
Getting to Aha

With TPU(s) using JAX nnx

  • Reasoning-from-Zero using TPUs for compute
    • Following the release of DeepSeek's R1 model, there was a nice follow-up from a group at Berkeley with a 'Countdown task reasoner' that can be trained from scratch for "$30 of H100s" (https://github.com/Jiayi-Pan/TinyZero)
    • The aim of this project is to replicate that same task, but using a gemma2 model and TPU infrastructure
    • This will make it far, far more likely that TPUs could become an experimentation platform for the curious : The current barriers to entry are very high

The Plan

  • Use gemma2-2B-base on:
    • Kaggle TPU v3-8; and
    • Colab TPU v2-8 (potentially - it would be very tight)
  • Reasoning task : Countdown task from TinyZero
    • RL objective : GRPO (see the sketch after this list)
  • Goal : Get to "Aha!" using $free TPU resources
    • with codebase that is:
      • Plain and Readable (i.e. not simply dressing up a call to trl)
      • Hackable (i.e. can implement more than the demo case)
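
As a reference point, here is a minimal sketch of the GRPO objective (group-normalised advantages plus a per-token KL penalty towards a frozen reference policy). The function and argument names are illustrative and not taken from this repo's code; the single-update / no-clipping simplification is assumed.

import jax
import jax.numpy as jnp

def grpo_advantages(rewards):
    # rewards: (num_prompts, group_size) - one scalar reward per sampled completion.
    # GRPO replaces a learned value baseline with group-normalised rewards.
    mean = rewards.mean(axis=-1, keepdims=True)
    std = rewards.std(axis=-1, keepdims=True)
    return (rewards - mean) / (std + 1e-6)

def grpo_loss(logprobs, ref_logprobs, completion_mask, rewards, beta=0.04):
    # logprobs, ref_logprobs: (num_prompts, group_size, seq_len) per-token log-probs
    #   of the sampled tokens under the current policy and a frozen reference policy.
    # completion_mask: 1.0 on completion tokens, 0.0 on prompt/padding tokens.
    adv = grpo_advantages(rewards)[..., None]               # broadcast over tokens
    # Single policy update per batch: the ratio equals 1 but carries the gradient.
    ratio = jnp.exp(logprobs - jax.lax.stop_gradient(logprobs))
    pg = -ratio * adv
    # Per-token KL penalty towards the reference policy (the 'k3' estimator).
    log_ratio = ref_logprobs - logprobs
    kl = jnp.exp(log_ratio) - log_ratio - 1.0
    per_token = (pg + beta * kl) * completion_mask
    return per_token.sum() / completion_mask.sum()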

Decision : Which framework?

  • JAX flax.nnx examples/gemma (i.e. new style)
    • Positives:
      • Framework being promoted as PyTorch-user-friendly
    • Negatives:
      • Early days (PROVEN)
      • gemma example in nnx documentation does not work
      • nnx.jit of the Transformer forward pass was measured to take >60 GB of RAM during compilation
        • (it would crash any VM with less than ~70 GB of available RAM)
        • Therefore impractical for use on Colab/Kaggle == DEAD END
  • Google-DeepMind gemma library in JAX flax.linen (i.e. old style)
    • Positives:
      • The library actually works with Gemma2
        • And consumes <1 GB of RAM when jit-compiling the forward pass / sampling
      • Library has LoRA and sharding
    • Negatives:
      • Flax/linen is (according to the nnx docs) backward-looking
      • Heavy dependency on kauldron for training (and LoRA, sharding, etc)
        • Undermines the goal of using plain, readable code
      • GDM gemma library transformer Sampler is greedy-only
        • Monkey-patching this functionality (which is deep inside the Sampler class) would smell bad
        • So the sampling features would have to be added to the library before starting
  • pytorch-gemma library for PyTorch/XLA
    • Positives:
      • Library appears ready for CPU, GPU and TPU
      • Includes distribution code (with Column-wise and Row-wise Linear implementations)
      • Includes 8-bit quantised code
    • Negatives:
      • Does not appear to include LoRA
        • Though may be compatible with PEFT (needs testing)
        • How does auto-LoRA interact with sharding? Eeek
      • While PyTorch XLA is clearly 'real' ...
        • Need to test whether XLA code can get 'compiled' in a similar way to JAX jit
  • Keras gemma implementation using JAX backend
    • Positives:
      • Ecosystem appears ready for CPU, GPU and TPU
      • Includes LoRA, more sophisticated sampling, and distribution over TPUs (see the loading sketch after this list)
      • Actually proven to work on TPUs via Colab (in this Repo)
    • Negatives:
      • IMHO, Keras is perceived as being somewhat lame vs other frameworks
      • Still need to test whether fancy sampling, fancy distribution strategy, and custom training step (GRPO) can be implemented at the same time
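
A minimal loading sketch for the Keras route, assuming the keras_nlp gemma2_2b_en preset with Kaggle model-access credentials already configured; the distribution and sampler choices below are illustrative defaults, not a tested recipe from this repo.

import os
os.environ["KERAS_BACKEND"] = "jax"   # must be set before importing keras

import keras
import keras_nlp

# Simplest option: replicate the model and split the batch across all local TPU cores.
keras.distribution.set_distribution(keras.distribution.DataParallel())

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.backbone.enable_lora(rank=4)   # train only the LoRA adapter weights
gemma_lm.summary()

# Swap the default sampler for top-k sampling before calling generate().
gemma_lm.compile(sampler=keras_nlp.samplers.TopKSampler(k=50, temperature=1.0))
print(gemma_lm.generate("Using the numbers [100, 45], make 55.", max_length=64))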

So far:

  • nnx has succeeded in:
    • causing me to laboriously debug and fix the example library
    • wasting many GPU hours in frustrated attempts to nnx.jit things without crashing the VM
  • gemma (GDM library)
    • only has a greedy Sampler - which would need fixing
    • relies very heavily on kauldron to do fancy things
  • PyTorch/XLA pytorch_gemma looks interesting, though would need:
    • LoRA to be added (ideally using PEFT)
    • actual benchmarking on TPUs vs JAX (time-consuming)
  • Keras.JAX seems likely to be a good basis,
    • and has started to show signs of life
    • though it remains to be seen whether it works as advertised as the model/RL gets more complex

Installation / Running the code

# Install uv, then create and activate a virtual environment with the notebook tooling
sudo snap install astral-uv --classic
uv venv env_flax.nnx
. ./env_flax.nnx/bin/activate
uv pip install jupyterlab ipywidgets jupytext OmegaConf
  • Run the JupyterLab notebook environment:
# Pair each .py script with an .ipynb notebook under ./cache-notebooks/
jupytext --set-formats cache-notebooks//ipynb,py:light *.py
#...
jupyter-lab --port 8282 --no-browser
  • Test the countdown puzzle generation (see the reward-check sketch below):
pushd ./aha_dataset/countdown/
python generator.py 
popd
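
As a complement to the generator, here is a hypothetical sketch of a rule-based reward check for the Countdown task; the function name and signature are illustrative and not the repo's generator.py API. It scores 1.0 only if the proposed equation uses exactly the given numbers and evaluates to the target.

import ast
import operator
from collections import Counter

_OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
        ast.Mult: operator.mul, ast.Div: operator.truediv}

def _eval(node):
    # Recursively evaluate an expression restricted to + - * / and numeric literals.
    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
        return _OPS[type(node.op)](_eval(node.left), _eval(node.right))
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        return node.value
    raise ValueError("disallowed expression")

def countdown_reward(numbers, target, equation):
    try:
        tree = ast.parse(equation, mode="eval")
        used = [n.value for n in ast.walk(tree) if isinstance(n, ast.Constant)]
        if Counter(used) != Counter(numbers):
            return 0.0   # must use exactly the given numbers, each exactly once
        return 1.0 if abs(_eval(tree.body) - target) < 1e-6 else 0.0
    except Exception:
        return 0.0       # unparseable / invalid equations get zero reward

print(countdown_reward([100, 45], 55, "100 - 45"))   # -> 1.0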

RL-related Resources

Post-R1 GRPO demos

Contrarian Ideas

GRPO expositions

GRPO Hints

GRPO libraries

R1 Notes

  • https://x.com/Guodaya/status/1886635010251518330 (now deleted)
    • Tweet author is a researcher at DeepSeek
    • The 660B R1-Zero and R1 began running after the release of V3, with training taking approximately 2-3 weeks
    • The R1 model prior to this time (e.g., in the V3 tech report) was the R1-Lite or the R1-Lite-Zero

Miscellaneous


Potential next ideas

RL on Deepseek 'hard distilled' models

Agentic RAG

Datasets

Agent RL

Task Vectors


TPU Resources

See this page for:

  • JAX Resources
    • JAX (generic)
    • Gemma Models
    • Keras (JAX backend)
    • LoRA for JAX
  • TPU training
    • TPU training (Node-style TPUs = old, including Colab)
    • TPU training (VM-style TPUs = modern)
