more grpo log #3801

Merged
merged 5 commits on Apr 8, 2025
Changes from 4 commits
3 changes: 3 additions & 0 deletions swift/trainers/arguments.py
@@ -188,6 +188,9 @@ class GRPOArgumentsMixin:
# Dr. GRPO, https://arxiv.org/abs/2503.20783
scale_rewards: bool = True

# compatible with trl main branch (0.17.0.dev0)
wandb_log_unique_prompts: Optional[bool] = None


@dataclass
class TrainingArguments(SwiftArgumentsMixin, HfTrainingArguments):
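
A minimal sketch (not part of this diff) of the new argument in isolation. The trimmed dataclass below copies only the fields visible in the hunk above; the instantiation is illustrative and assumes nothing beyond plain dataclass behavior.

from dataclasses import dataclass
from typing import Optional

@dataclass
class GRPOArgumentsMixin:  # trimmed to the fields shown in this hunk
    # Dr. GRPO, https://arxiv.org/abs/2503.20783
    scale_rewards: bool = True
    # compatible with trl main branch (0.17.0.dev0)
    wandb_log_unique_prompts: Optional[bool] = None

args = GRPOArgumentsMixin(wandb_log_unique_prompts=True)
assert args.wandb_log_unique_prompts is True  # defaults to None when not set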
122 changes: 89 additions & 33 deletions swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -5,7 +5,7 @@
import os
import re
import time
from collections import defaultdict
from collections import defaultdict, deque
from concurrent.futures import Future
from contextlib import contextmanager
from copy import copy, deepcopy
@@ -19,7 +19,9 @@
import numpy as np
import torch
import torch.nn as nn
import transformers
from accelerate.utils import gather, gather_object, is_peft_model, set_seed
from packaging import version
from torch.nn import ModuleList
from torch.utils.data import DataLoader
from transformers import PreTrainedModel, TrainerCallback
Expand All @@ -42,8 +44,8 @@
from trl.extras.profiling import profiling_decorator
except ImportError:
raise ImportError('Please install trl : `pip install -U trl`')

del HFGRPOTrainer.__init__
del HFGRPOTrainer.log

logger = get_logger()
if is_wandb_available():
@@ -165,7 +167,6 @@ def __init__(self,
self.temperature = args.temperature
model.warnings_issued['estimate_tokens'] = True
kwargs['data_collator'] = lambda features: features
self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)}

use_vllm = args.use_vllm
use_lmdeploy = args.use_lmdeploy
Expand All @@ -180,6 +181,19 @@ def __init__(self,

super().__init__(model, ref_model, *_args, **kwargs)

self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)}
self._total_train_tokens = 0
self.log_completions = args.log_completions
self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'completions.jsonl'))
# maxlen is set to the total number of forward passes per step. This value of `maxlen` ensures we log only the
# final optimization step.
maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.gradient_accumulation_steps
self._textual_logs = {
'prompt': deque(maxlen=maxlen),
'completion': deque(maxlen=maxlen),
'rewards': defaultdict(lambda: deque(maxlen=maxlen)),
}
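
A minimal, standalone sketch (not part of this diff) of why deque(maxlen=...) gives the behavior described in the comment above: older entries are evicted automatically, so after several optimization steps only the texts from the most recent window of forward passes remain. The numbers are illustrative.

from collections import deque

maxlen = 4                        # e.g. 1 process * per-device batch 2 * grad accumulation 2
textual_log = deque(maxlen=maxlen)
for step in range(3):             # three optimization steps
    for sample in range(maxlen):  # maxlen forward passes per step
        textual_log.append(f'step{step}-sample{sample}')
print(list(textual_log))          # only the step-2 entries survive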

num_processes = self.accelerator.num_processes
self.global_train_batch_size = global_batch_size = args.per_device_train_batch_size * num_processes
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
@@ -278,8 +292,6 @@ def __init__(self,
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, PreTrainedModel):
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
self.log_completions = args.log_completions
self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'completions.jsonl'))

# Multi-step
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
@@ -925,49 +937,59 @@ def _prepare_batch_inputs(self, inputs, rewards):
def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func):
"""Log training/evaluation metrics"""
mode = 'eval' if self.control.should_evaluate else 'train'
device = self.accelerator.device

# Calculate completion length metrics
completion_lengths = self.accelerator.gather_for_metrics(
torch.cat([inp['completion_mask'].sum(1) for inp in inputs]))
completion_lengths_mean = completion_lengths.float().mean().item()

self._metrics[mode]['completion_length'].append(completion_lengths_mean)
agg_completion_mask = gather(torch.cat([inp['completion_mask'].sum(1) for inp in inputs]))

self._metrics[mode]['completions/mean_length'].append(agg_completion_mask.float().mean().item())
self._metrics[mode]['completions/min_length'].append(agg_completion_mask.float().min().item())
self._metrics[mode]['completions/max_length'].append(agg_completion_mask.float().max().item())
# Calculate clip ratio
response_clip_ratio = (torch.gt(completion_lengths, self.args.max_completion_length).float().mean().item())

self._metrics[mode]['response_clip_ratio'].append(response_clip_ratio)

# Log rewards
reward_per_func = rewards_per_func.mean(0)
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, nn.Module):
agg_truncated_mask = gather(torch.cat([inp['truncated_mask'] for inp in inputs]).to(device))

term_completion_mask = agg_completion_mask[agg_truncated_mask]
clipped_completions_ratio = 1 - len(term_completion_mask) / len(agg_completion_mask)

if len(term_completion_mask) == 0:
# edge case where no completed sequences are found
term_completion_mask = torch.zeros(1, device=device)
self._metrics[mode]['completions/mean_terminated_length'].append(term_completion_mask.float().mean().item())
self._metrics[mode]['completions/min_terminated_length'].append(term_completion_mask.float().min().item())
self._metrics[mode]['completions/max_terminated_length'].append(term_completion_mask.float().max().item())
self._metrics[mode]['completions/clipped_ratio'].append(clipped_completions_ratio)

# Get the names of the reward functions
reward_func_names = []
for reward_func in self.reward_funcs:
if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
reward_func_name = reward_func.config._name_or_path.split('/')[-1]
else:
if inspect.isfunction(reward_func):
reward_func_name = reward_func.__name__
else:
reward_func_name = reward_func.__class__.__name__
self._metrics[mode][f'rewards/{reward_func_name}'].append(reward_per_func[i].item())

reward_func_names.append(reward_func_name)
metrics_mask = ~agg_truncated_mask if self.args.overlong_filter else torch.ones(
agg_completion_mask.shape[0], dtype=torch.bool)
for i, reward_func_name in enumerate(reward_func_names):
mean_rewards = (rewards_per_func[:, i][metrics_mask]).mean().item()
self._metrics[mode][f'rewards/{reward_func_name}/mean'].append(mean_rewards)
std_rewards = (rewards_per_func[:, i][metrics_mask]).std().item()
self._metrics[mode][f'rewards/{reward_func_name}/std'].append(std_rewards)

# Log overall reward stats
grouped_rewards = rewards.view(-1, self.num_generations)
grouped_rewards = rewards[metrics_mask].view(-1, self.num_generations)
self._metrics[mode]['reward'].append(grouped_rewards.mean().item())
self._metrics[mode]['reward_std'].append(grouped_rewards.std(dim=1).mean().item())

# Log completions if enabled
if self.log_completions and self.state.global_step % self.args.logging_steps == 0:
table = {
'step': [str(self.state.global_step)] * len(rewards),
'messages': gather_object(messages),
'completion': gather_object(completions),
'reward': rewards.tolist(),
}
self.jsonl_writer.append(table)
if 'wandb' in self.args.report_to and wandb.run is not None and self.accelerator.is_main_process:
import pandas as pd
df = pd.DataFrame(table)
wandb.log({'completions': wandb.Table(dataframe=df)})
# Log prompt and completion texts
self._textual_logs['prompt'].extend(m for m, mask in zip(gather_object(messages), metrics_mask) if mask)
self._textual_logs['completion'].extend(c for c, mask in zip(gather_object(completions), metrics_mask) if mask)

for i, name in enumerate(reward_func_names):
self._textual_logs['rewards'][name].extend(rewards_per_func[:, i][metrics_mask].tolist())

@profiling_decorator
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
@@ -1259,3 +1281,37 @@ def seed_context(self):
dataloader_params['prefetch_factor'] = self.args.dataloader_prefetch_factor

return self.accelerator.prepare(DataLoader(resample_dataset, **dataloader_params))

def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
# compatible with trl 0.16 and trl 0.17.0.dev
# remove this function once the next trl release (0.17.0) is out

mode = 'eval' if self.control.should_evaluate else 'train'
metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics

# This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
# start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
if mode == 'eval':
metrics = {f'eval_{key}': val for key, val in metrics.items()}

logs = {**logs, **metrics}
if version.parse(transformers.__version__) >= version.parse('4.47.0.dev0'):
super().log(logs, start_time)
else: # transformers<=4.46
super().log(logs)
self._metrics[mode].clear()

if (self.accelerator.is_main_process and self.args.report_to and 'wandb' in self.args.report_to
and wandb.run is not None):
import pandas as pd
table = {
'step': [str(self.state.global_step)] * len(self._textual_logs['prompt']),
'prompt': self._textual_logs['prompt'],
'completion': self._textual_logs['completion'],
**self._textual_logs['rewards'],
}
df = pd.DataFrame(table)
self.jsonl_writer.append(table)
if self.args.wandb_log_unique_prompts:
df = df.drop_duplicates(subset=['prompt'])
wandb.log({'completions': wandb.Table(dataframe=df)})
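
A minimal, self-contained sketch (not part of this diff) of the dedup branch above: with wandb_log_unique_prompts enabled, drop_duplicates keeps a single row per prompt before the table is pushed to wandb. The column names mirror the table built in log(); the reward column name and values are illustrative.

import pandas as pd

table = {
    'step': ['10'] * 3,
    'prompt': ['p1', 'p1', 'p2'],           # two generations share prompt 'p1'
    'completion': ['c1a', 'c1b', 'c2'],
    'accuracy': [1.0, 0.0, 1.0],            # example reward column, name is an assumption
}
df = pd.DataFrame(table)
df = df.drop_duplicates(subset=['prompt'])  # the step that wandb_log_unique_prompts toggles
print(df)                                   # one row for 'p1', one for 'p2'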