File tree 4 files changed +51
-23
lines changed
4 files changed +51
-23
lines changed Original file line number Diff line number Diff line change 24
24
if global_args ['quantization_config' ] is not None :
25
25
global_args ['quantization_config' ].load_in_4bit = load_in_bit == 4
26
26
global_args ['quantization_config' ].load_in_8bit = load_in_bit == 8
27
+ if load_in_bit == 0 :
28
+ global_args ["quantization_config" ] = None
27
29
28
30
if enable_lora :
29
31
enable_ptv2 = False
49
51
train_info_args ['use_fast_tokenizer' ] = True
50
52
51
53
52
- def get_deepspeed_config ():
54
+ def get_deepspeed_config (precision = 'fp16' ):
53
55
'''
54
56
lora prompt finetuning 使用 deepspeed_offload.json
55
57
普通finetuning 使用deepspeed.json
56
58
'''
57
59
# 是否开启deepspeed
58
60
if not enable_deepspeed :
59
61
return None
60
-
62
+ precision = str ( precision ). lower ()
61
63
# 选择 deepspeed 配置文件
62
64
is_need_update_config = False
63
65
if enable_lora :
@@ -79,5 +81,17 @@ def get_deepspeed_config():
79
81
optimizer ['params' ]['eps' ] = train_info_args .get ('adam_epsilon' , 1e-8 )
80
82
# deepspeed_offload 优化器有效
81
83
train_info_args ['optimizer' ] = optimizer ['type' ]
84
+
85
+ if precision == 'bf16' :
86
+ if 'fp16' in deepspeed_config :
87
+ deepspeed_config ["fp16" ]["enabled" ] = False
88
+ if 'bf16' in deepspeed_config :
89
+ deepspeed_config ["bf16" ]["enabled" ] = True
90
+ else :
91
+ deepspeed_config ['bf16' ] = {"enabled" : True }
92
+ elif precision == 'fp16' :
93
+ if 'bf16' in deepspeed_config :
94
+ deepspeed_config ["bf16" ]["enabled" ] = False
95
+
82
96
return deepspeed_config
83
97
Original file line number Diff line number Diff line change 22
22
if global_args ['quantization_config' ] is not None :
23
23
global_args ['quantization_config' ].load_in_4bit = load_in_bit == 4
24
24
global_args ['quantization_config' ].load_in_8bit = load_in_bit == 8
25
+ if load_in_bit == 0 :
26
+ global_args ["quantization_config" ] = None
25
27
26
28
if enable_lora :
27
29
enable_ptv2 = False
47
49
train_info_args ['use_fast_tokenizer' ] = True
48
50
49
51
50
- def get_deepspeed_config ():
52
+ def get_deepspeed_config (precision = 'fp16' ):
51
53
'''
52
54
lora prompt finetuning 使用 deepspeed_offload.json
53
55
普通finetuning 使用deepspeed.json
54
56
'''
55
57
# 是否开启deepspeed
56
58
if not enable_deepspeed :
57
59
return None
58
-
60
+ precision = str ( precision ). lower ()
59
61
# 选择 deepspeed 配置文件
60
62
is_need_update_config = False
61
63
if enable_lora :
@@ -78,5 +80,17 @@ def get_deepspeed_config():
78
80
79
81
# deepspeed_offload 优化器有效
80
82
train_info_args ['optimizer' ] = optimizer ['type' ]
83
+
84
+ if precision == 'bf16' :
85
+ if 'fp16' in deepspeed_config :
86
+ deepspeed_config ["fp16" ]["enabled" ] = False
87
+ if 'bf16' in deepspeed_config :
88
+ deepspeed_config ["bf16" ]["enabled" ] = True
89
+ else :
90
+ deepspeed_config ['bf16' ] = {"enabled" : True }
91
+ elif precision == 'fp16' :
92
+ if 'bf16' in deepspeed_config :
93
+ deepspeed_config ["bf16" ]["enabled" ] = False
94
+
81
95
return deepspeed_config
82
96
Original file line number Diff line number Diff line change 34
34
35
35
dataHelper .make_dataset_all ()
36
36
37
+ is_bf16_supported = torch .cuda .is_bf16_supported ()
38
+ # 精度 根据实际情况做调整
39
+ if is_bf16_supported :
40
+ precision = 'bf16'
41
+ else :
42
+ precision = '16'
37
43
38
-
39
- deepspeed_config = get_deepspeed_config ()
44
+ if global_args ["quantization_config" ] is not None and global_args ["quantization_config" ].load_in_8bit :
45
+ precision = "32"
46
+ deepspeed_config = get_deepspeed_config (precision )
40
47
strategy = 'ddp' if torch .cuda .device_count () > 1 else 'auto'
41
48
if deepspeed_config is not None and len (deepspeed_config ):
42
49
strategy = DeepSpeedStrategy (config = deepspeed_config , )
56
63
if deepspeed_config is not None and len (deepspeed_config ):
57
64
strategy = DeepSpeedStrategy (config = deepspeed_config , )
58
65
59
- is_bf16_supported = torch .cuda .is_bf16_supported ()
60
- # 精度 根据实际情况做调整
61
- if is_bf16_supported :
62
- precision = 'bf16'
63
- else :
64
- precision = '16'
65
66
66
- if global_args ["quantization_config" ] is not None and global_args ["quantization_config" ].load_in_8bit :
67
- precision = "32"
68
67
69
68
trainer = Trainer (
70
69
callbacks = [checkpoint_callback , LearningRateMonitor (logging_interval = 'step' )],
Original file line number Diff line number Diff line change 36
36
37
37
dataHelper .make_dataset_all ()
38
38
39
- deepspeed_config = get_deepspeed_config ()
39
+ is_bf16_supported = torch .cuda .is_bf16_supported ()
40
+ # 精度 根据实际情况做调整
41
+ if is_bf16_supported :
42
+ precision = 'bf16'
43
+ else :
44
+ precision = '16'
45
+
46
+ if global_args ["quantization_config" ] is not None and global_args ["quantization_config" ].load_in_8bit :
47
+ precision = "32"
48
+ deepspeed_config = get_deepspeed_config (precision )
40
49
strategy = 'ddp' if torch .cuda .device_count () >= 1 else 'auto'
41
50
if deepspeed_config is not None and len (deepspeed_config ):
42
51
strategy = DeepSpeedStrategy (config = deepspeed_config , )
53
62
training_args = training_args ,
54
63
lora_args = lora_args , )
55
64
56
- is_bf16_supported = torch .cuda .is_bf16_supported ()
57
- # 精度 根据实际情况做调整
58
- if is_bf16_supported :
59
- precision = 'bf16'
60
- else :
61
- precision = '16'
62
65
63
- if global_args ["quantization_config" ] is not None and global_args ["quantization_config" ].load_in_8bit :
64
- precision = "32"
65
66
trainer = PPOTrainer (
66
67
callbacks = [ checkpoint_callback ],
67
68
max_epochs = training_args .max_epochs ,
You can’t perform that action at this time.
0 commit comments