Skip to content

Commit df3c583

Browse files
committed
fix bugs
1 parent 84bf004 commit df3c583

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

examples/train/embedding/train_gte.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ swift sft \
1111
--train_type lora \
1212
--dataset 'sentence-transformers/stsb' \
1313
--torch_dtype bfloat16 \
14-
--num_train_epochs 10 \
14+
--num_train_epochs 1 \
1515
--per_device_train_batch_size 2 \
1616
--per_device_eval_batch_size 1 \
1717
--gradient_accumulation_steps $(expr 64 / $nproc_per_node) \

swift/ui/base.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -259,10 +259,15 @@ def set_lang(cls, lang):
259259
def get_choices_from_dataclass(dataclass):
260260
choice_dict = {}
261261
for f in fields(dataclass):
262+
default_value = f.default
263+
if 'MISSING_TYPE' in str(default_value):
264+
default_value = None
262265
if 'choices' in f.metadata:
263-
choice_dict[f.name] = f.metadata['choices']
266+
choice_dict[f.name] = list(f.metadata['choices'])
264267
if 'Literal' in str(f.type) and typing.get_args(f.type):
265-
choice_dict[f.name] = typing.get_args(f.type)
268+
choice_dict[f.name] = list(typing.get_args(f.type))
269+
if f.name in choice_dict and default_value not in choice_dict[f.name]:
270+
choice_dict[f.name].insert(0, default_value)
266271
return choice_dict
267272

268273
@staticmethod

0 commit comments

Comments
 (0)