Skip to content

Commit 52b3991

Browse files
committed
set all2all dtype using amp precision
1 parent bc47c6d commit 52b3991

File tree

3 files changed

+16
-0
lines changed

3 files changed

+16
-0
lines changed

megablocks/layers/common.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,15 @@ def dtype(args : Arguments):
88
elif args.bf16:
99
dtype = torch.bfloat16
1010
return dtype
11+
12+
13+
def cast_if_autocast_enabled(tensor):
    """Return ``tensor`` cast to the dtype selected by the active autocast context.

    If autocast is not enabled, the tensor is returned unchanged (same object,
    no copy). Only CUDA and CPU tensors are supported; any other device raises.

    NOTE(review): ``torch.is_autocast_enabled()`` reflects the CUDA autocast
    state on older torch releases — a CPU tensor under ``torch.autocast('cpu')``
    may therefore pass through uncast; confirm against the torch version in use.

    Args:
        tensor: a ``torch.Tensor`` on a CUDA or CPU device.

    Returns:
        The tensor converted to the autocast dtype for its device, or the
        original tensor when autocast is off.

    Raises:
        NotImplementedError: for devices other than ``cuda`` or ``cpu``.
    """
    if not torch.is_autocast_enabled():
        return tensor

    device_type = tensor.device.type
    if device_type == 'cuda':
        autocast_dtype = torch.get_autocast_gpu_dtype()
    elif device_type == 'cpu':
        autocast_dtype = torch.get_autocast_cpu_dtype()
    else:
        raise NotImplementedError()
    return tensor.to(dtype=autocast_dtype)

megablocks/layers/dmoe.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def forward_once(self, x, top_expert):
136136

137137
# Perform the expert computation.
138138
x = self.mlp(x, topo)
139+
x = common.cast_if_autocast_enabled(x)
139140

140141
# Un-route the data for the MoE output.
141142
x = ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
@@ -162,6 +163,7 @@ def permute_and_compute(
162163

163164
# Perform the expert computation.
164165
x = self.mlp(x, topo)
166+
x = common.cast_if_autocast_enabled(x)
165167

166168
# Un-route the data for the MoE output.
167169
return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)

megablocks/layers/moe.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def permute_and_compute(
185185
# Perform the expert computation. Note that we don't
186186
# use biases for these linear operations.
187187
x = self.mlp(x)
188+
x = common.cast_if_autocast_enabled(x)
188189

189190
# Un-route the data for the MoE output.
190191
return ops.binned_scatter(x, indices, bins)
@@ -344,6 +345,7 @@ def parallel_forward_once(self, x, top_expert):
344345
return x, tokens_per_expert.flatten()
345346

346347
def forward(self, x):
348+
x = common.cast_if_autocast_enabled(x)
347349
sl, bs, hs = x.size()
348350

349351
# Compute the top-1 expert routing.

0 commit comments

Comments
 (0)