# [BONUS 3] Group Relative Policy Optimization in vanilla torch - [GRPO](https://colab.research.google.com/#fileId=https%3A//huggingface.co/datasets/nanochat-students/notebooks/blob/main/grpo.ipynb)

This chapter demonstrates training the NanoChat model with Group Relative Policy Optimization (GRPO), a reinforcement learning approach that improves model responses based on reward signals.

## Import model and tokenizer

```python
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup

model_id = "karpathy/nanochat-d32"
revision = "refs/pr/1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=revision,
    # bf16 weights on GPU, full precision on CPU.
    torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
).to(device)

# Reuse the EOS token as the padding token so batched generation has a pad id.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id
```

## Setup LoRA

```python
from peft import LoraConfig, get_peft_model

# A rank-1 adapter on the attention projections and MLP layers; only these weights train.
lora_config = LoraConfig(
    r=1,
    lora_alpha=2,
    lora_dropout=0.00,
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "fc1", "fc2"],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```

```
trainable params: 1,179,648 || all params: 1,880,227,840 || trainable%: 0.0627
```

## Demo the model

Test with a plain autoregressive prompt:

```python
print("=" * 80)
print("TEST 1: Plain Autoregressive Prompt")
print("=" * 80)

prompt = "The Eiffel Tower stands in Paris and"
test_inputs = tokenizer(prompt, return_tensors="pt").to(device)

with torch.no_grad():
    test_outputs = model.generate(
        **test_inputs,
        max_new_tokens=64,
        do_sample=False,  # greedy decoding for a reproducible demo
        pad_token_id=tokenizer.pad_token_id,
    )

# Drop the prompt tokens and keep only the continuation.
generated_tokens = test_outputs[0, test_inputs["input_ids"].shape[1] :]
print(f"Prompt: {prompt}")
print(f"\nGenerated: {tokenizer.decode(generated_tokens, skip_special_tokens=True)}")
print("=" * 80)
```

```
================================================================================
TEST 1: Plain Autoregressive Prompt
================================================================================
Prompt: The Eiffel Tower stands in Paris and

Generated: is one of the most famous landmarks in the world. It is located on the Champ de Mars in the heart of the city. The tower was built for the 1889 World's Fair. It was designed by the French engineer Gustave Eiffel and took 2 years to build. The Eiffel Tower stands 324 meters
================================================================================
```

And with the chat template:

```python
print("=" * 80)
print("TEST 2: Chat Template")
print("=" * 80)

conversation = [
    {"role": "user", "content": "What is the capital of France?"},
]

# add_generation_prompt appends the assistant-start marker so generation answers as the assistant.
inputs = tokenizer.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(device)

print(f"Formatted prompt: {tokenizer.decode(inputs['input_ids'][0])}")
print(f"Input IDs: {inputs['input_ids'][0].tolist()}")

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=64,
        do_sample=False,
        # Pass the pad id explicitly, as in Test 1 and the training loop.
        pad_token_id=tokenizer.pad_token_id,
    )

generated_tokens = outputs[0, inputs["input_ids"].shape[1] :]
print(f"\nGenerated: {tokenizer.decode(generated_tokens)}")
print("=" * 80)
```

```
================================================================================
TEST 2: Chat Template
================================================================================
Formatted prompt: <|bos|><|user_start|>What is the capital of France?<|user_end|><|assistant_start|>
Input IDs: [65527, 65528, 1442, 309, 261, 3429, 281, 4215, 63, 65529, 65530]

Generated: The capital of France is Paris.<|assistant_end|>
================================================================================
```

## Dataset

We use the OpenR1-Math dataset for math reasoning tasks:

```python
raw_dataset = 
load_dataset("HuggingFaceH4/OpenR1-Math-220k-default-verified", split="train")
splits = raw_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = splits["train"]
eval_dataset = splits["test"]
```

## Training Configuration

```python
max_train_steps = 50
prompt_batch_size = 1  # the loop below draws one prompt per step; only 1 is supported
num_generations = 4  # completions sampled per prompt (the GRPO "group")
max_new_tokens = 128
temperature = 1.0
top_k = 50
learning_rate = 5e-6
weight_decay = 0.0
epsilon = 0.2  # clipping range for the probability ratio: [1 - eps, 1 + eps]
gradient_accumulation_steps = 1  # only affects the scheduler's step count; the loop steps every iteration
warmup_ratio = 0.1
logging_frequency = 5
max_train_samples = 1000
max_eval_samples = 100
```

## Reward Functions

GRPO needs reward signals to guide policy optimization. We define three simple rewards and combine them:

