Skip to content

Commit de6a69c

Browse files
committed
compatibility with hf test case
1 parent 762b5ba commit de6a69c

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

apollo_torch/apollo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def step(self, closure: Callable = None):
129129

130130
# APOLLO Step 1: Calculate gradient into low rank space.
131131
if "rank" in group:
132+
norm_dim = 0 if grad.shape[0] < grad.shape[1] else 1 # low-rank dimension
132133
if "projector" not in state:
133134
state["projector"] = self._initialize_projector(group, state)
134135
grad = state["projector"].project(grad, state["step"])
@@ -164,7 +165,6 @@ def step(self, closure: Callable = None):
164165
# APOLLO Step 3: Obtain approximated gradient scaling factor, channel-wise or tensor-wise.
165166
if "rank" in group:
166167
if group['scale_type'] == 'channel':
167-
norm_dim = 0 if norm_grad.shape[0] < norm_grad.shape[1] else 1
168168
grad_scaling_factor = (
169169
torch.norm(norm_grad, dim=norm_dim) /
170170
(torch.norm(grad, dim=norm_dim) + 1e-8)

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
setup(
1010
name="apollo-torch",
11-
version="1.0.2",
11+
version="1.0.3",
1212
description="APOLLO: SGD-like Memory, AdamW-level Performance",
1313
long_description=long_description,
1414
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)