A simple JAX reproduction of the Tiny Recursive Model (TRM), trained on the Sudoku task. Code here.
Training metrics are included (see the "Training Metrics" tab).