```python
import re
import numpy as np
import torch.nn.functional as F
from contextlib import nullcontext


def think_format_reward(completions):
    """
    Reward completions that open with a single <think>...</think> reasoning block.

    A completion scores 1.0 when it starts with "<think>", contains no second
    "<think>" tag, and closes the block with "</think>" (anything may follow).
    Every other completion scores 0.0.
    """
    pattern = r"^<think>(?!.*<think>)(.*?)</think>.*$"
    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions]
    return [1.0 if match else 0.0 for match in matches]


def accuracy_reward(completions, solutions):
    """
    Reward completions that contain the reference solution.

    For simplicity this is a case-insensitive substring match after stripping
    surrounding whitespace. An empty (or missing) reference always scores 0.0,
    because the empty string is a substring of every completion and would
    otherwise hand out a free reward.
    """
    rewards = []
    for completion, solution in zip(completions, solutions):
        target = (solution or "").strip().lower()
        reward = 1.0 if target and target in completion.strip().lower() else 0.0
        rewards.append(reward)
    return rewards


def min_length_reward(completions, min_length=10):
    """
    Reward completions that are at least `min_length` characters long.

    Returns 1.0 when len(completion) >= min_length, otherwise 0.0.
    """
    return [1.0 if len(completion) >= min_length else 0.0 for completion in completions]


def combined_reward(completions, solutions):
    """
    Average the format, accuracy, and minimum-length rewards with equal weight.

    Each component is 0.0 or 1.0, so every combined reward is one of
    0, 1/3, 2/3, or 1.
    """
    format_rewards = think_format_reward(completions)
    accuracy_rewards = accuracy_reward(completions, solutions)
    min_length_rewards = min_length_reward(completions)
    return [np.mean([f, a, m]) for f, a, m in zip(format_rewards, accuracy_rewards, min_length_rewards)]
```

## Helper Functions

```python
def per_token_log_probs(logits, labels):
    """Return the log-probability of each label under `logits` (last dim = vocabulary)."""
    # Upcast before log_softmax so bf16 logits do not lose precision.
    logits = logits.float()
    log_probs = F.log_softmax(logits, dim=-1)
    return log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)


def prepare_prompt(example, problem_key="problem", solution_key="solution"):
    """
    Tokenize one dataset example as a single-turn chat prompt.

    The text under `problem_key` becomes the user message, and the chat
    template's generation prompt is appended so `generate` continues as the
    assistant. Prompts are truncated to 2048 tokens. `solution_key` is accepted
    for API compatibility but is not used here.

    Returns (input_ids, attention_mask) as returned by the tokenizer.
    """
    prompt = example.get(problem_key, "")
    messages = [{"role": "user", "content": prompt}]
    formatted = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        truncation=True,
        max_length=2048,
        padding=False,
        return_dict=True,
        return_tensors="pt",
    )
    return formatted["input_ids"], formatted["attention_mask"]


# bf16 autocast on GPU; a no-op context manager on CPU.
if device.type == "cuda":
    autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
else:
    autocast_ctx = nullcontext()
```

## Optimizer and Scheduler

```python
# Only the LoRA parameters require gradients; the frozen base weights need no optimizer entries.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)

total_update_steps = max_train_steps // gradient_accumulation_steps
warmup_steps = max(1, int(total_update_steps * warmup_ratio))
scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_update_steps)
```

## The Training Loop

The GRPO training loop generates multiple completions per prompt, computes rewards, and updates the policy using a clipped objective similar to PPO. Each completion's advantage is its reward normalized by the mean and standard deviation of its own group, so no separate value network is needed:

