File tree 2 files changed +2
-2
lines changed
2 files changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -129,6 +129,7 @@ def step(self, closure: Callable = None):
129
129
130
130
# APOLLO Step 1: Calculate gradient into low rank space.
131
131
if "rank" in group :
132
+ norm_dim = 0 if grad .shape [0 ] < grad .shape [1 ] else 1 # low-rank dimension
132
133
if "projector" not in state :
133
134
state ["projector" ] = self ._initialize_projector (group , state )
134
135
grad = state ["projector" ].project (grad , state ["step" ])
@@ -164,7 +165,6 @@ def step(self, closure: Callable = None):
164
165
# APOLLO Step 3: Obtain approximated gradient scaling factor, channel-wise or tensor-wise.
165
166
if "rank" in group :
166
167
if group ['scale_type' ] == 'channel' :
167
- norm_dim = 0 if norm_grad .shape [0 ] < norm_grad .shape [1 ] else 1
168
168
grad_scaling_factor = (
169
169
torch .norm (norm_grad , dim = norm_dim ) /
170
170
(torch .norm (grad , dim = norm_dim ) + 1e-8 )
Original file line number Diff line number Diff line change 8
8
9
9
setup (
10
10
name = "apollo-torch" ,
11
- version = "1.0.2 " ,
11
+ version = "1.0.3 " ,
12
12
description = "APOLLO: SGD-like Memory, AdamW-level Performance" ,
13
13
long_description = long_description ,
14
14
long_description_content_type = "text/markdown" ,
You can’t perform that action at this time.
0 commit comments