Large language models like ChatGPT or Claude were not created solely by training on large datasets. RLHF (Reinforcement Learning from Human Feedback) plays a crucial role: a process in which models learn from human feedback to become more helpful and safer.
What Is RLHF and Why Does It Matter
Reinforcement Learning from Human Feedback (RLHF) is a key technique that enables large language models (LLMs) to give helpful and safe responses. Traditional LLM training teaches the model to predict the next token over huge text corpora; RLHF adds another layer on top, teaching the model to recognize what humans actually consider a good response.
The problem is that a model trained only on text prediction can generate responses that are technically correct but unhelpful, or even harmful. RLHF addresses this by bringing human judgments into training and optimizing the model against them with reinforcement learning.
Three Phases of the RLHF Process
1. Supervised Fine-tuning (SFT)
The first step fine-tunes a pre-trained model on a curated dataset of demonstrations, in which humans write high-quality example responses to a variety of prompts:
```python
# Example SFT dataset record (one prompt/completion pair)
sft_example = {
    "prompt": "Explain quantum mechanics simply",
    "completion": (
        "Quantum mechanics describes the behavior of particles at the atomic level. "
        "Key principles are: superposition (particles can be in multiple "
        "states simultaneously), uncertainty (we cannot precisely know both "
        "position and momentum of a particle) and quantum entanglement..."
    ),
}
```
This step creates the foundation - the model learns basic patterns of useful responses.
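Under the hood, SFT is ordinary next-token training. A common choice is to compute the loss only on the completion tokens, so the model is not trained to reproduce the prompt. The plain-Python sketch below illustrates that masking. It assumes prompt and completion are already tokenized into integer IDs and that per-token log-probabilities are already aligned with the labels; the value -100 mirrors the default `ignore_index` of PyTorch's `cross_entropy`, and the function names are illustrative, not from any library.

```python
import math

# Label value skipped by the loss (mirrors PyTorch's default ignore_index=-100)
IGNORE_INDEX = -100


def build_sft_example(prompt_ids, completion_ids):
    """Concatenate prompt and completion; mask the prompt so only the completion is learned."""
    input_ids = prompt_ids + completion_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + completion_ids
    return input_ids, labels


def masked_nll(token_logprobs, labels):
    """Mean negative log-likelihood over unmasked positions.

    token_logprobs[i] is the model's log-probability of labels[i]
    (i.e. already shifted for next-token prediction).
    """
    total, count = 0.0, 0
    for logprob, label in zip(token_logprobs, labels):
        if label == IGNORE_INDEX:
            continue  # prompt token: contributes nothing to the loss
        total -= logprob
        count += 1
    return total / count


input_ids, labels = build_sft_example([101, 7, 8], [42, 43])
print(labels)  # [-100, -100, -100, 42, 43]
loss = masked_nll([-5.0, -5.0, -5.0, math.log(0.5), math.log(0.25)], labels)
print(round(loss, 4))  # 1.0397 (= 1.5 * ln 2)
```

The prompt positions carry a very low log-probability (-5.0) in the example, yet they do not affect the loss at all.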
2. Training the Reward Model
The second step is arguably the most critical: it produces a reward model that can score response quality automatically. The process looks like this:
- The model generates several responses to the same prompt
- Humans rank these responses by quality
- A separate reward model learns to predict these preferences
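Because the pairwise `reward_loss` works on (chosen, rejected) pairs, a human ranking of K responses is typically expanded into all K(K-1)/2 pairs it implies. The sketch below shows one way to do that, assuming the ranking is given as a list ordered from best to worst; the function name and response labels are hypothetical.

```python
from itertools import combinations


def ranking_to_pairs(ranked_responses):
    """Expand a best-to-worst ranking into (chosen, rejected) pairs.

    combinations() preserves input order, so the first element of each
    pair is always the higher-ranked response. K responses give K*(K-1)/2 pairs.
    """
    return list(combinations(ranked_responses, 2))


ranking = ["response_b", "response_d", "response_a", "response_c"]  # best first
pairs = ranking_to_pairs(ranking)
print(len(pairs))  # 6
print(pairs[0])  # ('response_b', 'response_d')
```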
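```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# Reward model: transformer backbone + scalar head (Hugging Face-style backbone assumed)
class RewardModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.transformer = base_model
        self.reward_head = nn.Linear(base_model.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        hidden_states = self.transformer(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        # Score each sequence at its last non-padding token (assumes right padding)
        last_idx = attention_mask.sum(dim=1) - 1
        batch_idx = torch.arange(input_ids.size(0), device=input_ids.device)
        return self.reward_head(hidden_states[batch_idx, last_idx]).squeeze(-1)


# Loss function for pairwise ranking: push reward_chosen above reward_rejected
def reward_loss(reward_chosen, reward_rejected):
    # logsigmoid is the numerically stable form of log(sigmoid(x))
    return -F.logsigmoid(reward_chosen - reward_rejected).mean()
```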
The reward model learns to recognize patterns in responses that humans prefer - such as factual accuracy, helpfulness, safety, or communication style.
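The pairwise `reward_loss` corresponds to a Bradley-Terry preference model: the probability that the chosen response is preferred is taken to be sigmoid(r_chosen - r_rejected), and the loss is the negative log of that probability. The loss is small when the reward model already ranks a pair correctly and grows roughly linearly with the margin when it ranks it wrongly. The plain-Python sketch below evaluates it for a few margins, using the identity -log(sigmoid(m)) = log(1 + exp(-m)) in an overflow-safe form; the function name is illustrative.

```python
import math


def pairwise_loss(r_chosen, r_rejected):
    """-log(sigmoid(r_chosen - r_rejected)), computed as softplus(-margin)."""
    x = -(r_chosen - r_rejected)
    # softplus(x) = log(1 + exp(x)); the branch keeps exp() from overflowing
    if x > 0:
        return x + math.log1p(math.exp(-x))
    return math.log1p(math.exp(x))


for chosen, rejected in [(2.0, 0.0), (0.0, 0.0), (0.0, 2.0)]:
    print(chosen - rejected, round(pairwise_loss(chosen, rejected), 3))
# 2.0 0.127
# 0.0 0.693
# -2.0 2.127
```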
3. Reinforcement Learning with PPO
The final phase uses Proximal Policy Optimization (PPO) to optimize the SFT model (now called the policy) against the trained reward model:
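```python
import torch


# Simplified PPO training loop for RLHF (helper methods left abstract)
class PPOTrainer:
    def __init__(self, policy_model, reward_model, ref_model, optimizer, kl_coef=0.1):
        self.policy = policy_model
        self.reward_model = reward_model
        self.ref_model = ref_model  # Frozen reference model (copy of the SFT model)
        self.optimizer = optimizer  # Optimizer over the policy's parameters
        self.kl_coef = kl_coef  # KL divergence coefficient

    def compute_rewards(self, queries, responses):
        # Reward from reward model, scoring each prompt + response
        rewards = self.reward_model(queries, responses)
        # KL penalty against reference model
        kl_penalty = self.compute_kl_penalty(queries, responses)
        return rewards - self.kl_coef * kl_penalty

    def train_step(self, batch):
        # Rollout: generate and score responses (no gradients needed here)
        with torch.no_grad():
            responses = self.policy.generate(batch["queries"])
            rewards = self.compute_rewards(batch["queries"], responses)
        # PPO update on the collected rollout
        policy_loss = self.compute_policy_loss(batch["queries"], responses, rewards)
        self.optimizer.zero_grad()
        policy_loss.backward()
        self.optimizer.step()
        return policy_loss.item()

    def compute_kl_penalty(self, queries, responses):
        # Log-ratio of policy vs. ref_model on the sampled tokens
        raise NotImplementedError

    def compute_policy_loss(self, queries, responses, rewards):
        # Clipped surrogate objective (plus value loss) over the rollout
        raise NotImplementedError
```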
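The `compute_policy_loss` step is where PPO differs from a plain policy gradient. Its core is the clipped surrogate objective: the probability ratio between the new and the old policy is clipped to [1 - eps, 1 + eps], and the smaller (more pessimistic) of the clipped and unclipped objectives is used. The plain-Python sketch below computes it per token. It assumes log-probabilities and advantages are already available; eps = 0.2 is a commonly used default rather than a value from this text, and the value loss and entropy terms are omitted.

```python
import math


def ppo_clipped_loss(new_logprobs, old_logprobs, advantages, clip_eps=0.2):
    """PPO clipped surrogate loss (to be minimized), averaged over tokens."""
    losses = []
    for new_lp, old_lp, adv in zip(new_logprobs, old_logprobs, advantages):
        ratio = math.exp(new_lp - old_lp)  # pi_new(token) / pi_old(token)
        clipped_ratio = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Pessimistic bound: take the smaller objective, negate it to get a loss
        losses.append(-min(ratio * adv, clipped_ratio * adv))
    return sum(losses) / len(losses)


# The new policy doubled the probability of a token (ratio = 2):
print(round(ppo_clipped_loss([math.log(2.0)], [0.0], [1.0]), 3))   # -1.2 (gain capped at 1 + eps)
print(round(ppo_clipped_loss([math.log(2.0)], [0.0], [-1.0]), 3))  # 2.0 (penalty not capped)
```

The asymmetry is deliberate: an update can never earn more than the clipped gain, but it is always charged the full cost of making a bad token more likely.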
Practical Challenges and Solutions
KL Divergence Regularization
A critical problem is “reward hacking”: the model learns to generate responses that score highly with the reward model without actually being good. To counter this, a KL divergence penalty limits how far the policy can drift from the distribution of the original (reference) model:
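```python
import torch

def kl_divergence_penalty(policy_logprobs, ref_logprobs):
    """Penalty for drifting too far from the reference model.
    Mean log-ratio of sampled tokens: a Monte Carlo estimate of per-token KL(policy || ref)."""
    kl = policy_logprobs - ref_logprobs
    return torch.mean(kl)

# In the reward computation (base_reward comes from the reward model)
total_reward = base_reward - kl_coef * kl_penalty
```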
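Many RLHF implementations do not apply the KL term as a single sequence-level number but spread it over tokens: each generated token receives a penalty of -kl_coef * (log p_policy - log p_ref), and the reward model's score is added at the final token only. The sketch below illustrates that shaping, assuming per-token log-probabilities of the sampled response are available from both models. It is one common formulation, not necessarily the one any specific system uses, and the function name is hypothetical.

```python
def shaped_token_rewards(rm_score, policy_logprobs, ref_logprobs, kl_coef=0.1):
    """Per-token rewards: KL penalty on every token, reward-model score on the last one."""
    rewards = [
        -kl_coef * (policy_lp - ref_lp)
        for policy_lp, ref_lp in zip(policy_logprobs, ref_logprobs)
    ]
    rewards[-1] += rm_score  # sequence-level score arrives at the final token
    return rewards


# Token 1 became more likely than under the reference model; token 2 did not change
print(shaped_token_rewards(1.0, [-1.0, -2.0], [-1.5, -2.0]))  # [-0.05, 1.0]
```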
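Training Stability
PPO training on large language models is notoriously unstable and sensitive to hyperparameters. Key stabilization techniques include:
- Gradient clipping - limiting the gradient norm before each optimizer step
- Learning rate scheduling - gradually reducing the learning rate during training
- Value function baseline - subtracting a learned value estimate from the return to reduce the variance of the policy gradient
- Multiple epochs - reusing each batch of rollouts for several update passes, with PPO's ratio clipping keeping those repeated updates conservative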
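A value function baseline is often combined with Generalized Advantage Estimation (GAE), which blends one-step and multi-step return estimates through a parameter lam. GAE is not named in the text; it is shown here as one widely used way to turn per-token rewards and value estimates into advantages. The sketch treats one response as an episode that ends after its last token, and the defaults gamma = 1.0 and lam = 0.95 are illustrative assumptions.

```python
def gae_advantages(rewards, values, gamma=1.0, lam=0.95):
    """Advantages and returns for one response via Generalized Advantage Estimation.

    rewards[t] is the per-token reward, values[t] the value head's estimate
    at token t; the episode ends after the last token (bootstrap value 0).
    """
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]  # TD error
        running = delta + gamma * lam * running
        advantages[t] = running
    # Returns are the regression targets for the value function
    returns = [adv + val for adv, val in zip(advantages, values)]
    return advantages, returns


adv, ret = gae_advantages([0.0, 1.0], [0.5, 0.5], gamma=1.0, lam=1.0)
print(adv, ret)  # [0.5, 0.5] [1.0, 1.0]
```

With lam = 1 the advantage is simply the full return minus the baseline, which is the variance-reduction idea from the list; smaller lam trades some bias for lower variance.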
Metrics and Evaluation
Measuring the success of RLHF is hard, because a rising reward score can also be a symptom of reward hacking. Automatic metrics are therefore combined with human evaluation:
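```python
import numpy as np
import torch


# Automatic metrics for a batch of rollouts
# (log-prob tensors of shape [batch, tokens] without padding are assumed)
def automatic_metrics(rewards, policy_logprobs, ref_logprobs, response_lengths):
    return {
        "reward_score": rewards.mean().item(),
        # Sample-based KL estimate per response, averaged over the batch
        "kl_divergence": (policy_logprobs - ref_logprobs).sum(dim=-1).mean().item(),
        "response_length": response_lengths.float().mean().item(),
        # Perplexity of the generated text (here: under the reference model)
        "perplexity": torch.exp(-ref_logprobs.mean()).item(),
    }


# Human evaluation (A/B testing)
def human_evaluation(model_a_responses, model_b_responses, human_judge):
    """human_judge returns 1.0 if A is better, 0.0 if B is better, 0.5 for a tie."""
    preferences = []
    for resp_a, resp_b in zip(model_a_responses, model_b_responses):
        # Humans evaluate which response is better
        preferences.append(human_judge(resp_a, resp_b))
    return float(np.mean(preferences))  # Win rate of model A
```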
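Two of the automatic metrics reduce to simple arithmetic on per-token log-probabilities. The plain-Python sketch below shows both for a single response, assuming the log-probabilities of the sampled tokens are available from the policy and the reference model; the function names are illustrative. Note that a single-sample KL estimate is noisy and can even be negative, which is why it is averaged over many responses.

```python
import math


def sequence_kl_estimate(policy_logprobs, ref_logprobs):
    """Single-sample estimate of KL(policy || ref) for one sampled response."""
    return sum(p - r for p, r in zip(policy_logprobs, ref_logprobs))


def perplexity(token_logprobs):
    """exp of the mean negative log-likelihood per token."""
    return math.exp(-sum(token_logprobs) / len(token_logprobs))


# Every token had probability 0.25, so perplexity is 4 (as uncertain as 4 equal choices)
print(round(perplexity([math.log(0.25)] * 4), 6))  # 4.0
print(sequence_kl_estimate([-1.0, -2.0], [-1.5, -2.5]))  # 1.0
```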
Current Trends and Improvements
RLHF is evolving rapidly. Recent directions include:
- Constitutional AI - the model critiques and revises its own outputs according to an explicit set of principles (a "constitution")
- RLAIF - an AI model, rather than human annotators, generates the preference feedback
- Multi-objective optimization - optimizing several goals at once (helpfulness, safety, truthfulness)
- Online RLHF - continually collecting fresh feedback, for example from real user interactions, and updating the model with it
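For multi-objective optimization, the simplest scheme is linear scalarization: score each objective with its own reward model and optimize a weighted sum. This is only one of several possible approaches and not something the text specifies; the objective names, scores, and weights below are made-up illustrations.

```python
def combined_reward(scores, weights):
    """Weighted sum of per-objective reward scores (linear scalarization)."""
    missing = set(weights) - set(scores)
    if missing:
        raise KeyError(f"no score for objectives: {sorted(missing)}")
    return sum(weight * scores[name] for name, weight in weights.items())


# Hypothetical scores from three separate reward models for one response
scores = {"helpfulness": 0.8, "safety": 0.5, "truthfulness": 0.6}
weights = {"helpfulness": 0.5, "safety": 0.3, "truthfulness": 0.2}
print(round(combined_reward(scores, weights), 3))  # 0.67
```

A fixed weighted sum lets a strong score on one objective compensate for a weak one on another; whether that trade-off is acceptable (for safety in particular) is a design decision the weights alone do not resolve.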
Summary
RLHF represents a fundamental step forward on the AI alignment problem. By combining supervised learning, reward modeling, and reinforcement learning, a raw language model can be turned into a helpful assistant. Although the implementation poses significant technical challenges, from PPO instability to reward hacking, the results clearly point the way toward safer and more useful AI systems. For practical use, careful hyperparameter tuning, high-quality data for the reward model, and thorough testing on real-world use cases are essential.