```python
# Truncate the datasets to the configured sizes (keeps the first N rows).
if max_train_samples is not None and len(train_dataset) > max_train_samples:
    train_dataset = train_dataset.select(range(max_train_samples))
if max_eval_samples is not None and len(eval_dataset) > max_eval_samples:
    eval_dataset = eval_dataset.select(range(max_eval_samples))

model.train()
train_index = 0
global_step = 0
running_reward = 0.0
running_loss = 0.0

for step in range(1, max_train_steps + 1):
    # 1. Pick the next prompt, cycling through the training set.
    example = train_dataset[train_index % len(train_dataset)]
    train_index += 1

    prompt_ids, prompt_mask = prepare_prompt(example)
    prompt_ids = prompt_ids.to(device)
    prompt_mask = prompt_mask.to(device)
    prompt_length = prompt_ids.shape[1]

    # Repeat the prompt so a single generate() call samples the whole group.
    prompt_repeat = prompt_ids.repeat(num_generations, 1)
    mask_repeat = prompt_mask.repeat(num_generations, 1)

    # 2. Sample a group of completions.
    model.eval()
    with torch.no_grad():
        generated = model.generate(
            input_ids=prompt_repeat,
            attention_mask=mask_repeat,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            pad_token_id=tokenizer.pad_token_id,
        )
    model.train()

    sequences = generated
    completion_tokens = sequences[:, prompt_length:]

    # Keep the tokenizer's mask for the prompt: pad is EOS here, and that id may
    # legitimately appear inside a prompt. In the completion, pad ids mark the
    # positions generate() filled after a sequence finished.
    completion_attention = (completion_tokens != tokenizer.pad_token_id).long()
    attention_mask = torch.cat([mask_repeat, completion_attention], dim=1)
    completion_mask = attention_mask.clone()
    completion_mask[:, :prompt_length] = 0  # never train on prompt tokens

    completion_texts = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)

    # 3. Score the group.
    # Prefer the short final `answer` for the substring-match reward. `solution`
    # holds a full worked solution (per the OpenR1-Math dataset card; verify for
    # this subset) that a short completion would almost never contain verbatim.
    solution = example.get("answer") or example.get("solution", "")
    solutions = [solution] * num_generations

    rewards = combined_reward(completion_texts, solutions)
    rewards = torch.tensor(rewards, dtype=torch.float32, device=device)
    running_reward += rewards.mean().item()

    # 4. Group-relative advantages: (reward - group mean) / group std.
    rewards_view = rewards.view(prompt_batch_size, num_generations)
    mean_rewards = rewards_view.mean(dim=1, keepdim=True)
    std_rewards = rewards_view.std(dim=1, keepdim=True)
    # Identical rewards give std == 0 (NaN for a single sample); divide by 1 instead.
    std_rewards = torch.where(std_rewards > 0, std_rewards, torch.ones_like(std_rewards))
    advantages = ((rewards_view - mean_rewards) / std_rewards).view(-1)

    # Next-token targets; masked positions are overwritten with the pad id.
    labels = sequences[:, 1:].clone()
    labels[attention_mask[:, 1:] == 0] = tokenizer.pad_token_id

    # 5. Log-probs under the sampling policy (no gradient).
    with torch.no_grad():
        with autocast_ctx:
            old_outputs = model(
                input_ids=sequences,
                attention_mask=attention_mask,
                use_cache=False,
            )
        old_log_probs = per_token_log_probs(old_outputs.logits[:, :-1], labels)

    valid_mask = (completion_mask[:, 1:] == 1) & (labels != tokenizer.pad_token_id)

    # 6. Clipped surrogate loss.
    optimizer.zero_grad(set_to_none=True)
    with autocast_ctx:
        outputs = model(
            input_ids=sequences,
            attention_mask=attention_mask,
            use_cache=False,
        )
    log_probs = per_token_log_probs(outputs.logits[:, :-1], labels)

    # old_log_probs come from the same weights (no optimizer step in between), so
    # the ratio is ~1 in value but still carries the policy gradient. Clipping
    # only bites when several updates reuse one batch of samples.
    ratio = (log_probs - old_log_probs).exp()
    ratio = torch.where(valid_mask, ratio, torch.ones_like(ratio))
    clipped_ratio = ratio.clamp(1.0 - epsilon, 1.0 + epsilon)

    adv = advantages.unsqueeze(1)  # one advantage per completion, broadcast over its tokens
    loss_unclipped = ratio * adv
    loss_clipped = clipped_ratio * adv
    per_token_loss = -torch.min(loss_unclipped, loss_clipped)
    per_token_loss = torch.where(valid_mask, per_token_loss, torch.zeros_like(per_token_loss))

    # Average over all valid completion tokens in the group.
    denom = valid_mask.sum().clamp(min=1)
    loss = per_token_loss.sum() / denom

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()
    global_step += 1
    running_loss += loss.item()

    if step % logging_frequency == 0:
        avg_reward = running_reward / logging_frequency
        avg_loss = running_loss / logging_frequency
        current_lr = scheduler.get_last_lr()[0]
        print(
            f"step={step:04d} | loss={avg_loss:.4f} | avg_reward={avg_reward:.4f} | lr={current_lr:.2e}"
        )
        running_reward = 0.0
        running_loss = 0.0

        # Sample evaluation on the first eval example.
        model.eval()
        eval_example = eval_dataset[0]
        prompt_ids, prompt_mask = prepare_prompt(eval_example)
        with torch.no_grad():
            eval_sequences = model.generate(
                input_ids=prompt_ids.to(device),
                attention_mask=prompt_mask.to(device),
                max_new_tokens=max_new_tokens,
                do_sample=True,
                top_k=top_k,
                temperature=temperature,
                pad_token_id=tokenizer.pad_token_id,
            )
        model.train()
        completion = eval_sequences[0, prompt_ids.shape[1] :]
        print("Sample eval completion:", tokenizer.decode(completion, skip_special_tokens=True)[:100])

print("Training complete.")
```

```
step=0005 | loss=0.0000 | avg_reward=0.4000 | lr=0.00e+00
Sample eval completion: 3^4 - 11 and 3^6 - 17
step=0010 | loss=0.0000 | 
avg_reward=0.3333 | lr=0.00e+00 Sample eval completion: 11. This statement refers to an optimization problem where we seek to find the smallest prime \( p step=0015 | loss=0.0000 | avg_reward=0.4667 | lr=0.00e+00 Sample eval completion: What number has two prime factors, 1 and itself, without additional restrictions? One possible combi step=0020 | loss=-0.0983 | avg_reward=0.4500 | lr=0.00e+00 ... Training complete. ```