
Commit 6640ebd

merge conflict
1 parent: f2fef59

File tree

1 file changed: +3, -3 lines


megablocks/layers/mlp.py

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@
 from megablocks.layers import weight_parallel as wp
 from megablocks.layers.arguments import Arguments, InitFn
 from megablocks import turbo_util as turbo
-from megablocks import grouped_gemm_util as grouped_gemm
+from megablocks import grouped_gemm_util as gg
 import stk
 import torch
 import torch.nn.functional as F
@@ -522,6 +522,6 @@ def forward(self, x, tokens_per_expert):
                 self.args.quantize_rematerialize_num_bits)

         # Compute the MLP.
-        x = grouped_gemm.gmm(x, w1, batch_sizes, trans_b=True)
+        x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
         x = F.gelu(x, approximate="tanh")
-        return grouped_gemm.gmm(x, w2, batch_sizes)
+        return gg.ops.gmm(x, w2, batch_sizes)
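For context, both call sites now go through the `gg.ops` namespace of `grouped_gemm_util`, which wraps the grouped GEMM kernels. The sketch below is a minimal pure-PyTorch reference for the semantics of a grouped GEMM driven by `batch_sizes`; `grouped_gemm_reference` is a hypothetical helper written only for illustration, not the library kernel.

import torch

def grouped_gemm_reference(x, w, batch_sizes, trans_b=False):
    # x: (num_tokens, k) tokens already sorted by expert.
    # w: (num_experts, ...) per-expert weight matrices.
    # batch_sizes: (num_experts,) number of tokens routed to each expert.
    # Multiplies each expert's slice of tokens by that expert's weights.
    outputs, start = [], 0
    for expert_idx, n in enumerate(batch_sizes.tolist()):
        chunk = x[start:start + n]
        w_e = w[expert_idx].t() if trans_b else w[expert_idx]
        outputs.append(chunk @ w_e)
        start += n
    return torch.cat(outputs, dim=0)

Under that reading, `gg.ops.gmm(x, w1, batch_sizes, trans_b=True)` would correspond to `grouped_gemm_reference(x, w1, batch_sizes, trans_b=True)`, with the library presumably replacing the Python loop with a fused kernel.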
