diff --git a/bagua/torch_api/algorithms/async_model_average.py b/bagua/torch_api/algorithms/async_model_average.py index e76f6ff9e..f3dcfabe4 100644 --- a/bagua/torch_api/algorithms/async_model_average.py +++ b/bagua/torch_api/algorithms/async_model_average.py @@ -78,15 +78,19 @@ def __init__( process_ranks, stream=torch.cuda.Stream(priority=-1) ) - def tensors_to_buckets(self, tensors: List[List[BaguaTensor]]) -> List[BaguaBucket]: + def tensors_to_buckets( + self, tensors: List[List[BaguaTensor]], do_flatten: bool + ) -> List[BaguaBucket]: + # TODO: async algorithm conflict with fused optimizer, can only support flattened inplace bucket. + assert do_flatten, "async does not support `do_flatten=False` at present." if self.step_id < self.warmup_steps: - return super().tensors_to_buckets(tensors) + return super().tensors_to_buckets(tensors, do_flatten) all_tensors = [] for idx, bucket in enumerate(tensors): all_tensors.extend(bucket) - bagua_bucket = BaguaBucket(all_tensors, flatten=True, name=str(0)) + bagua_bucket = BaguaBucket(all_tensors, flatten=do_flatten, name=str(0)) return [bagua_bucket] diff --git a/bagua/torch_api/algorithms/base.py b/bagua/torch_api/algorithms/base.py index 2c8ddd79d..5f3b60694 100644 --- a/bagua/torch_api/algorithms/base.py +++ b/bagua/torch_api/algorithms/base.py @@ -73,7 +73,9 @@ def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]: ), "tensor names should be unique" return tensors - def tensors_to_buckets(self, tensors: List[List[BaguaTensor]]) -> List[BaguaBucket]: + def tensors_to_buckets( + self, tensors: List[List[BaguaTensor]], do_flatten: bool + ) -> List[BaguaBucket]: """ Given the bucketing suggestion from Bagua, return the actual Bagua buckets. The default implementation follows the suggestion to do the bucketing. @@ -82,6 +84,7 @@ def tensors_to_buckets(self, tensors: List[List[BaguaTensor]]) -> List[BaguaBuck tensors: Bagua tensors grouped in different lists, representing Bagua's suggestion on how to bucketing the tensors. + do_flatten: Whether to flatten the Bagua buckets. Returns: A list of Bagua buckets. @@ -89,7 +92,7 @@ def tensors_to_buckets(self, tensors: List[List[BaguaTensor]]) -> List[BaguaBuck bagua_buckets = [] for idx, bucket in enumerate(tensors): bagua_bucket = BaguaBucket( - bucket, flatten=True, name=str(idx) + bucket, flatten=do_flatten, name=str(idx) ) # TODO: check duplicated names bagua_buckets.append(bagua_bucket) return bagua_buckets diff --git a/bagua/torch_api/algorithms/bytegrad.py b/bagua/torch_api/algorithms/bytegrad.py index 0378632b8..a9f14a8cc 100644 --- a/bagua/torch_api/algorithms/bytegrad.py +++ b/bagua/torch_api/algorithms/bytegrad.py @@ -30,24 +30,14 @@ def __init__( self.hierarchical = hierarchical self.average = average - def tensors_to_buckets(self, tensors: List[List[BaguaTensor]]) -> List[BaguaBucket]: - """ - Given the bucketing suggestion from Bagua, return the actual Bagua buckets. - The default implementation follows the suggestion to do the bucketing. - - Args: - tensors: Bagua tensors grouped in different - lists, representing Bagua's suggestion on how to bucketing the - tensors. - - Returns: - A list of Bagua buckets. 
- """ + def tensors_to_buckets( + self, tensors: List[List[BaguaTensor]], do_flatten: bool + ) -> List[BaguaBucket]: bagua_buckets = [] for idx, bucket in enumerate(tensors): bagua_bucket = BaguaBucket( bucket, - flatten=True, + flatten=do_flatten, name=str(idx), alignment=self.process_group.get_global_communicator().nranks(), ) diff --git a/bagua/torch_api/algorithms/decentralized.py b/bagua/torch_api/algorithms/decentralized.py index fc2130392..b9c5c5ec1 100644 --- a/bagua/torch_api/algorithms/decentralized.py +++ b/bagua/torch_api/algorithms/decentralized.py @@ -48,12 +48,14 @@ def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]: ] return self.tensors - def tensors_to_buckets(self, tensors: List[List[BaguaTensor]]) -> List[BaguaBucket]: + def tensors_to_buckets( + self, tensors: List[List[BaguaTensor]], do_flatten: bool + ) -> List[BaguaBucket]: all_tensors = [] for idx, bucket in enumerate(tensors): all_tensors.extend(bucket) - bagua_bucket = BaguaBucket(all_tensors, flatten=True, name=str(0)) + bagua_bucket = BaguaBucket(all_tensors, flatten=do_flatten, name=str(0)) return [bagua_bucket] diff --git a/bagua/torch_api/algorithms/q_adam.py b/bagua/torch_api/algorithms/q_adam.py index fafd43bb9..022941389 100644 --- a/bagua/torch_api/algorithms/q_adam.py +++ b/bagua/torch_api/algorithms/q_adam.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 from bagua.torch_api.bucket import BaguaBucket from bagua.torch_api.tensor import BaguaTensor -from bagua.torch_api import get_world_size from bagua.torch_api.distributed import BaguaModule from bagua.torch_api.algorithms import Algorithm, AlgorithmImpl from bagua.torch_api.communication import BaguaProcessGroup @@ -45,7 +44,7 @@ def __init__( raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super(QAdamOptimizer, self).__init__(params, defaults) - + # TODO: qadam optimizer maintain `step_id` in its state self.step_id = 0 self.warmup_steps = warmup_steps @@ -162,12 +161,14 @@ def set_momentum_fn(param, t): tensor_groups.sort(key=lambda x: x._q_adam_idx) return tensor_groups - def tensors_to_buckets(self, tensors: List[List[BaguaTensor]]) -> List[BaguaBucket]: + def tensors_to_buckets( + self, tensors: List[List[BaguaTensor]], do_flatten: bool + ) -> List[BaguaBucket]: bagua_buckets = [] for idx, bucket in enumerate(tensors): bagua_bucket = BaguaBucket( bucket, - flatten=True, + flatten=do_flatten, name=str(idx), alignment=self.process_group.get_global_communicator().nranks(), ) diff --git a/bagua/torch_api/bucket.py b/bagua/torch_api/bucket.py index cec70b29f..3a850a810 100644 --- a/bagua/torch_api/bucket.py +++ b/bagua/torch_api/bucket.py @@ -8,7 +8,7 @@ import torch from bagua.torch_api.tensor import BaguaTensor -from bagua.torch_api.utils import check_contiguous +from bagua.torch_api.utils import check_contiguous, get_flattened_tensor from bagua.torch_api.communication import ( BaguaProcessGroup, _bagua_backend_comm, @@ -87,25 +87,10 @@ def flattened_tensor(self) -> torch.Tensor: :attr:`self` tensors and padding tensor (if exists). 
""" - all_registered_tensors = [ + all_effective_tensors = [ tensor.bagua_getter_closure() for tensor in self._all_tensors ] - total_size = 0 - for tensor in all_registered_tensors: - total_size += tensor.numel() - - flatten_tensor = torch.zeros( - total_size, - dtype=all_registered_tensors[0].dtype, - device=all_registered_tensors[0].device, - ) - - offset = 0 - for tensor in all_registered_tensors: - # copy data - flatten_tensor[offset : offset + tensor.numel()] = tensor.reshape(-1) - offset += tensor.numel() - return flatten_tensor + return get_flattened_tensor(all_effective_tensors) def _flatten_(self): """ @@ -372,7 +357,7 @@ def clear_ops(self) -> BaguaBucket: def bytes(self) -> int: """Returns the total number of bytes occupied by the bucket.""" - registered_tensors = [tensor.bagua_getter_closure() for tensor in self.tensors] + effective_tensors = [tensor.bagua_getter_closure() for tensor in self.tensors] return sum( - tensor.numel() * tensor.element_size() for tensor in registered_tensors + tensor.numel() * tensor.element_size() for tensor in effective_tensors ) diff --git a/bagua/torch_api/contrib/__init__.py b/bagua/torch_api/contrib/__init__.py index 40c4c9de9..84aebb257 100644 --- a/bagua/torch_api/contrib/__init__.py +++ b/bagua/torch_api/contrib/__init__.py @@ -1,4 +1,4 @@ -from .fused_optimizer import FusedOptimizer # noqa: F401 +from .fuse.optimizer import fuse_optimizer # noqa: F401 from .load_balancing_data_loader import ( # noqa: F401 LoadBalancingDistributedSampler, LoadBalancingDistributedBatchSampler, diff --git a/bagua/torch_api/contrib/fuse/__init__.py b/bagua/torch_api/contrib/fuse/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bagua/torch_api/contrib/fuse/optimizer.py b/bagua/torch_api/contrib/fuse/optimizer.py new file mode 100644 index 000000000..3ff1b89a3 --- /dev/null +++ b/bagua/torch_api/contrib/fuse/optimizer.py @@ -0,0 +1,527 @@ +import torch +from typing import List, Dict, Optional, Any +import copy +import logging +from functools import reduce +from bagua.torch_api.utils import check_contiguous, get_flattened_tensor +import gorilla + + +__all__ = ["fuse_optimizer", "fuse_step", "is_fused_optimizer"] + + +def flatten_params_and_states(optimizer: torch.optim.Optimizer): + """ + Flatten parameter tensors in the sampe group into contiguous ones. + """ + + type_params = {} + for group in optimizer.param_groups: + for param in group["params"]: + + params_of_type = type_params.get(param.type(), []) + params_of_type.append(param) + type_params[param.type()] = params_of_type + + for param_type, params in type_params.items(): + grads = [p.bagua_ensure_grad().grad for p in params] + state_tensors, state_scalars = get_optimizer_param_states(optimizer, params) + + if state_tensors is None: + continue + + flatten_tensors(params) + flatten_tensors_with_closure( + grads, + params, + getter_closure=lambda p: p.grad, + setter_closure=lambda p, new_grad: setattr(p, "grad", new_grad), + ) + + for name, tensors in state_tensors.items(): + + def set_state_fn(p, t): + optimizer.state[p][name] = t + + flatten_tensors_with_closure( + tensors, + params, + getter_closure=lambda p: optimizer.state[p][name], + setter_closure=set_state_fn, + ) + + +def flatten_tensors(tensors: List[torch.Tensor]): + """ + Flatten :attr:`tensors` into contiguous one. 
+    """
+    if len(tensors) == 0:
+        return
+
+    if check_contiguous(tensors):
+        return
+
+    flatten_tensor = get_flattened_tensor(tensors)
+    flatten_storage = flatten_tensor.storage()
+
+    offset = 0
+    for tensor in tensors:
+        with torch.no_grad():
+            tensor.set_(flatten_storage, offset, tensor.shape)
+
+        offset += tensor.numel()
+        logging.debug(f"flatten done {offset}")
+
+    check_contiguous(tensors)
+
+
+def flatten_tensors_with_closure(tensors, params, getter_closure, setter_closure):
+    if len(tensors) == 0:
+        return
+
+    if check_contiguous(tensors):
+        return
+
+    flatten_tensor = get_flattened_tensor(tensors)
+    flatten_storage = flatten_tensor.storage()
+
+    offset = 0
+    for tensor, param in zip(tensors, params):
+        with torch.no_grad():
+            z = torch.zeros_like(getter_closure(param))
+            z.set_(flatten_storage, offset, z.shape)
+            setter_closure(param, z)
+
+        offset += tensor.numel()
+        logging.debug(f"flatten with closure done {offset}")
+
+    check_contiguous([getter_closure(p) for p in params])
+
+
+def _is_contiguous_tensor(a: torch.Tensor, b: torch.Tensor):
+    """
+    Check whether tensors :attr:`a` and :attr:`b` are adjacent (contiguous) in memory.
+    """
+    size_a = a.numel() * a.element_size()
+    size_b = b.numel() * b.element_size()
+
+    return (a.data_ptr() == b.data_ptr() + size_b) or (
+        b.data_ptr() == a.data_ptr() + size_a
+    )
+
+
+def _find_continuous_tensors(tensors: List[torch.Tensor]):
+    tensor_list = zip(tensors, list(range(len(tensors))))
+    sorted_tensor_list = sorted(tensor_list, key=lambda x: x[0].data_ptr())
+
+    grouped_indices = []
+    tmp_tensors = []
+    tmp_indices = []
+
+    for tensor, idx in sorted_tensor_list:
+        if len(tmp_tensors) > 0 and not _is_contiguous_tensor(tensor, tmp_tensors[-1]):
+            if len(tmp_tensors) > 1:
+                grouped_indices.append(tmp_indices)
+            tmp_tensors = []
+            tmp_indices = []
+
+        tmp_tensors.append(tensor)
+        tmp_indices.append(idx)
+
+    if len(tmp_tensors) > 1:
+        grouped_indices.append(tmp_indices)
+
+    return grouped_indices
+
+
+def calculate_mutual_groups(tensors_list: List[List[torch.Tensor]]):
+    constraints = []
+
+    size = len(tensors_list[0])
+    for tensors in tensors_list:
+        assert size == len(
+            tensors
+        ), "Tensors to calculate mutual groups must have equal size."
+
+        grouped_indices = _find_continuous_tensors(tensors)
+        constraints.append(grouped_indices)
+
+    if len(constraints) == 0:
+        return constraints
+
+    grouped_indices = constraints[0]
+    for i in range(1, len(constraints)):
+        grouped_indices = _intersect(grouped_indices, constraints[i])
+
+    logging.debug(
+        f"calculate mutual groups: {grouped_indices}, constraints: {constraints}"
+    )
+    return grouped_indices
+
+
+def _intersect(a: List[List[int]], b: List[List[int]]):
+    c = [value for value in a if value in b]
+    return c
+
+
+def group_tensors(tensors: List[torch.Tensor], indices: List[int]) -> torch.Tensor:
+    if len(indices) == 0:
+        return
+
+    to_group = [tensors[idx] for idx in indices]
+    assert check_contiguous(to_group), "tensors grouped must be contiguous"
+
+    total_size = sum([t.numel() for t in to_group])
+    with torch.no_grad():
+        tensor_view = torch.zeros(
+            total_size, dtype=to_group[0].dtype, device=to_group[0].device
+        )
+        tensor_view.set_(to_group[0].storage(), 0, tensor_view.shape)
+
+    return tensor_view
+
+
+def ungroup_tensor(
+    tensor_view: torch.Tensor, tensors: List[torch.Tensor]
+) -> Optional[List[torch.Tensor]]:
+    """
+    Ungroup :attr:`tensor_view` into a list of tensors that have the same data types and sizes as :attr:`tensors`.
+    """
+
+    offset = 0
+    ungrouped = []
+    for tensor in tensors:
+        if tensor_view.dtype != tensor.dtype:
+            logging.warning(
+                "Fused optimizer failed to recover parameter state from fused parameter state, due to a mismatch of data type between parameter and parameter state."
+            )
+            return
+
+        z = torch.zeros_like(tensor)
+        z.set_(tensor_view.storage(), offset, tensor.shape)
+
+        offset += tensor.numel()
+        ungrouped.append(z)
+
+    if offset != tensor_view.numel():
+        logging.warning(
+            "Fused optimizer failed to recover parameter state from fused parameter state, due to a mismatch of size between parameter and parameter state."
+        )
+        return
+
+    return ungrouped
+
+
+def fuse_optimizer(
+    optimizer: torch.optim.Optimizer,
+    do_flatten: bool = True,
+    check_flatten: bool = True,
+):
+    """
+    Convert any optimizer into a fused optimizer.
+
+    A fused optimizer can fuse multiple parameter updates into one or a few updates. To achieve this, users need to:
+
+    | 1) flatten multiple parameters in the same group into a fused parameter by setting :attr:`do_flatten=True`,
+        which is also the default behavior of a fused optimizer;
+    | 2) perform a fused parameter update by calling :meth:`fuse_step`.
+
+    This fused optimizer is implemented for general use. It can be used in conjunction with
+    a :class:`~bagua.torch_api.distributed.BaguaModule` as well as a
+    `torch.nn.parallel.DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html>`_
+    wrapped module, among other use cases.
+
+    Args:
+        optimizer (torch.optim.Optimizer): Any PyTorch optimizer.
+        do_flatten (bool): Whether to flatten the parameters. The flatten operation will reset the data pointers of
+            parameter tensors so that they can be fused together. Default: ``True``.
+        check_flatten (bool): When set to ``True``, the fused optimizer automatically checks whether
+            parameter tensors remain contiguous after being flattened. Only works with :attr:`do_flatten=True`.
+            Default: ``True``.
+
+    Returns:
+        A fused optimizer.
+
+    Example::
+        >>> optimizer = torch.optim.Adadelta(model.parameters(), ....)
+        >>> optimizer = bagua.torch_api.contrib.fuse_optimizer(optimizer, do_flatten=True)
+        >>>
+        >>> optimizer.fuse_step()
+
+    When used in conjunction with a :class:`~bagua.torch_api.distributed.BaguaModule`, set :attr:`do_flatten=False`
+    in :meth:`~bagua.torch_api.distributed.BaguaModule.with_bagua` explicitly:
+
+        >>> optimizer = bagua.torch_api.contrib.fuse_optimizer(optimizer, do_flatten=True)
+        >>> model = model.with_bagua([optimizer], GradientAllReduceAlgorithm(), do_flatten=False)
+        >>>
+        >>> optimizer.fuse_step()
+
+    .. note::
+        Both this function and the :meth:`~bagua.torch_api.distributed.BaguaModule.with_bagua` method reset data
+        pointers of module parameters by default. In order to perform a more effective fused parameter update,
+        users need to disable bucket flattening in :meth:`~bagua.torch_api.distributed.BaguaModule.with_bagua`
+        by setting its :attr:`do_flatten` to ``False``.
+
+    .. note::
+        A fused optimizer does not change the original behaviors of :attr:`optimizer`, but additionally enables it to
+        perform a fused parameter update through :meth:`fuse_step`. Users can still perform a normal parameter update
+        through :meth:`step`.
+    """
+
+    if is_fused_optimizer(optimizer):
+        raise RuntimeError("trying to fuse an optimizer twice!")
+
+    optimizer._bagua_check_flatten = do_flatten and check_flatten
+    optimizer._bagua_fused_count = 0
+    optimizer._bagua_cloned_attrs = {}
+    optimizer._bagua_fused_optimizer = make_optimizer_instance(optimizer)
+
+    if do_flatten:
+        flatten_params_and_states(optimizer)
+
+    if not hasattr(optimizer, "fuse_step"):
+        patch = gorilla.Patch(optimizer.__class__, "fuse_step", fuse_step)
+        gorilla.apply(patch)
+
+    return optimizer
+
+
+def is_fused_optimizer(optimizer: torch.optim.Optimizer):
+    """
+    Check whether :attr:`optimizer` is a fused optimizer.
+    """
+    return hasattr(optimizer, "_bagua_fused_optimizer")
+
+
+def make_optimizer_instance(optimizer: torch.optim.Optimizer):
+    ignore_attrs = [
+        "_bagua_check_flatten",
+        "_bagua_fused_count",
+        "_bagua_cloned_attrs",
+    ]
+    new_optimizer = copy.copy(optimizer)
+
+    for attr in dir(optimizer):
+        if attr not in ignore_attrs and attr not in dir(new_optimizer):
+            logging.warning(
+                f"Cloning attribute {attr} to the fused optimizer; do not modify it in `optimizer.step()`."
+            )
+            setattr(new_optimizer, attr, getattr(optimizer, attr))
+            optimizer._bagua_cloned_attrs[attr] = getattr(optimizer, attr)
+
+    # new_optimizer.param_groups = []
+    # for group in optimizer.param_groups:
+    #     new_group = {"params": list(group["params"])}
+    #     new_optimizer.add_param_group(new_group)
+
+    return new_optimizer
+
+
+def fuse_step(optimizer: torch.optim.Optimizer, closure=None):
+    r"""Perform a fused parameter update.
+
+    This operation will fuse multiple contiguous parameters into a fused parameter, by creating a tensor
+    view that shares the same underlying storage with them, and then performing the parameter update on the
+    fused parameters. If none of the parameter tensors are contiguous, this operation is equivalent to :meth:`step`.
+
+    Args:
+        optimizer: A fused optimizer.
+        closure (Callable): A closure that reevaluates the model and
+            returns the loss. Optional for most optimizers.
+
+    .. note::
+        This function will not modify the storage of parameter tensors.
+    """
+
+    assert is_fused_optimizer(
+        optimizer
+    ), "Fused optimizer should be initialized by calling `fuse_optimizer` first."
+
+    do_fuse(optimizer)
+    optimizer._bagua_fused_optimizer.step(closure)
+    check_optimizer(optimizer)
+    sync_optimizer_state(optimizer)
+
+
+def do_fuse(optimizer: torch.optim.Optimizer):
+    _fused_optimizer = optimizer._bagua_fused_optimizer
+
+    # Note: the optimizer and the fused optimizer share the same state but have different param groups
+    _fused_optimizer.param_groups = []
+    for index, group in enumerate(optimizer.param_groups):
+        params = group["params"]
+
+        weights = [p.data for p in params]
+        grads = [p.grad for p in params]
+
+        state_tensors, state_scalars = get_optimizer_param_states(optimizer, params)
+
+        if state_tensors is None:
+            continue
+
+        check_flatten = optimizer._bagua_check_flatten
+        if check_flatten and not check_contiguous(weights):
+            logging.warning(
+                "Parameter weight storage changed after flattening in the fused optimizer; this may degrade performance."
+            )
+            check_flatten = False
+
+        if check_flatten and not check_contiguous(grads):
+            logging.warning(
+                "Parameter gradient storage changed after flattening in the fused optimizer; this may degrade performance."
+ ) + check_flatten = False + + for name, tensors in state_tensors.items(): + if check_flatten and not check_contiguous(tensors): + logging.warning( + "Parameter state {} storage changed after flattened in fused optimizer, may degrade performance.".format( + name + ) + ) + check_flatten = False + + grouped_indices = calculate_mutual_groups( + [weights, grads] + list(state_tensors.values()) + ) + + if len(grouped_indices) == 0: + _fused_optimizer.add_param_group(group) + continue + + optimizer._bagua_fused_count += 1 + + new_params = [] + + for indices in grouped_indices: + grouped_weight = group_tensors(weights, indices) + grouped_grad = group_tensors(grads, indices) + + grouped_states = {} + for name, tensors in state_tensors.items(): + ts = group_tensors(tensors, indices) + grouped_states[name] = ts + + with torch.no_grad(): + p = torch.nn.Parameter(grouped_weight, requires_grad=False) + p.grad = grouped_grad + p._bagua_fused_param_ids = indices + + # sync original param state to fused param state + for name, ts in grouped_states.items(): + optimizer.state[p][name] = ts + + for name, v in state_scalars.items(): + optimizer.state[p][name] = v + + new_params.append(p) + + grouped_indices_flat = list(reduce(lambda x, y: x + y, grouped_indices)) + for idx, param in enumerate(params): + if idx not in grouped_indices_flat: + new_params.append(param) + + new_group = {"params": new_params} + _fused_optimizer.add_param_group(new_group) + + +def check_optimizer(optimizer): + # make sure cloned attributes are not modified + for attr in optimizer._bagua_cloned_attrs: + if getattr(optimizer, attr) != getattr(optimizer._bagua_fused_optimizer, attr): + logging.error( + f"Should not change attribute {attr} in `optimizer.step(), maintain it in optimizer state.`" + ) + + +def sync_optimizer_state(optimizer): + # write back state for original params + # Note: we should make sure every module parameter in original params groups has the right state + for group, fused_group in zip( + optimizer.param_groups, optimizer._bagua_fused_optimizer.param_groups + ): + + params = group["params"] + fused_params = fused_group["params"] + + fused_state_tensors, fused_state_scalars = get_optimizer_param_states( + optimizer, fused_params + ) + + for fp in fused_params: + if not hasattr(fp, "_bagua_fused_param_ids"): + continue + + original_params = [params[i] for i in fp._bagua_fused_param_ids] + + for name in fused_state_tensors.keys(): + state_tensors = ungroup_tensor( + optimizer.state[fp][name], original_params + ) + + if state_tensors is not None: + for p, state in zip(original_params, state_tensors): + optimizer.state[p][name] = state + + for name, v in fused_state_scalars.items(): + for p in original_params: + optimizer.state[p][name] = v + + # clear outdated state for fused param + logging.debug("delete outdated params state") + del optimizer.state[fp] + + +def get_optimizer_param_states(optimizer, params): + state_tensors = {} # type: Dict[str, List[torch.Tensor]] + state_scalars = {} # type: Dict[str, Any] + + state_tensor_names = set( + [ + k + for p in params + for k, v in optimizer.state[p].items() + if isinstance(v, torch.Tensor) + ] + ) + state_scalar_names = set( + [ + k + for p in params + for k, v in optimizer.state[p].items() + if not isinstance(v, torch.Tensor) + ] + ) + + for name in state_tensor_names: + tensors = [] + for p in params: + if name not in optimizer.state[p]: + logging.error( + f"Unexpected parameter state {name}, failed not fuse optimizer." 
+ ) + return None, None + + tensors.append(optimizer.state[p][name]) + + state_tensors[name] = tensors + + for name in state_scalar_names: + scalar = None + + for p in params: + if name not in optimizer.state[p]: + logging.error( + f"Unexpected parameter state {name}, failed not fuse optimizer." + ) + return None, None + + if scalar is not None and scalar != optimizer.state[p][name]: + logging.error( + f"Parameter state '{name}' does not match, failed not fuse optimizer." + ) + return None, None + + state_scalars[name] = optimizer.state[p][name] + + return state_tensors, state_scalars diff --git a/bagua/torch_api/contrib/fused_optimizer.py b/bagua/torch_api/contrib/fused_optimizer.py deleted file mode 100644 index 80a9c28d4..000000000 --- a/bagua/torch_api/contrib/fused_optimizer.py +++ /dev/null @@ -1,134 +0,0 @@ -import torch -import copy -from bagua.torch_api.utils import collocate_params, flatten_module_params - -__all__ = ["FusedOptimizer"] - - -class FusedOptimizer(torch.optim.Optimizer): - """Convert any optimizer into a fused optimizer. - - This fused optimizer fuses multiple module parameter update kernel launches - into one or a few, by flattening parameter tensors into one or more - contiguous buckets. - - It can be used in conjunction with :meth:`~bagua.torch_api.distributed.BaguaModule.with_bagua` method. In this case, - Bagua will do the fusions automatically, otherwise, you need to explicitly - set :attr:`do_flatten=True`. - - Args: - optimizer (torch.optim.Optimizer): Any PyTorch optimizer. - do_flatten (bool): Whether to flatten the parameters. Default: ``False``. - - Returns: - Fused optimizer. - - - Example:: - To use in conjunction with :meth:`~bagua.torch_api.distributed.BaguaModule.with_bagua` method: - - >>> optimizer = torch.optim.Adadelta(model.parameters(), ....) - >>> optimizer = bagua.torch_api.contrib.FusedOptimizer(optimizer) - >>> model = model.with_bagua([optimizer], GradientAllReduceAlgorithm()) - - To use alone or with `torch.nn.parallel.DistributedDataParallel `_, - set :attr:`do_flatten=True`: - - >>> optimizer = torch.optim.Adadelta(model.parameters(), ....) - >>> optimizer = bagua.torch_api.contrib.FusedOptimizer(optimizer, do_flatten=True) - """ - - def __init__(self, optimizer: torch.optim.Optimizer, do_flatten: bool = False): - self.optimizer = copy.copy(optimizer) - super(FusedOptimizer, self).__init__(optimizer.param_groups, optimizer.defaults) - - if do_flatten: - f32_params = [ - param - for group in self.optimizer.param_groups - for param in group["params"] - if param.type() == "torch.cuda.FloatTensor" - ] - f16_params = [ - param - for group in self.optimizer.param_groups - for param in group["params"] - if param.type() == "torch.cuda.HalfTensor" - ] - - flatten_module_params(f32_params, align_bytes=1) - flatten_module_params(f16_params, align_bytes=1) - - def step(self, closure=None): - r"""Performs a single optimization step (parameter update). - - Args: - closure (Callable): A closure that reevaluates the model and - returns the loss. Optional for most optimizers. - - .. note:: - Unless otherwise specified, this function should not modify the - ``.grad`` field of the parameters. 
- """ - for group in self.optimizer.param_groups: - params = group["params"] - grouped_params = group_params_by_storage(params) - - fused_params = [] - - for _, group_p in grouped_params.items(): - fused_params.extend(reorder_params(group_p)) - - group["params"] = fused_params - - return self.optimizer.step(closure) - - -def reorder_params(params): - """Input params share same storage, reorder them by their storage offset""" - - sorted_params = sorted(params, key=lambda x: x.storage_offset()) - - grouped = [] - tmp_params = [] - - for p in sorted_params: - if len(tmp_params) > 0 and not is_contiguous_param(p, tmp_params[-1]): - grouped.append(collocate_params(tmp_params)) - tmp_params = [] - - tmp_params.append(p) - - if len(tmp_params) > 0: - grouped.append(collocate_params(tmp_params)) # FIXME: potential OOM - - return grouped - - -def is_contiguous_param(a, b): - allocate_size_a = ( - a.bagua_tensor.num_elem_allocated() if hasattr(a, "bagua_tensor") else a.numel() - ) - allocate_size_b = ( - b.bagua_tensor.num_elem_allocated() if hasattr(b, "bagua_tensor") else b.numel() - ) - return ( - a.data.storage_offset() == b.data.storage_offset() + allocate_size_b - and a.grad.data.storage_offset() - == b.grad.data.storage_offset() + allocate_size_b - ) or ( - b.data.storage_offset() == a.data.storage_offset() + allocate_size_a - and b.grad.data.storage_offset() - == a.grad.data.storage_offset() + allocate_size_a - ) - - -def group_params_by_storage(params): - grouped_params = {} - for p in params: - weight_storage = p.data.storage().data_ptr() - param_list = grouped_params.get(weight_storage, []) - param_list.append(p) - grouped_params[weight_storage] = param_list - - return grouped_params diff --git a/bagua/torch_api/distributed.py b/bagua/torch_api/distributed.py index 42e611cb2..829ae3f0e 100644 --- a/bagua/torch_api/distributed.py +++ b/bagua/torch_api/distributed.py @@ -140,7 +140,7 @@ def _bagua_broadcast_optimizer_state(self, optimizer): for group in optimizer.param_groups: for p in group["params"]: if p.requires_grad and id(p) not in optimizer_state_dict["state"]: - p.grad = p.data.new(p.size()).zero_() + p.bagua_ensure_grad() if isinstance(optimizer, torch.optim.SparseAdam): p.grad = p.grad.to_sparse() optimizer_state_dict = optimizer.state_dict() @@ -263,6 +263,7 @@ def with_bagua( # pytype: disable=module-attr optimizers: List[torch.optim.Optimizer], algorithm: "bagua.torch_api.algorithms.Algorithm", process_group: Optional[BaguaProcessGroup] = None, + do_flatten: bool = True, ) -> BaguaModule: r"""``with_bagua`` enables easy distributed data parallel training on a `torch.nn.Module `_. @@ -274,6 +275,8 @@ def with_bagua( # pytype: disable=module-attr used to do the actual communication and update. process_group: The process group to be used for distributed data all-reduction. If ``None``, the default process group, which is created by :func:`bagua.torch_api.init_process_group`, will be used. (default: ``None``) + do_flatten: Whether to flatten the Bagua buckets. The flatten operation will reset data pointer of bucket + tensors so that they can use faster code paths. Default: ``True``. Returns: The original module, with Bagua related environments initialized. 
@@ -328,6 +331,8 @@ def with_bagua( # pytype: disable=module-attr ): # for compatibility with PyTorch DDP self.parameters_to_ignore.extend(self._ddp_params_and_buffers_to_ignore) + self._bagua_do_flatten = do_flatten + self.bagua_train_step_counter = 0 """ @@ -476,7 +481,9 @@ def _bagua_reset_module(self): def _bagua_reset_algorithm_buckets(self): self._bagua_cleanup_algorithm() raw_buckets = self._bagua_autotune_get_buckets() - self.bagua_buckets.extend(self.bagua_algorithm.tensors_to_buckets(raw_buckets)) + self.bagua_buckets.extend( + self.bagua_algorithm.tensors_to_buckets(raw_buckets, self._bagua_do_flatten) + ) for name, param in self.named_parameters(): @@ -514,6 +521,7 @@ def real_post_backward_hook(*unused): if not hasattr(optimizer, "_bagua_original_step"): optimizer._bagua_original_step = optimizer.step + # TODO: `fused_step` may miss `init_post_optimizer_step_hook` def new_step_factory(optimizer): def new_step(self, *args, **kwargs): result = self._bagua_original_step(*args, **kwargs) diff --git a/bagua/torch_api/tensor.py b/bagua/torch_api/tensor.py index be436fc81..32eb3ec0e 100644 --- a/bagua/torch_api/tensor.py +++ b/bagua/torch_api/tensor.py @@ -247,7 +247,9 @@ def bagua_set_storage( if self._bagua_setter_closure is None: # set directly with torch.no_grad(): - self.bagua_getter_closure().set_(storage, storage_offset, self.shape) + self.bagua_getter_closure().set_( + storage, storage_offset, self.bagua_getter_closure().shape + ) return with torch.no_grad(): diff --git a/bagua/torch_api/utils.py b/bagua/torch_api/utils.py index 9b3aa79af..58afcc3e1 100644 --- a/bagua/torch_api/utils.py +++ b/bagua/torch_api/utils.py @@ -48,10 +48,6 @@ def apply_flattened_call_all(tensors, call): apply_flattened_call(tensors, call) -def align_size(size, align): - return int((size + align - 1) / align) * align - - def check_contiguous(tensors): data_ptr = None for t in tensors: @@ -61,145 +57,25 @@ def check_contiguous(tensors): return True -def _get_params_flattened_aligned_size(params, align_bytes): - assert align_bytes == 1 or ( - align_bytes % params[0].element_size() == 0 - ), "align bytes must be multiples of element size" - - sizes = [p.numel() for p in params] - - total_size = sum(sizes) - aligned_total_size = ( - align_size(total_size * params[0].element_size(), align_bytes) - // params[0].element_size() - ) - - # padding to the last param - sizes[-1] += aligned_total_size - total_size - - for p, sz in zip(params, sizes): - p.allocated_size = sz - - return aligned_total_size - - -def flatten_module_params(params_list, align_bytes: int): - if len(params_list) == 0: +def get_flattened_tensor(tensors: List[torch.Tensor]) -> torch.Tensor: + if len(tensors) == 0: return - if not isinstance(params_list[0], list): - params_list = [params_list] - total_size = 0 - for params in params_list: - total_size += _get_params_flattened_aligned_size(params, align_bytes) - - logging.debug( - f"flatten {str(params_list[0][0].dtype).partition('.')[-1]} params aligned to {align_bytes} bytes, total numels: {total_size}" - ) + for tensor in tensors: + total_size += tensor.numel() - flatten_weights_tensor = torch.zeros(total_size, dtype=params_list[0][0].dtype).to( - params_list[0][0].device + flatten_tensor = torch.zeros( + total_size, dtype=tensors[0].dtype, device=tensors[0].device ) - flatten_grads_tensor = torch.zeros(total_size, dtype=params_list[0][0].dtype).to( - params_list[0][0].device - ) - - flatten_weights_storage = flatten_weights_tensor.storage() - flatten_grads_storage = 
flatten_grads_tensor.storage() - - def set_storage(param, weight_storage, grad_storage, storage_offset): - with torch.no_grad(): - z = torch.zeros_like(param.data) - z.set_(weight_storage, storage_offset, param.shape) - param.data = z - - t = torch.zeros_like(param.data) - t.set_(grad_storage, storage_offset, param.shape) - param.grad = t - - offset = 0 - for params in params_list: - for p in params: - # copy data - flatten_weights_tensor[offset : offset + p.numel()] = p.data.reshape(-1) - - if p.grad is not None: - flatten_grads_tensor[offset : offset + p.numel()] = p.grad.data.reshape( - -1 - ) - else: - logging.debug(f"grad is none, {offset}") - # flatten - set_storage(p, flatten_weights_storage, flatten_grads_storage, offset) - offset += p.allocated_size - logging.debug(f"flatten param done {offset}") - - # # check - for params in params_list: - weight_tensors = [p.data for p in params] - grad_tensors = [p.grad.data for p in params] - - assert check_contiguous(weight_tensors) - assert check_contiguous(grad_tensors) - - return new_param(flatten_weights_tensor, flatten_grads_tensor) - - -def collocate_params(params): - """ - `tensors` share the same storage - """ - if len(params) == 1: - return params[0] - - logging.debug(f"fuse {len(params)} params") - - sorted_params = sorted(params, key=lambda x: x.storage_offset()) - - start = None offset = 0 - for p in sorted_params: - weight = p.data - grad = p.grad.data - - assert ( - weight.storage_offset() == grad.storage_offset() - ), "collocated weights and grads must have consistent storage offset" - - if start is None: - start = offset = weight.storage_offset() - else: - assert ( - offset == weight.storage_offset() - ), "params collocated must be contiguous" - - offset += ( - p.bagua_tensor.num_elem_allocated() - if hasattr(p, "bagua_tensor") - else p.numel() - ) - - with torch.no_grad(): - weight_tensor = torch.zeros(offset - start, dtype=params[0].dtype).to( - params[0].device - ) - weight_tensor.set_(params[0].data.storage(), start, weight_tensor.shape) - - grad_tensor = torch.zeros(offset - start, dtype=params[0].dtype).to( - params[0].device - ) - grad_tensor.set_(params[0].grad.data.storage(), start, grad_tensor.shape) - - return new_param(weight_tensor, grad_tensor) - + for tensor in tensors: + # copy data + flatten_tensor[offset : offset + tensor.numel()] = tensor.reshape(-1) + offset += tensor.numel() -def new_param(weight, grad): - with torch.no_grad(): - p = torch.nn.Parameter(weight, requires_grad=False) - p.grad = grad - return p + return flatten_tensor def to_bagua_datatype(datatype): diff --git a/examples/mnist/main.py b/examples/mnist/main.py index 2862589d6..3c0828d3d 100644 --- a/examples/mnist/main.py +++ b/examples/mnist/main.py @@ -41,12 +41,16 @@ def forward(self, x): def train(args, model, train_loader, optimizer, epoch): model.train() for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.cuda(), target.cuda() optimizer.zero_grad() output = model(data) loss = F.nll_loss(output, target) loss.backward() - optimizer.step() + if args.fuse_optimizer: + optimizer.fuse_step() + else: + optimizer.step() if batch_idx % args.log_interval == 0: logging.info( "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( @@ -154,7 +158,13 @@ def main(): "--set-deterministic", action="store_true", default=False, - help="whether set deterministic", + help="set deterministic or not", + ) + parser.add_argument( + "--fuse-optimizer", + action="store_true", + default=False, + help="fuse optimizer or not", ) args = 
parser.parse_args() @@ -213,6 +223,9 @@ def main(): model = Net().cuda() optimizer = optim.Adadelta(model.parameters(), lr=args.lr) + if args.fuse_optimizer: + optimizer = bagua.contrib.fuse_optimizer(optimizer) + if args.algorithm == "gradient_allreduce": from bagua.torch_api.algorithms import gradient_allreduce @@ -248,6 +261,7 @@ def main(): model = model.with_bagua( [optimizer], algorithm, + do_flatten=not args.fuse_optimizer, ) scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) diff --git a/examples/squad/main.py b/examples/squad/main.py index df4e55091..70f6a3cd7 100644 --- a/examples/squad/main.py +++ b/examples/squad/main.py @@ -134,6 +134,9 @@ def train(args, train_dataset, model, tokenizer): optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon ) + if args.fuse_optimizer: + optimizer = bagua.contrib.fuse_optimizer(optimizer) + if args.algorithm == "gradient_allreduce": from bagua.torch_api.algorithms import gradient_allreduce @@ -204,7 +207,9 @@ def train(args, train_dataset, model, tokenizer): # Distributed training (should be after apex fp16 initialization) if args.distributed: - model = model.with_bagua([optimizer], algorithm) + model = model.with_bagua( + [optimizer], algorithm, do_flatten=not args.fuse_optimizer + ) # Train! logger.info("***** Running training *****") @@ -258,8 +263,6 @@ def train(args, train_dataset, model, tokenizer): desc="Epoch", disable=bagua.get_rank() != 0, ) - # Added here for reproductibility - set_seed(args) for _ in train_iterator: epoch_iterator = tqdm( @@ -335,7 +338,10 @@ def train(args, train_dataset, model, tokenizer): model.parameters(), args.max_grad_norm ) - optimizer.step() + if args.fuse_optimizer: + optimizer.fuse_step() + else: + optimizer.step() scheduler.step() # Update learning rate schedule model.zero_grad() global_step += 1 @@ -907,6 +913,12 @@ def main(): default=1, help="multiple threads for converting example to features", ) + parser.add_argument( + "--set-deterministic", + action="store_true", + default=False, + help="set deterministic or not", + ) parser.add_argument( "--algorithm", type=str, @@ -925,6 +937,13 @@ def main(): type=int, help="Warmup(allreduce) steps for async algorithm", ) + parser.add_argument( + "--fuse-optimizer", + action="store_true", + default=False, + help="fuse optimizer or not", + ) + args = parser.parse_args() if args.doc_stride >= args.max_seq_length - args.max_query_length: @@ -994,8 +1013,13 @@ def main(): transformers.utils.logging.set_verbosity_info() transformers.utils.logging.enable_default_handler() transformers.utils.logging.enable_explicit_format() - # Set seed - set_seed(args) + + # Added here for reproductibility + if args.set_deterministic: + # Set seed + set_seed(args) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True # Load pretrained model and tokenizer if args.distributed and bagua.get_rank() != 0: diff --git a/tests/contrib/test_fused_optimizer.py b/tests/contrib/test_fused_optimizer.py index 16c918db2..49fe0624c 100644 --- a/tests/contrib/test_fused_optimizer.py +++ b/tests/contrib/test_fused_optimizer.py @@ -6,8 +6,12 @@ from tests.internal.common_utils import find_free_port from tests import skip_if_cuda_available, skip_if_cuda_not_available +import logging -def run_step(opt, flag_param, fuse, wrap, device): +logging.getLogger().setLevel(logging.INFO) + + +def construct_model_and_optimizer(opt, flag_param, device): weight = torch.tensor( [[-0.2109, -0.4976], [-0.1413, -0.3420], [-0.2524, 0.6976]], 
requires_grad=True, @@ -39,31 +43,186 @@ def run_step(opt, flag_param, fuse, wrap, device): model = model.to(device) optimizer = opt(model.parameters(), **flag_param) - if fuse: - bagua.contrib.FusedOptimizer(optimizer, do_flatten=not wrap) + return model, optimizer - if wrap: - model.with_bagua( - [optimizer], - bagua.algorithms.gradient_allreduce.GradientAllReduceAlgorithm(), - ) +def train_model(model, optimizer, device, num_epochs): input = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], device=device).reshape(3, 2) - for _ in range(1001): + for epoch in range(num_epochs): optimizer.zero_grad() output = model(input) loss = output.sum() loss.backward() optimizer.step() + # logging.debug(f"#train model#{epoch} params: {optimizer.param_groups}") + # logging.debug(f"#train model#{epoch} state: {optimizer.state}") + + +def train_model_fused(model, optimizer, device, num_epochs): + input = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], device=device).reshape(3, 2) + + for epoch in range(num_epochs): + optimizer.zero_grad() + output = model(input) + loss = output.sum() + loss.backward() + + if epoch % 2 == 0: + optimizer.fuse_step() + else: + optimizer.step() + # logging.debug(f"#train model fused#{epoch} params: {optimizer._bagua_fused_optimizer.param_groups}") + # logging.debug(f"#train model fused#{epoch} state: {optimizer.state}") + + +def bagua_init(model, optimizer, algorithm, do_flatten): + # wrap model + if algorithm == "gradient_allreduce": + from bagua.torch_api.algorithms import gradient_allreduce + + bagua_algorithm = gradient_allreduce.GradientAllReduceAlgorithm() + elif algorithm == "bytegrad": + from bagua.torch_api.algorithms import bytegrad + + bagua_algorithm = bytegrad.ByteGradAlgorithm() + elif algorithm == "decentralized": + from bagua.torch_api.algorithms import decentralized + + bagua_algorithm = decentralized.DecentralizedAlgorithm(hierarchical=False) + elif algorithm == "async": + from bagua.torch_api.algorithms import async_model_average + + bagua_algorithm = async_model_average.AsyncModelAverageAlgorithm( + sync_interval_ms=10, + ) + elif algorithm == "low_prec_decentralized": + from bagua.torch_api.algorithms import decentralized + + bagua_algorithm = decentralized.LowPrecisionDecentralizedAlgorithm( + hierarchical=False + ) + elif algorithm == "qadam": + from bagua.torch_api.algorithms.q_adam import QAdamAlgorithm, QAdamOptimizer + + optimizer = QAdamOptimizer(model.parameters(), warmup_steps=1) + bagua_algorithm = QAdamAlgorithm(optimizer, hierarchical=False) + else: + raise ValueError("unsupported algorithm") + + model = model.with_bagua([optimizer], bagua_algorithm, do_flatten=do_flatten) + + return model, optimizer + + +def setup_bagua_env(): + # init env + os.environ["WORLD_SIZE"] = "1" + os.environ["LOCAL_WORLD_SIZE"] = "1" + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(find_free_port(8000, 8100)) + os.environ["BAGUA_SERVICE_PORT"] = str(find_free_port(9000, 9100)) + os.environ["RANK"] = "0" + os.environ["LOCAL_RANK"] = "0" + + # init bagua distributed process group + torch.cuda.set_device(0) + # TODO: remove this after process group destroy supported + if not bagua.communication.is_initialized(): + bagua.init_process_group() + + +def run(opt, flag_param, device, num_epochs): + model, optimizer = construct_model_and_optimizer(opt, flag_param, device) + + train_model(model, optimizer, device, num_epochs=num_epochs) + return model.parameters() + + +def run_fused(opt, flag_param, device, num_epochs): + model, optimizer = 
construct_model_and_optimizer(opt, flag_param, device) + optimizer = bagua.contrib.fuse_optimizer(optimizer, do_flatten=True) + + train_model_fused(model, optimizer, device, num_epochs=num_epochs) + return model.parameters(), optimizer._bagua_fused_count + + +def run_with_bagua(opt, flag_param, device, num_epochs, algorithm): + model, optimizer = construct_model_and_optimizer(opt, flag_param, device) + + model, optimizer = bagua_init(model, optimizer, algorithm, do_flatten=True) + + train_model(model, optimizer, device, num_epochs=num_epochs) + + if algorithm == "async": + model.bagua_algorithm.abort(model) return model.parameters() +def run_fused_with_bagua( + opt, flag_param, device, num_epochs, algorithm, optimizer_flatten, bagua_flatten +): + model, optimizer = construct_model_and_optimizer(opt, flag_param, device) + + # First fuse optimizer, then wrap module + optimizer = bagua.contrib.fuse_optimizer(optimizer, do_flatten=optimizer_flatten) + model, optimizer = bagua_init(model, optimizer, algorithm, bagua_flatten) + + train_model_fused(model, optimizer, device, num_epochs=num_epochs) + # torch.cuda.current_stream().synchronize() + if algorithm == "async": + model.bagua_algorithm.abort(model) + # torch.cuda.synchronize() + return model.parameters(), optimizer._bagua_fused_count + + +def run_fused_with_bagua_v2( + opt, flag_param, device, num_epochs, algorithm, optimizer_flatten, bagua_flatten +): + model, optimizer = construct_model_and_optimizer(opt, flag_param, device) + + # First wrap module, then fuse optimizer + model, optimizer = bagua_init(model, optimizer, algorithm, bagua_flatten) + optimizer = bagua.contrib.fuse_optimizer(optimizer, do_flatten=optimizer_flatten) + + train_model_fused(model, optimizer, device, num_epochs=num_epochs) + + if algorithm == "async": + model.bagua_algorithm.abort(model) + return model.parameters(), optimizer._bagua_fused_count + + class TestFusedOptimizer(unittest.TestCase): - def run_all_optimizers_once(self, wrap, device): + def run_qadam( + self, device, num_epochs, fused_count, optimizer_flatten, bagua_flatten + ): + res1 = run_with_bagua( + optim.SGD, + dict(lr=0.01), + device=device, + num_epochs=num_epochs, + algorithm="qadam", + ) + res2, cnt2 = run_fused_with_bagua_v2( + optim.SGD, + dict(lr=0.01), + device=device, + num_epochs=num_epochs, + algorithm="qadam", + optimizer_flatten=optimizer_flatten, + bagua_flatten=bagua_flatten, + ) + + for p1, p2 in zip(res1, res2): + self.assertTrue(torch.equal(p1, p2)) + self.assertTrue(cnt2 == fused_count) + + def run_all_optimizers_once(self, fn1, fn2, device, num_epochs, fused_count): optimizer_list = [ + optim.SGD, + optim.SGD, optim.Adam, optim.Adam, optim.Adam, @@ -72,8 +231,6 @@ def run_all_optimizers_once(self, wrap, device): optim.AdamW, optim.AdamW, optim.AdamW, - optim.SGD, - optim.SGD, optim.RMSprop, optim.RMSprop, optim.RMSprop, @@ -88,6 +245,10 @@ def run_all_optimizers_once(self, wrap, device): ] flag_params = [ + dict(lr=0.2, momentum=1, dampening=0, weight_decay=1, nesterov=True), # SGD + dict( + lr=0.2, momentum=1, dampening=0.5, weight_decay=1, nesterov=False + ), # SGD dict(weight_decay=1.0, amsgrad=True), # Adam dict(weight_decay=1.0, amsgrad=False), # Adam dict(weight_decay=0.0, amsgrad=True), # Adam @@ -96,10 +257,6 @@ def run_all_optimizers_once(self, wrap, device): dict(weight_decay=1.0, amsgrad=False), # AdamW dict(weight_decay=0.0, amsgrad=True), # AdamW dict(weight_decay=0.0, amsgrad=False), # AdamW - dict(lr=0.2, momentum=1, dampening=0, weight_decay=1, nesterov=True), # SGD 
- dict( - lr=0.2, momentum=1, dampening=0.5, weight_decay=1, nesterov=False - ), # SGD dict(weight_decay=1, momentum=1, centered=True), # RMSprop dict(weight_decay=1, momentum=0, centered=True), # RMSprop dict(weight_decay=1, momentum=1, centered=False), # RMSprop @@ -113,34 +270,210 @@ def run_all_optimizers_once(self, wrap, device): dict(weight_decay=1), # Adadelta ] + count = 0 for opt, flag_param in zip(optimizer_list, flag_params): - res1 = run_step(opt, flag_param, fuse=True, wrap=wrap, device=device) - res2 = run_step(opt, flag_param, fuse=False, wrap=wrap, device=device) + res1 = fn1(opt, flag_param, device=device, num_epochs=num_epochs) + res2, cnt2 = fn2(opt, flag_param, device=device, num_epochs=num_epochs) for p1, p2 in zip(res1, res2): self.assertTrue(torch.equal(p1, p2)) + self.assertTrue(cnt2 == fused_count) + + count += 1 + if count % 5 == 0: + logging.info(f"Tests Passed [{count}/{len(optimizer_list)}]") + + def run_fused_with_bagua_wrapper(self, fn1, fn2, num_epochs, fused_count): + self.run_all_optimizers_once(fn1, fn2, "cuda:0", num_epochs, fused_count) @skip_if_cuda_available() def test_fused_optimizer(self): - self.run_all_optimizers_once(device="cpu", wrap=False) + self.run_all_optimizers_once( + fn1=run, fn2=run_fused, device="cpu", num_epochs=101, fused_count=51 + ) @skip_if_cuda_not_available() - def test_fused_optimizer_with_bagua_wrapper(self): - # init env - os.environ["WORLD_SIZE"] = "1" - os.environ["LOCAL_WORLD_SIZE"] = "1" - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = str(find_free_port(8000, 8100)) - os.environ["BAGUA_SERVICE_PORT"] = str(find_free_port(9000, 9100)) - - os.environ["RANK"] = "0" - os.environ["LOCAL_RANK"] = "0" - - # init bagua distributed process group - torch.cuda.set_device(0) - bagua.init_process_group() + def test_gradient_allreduce(self): + setup_bagua_env() + # check: optimizer param groups is flattened, should fuse + self.run_fused_with_bagua_wrapper( + fn1=run, + fn2=lambda p1, p2, device, num_epochs: run_fused_with_bagua( + p1, p2, device, num_epochs, "gradient_allreduce", True, False + ), + num_epochs=101, + fused_count=51, + ) + # check: both are falttened, should not fuse + self.run_fused_with_bagua_wrapper( + fn1=run, + fn2=lambda p1, p2, device, num_epochs: run_fused_with_bagua( + p1, p2, device, num_epochs, "gradient_allreduce", True, True + ), + num_epochs=101, + fused_count=0, + ) + # check: bagua module is falttened, should not fuse + self.run_fused_with_bagua_wrapper( + fn1=run, + fn2=lambda p1, p2, device, num_epochs: run_fused_with_bagua( + p1, p2, device, num_epochs, "gradient_allreduce", False, True + ), + num_epochs=101, + fused_count=0, + ) + + @skip_if_cuda_not_available() + def test_bytegrad(self): + setup_bagua_env() + # check: optimizer param groups is flattened, should fuse + self.run_fused_with_bagua_wrapper( + fn1=lambda p1, p2, device, num_epochs: run_with_bagua( + p1, p2, device, num_epochs, "bytegrad" + ), + fn2=lambda p1, p2, device, num_epochs: run_fused_with_bagua( + p1, p2, device, num_epochs, "bytegrad", True, False + ), + num_epochs=101, + fused_count=51, + ) + + @skip_if_cuda_not_available() + def test_decentralized(self): + setup_bagua_env() + # check: optimizer param groups is flattened, should fuse + self.run_fused_with_bagua_wrapper( + fn1=run, + fn2=lambda p1, p2, device, num_epochs: run_fused_with_bagua( + p1, p2, device, num_epochs, "decentralized", True, False + ), + num_epochs=101, + fused_count=51, + ) + self.run_fused_with_bagua_wrapper( + fn1=run, + fn2=lambda 
p1, p2, device, num_epochs: run_fused_with_bagua( + p1, p2, device, num_epochs, "decentralized", True, True + ), + num_epochs=101, + fused_count=0, + ) + self.run_fused_with_bagua_wrapper( + fn1=run, + fn2=lambda p1, p2, device, num_epochs: run_fused_with_bagua( + p1, p2, device, num_epochs, "decentralized", False, True + ), + num_epochs=101, + fused_count=0, + ) + + @skip_if_cuda_not_available() + def test_async(self): + return + setup_bagua_env() + self.run_fused_with_bagua_wrapper( + fn1=run, + fn2=lambda p1, p2, device, num_epochs: run_fused_with_bagua( + p1, p2, device, num_epochs, "async", True, False + ), + num_epochs=101, + fused_count=51, + ) + self.run_fused_with_bagua_wrapper( + fn1=run, + fn2=lambda p1, p2, device, num_epochs: run_fused_with_bagua( + p1, p2, device, num_epochs, "async", False, True + ), + num_epochs=101, + fused_count=0, + ) + + @skip_if_cuda_not_available() + def test_low_prec_decentralized(self): + return + setup_bagua_env() + self.run_fused_with_bagua_wrapper( + fn1=lambda p1, p2, device, num_epochs: run_with_bagua( + p1, p2, device, num_epochs, "low_prec_decentralized" + ), + fn2=lambda p1, p2, device, num_epochs: run_fused_with_bagua( + p1, p2, device, num_epochs, "low_prec_decentralized", True, False + ), + num_epochs=101, + fused_count=51, + ) + + @skip_if_cuda_not_available() + def test_qadam(self): + return + setup_bagua_env() + self.run_qadam( + device="cuda:0", + num_epochs=101, + fused_count=51, + optimizer_flatten=True, + bagua_flatten=False, + ) + + @skip_if_cuda_available() + def test_calculate_mutual_groups(self): + from bagua.torch_api.contrib.fuse.optimizer import calculate_mutual_groups + + tensor = torch.rand(100) + + tensor_pieces = [] + for i in range(10): + tensor_pieces.append(tensor[i * 10 : (i + 1) * 10]) + + g1 = [ + tensor_pieces[3], + tensor_pieces[1], + tensor_pieces[2], + tensor_pieces[0], + tensor_pieces[8], + tensor_pieces[9], + torch.rand(10), + ] + g2 = [torch.rand(10) for _ in range(len(g1))] + g3 = [ + tensor_pieces[3], + tensor_pieces[1], + tensor_pieces[2], + tensor_pieces[0], + torch.rand(10), + torch.rand(10), + torch.rand(10), + ] + g4 = [ + torch.rand(10), + tensor_pieces[1], + tensor_pieces[2], + tensor_pieces[0], + tensor_pieces[8], + tensor_pieces[9], + torch.rand(10), + ] + g5 = [ + tensor_pieces[3], + tensor_pieces[1], + tensor_pieces[2], + tensor_pieces[0], + tensor_pieces[8], + tensor_pieces[9], + torch.rand(10), + ] + + ret = calculate_mutual_groups([g1, g2]) + self.assertTrue(ret == []) + + ret = calculate_mutual_groups([g1, g3]) + self.assertTrue(ret == [[3, 1, 2, 0]]) + + ret = calculate_mutual_groups([g1, g4]) + self.assertTrue(ret == [[4, 5]]) - self.run_all_optimizers_once(device="cuda:0", wrap=True) + ret = calculate_mutual_groups([g1, g5]) + self.assertTrue(ret == [[3, 1, 2, 0], [4, 5]]) if __name__ == "__main__":