@@ -165,7 +165,7 @@ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
         knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
         for knapsack in knapsacks:
             packed_input_ids, packed_attention_masks, packed_labels = [], [], []
-            packed_images, packed_videos, packed_audios = [], [], []
+            packed_images, packed_videos, packed_audios, packed_position_ids = [], [], [], []
             for i, length in enumerate(knapsack):
                 index = length2indexes[length].pop()
                 packed_input_ids += batch_input_ids[index]
@@ -175,6 +175,7 @@ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
                 packed_audios += batch_audios[index]
                 if self.data_args.neat_packing:
                     packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
+                    packed_position_ids += list(range(len(batch_input_ids[index])))
                 else:
                     packed_attention_masks += [1] * len(batch_input_ids[index])
@@ -184,6 +185,7 @@ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
                 packed_labels += [IGNORE_INDEX] * pad_length
                 if self.data_args.neat_packing:
                     packed_attention_masks += [0] * pad_length
+                    packed_position_ids += [0] * pad_length
                 else:
                     packed_attention_masks += [1] * pad_length  # more efficient flash_attn
@@ -196,5 +198,6 @@ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
             model_inputs["images"].append(packed_images or None)
             model_inputs["videos"].append(packed_videos or None)
             model_inputs["audios"].append(packed_audios or None)
+            model_inputs["position_ids"].append(packed_position_ids or None)

         return model_inputs
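
The sketch below is a minimal, self-contained illustration of the neat-packing bookkeeping this diff adds (it is not the project's own code): each packed segment gets its own attention-mask index starting from 1 and position ids that restart from 0, while padding positions get 0 for both. `pack_segments`, `PAD_TOKEN_ID`, and the toy token lists are hypothetical stand-ins for the real `greedy_knapsack`/tokenizer machinery.

```python
# Sketch only: shows the shape of the outputs, assuming the neat_packing branch.
IGNORE_INDEX = -100  # assumed label-padding value, matching common HF practice
PAD_TOKEN_ID = 0     # illustrative pad token id

def pack_segments(segments: list[list[int]], cutoff_len: int):
    input_ids, attention_masks, position_ids, labels = [], [], [], []
    for i, seg in enumerate(segments):
        input_ids += seg
        attention_masks += [i + 1] * len(seg)  # segment index, starting from 1
        position_ids += list(range(len(seg)))  # positions restart per segment
        labels += seg
    pad_length = cutoff_len + 1 - len(input_ids)
    if pad_length > 0:
        input_ids += [PAD_TOKEN_ID] * pad_length
        attention_masks += [0] * pad_length    # 0 marks padding
        position_ids += [0] * pad_length
        labels += [IGNORE_INDEX] * pad_length
    return input_ids, attention_masks, position_ids, labels

ids, masks, pos, labels = pack_segments([[11, 12, 13], [21, 22]], cutoff_len=7)
print(masks)  # [1, 1, 1, 2, 2, 0, 0, 0]
print(pos)    # [0, 1, 2, 0, 1, 0, 0, 0]
```

Restarting position ids per segment keeps each packed example's positional encoding consistent with how it would look unpacked, and the segment-indexed attention mask gives a downstream collator enough information to build a block-diagonal mask so packed sequences cannot attend to each other.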