
Commit 1199759

[trainer] fix batch processing in PPO trainer (#7576)
1 parent 903db09 commit 1199759

1 file changed, 5 insertions(+), 3 deletions(-)

src/llamafactory/train/ppo/trainer.py

@@ -241,9 +241,11 @@ def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
             self.tokenizer.padding_side = "right"  # change padding side
             queries, responses, rewards = [], [], []
             for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
-                mini_batch_queries, mini_batch_responses = self.get_inputs(
-                    batch[idx : idx + self.config.mini_batch_size]
-                )
+                mini_batch = {
+                    "input_ids": batch["input_ids"][idx : idx + self.config.mini_batch_size],
+                    "attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size]
+                }
+                mini_batch_queries, mini_batch_responses = self.get_inputs(mini_batch)
                 mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)
                 queries.extend(mini_batch_queries)
                 responses.extend(mini_batch_responses)
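
The change stops slicing the batch object as a whole (batch[idx : idx + mini_batch_size], which assumes the batch supports direct slice indexing) and instead slices each tensor field of the batch dict with the same index range. A minimal, self-contained sketch of that pattern follows; the tensor shapes, batch_size, and mini_batch_size values are assumptions for illustration, not the project's code.

# Sketch: slicing a dict-of-tensors batch into mini-batches field by field.
# Shapes and sizes below are illustrative assumptions.
import torch

batch_size, mini_batch_size, seq_len = 8, 2, 16
batch = {
    "input_ids": torch.randint(0, 100, (batch_size, seq_len)),
    "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
}

for idx in range(0, batch_size, mini_batch_size):
    # Slice every tensor field with the same index range so the
    # mini-batch stays aligned across keys.
    mini_batch = {
        "input_ids": batch["input_ids"][idx : idx + mini_batch_size],
        "attention_mask": batch["attention_mask"][idx : idx + mini_batch_size],
    }
    assert mini_batch["input_ids"].shape == (mini_batch_size, seq_len)

Slicing field by field keeps input_ids and attention_mask aligned row for row, which is what the downstream get_inputs call expects when building each mini-batch.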
