File tree 2 files changed +8
-3
lines changed
2 files changed +8
-3
lines changed Original file line number Diff line number Diff line change @@ -11,7 +11,7 @@ swift sft \
11
11
--train_type lora \
12
12
--dataset ' sentence-transformers/stsb' \
13
13
--torch_dtype bfloat16 \
14
- --num_train_epochs 10 \
14
+ --num_train_epochs 1 \
15
15
--per_device_train_batch_size 2 \
16
16
--per_device_eval_batch_size 1 \
17
17
--gradient_accumulation_steps $( expr 64 / $nproc_per_node ) \
Original file line number Diff line number Diff line change @@ -259,10 +259,15 @@ def set_lang(cls, lang):
259
259
def get_choices_from_dataclass (dataclass ):
260
260
choice_dict = {}
261
261
for f in fields (dataclass ):
262
+ default_value = f .default
263
+ if 'MISSING_TYPE' in str (default_value ):
264
+ default_value = None
262
265
if 'choices' in f .metadata :
263
- choice_dict [f .name ] = f .metadata ['choices' ]
266
+ choice_dict [f .name ] = list ( f .metadata ['choices' ])
264
267
if 'Literal' in str (f .type ) and typing .get_args (f .type ):
265
- choice_dict [f .name ] = typing .get_args (f .type )
268
+ choice_dict [f .name ] = list (typing .get_args (f .type ))
269
+ if f .name in choice_dict and default_value not in choice_dict [f .name ]:
270
+ choice_dict [f .name ].insert (0 , default_value )
266
271
return choice_dict
267
272
268
273
@staticmethod
You can’t perform that action at this time.
0 commit comments