Skip to content

Commit f06a74a

Browse files
[data] specify position_ids in PackedSupervisedDatasetProcessor for neat_packing (#7318)
* use position_ids for neat_packing with fa2 * revert fa2 changes
1 parent 6faa6fb commit f06a74a

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/llamafactory/data/processor/supervised.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[A
165165
knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
166166
for knapsack in knapsacks:
167167
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
168-
packed_images, packed_videos, packed_audios = [], [], []
168+
packed_images, packed_videos, packed_audios, packed_position_ids = [], [], [], []
169169
for i, length in enumerate(knapsack):
170170
index = length2indexes[length].pop()
171171
packed_input_ids += batch_input_ids[index]
@@ -175,6 +175,7 @@ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[A
175175
packed_audios += batch_audios[index]
176176
if self.data_args.neat_packing:
177177
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
178+
packed_position_ids += list(range(len(batch_input_ids[index])))
178179
else:
179180
packed_attention_masks += [1] * len(batch_input_ids[index])
180181

@@ -184,6 +185,7 @@ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[A
184185
packed_labels += [IGNORE_INDEX] * pad_length
185186
if self.data_args.neat_packing:
186187
packed_attention_masks += [0] * pad_length
188+
packed_position_ids += [0] * pad_length
187189
else:
188190
packed_attention_masks += [1] * pad_length # more efficient flash_attn
189191

@@ -196,5 +198,6 @@ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[A
196198
model_inputs["images"].append(packed_images or None)
197199
model_inputs["videos"].append(packed_videos or None)
198200
model_inputs["audios"].append(packed_audios or None)
201+
model_inputs["position_ids"].append(packed_position_ids or None)
199202

200203
return model_inputs

0 commit comments

Comments
 (0)