diff --git a/README.md b/README.md
index 109099fa..e8d53be2 100644
--- a/README.md
+++ b/README.md
@@ -12,6 +12,8 @@ MegaBlocks dMoEs outperform MoEs trained with [Tutel](https://github.com/microso
 
 # :building_construction: Installation
 
+Note: this assumes you have `numpy` and `torch` installed
+
 **Training models with Megatron-LM:** We recommend using NGC's [`nvcr.io/nvidia/pytorch:23.01-py3`](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch/tags) PyTorch container. The [Dockerfile](Dockerfile) builds on this image with additional dependencies. To build the image, run `docker build . -t megablocks-dev` and then `bash docker.sh` to launch the container. Once inside the container, install MegaBlocks with `pip install .`. See [Usage](#steam_locomotive-usage) for instructions on training MoEs with MegaBlocks + Megatron-LM.
 
 **Using MegaBlocks in other packages:** To install the MegaBlocks package for use in other frameworks, run `pip install megablocks`.
diff --git a/megablocks/backend/__init__.py b/megablocks/backend/__init__.py
new file mode 100644
index 00000000..8b137891
--- /dev/null
+++ b/megablocks/backend/__init__.py
@@ -0,0 +1 @@
+
diff --git a/megablocks/layers/common.py b/megablocks/layers/common.py
index eb1cf397..c15bf02a 100644
--- a/megablocks/layers/common.py
+++ b/megablocks/layers/common.py
@@ -8,3 +8,15 @@ def dtype(args : Arguments):
     elif args.bf16:
         dtype = torch.bfloat16
     return dtype
+
+
+def cast_if_autocast_enabled(tensor):
+    if torch.is_autocast_enabled():
+        if tensor.device.type == 'cuda':
+            dtype = torch.get_autocast_gpu_dtype()
+        elif tensor.device.type == 'cpu':
+            dtype = torch.get_autocast_cpu_dtype()
+        else:
+            raise NotImplementedError()
+        return tensor.to(dtype=dtype)
+    return tensor
diff --git a/megablocks/layers/dmoe.py b/megablocks/layers/dmoe.py
index 194047f0..3b31dde8 100644
--- a/megablocks/layers/dmoe.py
+++ b/megablocks/layers/dmoe.py
@@ -148,6 +148,7 @@ def forward_once(self, x, expert_weights, top_experts):
 
         # Perform the expert computation.
         x = self.mlp(x, topo)
+        x = common.cast_if_autocast_enabled(x)
 
         # Un-route the data for the MoE output.
         x = ops.padded_scatter(
@@ -195,6 +196,7 @@ def permute_and_compute(
 
         # Perform the expert computation.
         x = self.mlp(x, topo)
+        x = common.cast_if_autocast_enabled(x)
 
         # Un-route the data for the MoE output.
         return ops.padded_scatter(
diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py
index c32bd549..2ff01807 100644
--- a/megablocks/layers/moe.py
+++ b/megablocks/layers/moe.py
@@ -70,10 +70,9 @@ def batched_load_balancing_loss(args : Arguments):
     # the correct types and formats for the dot product.
     if args.moe_lbl_in_fp32:
         expert_scores = torch.cat(expert_scores, dim=1).float().mean(dim=0)
-        tokens_per_expert = torch.cat(tokens_per_expert).float()
     else:
         expert_scores = torch.cat(expert_scores, dim=1).mean(dim=0)
-        tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+    tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
 
     expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
     assert tokens_per_expert.numel() == expected_values
@@ -147,7 +146,7 @@ def load_balancing_loss(self, tokens_per_expert, expert_scores):
         assert num_experts == self.num_experts
         scale = self.num_experts / (tokens * self.top_k)
         return scale * torch.dot(
-            tokens_per_expert.half(),
+            tokens_per_expert.to(expert_scores.dtype),
             expert_scores.mean(dim=0))
 
     def indices_and_bins(self, top_expert):
@@ -191,6 +190,7 @@ def permute_and_compute(
         # Perform the expert computation. Note that we don't
         # use biases for these linear operations.
         x = self.mlp(x)
+        x = common.cast_if_autocast_enabled(x)
 
         # Un-route the data for the MoE output.
         return ops.binned_scatter(
@@ -387,6 +387,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts):
         return x, tokens_per_expert.flatten()
 
     def forward(self, x):
+        x = common.cast_if_autocast_enabled(x)
         sl, bs, hs = x.size()
 
         # Compute the expert scores and assignments.
diff --git a/setup.py b/setup.py
index f41163ab..8388d736 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,10 @@
 from setuptools import setup, find_packages
+from torch import cuda
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 
 
+_dc = cuda.get_device_capability()
+_dc = f"{_dc[0]}{_dc[1]}"
 ext_modules = [
     CUDAExtension(
         "megablocks_ops",
@@ -12,11 +15,22 @@
             "nvcc": [
                 "--ptxas-options=-v",
                 "--optimize=2",
-                "--generate-code=arch=compute_80,code=sm_80"
+                f"--generate-code=arch=compute_{_dc},code=sm_{_dc}"
             ]
         })
 ]
 
+install_requires=[
+    'stanford-stk @ git+https://github.com/stanford-futuredata/stk.git@main',
+]
+
+extra_deps = {}
+
+extra_deps['dev'] = [
+    'absl-py',
+]
+
+extra_deps['all'] = set(dep for deps in extra_deps.values() for dep in deps)
 
 setup(
     name="megablocks",
@@ -35,10 +49,6 @@
     packages=find_packages(),
     ext_modules=ext_modules,
     cmdclass={"build_ext": BuildExtension},
-    install_requires=[
-        "absl-py",
-        "numpy",
-        "torch",
-        "stanford-stk @ git+https://github.com/stanford-futuredata/stk.git@main"
-    ],
+    install_requires=install_requires,
+    extras_require=extra_deps,
 )
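
For reviewers, a minimal sketch (not part of the patch) of how the new `cast_if_autocast_enabled` helper behaves; the tensor shape and dtypes below are illustrative assumptions, and a CUDA device is assumed to be available:

    import torch
    from megablocks.layers import common

    x = torch.randn(4, 8, device='cuda', dtype=torch.float32)

    # Inside an autocast region, the tensor is cast to the autocast dtype.
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        y = common.cast_if_autocast_enabled(x)
        assert y.dtype == torch.bfloat16

    # Outside an autocast region, the tensor is returned unchanged.
    z = common.cast_if_autocast_enabled(x)
    assert z.dtype == torch.float32

Note also that with the revised `setup.py`, `numpy` and `torch` are no longer installed automatically (hence the new README note), and `absl-py` moves to a `dev` extra, installable with e.g. `pip install megablocks[dev]`.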