- Reasoning-from-Zero using TPUs for compute
- Following the release of DeepSeek's R1 model, there was a nice follow-up from a group at Berkeley with a 'Countdown task reasoner' that can be trained from scratch for "$30 of H100s" (https://github.com/Jiayi-Pan/TinyZero)
- The aim of this project is to replicate that same task, but using a gemma2 model and TPU infrastructure
- This will make it far, far more likely that TPUs could become an experimentation platform for the curious : the current barriers to entry are very high
- Use `gemma2-2B-base` on:
- Kaggle TPU v3-8; and
- Colab TPU v2-8 (potentially - it would be very tight)
- Reasoning task : Countdown task from TinyZero
- RL objective : GRPO (group-relative advantages - minimal sketch after this list)
- Goal : Get to "Aha!" using $free TPU resources
- with a codebase that is:
- Plain and Readable (i.e. not simply dressing up a call to `trl`)
- Hackable (i.e. can implement more than the demo case)
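For orientation, here is a minimal sketch (plain NumPy, illustrative only - not this repo's training code) of the group-relative advantage that GRPO uses in place of a learned value function:

```python
import numpy as np

def grpo_advantages(rewards: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Group-relative advantages: each prompt's G sampled completions are
    scored by the rule-based reward, then normalised against that group's
    own mean/std - no critic network needed."""
    # rewards: shape (num_prompts, G), one scalar reward per completion
    mean = rewards.mean(axis=1, keepdims=True)
    std = rewards.std(axis=1, keepdims=True)
    return (rewards - mean) / (std + eps)

# e.g. one prompt, G=4 samples, only the first completion solved the puzzle
print(grpo_advantages(np.array([[1.0, 0.0, 0.0, 0.0]])))
# -> [[ 1.73 -0.58 -0.58 -0.58]]
```

Each completion's advantage is then applied to all of its tokens in a clipped policy-gradient objective, with a KL penalty to the reference model (see the GRPO notes further down).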
- JAX `flax.nnx` examples/gemma (i.e. new style)
- Positives:
- Framework being promoted as PyTorch-user-friendly
- Negatives:
- Early days (PROVEN):
- `gemma` example in `nnx` documentation does not work
- `nnx.jit` of Transformer forward pass proven to take >60Gb RAM during compilation
- (it would only avoid crashing the VM if the instance had >70Gb of available RAM)
- Therefore impractical for use on Colab/Kaggle == DEAD END
- Google-DeepMind `gemma` library in JAX `flax.linen` (i.e. old style)
- Positives:
- The library actually works with Gemma2
- And consumes <1Gb RAM doing `jit` on forward pass / sampling
- Library has LoRA and sharding
- Negatives:
- Flax/linen is (according to the `nnx` docs) backward-looking
- Heavy dependency on `kauldron` for training (and LoRA, sharding, etc)
- Undermines the goal of using plain, readable code
- GDM `gemma` library transformer Sampler is greedy-only
- Monkey-patching this functionality (which is deep inside the `Sampler` class) would smell bad
- So adding library features would have to be done before beginning
- `pytorch-gemma` library for PyTorch/XLA
- Positives:
- Library appears ready for CPU, GPU and TPU
- Includes distribution code (with Column-wise and Row-wise Linear implementations)
- Includes 8-bit quantised code
- Negatives:
- Does not appear to include LoRA
- Though may be compatible with PEFT (needs testing)
- How does auto-LoRA interact with sharding? Eeek
- While PyTorch XLA is clearly 'real' ...
- Need to test whether XLA code can get 'compiled' in a similar way to JAX `jit`
- Keras gemma implementation using JAX backend
- Positives:
- Ecosystem appears ready for CPU, GPU and TPU
- Includes LoRA, more sophisticated sampling and distribution over TPUs
- Actually proven to work on TPUs via Colab (in this Repo)
- Negatives:
- IMHO, Keras is perceived as being somewhat lame vs other frameworks
- Still need to test whether fancy sampling, fancy distribution strategy, and custom training step (GRPO) can be implemented at the same time (a loading/LoRA/sharding sketch follows this list)
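As a starting point for the Keras route, here is a hedged sketch of loading Gemma2 with LoRA and model-parallel sharding on a TPU (the preset name, mesh shape and exact `keras.distribution` call signatures are assumptions taken from the Keras/KerasNLP docs - not yet verified in this repo):

```python
import keras
import keras_nlp

# Assumes KERAS_BACKEND=jax and a TPU runtime with 8 cores visible.
devices = keras.distribution.list_devices()

# Shard the model weights over all cores ("model" axis); keep batch unsharded.
mesh = keras.distribution.DeviceMesh(
    shape=(1, len(devices)), axis_names=("batch", "model"), devices=devices)
layout_map = keras_nlp.models.GemmaBackbone.get_layout_map(mesh)
keras.distribution.set_distribution(
    keras.distribution.ModelParallel(layout_map=layout_map, batch_dim_name="batch"))
# NB: older Keras 3 releases take the device mesh as the first ModelParallel argument.

# Preset name is an assumption - check the KerasNLP preset list for Gemma2 2B.
gemma = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma.backbone.enable_lora(rank=8)   # only the low-rank adapters stay trainable

# Sampling is configurable (unlike the greedy-only GDM Sampler), e.g. top-k:
gemma.compile(sampler=keras_nlp.samplers.TopKSampler(k=50, temperature=1.0))
print(gemma.generate("The Countdown numbers game is", max_length=48))
```

The open question from the list above is whether this composes with a custom GRPO training step (presumably by overriding `train_step`) while the distribution strategy is active.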
So far:
- `nnx` has succeeded in:
- causing me to laboriously debug and fix the example library
- wasting many GPU hours frustratedly trying to `nnx.jit` things without crashing the VM
- `gemma` (GDM library):
- only has a greedy Sampler - which would need fixing
- relies very heavily on `kauldron` to do fancy things
- PyTorch/XLA `pytorch_gemma` looks interesting, though would need:
- LoRA to be added (ideally using PEFT)
- actual benchmarking on TPUs vs JAX (time-consuming)
- Keras.JAX seems likely to be a good basis,
- and has started to show signs of life
- though it remains to be seen whether it works as advertised as the model/RL gets more complex
- Install a Python environment (via `uv`):
```bash
sudo snap install astral-uv --classic
uv venv env_flax.nnx
. ./env_flax.nnx/bin/activate
uv pip install jupyterlab ipywidgets jupytext OmegaConf
```
- Run jupyterlab notebook environment:
```bash
jupytext --set-formats cache-notebooks//ipynb,py:light *.py
#...
jupyter-lab --port 8282 --no-browser
```
- Test the countdown puzzle generation:
```bash
pushd ./aha_dataset/countdown/
python generator.py
popd
```
- Experience the Ahah moment yourself for <$30
- Berkeley : Jiayi Pan=@jiayi_pirate, @JunjieZhang12, @xingyaow_, @lifan__yuan
- Author Twitter thread
- TinyZero is a reproduction of DeepSeek R1 Zero in countdown and multiplication tasks
- We built upon `veRL`
- CountDown: a game where players combine numbers with basic arithmetic to reach a target number
- Scoring Function, Dataset with correct answers loaded here
- This is a bit strange, since the rows of the dataset do not include `{+,-,*,/}` answers
- ... presumably there is a way to get the `target` from the `nums` ...
- Maybe: generated by exhaustive search (toy sketch after this list)
- Tried : Qwen-2.5-Base 0.5B, 1.5B, 3B, 7B
- >1.5B models start learning to search, to self-verify and to revise solutions
- Either base or instruct model works
- Converge to same performance (instruct learns more quickly)
- Instruct model's outputs are more structured and readable
- PPO, GRPO and PRIME all worked
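To make the exhaustive-search idea concrete, a toy illustration (not TinyZero's actual generator, and it only tries left-to-right combinations rather than every bracketing):

```python
from itertools import permutations, product
from operator import add, sub, mul, truediv

def reachable(nums):
    """Integer targets reachable by using every number exactly once with
    {+,-,*,/}, combining left-to-right (a simplification of the full game)."""
    results = set()
    for perm in permutations(nums):
        for ops in product((add, sub, mul, truediv), repeat=len(nums) - 1):
            acc = perm[0]
            try:
                for op, n in zip(ops, perm[1:]):
                    acc = op(acc, n)
            except ZeroDivisionError:
                continue
            if acc == int(acc):
                results.add(int(acc))
    return results

print(84 in reachable([4, 10, 6]))   # True: (4 + 10) * 6 = 84
```

A generator can simply sample `nums`, pick any reachable value as the `target`, and let the reward function check a proposed expression at training time.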
- Mini-R1: Reproduce Deepseek R1 "aha moment" - an RL tutorial
- = same as above
- "This blog is inspired by Jiayi Pan who initially explored the idea and proofed it with a small model."
-
- Will Brown = @willccbb
- Code gist, with comments from implementers / testers (a hedged usage sketch follows this list)
- Llama 1B, GSM8k
```python
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
```
- beta: (float, optional, defaults to 0.04) — KL coefficient
- Commenter had success with beta=0.01
- https://huggingface.co/docs/trl/main/en/grpo_trainer#trl.GRPOConfig.beta
- (updated code, running smooth now on Qwen-1.5B w/ longer max_completion_length + higher num_generations)
- "TRL GRPO has vLLM now btw + it's soooo much faster wow"
- Next version (?) uses TRL_GRPOTrainer
- Colab version with Qwen 0.5B
- Runs `vLLM` on Colab GPU too
- Decent looking code, but Aha is not directly visible...
- This used by @anton
- Qwen2.5-0.5B (base model) directly goes into step by step breakdown with zero prompting (and Llama doesn't produce step-wise thinking of its own accord)
- when reward starts going up at step >100 it's either hacking it or discovered something
- See below...
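For reference, a minimal sketch of driving `GRPOTrainer` in the style of the gist above (argument and column names follow the TRL docs linked earlier, and the model/dataset are just examples - treat this as illustrative, not a copy of willccbb's code):

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Reward functions receive the sampled completions (plus dataset columns via
# kwargs) and return one float per completion.
def format_reward(completions, **kwargs):
    # toy reward: completion wraps its result in <answer>...</answer> tags
    return [1.0 if "<answer>" in c and "</answer>" in c else 0.0 for c in completions]

dataset = load_dataset("openai/gsm8k", "main", split="train")
dataset = dataset.map(lambda x: {"prompt": x["question"]})  # GRPOTrainer expects a "prompt" column

config = GRPOConfig(
    output_dir="qwen-grpo",
    num_generations=8,           # G completions sampled per prompt
    max_completion_length=256,
    beta=0.01,                   # KL coefficient - commenters report 0.01 works well
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=format_reward,  # a real run adds a correctness reward as well
    args=config,
    train_dataset=dataset,
)
trainer.train()
```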
- @anton experiments
- "Perfect Reward Function"
- "Finished a run (R1 style) GRPO on Qwen-2.5-0.5B (base model) yield +10 accuracy points on GSM8K. Literally just works"
- Two tricks I found work well :
- use a good system prompt, and
- try lower beta (KL coefficient).
- 3 rewards: int reward, final_answer tags, and correctness reward
- has commented on original `willccbb/grpo_demo.py` gist
- Has own gist of GRPOTrainer to run gsm8k eval during training
- "Got a better result on qwen2.5-0.5b (base) → 56% gsm8k"
- Full GRPO fine-tuning Qwen2.5 0.5B on a single T4
- @qunash on GitHub = https://huggingface.co/anzorq
- Fork of the TRL repo by GitHub @andyl98 - with more optimisations
- `Qwen2.5-0.5B-Instruct` gsm8k eval result from 22.4% to 48.6%
- in just ~150 steps (~30 minutes) on a single T4 GPU
- Train your own R1 reasoning model with Unsloth
- Daniel Han (unsloth) thread
- We removed double memory usage during vLLM serving and finetuning
- 70% less VRAM finetuning and 20x faster inference all in one package!
- LoRA / QLoRA also originally did not work for people when doing GRPO in the starter script
- unsloth thread
- GRPO with unsloth on free colab
- "it's painfully slow; but works :p"
- Exposes code from TRL training loop a little...
model="Qwen/Qwen2-0.5B-Instruct", reward_funcs="weqweasdas/RM-Gemma-2B",
... reward model?
- Commentary
- GRPO is now optimized to use 80% less VRAM
- GRPO now with LoRA and QLoRA
- Qwen2.5(1.5B) can be trained with just 7GB!
- Llama3.1(8B) training with 15GB
- Daniel Han (unsloth) thread
-
- Weihao Zeng, Yuzhen Huang, Wei Liu, Keqing He, Qian Liu, Zejun Ma, Junxian He = @junxian_he : hkust
- We reproduce the training of DeepSeek-R1-Zero and DeepSeek-R1 for complex mathematical reasoning, starting from Qwen-2.5-Math-7B (base model), and only using 8K (query, final answer) examples from the original MATH dataset.
- Code on GitHub
- We are working on the paper and will release it very soon
- Uses OpenRLHF
- "R1-V: Reinforcing Super Generalization Ability in Vision Language Models"
- Liang Chen = @liangchen5518
- https://github.com/Deep-Agent/R1-V
- Cost<$3 : 8 A100 GPUs for 30 minutes
- 100 training steps
- The Thought Process Behind Kimi k1.5
-
- Add 'Wait!' when model wants to do '</think>' to extend thought process
- SFT on thought traces from ...?
- s1: The $6 R1 Competitor?
- Entropix Tie In - in entropix, extra 'encouragement' tokens were added in... So: similar idea
- repo on GitHub
- Project Page
- Frugality:
- Sifted their dataset of 56K examples down to just the best 1K,
- the core 1K is all that's needed to achieve o1-preview performance on a 32B model.
- Adding data didn't raise performance at all.
- s1.1 : trained on same 1K questions
- DeepSeek answers, rather than Gemini generations
- As it is just 1K examples, training is extremely cheap and took just 26 minutes
- To control test-time compute, we develop "budget forcing" (rough sketch after this list):
- We either force the model to end its thinking or
- extend it by appending Wait when the model tries to stop
- This simple method improves our model
- GDE Blogpost : s1 and s1.1
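A rough sketch of the budget-forcing control loop (pseudo-code against a made-up `lm` sampler interface - the real s1 implementation works on token ids and differs in detail):

```python
def budget_forced_think(lm, prompt, min_think_tokens, max_think_tokens):
    """Keep the model 'thinking' for at least min_think_tokens (by appending
    'Wait' whenever it tries to close its thoughts early) and cut it off at
    max_think_tokens. `lm.generate` / `lm.count_tokens` are hypothetical."""
    trace = prompt + "<think>"
    while True:
        trace += lm.generate(trace, stop=["</think>"], max_new_tokens=max_think_tokens)
        used = lm.count_tokens(trace) - lm.count_tokens(prompt)
        if used >= max_think_tokens:   # budget exhausted: force the thinking to end
            break
        if used < min_think_tokens:    # stopped too early: suppress '</think>' and extend
            trace += " Wait"
            continue
        break                          # within budget and finished naturally
    trace += "</think>"
    return trace + lm.generate(trace, stop=["<eos>"], max_new_tokens=512)
```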
- There May Not be Aha Moment in R1-Zero-like Training — A Pilot Study
- SEA AI Labs in SG (!)
- OAT-Zero code on GitHub
- Key points:
- "We found Aha moment (such as self-reflection patterns) appears at epoch 0, namely base models"
- "Superficial Self-Reflection (SSR) from base models' responses" - leading to wrong answer
- "increasing response length phenomenon not emergence .. but RL optimizing rule-based reward"
- OAT RL library - A research-friendly framework for LLM online alignment
- GRPO with Verifiable (Binary) Rewards Is an Adaptive Weighted Contrastive Loss
- IBM researcher : breaks down the whitening (group normalisation) into its factors (spelled out below)
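The core algebra (my own spelling-out of the standard normalisation, not the paper's notation): with binary rewards the group-normalised advantage reduces to two difficulty-dependent weights.

```latex
% Group of G samples, binary rewards r_i \in \{0,1\}, empirical success rate
% p = \tfrac{1}{G}\sum_i r_i, group std \sigma = \sqrt{p(1-p)}:
\hat{A}_i \;=\; \frac{r_i - p}{\sqrt{p(1-p)}} \;=\;
\begin{cases}
  +\sqrt{\dfrac{1-p}{p}}, & r_i = 1 \\[6pt]
  -\sqrt{\dfrac{p}{1-p}}, & r_i = 0
\end{cases}
```

So hard prompts (small p) put a large positive weight on the rare correct completions, and easy prompts put a large negative weight on the rare failures - which appears to be the "adaptive weighting" in the title.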
- Fine-tuning 20B LLMs with RLHF on a 24GB consumer GPU
- Includes PEFT and `trl` (9-March-2023)
- GRPO also works very well for Llama 2 7B, with an impressive +15 accuracy point increase in GSM8K
- "There's nothing magical about recent model families. If the model can perform some task with sufficient accuracy, then RL with verifiable rewards will likely boost performance"
- Run it yourself: `RicardoDominguez/grpo_llama2-7b-chat_gsm8k.sh`
- Seems like unrolled code from TRL ... everything is there
-
- Somewhat ranty
- Trellis video series:
- 1: Reinforcement Learning for LLMs in 2025
- Set-up of training, with curation of SFT data (mostly)
- 2: How does GRPO work?
- 32mins : TODO:WATCH!
-
- Fixing up the implementation in AllenAI RL library
- Other comments:
- When directly minimizing the KL loss, kl3 just appears much more numerically stable.
- And the >0 guarantee here is also really nice (kl1 could go negative).
- John Schulman's Homepage : Approximating KL Divergence (the estimators are summarised below)
- BUT ... LMs with GRPO etc with KL penalty = 0 works
- "These are from experiments and this is not official training advice."
- GRPO from DeepSeek-R1 is now available in Hugging Face `trl` library
- `GRPOTrainer` Docs
- KL divergence is estimated using the approximator introduced by Schulman et al. (2020)
- The approximator is defined as follows: `p_ratio - log(p_ratio) - 1`
- Has a `use_vllm=True` parameter to do generations using `vllm`
- "just a reminder : trl grpo is not same as same as described in deepseek paper"
- No clipping objective (though does have KL term) (may not be important at all)
- Also "Joey (e/λ)" has comments about gradient / loss and removing constants...
- Claim : "loss = advantage*log_softmax(logits) works, same gradients"
- (Makes sense at first glance, but not clear whether there's something else going on - a quick gradient check is below)
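A quick check of the "same gradients" claim, under the assumption that only one gradient step is taken per batch (so the sampling policy equals the current policy, the ratio is 1, and clipping is inactive); the KL term is ignored here:

```latex
\nabla_\theta \!\left[ \frac{\pi_\theta(a)}{\pi_{\theta_{\mathrm{old}}}(a)}\,\hat{A} \right]_{\theta=\theta_{\mathrm{old}}}
  \;=\; \hat{A}\,\frac{\nabla_\theta \pi_\theta(a)}{\pi_{\theta_{\mathrm{old}}}(a)}\bigg|_{\theta=\theta_{\mathrm{old}}}
  \;=\; \hat{A}\,\nabla_\theta \log \pi_\theta(a)
```

Since log pi_theta(a) is just the log-softmax of the logits at the sampled token, `advantage * log_softmax(logits)` gives the same gradient in that regime; with multiple inner epochs (pi_old != pi_theta) the ratio and the clipping do start to matter.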
- OpenRLHF
- High-performance RLHF framework built on Ray, DeepSpeed and HF Transformers
- veRL
- Volcano Engine Reinforcement Learning for LLM
- https://x.com/Guodaya/status/1886635010251518330 (now deleted)
- =Researcher at DeepSeek
- The 660B R1-Zero and R1 began running after the release of V3, with training taking approximately 2-3 weeks
- The R1 model prior to this time (e.g., in the V3 tech report) was the R1-Lite or the R1-Lite-Zero
- GRPO VRAM Requirements For the GPU Poor
- Points out RAM requirements (with potential torch ideas)
- GRPO explanation not very useful
- The N Implementation Details of RLHF with PPO
- 2023-10-24
- The 37 Implementation Details of Proximal Policy Optimization
- 2022-03-25
- DeepScaleR: Surpassing O1-Preview with a 1.5B Model by Scaling RL
- Berkeley Sky Computing Lab (not the same authors as original $30 one, AFAICT)
- "1.5B model beats o1-preview on math by RL"
- Cost:
- Overall, our training run consists of ~1,750 steps.
- The initial 8K context phase was trained on 8 A100 GPUs,
- while the 16K and 24K phases scaled up training to 32 A100 GPUs.
- In total, the training took ~3,800 A100 hours = roughly 5 days on 32 A100s
- $4500 in compute cost
- Reddit discussion DeepScaleR-1.5B-Preview
- Model on HF
- Project on GitHub
- uses their own veRL
- "DeepScaleR is by far the most sophisticated and impressive thing built on R1 this far"
- Maximizing intelligence per FLOP is a natural step after test time unlock
- Eliciting Critical Reasoning in Retrieval-Augmented Language Models via Contrastive Explanations
- Self-RAG: Learning to Retrieve, Generate, and Critique through Self-Reflection
- Chain-of-Retrieval Augmented Generation
- Microsoft
- More than 10 points improvement in EM score compared to strong baseline
- Establishes a new SotA performance across a diverse range of knowledge-intensive tasks
- BERGEN: A Benchmarking Library for Retrieval-Augmented Generation
- Benchmarking Large Language Models in Retrieval-Augmented Generation
- LegalBench-RAG: A Benchmark for Retrieval-Augmented Generation in the Legal Domain
- Fact, Fetch, and Reason: A Unified Evaluation of Retrieval-Augmented Generation
- CRAG: Comprehensive RAG Benchmark
- Natural Questions: A Benchmark for Question Answering Research
- Grounding Large Language Models in Interactive Environments with Online Reinforcement Learning
- T5 (in 2023-02)
- RAGEN: A General-Purpose Reasoning Agent Training Framework
- Code on GitHub
- Author Thread
- We run RAGEN on the Gym-Sokoban task:
- Qwen-2.5-{0.5B, 3B}-{Instruct, None}
- DeepSeek-R1-Distill-Qwen-1.5B
- Scaled Cognition: "first ever models trained specifically for agentic applications"
- "APT-1, is now #1 on agentic benchmarks" ...
- https://x.com/chrisbarber/status/1885047105741611507
- Shannon Sands (@max_paperclips) from @NousResearch
- backtracking vector
- "caused the chain of thought to backtrack much more often, and when suppressed caused it to be a linear and much shorter CoT"
See this page for:
- JAX Resources
- JAX (generic)
- Gemma Models
- Keras (JAX backend)
- LoRA for JAX
- TPU training
- TPU training (Node-style TPUs = old, including Colab)
- TPU training (VM-style TPUs = modern)