Commit b6d2186
fix main_train minor issue
1 parent 188f1b6

8 files changed: +24 −8 lines

README.md (+15 −4)

@@ -110,10 +110,21 @@ For APOLLO and APOLLO-Mini, we have the following arguments
   - **`channel`**: Applies gradient scaling at the channel level (APOLLO)
   - **`tensor`**: Applies gradient scaling at the tensor level (APOLLO-Mini).
 
-#### `scale`
-- Governs the scaling factor for gradient updates. Can be tuned for better performance.
-  - `1` for APOLLO by default (validated on A100).
-  - `128` for APOLLO-Mini by default. You can scale it larger, especially when the model is large.
+#### **`scale`**
+The `scale` parameter plays a crucial role in heuristically adjusting gradient updates to compensate for the scaling-factor approximation error that arises from using a lower rank. Proper tuning of this parameter can significantly improve performance:
+- **`1`**: Default value for APOLLO (validated on A100 GPUs).
+- **`128`**: Default value for APOLLO-Mini. For larger models, experimenting with higher values is recommended.
+
+#### `--scale_front`
+To stabilize training, we adopt the **Norm-Growth Limiter (NL)** from [Fira](https://github.com/xichen-fy/Fira), which has proven slightly more effective than traditional gradient clipping.
+
+There are two ways to apply the Norm-Growth Limiter, depending on when it runs relative to the heuristic scaling (`scale`):
+1. **After scaling**: NL is applied after the gradient is multiplied by `scale`.
+   - Recommended for smaller models or when training involves fewer warmup steps.
+   - Enable this by setting `--scale_front`.
+2. **Before scaling**: NL is applied before the gradient is scaled.
+   - With sufficient warmup steps, both methods yield similar performance for large models.
 
 ---
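The two orderings described in the README can be sketched as follows. This is a minimal illustration, not the actual `APOLLOAdamW` implementation: the names `norm_growth_limiter` and `apply_update` and the growth ratio `gamma=1.01` are hypothetical, chosen only to show how `--scale_front` reorders the limiter and the heuristic scaling.

```python
import numpy as np

def norm_growth_limiter(grad, prev_norm, gamma=1.01):
    # Cap how fast the gradient norm may grow between steps:
    # if the new norm exceeds gamma * prev_norm, shrink the
    # gradient back down to that bound (hedged sketch of Fira's NL).
    cur_norm = np.linalg.norm(grad)
    if prev_norm is not None and cur_norm > gamma * prev_norm:
        grad = grad * (gamma * prev_norm / cur_norm)
    return grad

def apply_update(grad, prev_norm, apollo_scale, scale_front):
    if scale_front:
        # --scale_front: multiply by `scale` first, then limit norm growth
        return norm_growth_limiter(grad * apollo_scale, prev_norm)
    # default: limit norm growth first, then multiply by `scale`
    return norm_growth_limiter(grad, prev_norm) * apollo_scale
```

With a large `scale` (e.g. APOLLO-Mini's 128), the two orderings differ sharply: limiting after scaling clamps the final update norm, while limiting before scaling still multiplies the clamped gradient by `scale`.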

main_pretrain.py (+1 −1)

@@ -289,7 +289,7 @@ def main(args):
     total_svd_count = 0
 
     for batch_idx, batch in enumerate(dataloader):
-        if update_step != 0 and batch_idx <= args.gradient_accumulation * update_step:
+        if update_step != 0 and batch_idx < args.gradient_accumulation * update_step:
             continue # skipping learned data when resuming from checkpointing
 
         global_step += 1
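The one-character fix above changes which batches are skipped on resume: after `update_step` optimizer updates with `gradient_accumulation`-way accumulation, exactly `gradient_accumulation * update_step` batches were consumed, so `<` resumes at the first unseen batch while the old `<=` also dropped that batch. A stand-alone sketch of the condition (the helper name is hypothetical, not code from the repo):

```python
def skipped_batches(num_batches, gradient_accumulation, update_step, strict):
    # Model the resume-skip condition: strict=True is the fixed `<`
    # comparison, strict=False the old off-by-one `<=` comparison.
    boundary = gradient_accumulation * update_step
    skipped = []
    for batch_idx in range(num_batches):
        if update_step != 0 and (batch_idx < boundary if strict else batch_idx <= boundary):
            skipped.append(batch_idx)
    return skipped

# After 2 updates with 4-way accumulation, batches 0..7 were consumed.
fixed = skipped_batches(12, gradient_accumulation=4, update_step=2, strict=True)
old = skipped_batches(12, gradient_accumulation=4, update_step=2, strict=False)
# `fixed` skips exactly batches 0..7; the old comparison also dropped batch 8
```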

scripts/pretrain_c4/llama_130m_apollo.sh (+2 −1)

@@ -1,5 +1,5 @@
 # LLaMA-130M, APOLLO, 4 A100, 1 Node
-num_rank=256
+num_rank=192 # exactly 1/4 of the LLaMA-130M model dimension
 scale_type=channel
 proj_type=random
 apollo_scale=1 # A6000 uses a smaller one to avoid loss spikes
@@ -14,6 +14,7 @@ torchrun --standalone --nproc_per_node 4 main_pretrain.py \
     --warmup_steps 2000 \
     --num_training_steps 20000 \
     --optimizer apollo_adamw \
+    --scale_front \
     --apollo_scale ${apollo_scale} \
     --rank ${num_rank} \
     --scale_type ${scale_type} \

scripts/pretrain_c4/llama_130m_apollo_mini.sh (+2 −1)

@@ -2,7 +2,7 @@
 num_rank=1
 scale_type=tensor
 proj_type=random
-apollo_scale=128
+apollo_scale=192.0 # exactly 1/4 of the LLaMA-130M model dimension
 
 torchrun --standalone --nproc_per_node 4 main_pretrain.py \
     --model_config configs/llama_130m.json \
@@ -14,6 +14,7 @@ torchrun --standalone --nproc_per_node 4 main_pretrain.py \
     --warmup_steps 2000 \
     --num_training_steps 20000 \
     --optimizer apollo_adamw \
+    --scale_front \
     --apollo_scale ${apollo_scale} \
     --rank ${num_rank} \
     --scale_type ${scale_type} \

scripts/pretrain_c4/llama_60m_apollo.sh (+1)

@@ -14,6 +14,7 @@ torchrun --standalone --nproc_per_node 1 main_pretrain.py \
     --warmup_steps 1000 \
     --num_training_steps 10000 \
     --optimizer apollo_adamw \
+    --scale_front \
     --apollo_scale ${apollo_scale} \
     --rank ${num_rank} \
     --scale_type ${scale_type} \

scripts/pretrain_c4/llama_60m_apollo_mini.sh (+1)

@@ -16,6 +16,7 @@ torchrun --standalone --nproc_per_node 1 main_pretrain.py \
     --warmup_steps 1000 \
     --num_training_steps 10000 \
     --optimizer apollo_adamw \
+    --scale_front \
     --apollo_scale ${apollo_scale} \
     --rank ${num_rank} \
     --scale_type ${scale_type} \

utils/argparse.py (+1)

@@ -77,6 +77,7 @@ def parse_args(args):
     parser.add_argument("--proj", type=str, default="random") # "random" or "svd"
     parser.add_argument("--scale_type", type=str, default="tensor") # "tensor" or "channel"
     parser.add_argument("--apollo_scale", type=float, default=1.0) # scale for gradient scaling factor
+    parser.add_argument("--scale_front", action='store_true') # apply the Norm-Growth Limiter after (rather than before) scaling the gradient with apollo_scale
 
     args = parser.parse_args(args)
     args = check_args_torchrun_main(args)
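A quick stand-alone mirror of the new flag, showing the `store_true` semantics. This re-creates only the two relevant arguments, not the project's full parser:

```python
import argparse

# Minimal mirror of the parser additions above.
parser = argparse.ArgumentParser()
parser.add_argument("--apollo_scale", type=float, default=1.0)
parser.add_argument("--scale_front", action="store_true")

# Present flag -> True; absent flag -> the default False.
with_flag = parser.parse_args(["--apollo_scale", "192.0", "--scale_front"])
without_flag = parser.parse_args([])
```

Because `store_true` defaults to `False`, existing scripts that omit `--scale_front` keep the old (limit-before-scale) behavior unchanged.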

utils/setup.py (+1 −1)

@@ -147,7 +147,7 @@ def setup_optimization(args, model, trainable_params, param_groups, id_lowrank_p
         optimizer = GaLoreAdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)
 
     elif args.optimizer.lower() == "apollo_adamw":
-        optimizer = APOLLOAdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)
+        optimizer = APOLLOAdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay, scale_front=args.scale_front)
 
     elif args.optimizer.lower() == "q_apollo":
         optimizer = QAPOLLOAdamW(
