Skip to content

Commit 8950578

Browse files
committed
Signed-off-by: ssbuild <[email protected]>
1 parent 9c69177 commit 8950578

15 files changed

+56
-432
lines changed

args.MD

+1-2
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
enable_deepspeed = False
66
enable_ptv2 = False
77
enable_lora = True
8-
enable_int8 = False # qlora int8
9-
enable_int4 = False # qlora int4
8+
load_in_bit = 0 # 4 load_in_4bit, 8 load_in_8bit other 0
109

1110
### 单机多卡
1211
```text

config/reward_config/main.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -11,28 +11,24 @@
1111
enable_deepspeed = False
1212
enable_ptv2 = False
1313
enable_lora = True
14-
enable_int8 = False # qlora int8
15-
enable_int4 = False # qlora int4
14+
load_in_bit = 0 # 4 load_in_4bit, 8 load_in_8bit other 0
1615

1716

1817
if enable_lora:
1918
from config.reward_config.reward_config_lora import *
19+
elif enable_ptv2:
20+
raise NotImplemented
2021
else:
2122
from config.reward_config.reward_config import *
2223

2324

2425

2526
if enable_lora:
2627
enable_ptv2 = False
27-
if enable_int4:
28-
global_args['load_in_4bit'] = True
29-
global_args['load_in_8bit'] = False
30-
31-
if enable_int8:
32-
global_args['load_in_4bit'] = False
33-
global_args['load_in_8bit'] = True
28+
global_args['load_in_4bit'] = load_in_bit == 4
29+
global_args['load_in_8bit'] = load_in_bit == 8
3430

35-
if not enable_int4:
31+
if global_args['load_in_4bit']:
3632
global_args['quantization_config'] = None
3733

3834
#检查lora adalora是否开启

models/ppo_model.py

-127
This file was deleted.

models/reward_model.py

-202
This file was deleted.

0 commit comments

Comments
 (0)