Skip to content

Commit 12fdcbf

Browse files
author
ssbuild
committed
deepspeed precision
Signed-off-by: ssbuild <[email protected]>
1 parent 1a47a80 commit 12fdcbf

File tree

4 files changed

+51
-23
lines changed

4 files changed

+51
-23
lines changed

config/reward_config/main.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
if global_args['quantization_config'] is not None:
2525
global_args['quantization_config'].load_in_4bit = load_in_bit == 4
2626
global_args['quantization_config'].load_in_8bit = load_in_bit == 8
27+
if load_in_bit == 0:
28+
global_args["quantization_config"] = None
2729

2830
if enable_lora:
2931
enable_ptv2 = False
@@ -49,15 +51,15 @@
4951
train_info_args['use_fast_tokenizer'] = True
5052

5153

52-
def get_deepspeed_config():
54+
def get_deepspeed_config(precision='fp16'):
5355
'''
5456
lora prompt finetuning 使用 deepspeed_offload.json
5557
普通finetuning 使用deepspeed.json
5658
'''
5759
# 是否开启deepspeed
5860
if not enable_deepspeed:
5961
return None
60-
62+
precision = str(precision).lower()
6163
# 选择 deepspeed 配置文件
6264
is_need_update_config = False
6365
if enable_lora:
@@ -79,5 +81,17 @@ def get_deepspeed_config():
7981
optimizer['params']['eps'] = train_info_args.get('adam_epsilon', 1e-8)
8082
# deepspeed_offload 优化器有效
8183
train_info_args['optimizer'] = optimizer['type']
84+
85+
if precision == 'bf16':
86+
if 'fp16' in deepspeed_config:
87+
deepspeed_config["fp16"]["enabled"] = False
88+
if 'bf16' in deepspeed_config:
89+
deepspeed_config["bf16"]["enabled"] = True
90+
else:
91+
deepspeed_config['bf16'] = {"enabled": True}
92+
elif precision == 'fp16':
93+
if 'bf16' in deepspeed_config:
94+
deepspeed_config["bf16"]["enabled"] = False
95+
8296
return deepspeed_config
8397

config/rlhf_config/main.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
if global_args['quantization_config'] is not None:
2323
global_args['quantization_config'].load_in_4bit = load_in_bit == 4
2424
global_args['quantization_config'].load_in_8bit = load_in_bit == 8
25+
if load_in_bit == 0:
26+
global_args["quantization_config"] = None
2527

2628
if enable_lora:
2729
enable_ptv2 = False
@@ -47,15 +49,15 @@
4749
train_info_args['use_fast_tokenizer'] = True
4850

4951

50-
def get_deepspeed_config():
52+
def get_deepspeed_config(precision='fp16'):
5153
'''
5254
lora prompt finetuning 使用 deepspeed_offload.json
5355
普通finetuning 使用deepspeed.json
5456
'''
5557
# 是否开启deepspeed
5658
if not enable_deepspeed:
5759
return None
58-
60+
precision = str(precision).lower()
5961
# 选择 deepspeed 配置文件
6062
is_need_update_config = False
6163
if enable_lora:
@@ -78,5 +80,17 @@ def get_deepspeed_config():
7880

7981
# deepspeed_offload 优化器有效
8082
train_info_args['optimizer'] = optimizer['type']
83+
84+
if precision == 'bf16':
85+
if 'fp16' in deepspeed_config:
86+
deepspeed_config["fp16"]["enabled"] = False
87+
if 'bf16' in deepspeed_config:
88+
deepspeed_config["bf16"]["enabled"] = True
89+
else:
90+
deepspeed_config['bf16'] = {"enabled": True}
91+
elif precision == 'fp16':
92+
if 'bf16' in deepspeed_config:
93+
deepspeed_config["bf16"]["enabled"] = False
94+
8195
return deepspeed_config
8296

stage2_reward/train.py

+9-10
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,16 @@
3434

3535
dataHelper.make_dataset_all()
3636

37+
is_bf16_supported = torch.cuda.is_bf16_supported()
38+
# 精度 根据实际情况做调整
39+
if is_bf16_supported:
40+
precision = 'bf16'
41+
else:
42+
precision = '16'
3743

38-
39-
deepspeed_config = get_deepspeed_config()
44+
if global_args["quantization_config"] is not None and global_args["quantization_config"].load_in_8bit:
45+
precision = "32"
46+
deepspeed_config = get_deepspeed_config(precision)
4047
strategy = 'ddp' if torch.cuda.device_count() > 1 else 'auto'
4148
if deepspeed_config is not None and len(deepspeed_config):
4249
strategy = DeepSpeedStrategy(config=deepspeed_config, )
@@ -56,15 +63,7 @@
5663
if deepspeed_config is not None and len(deepspeed_config):
5764
strategy = DeepSpeedStrategy(config=deepspeed_config, )
5865

59-
is_bf16_supported = torch.cuda.is_bf16_supported()
60-
# 精度 根据实际情况做调整
61-
if is_bf16_supported:
62-
precision = 'bf16'
63-
else:
64-
precision = '16'
6566

66-
if global_args["quantization_config"] is not None and global_args["quantization_config"].load_in_8bit:
67-
precision = "32"
6867

6968
trainer = Trainer(
7069
callbacks=[checkpoint_callback, LearningRateMonitor(logging_interval='step')],

stage3_rlhf/train.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,16 @@
3636

3737
dataHelper.make_dataset_all()
3838

39-
deepspeed_config = get_deepspeed_config()
39+
is_bf16_supported = torch.cuda.is_bf16_supported()
40+
# 精度 根据实际情况做调整
41+
if is_bf16_supported:
42+
precision = 'bf16'
43+
else:
44+
precision = '16'
45+
46+
if global_args["quantization_config"] is not None and global_args["quantization_config"].load_in_8bit:
47+
precision = "32"
48+
deepspeed_config = get_deepspeed_config(precision)
4049
strategy = 'ddp' if torch.cuda.device_count() >= 1 else 'auto'
4150
if deepspeed_config is not None and len(deepspeed_config):
4251
strategy = DeepSpeedStrategy(config=deepspeed_config, )
@@ -53,15 +62,7 @@
5362
training_args=training_args,
5463
lora_args=lora_args, )
5564

56-
is_bf16_supported = torch.cuda.is_bf16_supported()
57-
# 精度 根据实际情况做调整
58-
if is_bf16_supported:
59-
precision = 'bf16'
60-
else:
61-
precision = '16'
6265

63-
if global_args["quantization_config"] is not None and global_args["quantization_config"].load_in_8bit:
64-
precision = "32"
6566
trainer = PPOTrainer(
6667
callbacks=[ checkpoint_callback],
6768
max_epochs=training_args.max_epochs,

0 commit comments

Comments
 (0)