This repository was archived by the owner on Dec 16, 2022. It is now read-only.

FairScale integration #5242

Merged · 104 commits · Jul 19, 2021
Changes from 96 commits

Commits (104 total)
a4d7165
start
epwalsh Apr 27, 2021
8a82679
fix up
epwalsh Apr 27, 2021
88f5206
start
epwalsh Apr 27, 2021
ba8f1ef
fix up
epwalsh Apr 27, 2021
56998d2
Merge branch 'fairscale' of github.com:allenai/allennlp into fairscale
epwalsh Apr 27, 2021
a991275
DdpWrapper
epwalsh Apr 27, 2021
fdff8bb
generalize GradScaler
epwalsh Apr 28, 2021
178c689
OSS
epwalsh Apr 29, 2021
43a86c2
fp16 params
epwalsh Apr 29, 2021
453e75a
idk
epwalsh May 2, 2021
5895bf6
undo CHANGELOG for now
epwalsh May 3, 2021
5738bb4
revert API
epwalsh May 3, 2021
4645a4b
refactor
epwalsh May 3, 2021
dbf502f
CHANGELOG
epwalsh May 3, 2021
91184c8
fixes
epwalsh May 4, 2021
2548b0f
refactor
epwalsh May 4, 2021
8a06ab7
wrap modules
epwalsh May 5, 2021
d12c017
refactor
epwalsh May 5, 2021
6d85e13
fix when no checkpointer
epwalsh May 6, 2021
c606451
fix merge conflicts
epwalsh May 18, 2021
0571b47
fix loading
epwalsh May 19, 2021
338028f
fix cicular import issue
epwalsh May 19, 2021
8e803ae
upgrade fairscale
epwalsh May 20, 2021
04dde82
Merge branch 'main' into fairscale
epwalsh May 26, 2021
3ad797d
fix race condition when extracting files with cached_path
epwalsh May 26, 2021
6d7593b
fix merge conflicts
epwalsh May 27, 2021
9d7730b
better logging
epwalsh May 27, 2021
7cc8f7d
improve logging
epwalsh May 27, 2021
c7c856b
fix
epwalsh May 27, 2021
f8c42a3
fix
epwalsh May 27, 2021
73b5e43
more logging
epwalsh May 27, 2021
f846a1b
keep state_dict tensors on CPU
epwalsh May 27, 2021
44dba38
remove annoying logging
epwalsh May 27, 2021
8756605
set gradients to none
epwalsh May 27, 2021
8ef9fa6
move params to CPU in mixed precision
epwalsh May 27, 2021
59d83a9
fixes
epwalsh May 27, 2021
bc8a819
find_unused_parameters default to False
epwalsh May 28, 2021
6993c0d
add TODO
epwalsh May 28, 2021
5784ccc
add more tests, make grad scaler configurable
epwalsh May 28, 2021
10b1b8c
update with main
epwalsh Jun 3, 2021
bc93d8a
patch models branch temporarily
epwalsh Jun 3, 2021
310503d
wow, good start
epwalsh Jun 3, 2021
09d2a38
fix Dockerfile
epwalsh Jun 3, 2021
d5cdbc1
beam search as a parameter
epwalsh Jun 3, 2021
1d9c052
ignore import error of optional dependencies
epwalsh Jun 3, 2021
ca667fc
fix
epwalsh Jun 3, 2021
e1734c2
add debugging env variables
epwalsh Jun 3, 2021
79c13b5
more debugging
epwalsh Jun 3, 2021
51555fe
update Dockerfile
epwalsh Jun 3, 2021
b4cad9e
increase shared memory for test container
epwalsh Jun 3, 2021
1ac97b3
fix
epwalsh Jun 3, 2021
326e07b
log optional import failures
epwalsh Jun 3, 2021
0ba2474
allow disabling checkpointer
epwalsh Jun 4, 2021
5893fae
fix deadlock when checkpointer disabled
epwalsh Jun 4, 2021
0a436e3
fix unbound variable
epwalsh Jun 4, 2021
01c3b5a
Merge branch 'main' into fairscale
epwalsh Jun 4, 2021
e3cab77
start using nn.Sequential to prepare for activation checkpointing
epwalsh Jun 4, 2021
47d97bc
Revert "start using nn.Sequential to prepare for activation checkpoin…
epwalsh Jun 4, 2021
3bb1287
fix merge conflicts
epwalsh Jun 7, 2021
ed39623
add chkpt wrapper class with default torch implementation
epwalsh Jun 8, 2021
055a7c9
get FairScale activation/grad checkpointing working
epwalsh Jun 9, 2021
9ebb521
pin fairscale to commit
epwalsh Jun 9, 2021
a303b5c
ignore line-too-long in setup
epwalsh Jun 9, 2021
2bceaaa
add xfail test for TorchCheckpointWrapper
epwalsh Jun 9, 2021
49ed5f5
fix bugs
epwalsh Jun 9, 2021
bc956bc
update fairscale commit
epwalsh Jun 11, 2021
747de54
use mixin class instead of flags
epwalsh Jun 11, 2021
db95c99
clean up APIs for wrapper classes
epwalsh Jun 11, 2021
313c252
clean up test
epwalsh Jun 11, 2021
048c300
Merge branch 'main' into fairscale
epwalsh Jun 11, 2021
cdb6768
fix merge conflicts
epwalsh Jun 15, 2021
a2bbfa0
fix merge conflicts
epwalsh Jun 21, 2021
5581df8
add Adafactor optimizer
epwalsh Jun 22, 2021
e799eac
implement state checkpointing
epwalsh Jun 23, 2021
d0aa97a
fix
epwalsh Jun 23, 2021
baf796b
fix test
epwalsh Jun 24, 2021
54fd6ca
Merge branch 'main' into fairscale
epwalsh Jun 24, 2021
f190030
update FairScale commit pin
epwalsh Jun 24, 2021
228b73b
fix merge conflicts
epwalsh Jun 24, 2021
5cfb722
add Module class
epwalsh Jun 24, 2021
a37b53d
Merge branch 'main' into fairscale
epwalsh Jun 28, 2021
587c228
doc fixes
epwalsh Jun 29, 2021
52fb7fd
improve repr method of IncompatibleKeys
epwalsh Jun 29, 2021
7a5fd41
clean up FSDP tests
epwalsh Jun 29, 2021
4078d46
Merge branch 'main' into fairscale
epwalsh Jun 29, 2021
d8fa9bb
changelog clean up
epwalsh Jun 29, 2021
6072a6a
Merge branch 'main' into fairscale
dirkgr Jul 1, 2021
afc81c6
make fairscale a required dependency
epwalsh Jul 7, 2021
173828f
rename 'get_grad_scaler' -> 'init_grad_scaler'
epwalsh Jul 8, 2021
0bc1d19
make hooks private methods
epwalsh Jul 8, 2021
5258dc8
make _post_load_state_dict pure
epwalsh Jul 8, 2021
7a130cb
fix comment
epwalsh Jul 8, 2021
3378a0c
use hardlink
epwalsh Jul 8, 2021
7dcd9e9
rename DdpWrapper -> DdpAccelerator
epwalsh Jul 8, 2021
984ac6c
Merge branch 'main' into fairscale
epwalsh Jul 8, 2021
33496e2
format fix
epwalsh Jul 8, 2021
90757a9
update FairScale to latest release
epwalsh Jul 14, 2021
b82f027
fix GradientDescientTrainer.get_best_weights_path
epwalsh Jul 14, 2021
50db06c
fix typo
epwalsh Jul 14, 2021
920ef23
Merge branch 'main' into fairscale
epwalsh Jul 19, 2021
a0a239e
clarify docstring
epwalsh Jul 19, 2021
b62b0c3
Merge branch 'main' into fairscale
epwalsh Jul 19, 2021
b84cf85
update CHANGELOG
epwalsh Jul 19, 2021
2436671
revert CI patch
epwalsh Jul 19, 2021
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
@@ -162,6 +162,8 @@ jobs:
. .venv/bin/activate
git clone https://github.com/allenai/allennlp-models.git
cd allennlp-models
# TODO: remove
git checkout fairscale
pip install --upgrade --upgrade-strategy eager -e . -r dev-requirements.txt

- name: Debug info
20 changes: 20 additions & 0 deletions CHANGELOG.md
@@ -19,10 +19,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `TransformerTextField`, for cases where you don't care about AllenNLP's advanced text handling capabilities.
- Added `TransformerModule._post_load_pretrained_state_dict_hook()` method. Can be used to modify `missing_keys` and `unexpected_keys` after
loading a pretrained state dictionary. This is useful when tying weights, for example.
- Added a module `allennlp.nn.parallel` with a new base class, `DdpAccelerator`, which generalizes
PyTorch's `DistributedDataParallel` wrapper to support other implementations. Two implementations of
this class are provided. The default is `TorchDdpAccelerator` (registered as "torch"), which is just a thin wrapper around
`DistributedDataParallel`. The other is `FairScaleFsdpAccelerator`, which wraps FairScale's
[`FullyShardedDataParallel`](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html).
You can specify the `DdpAccelerator` in the "distributed" section of a configuration file under the key "ddp_accelerator".
- Added a module `allennlp.nn.checkpoint` with a new base class, `CheckpointWrapper`, for implementations
of activation/gradient checkpointing. Two implementations are provided. The default is `TorchCheckpointWrapper` (registered as "torch"),
which exposes [PyTorch's checkpoint functionality](https://pytorch.org/docs/stable/checkpoint.html).
The other is `FairScaleCheckpointWrapper`, which exposes the more flexible
[checkpointing functionality from FairScale](https://fairscale.readthedocs.io/en/latest/api/nn/checkpoint/checkpoint_activations.html).
- The `Model` base class now takes a `ddp_accelerator` parameter (an instance of `DdpAccelerator`) which will be available as
`self.ddp_accelerator` during distributed training. This is useful when, for example, instantiating submodules in your
model's `__init__()` method and wrapping them with `self.ddp_accelerator.wrap_module()`. See the `allennlp.modules.transformer.t5`
module for an example, or the sketch just below this list.
- Added an end-to-end test for the Transformer Toolkit.
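To make the `DdpAccelerator` and `Model` entries above concrete, here is a minimal sketch of a model that wraps a submodule through the new `ddp_accelerator` parameter. It assumes the keyword can simply be forwarded to `Model.__init__()` and that `wrap_module()` returns the wrapped module; the class name and the `Linear` layer are made up for illustration.

```python
from typing import Dict, Optional

import torch
from allennlp.data import Vocabulary
from allennlp.models import Model
from allennlp.nn.parallel import DdpAccelerator


@Model.register("toy_sharded_model")  # hypothetical name, only for this sketch
class ToyShardedModel(Model):
    def __init__(self, vocab: Vocabulary, ddp_accelerator: Optional[DdpAccelerator] = None) -> None:
        super().__init__(vocab, ddp_accelerator=ddp_accelerator)
        encoder = torch.nn.Linear(16, 16)  # stand-in for a large submodule
        if self.ddp_accelerator is not None:
            # In distributed training this lets the accelerator (e.g. FSDP) wrap,
            # and potentially shard, the submodule as it is constructed.
            encoder = self.ddp_accelerator.wrap_module(encoder)
        self.encoder = encoder

    def forward(self, features: torch.Tensor) -> Dict[str, torch.Tensor]:
        projected = self.encoder(features)
        return {"loss": projected.sum()}
```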

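The `CheckpointWrapper` API itself is not shown in this diff, so the sketch below only exercises the underlying PyTorch checkpoint functionality that the default `TorchCheckpointWrapper` is described as exposing; the block and tensor shapes are arbitrary.

```python
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(
    torch.nn.Linear(128, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 128),
)

x = torch.randn(4, 128, requires_grad=True)

# Activations inside `block` are not kept during the forward pass; they are
# recomputed during backward, trading extra compute for lower memory use.
y = checkpoint(block, x)
y.sum().backward()
```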
### Fixed

- Fixed a misspelling: the parameter `contructor_extras` in `Lazy()` is now correctly called `constructor_extras`.
- Fixed a broken link in the `allennlp.fairness.fairness_metrics.Separation` docs.
- Ensured all `allennlp` submodules are imported with `allennlp.common.plugins.import_plugins()`.
- Fixed `IndexOutOfBoundsException` in `MultiOptimizer` when checking if optimizer received any parameters.
@@ -38,6 +54,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- Changed behavior of `MultiOptimizer` so that while a default optimizer is still required, an error is not thrown if the default optimizer receives no parameters.
- The type of the `grad_norm` parameter of `GradientDescentTrainer` is now `Union[float, bool]`,
with a default value of `False`. `False` means gradients are not rescaled and the gradient
norm is never even calculated. `True` means the gradients are still not rescaled but the gradient
norm is calculated and passed on to callbacks. A `float` value means gradients are rescaled.
Reviewer comment (Member):

I haven't seen the code yet, but I'm not too wild about this API. That means you have to know whether some other component needs the gradient norm or not. I'd rather provide a function called get_grad_norm() or something like that, which calculates it lazily.
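A rough sketch of the lazily computed alternative suggested above; `get_grad_norm()` is hypothetical and not part of this PR. A `float` value of `grad_norm` would instead rescale gradients via something like `torch.nn.utils.clip_grad_norm_`.

```python
from typing import Iterable, Optional

import torch


def get_grad_norm(parameters: Iterable[torch.nn.Parameter]) -> Optional[torch.Tensor]:
    """Total L2 norm of all gradients, computed only when a caller actually asks for it."""
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    if not grads:
        return None
    return torch.norm(torch.stack([torch.norm(g) for g in grads]))
```

A callback that wants the norm could call this on demand, so training runs that never look at it pay nothing for the computation.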

- Made the epsilon parameter for the layer normalization in token embeddings configurable.

### Removed
5 changes: 5 additions & 0 deletions Dockerfile.test
@@ -15,6 +15,11 @@ ENV NVIDIA_VISIBLE_DEVICES all
ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
LABEL com.nvidia.volumes.needed="nvidia_driver"

# These environment variables are helpful for debugging.
# See https://pytorch.org/docs/stable/distributed.html#common-environment-variables for more info.
ENV NCCL_DEBUG INFO
ENV NCCL_DEBUG_SUBSYS ALL

WORKDIR /stage/allennlp

# Install torch ecosystem first. This build arg should be in the form of a version requirement,
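Outside the test image, the same NCCL debugging output can be turned on per process; a small sketch, assuming the variables are exported before the NCCL process group is created:

```python
import os

# Must be set before torch.distributed.init_process_group(backend="nccl")
# so that NCCL picks them up when it initializes.
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["NCCL_DEBUG_SUBSYS"] = "ALL"
```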
3 changes: 2 additions & 1 deletion Makefile
@@ -100,6 +100,7 @@ install :
pip install --upgrade --upgrade-strategy eager -e . -r dev-requirements.txt
# These nltk packages are used by the 'checklist' module.
$(NLTK_DOWNLOAD_CMD)

#
# Documention helpers.
#
@@ -175,4 +176,4 @@ docker-test-image :

.PHONY : docker-test-run
docker-test-run :
$(DOCKER_RUN_CMD) $(DOCKER_GPUS) $(DOCKER_TEST_IMAGE_NAME) $(ARGS)
$(DOCKER_RUN_CMD) --shm-size 2G $(DOCKER_GPUS) $(DOCKER_TEST_IMAGE_NAME) $(ARGS)
78 changes: 60 additions & 18 deletions allennlp/commands/train.py
@@ -26,7 +26,8 @@
from allennlp.data import DatasetReader, Vocabulary
from allennlp.data import DataLoader
from allennlp.models.archival import archive_model, CONFIG_NAME, verify_include_in_archive
from allennlp.models.model import _DEFAULT_WEIGHTS, Model
from allennlp.models.model import Model
from allennlp.nn.parallel import DdpAccelerator
from allennlp.training.trainer import Trainer
from allennlp.training import util as training_util

@@ -131,6 +132,7 @@ def train_model_from_file(
include_package: List[str] = None,
dry_run: bool = False,
file_friendly_logging: bool = False,
return_model: Optional[bool] = None,
) -> Optional[Model]:
"""
A wrapper around [`train_model`](#train_model) which loads the params from a file.
@@ -160,11 +162,16 @@ def train_model_from_file(
file_friendly_logging : `bool`, optional (default=`False`)
If `True`, we add newlines to tqdm output, even on an interactive terminal, and we slow
down tqdm's output to only once every 10 seconds.
return_model : `Optional[bool]`, optional (default = `None`)
Whether or not to return the final model. If not specified, this defaults to `False` for
distributed training and `True` otherwise.

# Returns

best_model : `Optional[str]`
The path to the archived model with the best weights or `None` if in dry run.
best_model : `Optional[Model]`
The model with the best epoch weights or `None` if in dry run.
The model with the best epoch weights or `None`, depending on the value of `return_model` and `dry_run`.
"""
# Load the experiment config from a file and pass it to `train_model`.
params = Params.from_file(parameter_filename, overrides)
@@ -177,6 +184,7 @@ def train_model_from_file(
include_package=include_package,
dry_run=dry_run,
file_friendly_logging=file_friendly_logging,
return_model=return_model,
)


@@ -189,6 +197,7 @@ def train_model(
include_package: List[str] = None,
dry_run: bool = False,
file_friendly_logging: bool = False,
return_model: Optional[bool] = None,
) -> Optional[Model]:
"""
Trains the model specified in the given [`Params`](../common/params.md#params) object, using the data
@@ -216,11 +225,14 @@ def train_model(
file_friendly_logging : `bool`, optional (default=`False`)
If `True`, we add newlines to tqdm output, even on an interactive terminal, and we slow
down tqdm's output to only once every 10 seconds.
return_model : `Optional[bool]`, optional (default = `None`)
Whether or not to return the final model. If not specified, this defaults to `False` for
distributed training and `True` otherwise.

# Returns

best_model : `Optional[Model]`
The model with the best epoch weights or `None` if in dry run.
The model with the best epoch weights or `None`, depending on the value of `return_model` and `dry_run`.
"""
common_logging.FILE_FRIENDLY_LOGGING = file_friendly_logging

@@ -233,6 +245,8 @@ def train_model(
include_in_archive = params.pop("include_in_archive", None)
verify_include_in_archive(include_in_archive)

model: Optional[Model] = None

distributed_params = params.params.pop("distributed", None)
# If distributed isn't in the config and the config contains strictly
# one cuda device, we just run a single training process.
@@ -245,11 +259,6 @@ def train_model(
dry_run=dry_run,
file_friendly_logging=file_friendly_logging,
)

if not dry_run:
archive_model(serialization_dir, include_in_archive=include_in_archive)
return model

# Otherwise, we are running multiple processes for training.
else:
common_logging.prepare_global_logging(
@@ -323,15 +332,22 @@ def train_model(
device_ids,
file_friendly_logging,
include_in_archive,
Params(distributed_params),
),
nprocs=num_procs,
)
if dry_run:
return None
else:
archive_model(serialization_dir, include_in_archive=include_in_archive)
model = Model.load(params, serialization_dir)
return model

if not dry_run:
archive_model(serialization_dir, include_in_archive=include_in_archive)
else:
return None

if return_model is None:
return model # model may or may not be `None`.
elif return_model is True:
return model if model is not None else Model.load(params, serialization_dir)
else:
return None


def _train_worker(
@@ -347,6 +363,7 @@ def _train_worker(
distributed_device_ids: List[int] = None,
file_friendly_logging: bool = False,
include_in_archive: List[str] = None,
distributed_params: Optional[Params] = None,
) -> Optional[Model]:
"""
Helper to train the configured model/experiment. In distributed mode, this is spawned as a
@@ -383,6 +400,8 @@ def _train_worker(
down tqdm's output to only once every 10 seconds.
include_in_archive : `List[str]`, optional
Paths relative to `serialization_dir` that should be archived in addition to the default ones.
distributed_params : `Optional[Params]`, optional
Additional distributed params.

# Returns

@@ -404,8 +423,11 @@ def _train_worker(

include_package = include_package or []

ddp_accelerator: Optional[DdpAccelerator] = None

if distributed:
assert distributed_device_ids is not None
assert distributed_params is not None

# Since the worker is spawned and not forked, the extra imports need to be done again.
# Both the ones from the plugins and the ones from `include_package`.
@@ -426,16 +448,17 @@ def _train_worker(
# In distributed training, the configured device is always going to be a list.
# The corresponding gpu id for the particular worker is obtained by picking the id
# from the device list with the rank as index
gpu_id = distributed_device_ids[process_rank] # type: ignore
gpu_id = int(distributed_device_ids[process_rank]) # type: ignore

# Till now, "cuda_device" might not be set in the trainer params.
# But a worker trainer needs to only know about its specific GPU id.
params["trainer"]["local_rank"] = process_rank
params["trainer"]["cuda_device"] = gpu_id
params["trainer"]["world_size"] = world_size
params["trainer"]["distributed"] = True

if gpu_id >= 0:
torch.cuda.set_device(int(gpu_id))
torch.cuda.set_device(gpu_id)
dist.init_process_group(
backend="nccl",
init_method=f"tcp://{primary_addr}:{primary_port}",
@@ -449,6 +472,16 @@ def _train_worker(
world_size=world_size,
rank=global_rank,
)

if "ddp_accelerator" in distributed_params:
ddp_accelerator_params = distributed_params.pop("ddp_accelerator")
ddp_accelerator = DdpAccelerator.from_params(
ddp_accelerator_params,
local_rank=process_rank,
world_size=world_size,
cuda_device=gpu_id,
)

logging.info(
f"Process group of world size {world_size} initialized "
f"for distributed training in worker {global_rank}"
@@ -458,6 +491,7 @@ def _train_worker(
params=params,
serialization_dir=serialization_dir,
local_rank=process_rank,
ddp_accelerator=ddp_accelerator,
)

if dry_run:
Expand All @@ -470,7 +504,7 @@ def _train_worker(
metrics = train_loop.run()
except KeyboardInterrupt:
# if we have completed an epoch, try to create a model archive.
if primary and os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)):
if primary:
best_weights_path = train_loop.trainer.get_best_weights_path()
if best_weights_path is None:
logging.info(
@@ -581,6 +615,7 @@ def from_partial_objects(
test_data_path: Any = None,
evaluate_on_test: bool = False,
batch_weight_key: str = "",
ddp_accelerator: Optional[DdpAccelerator] = None,
) -> "TrainModel":
"""
This method is intended for use with our `FromParams` logic, to construct a `TrainModel`
@@ -667,6 +702,10 @@ def from_partial_objects(
batch_weight_key: `str`, optional (default=`""`)
The name of metric used to weight the loss on a per-batch basis. This is only used
during evaluation on final test data, if you've specified `evaluate_on_test=True`.

ddp_accelerator : `Optional[DdpAccelerator]`, optional (default = `None`)
A `DdpAccelerator` to use in the distributed trainer. Passed to the model and the trainer.

"""
# Train data loader.
data_loaders: Dict[str, DataLoader] = {
Expand Down Expand Up @@ -724,7 +763,9 @@ def from_partial_objects(

vocabulary_ = vocabulary.construct(instances=instance_generator)

model_ = model.construct(vocab=vocabulary_, serialization_dir=serialization_dir)
model_ = model.construct(
vocab=vocabulary_, serialization_dir=serialization_dir, ddp_accelerator=ddp_accelerator
)

# Initializing the model can have side effect of expanding the vocabulary.
# Save the vocab only in the primary. In the degenerate non-distributed
Expand All @@ -744,6 +785,7 @@ def from_partial_objects(
data_loader=data_loaders["train"],
validation_data_loader=data_loaders.get("validation"),
local_rank=local_rank,
ddp_accelerator=ddp_accelerator,
)
assert trainer_ is not None

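For reference, `_train_worker` above pops the accelerator out of the `distributed` section of the experiment config. A configuration that opts in might contain something like the following, shown as the Python dict a `Params` object would wrap; only the `"torch"` registration is confirmed by the changelog, the rest of the keys are schematic:

```python
config = {
    # ... "dataset_reader", "model", "data_loader", "trainer", etc. ...
    "distributed": {
        "cuda_devices": [0, 1],
        # Constructed with DdpAccelerator.from_params() in _train_worker.
        "ddp_accelerator": {
            "type": "torch",
        },
    },
}
```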
2 changes: 1 addition & 1 deletion allennlp/common/from_params.py
@@ -458,7 +458,7 @@ def construct_arg(

value_cls = args[0]
subextras = create_extras(value_cls, extras)
return Lazy(value_cls, params=deepcopy(popped_params), contructor_extras=subextras) # type: ignore
return Lazy(value_cls, params=deepcopy(popped_params), constructor_extras=subextras) # type: ignore

# For any other kind of iterable, we will just assume that a list is good enough, and treat
# it the same as List. This condition needs to be at the end, so we don't catch other kinds
6 changes: 4 additions & 2 deletions allennlp/common/lazy.py
@@ -50,11 +50,13 @@ def __init__(
self,
constructor: Union[Type[T], Callable[..., T]],
params: Optional[Params] = None,
contructor_extras: Optional[Dict[str, Any]] = None,
constructor_extras: Optional[Dict[str, Any]] = None,
**kwargs,
) -> None:
self._constructor = constructor
self._params = params or Params({})
self._constructor_extras = contructor_extras or {}
self._constructor_extras = constructor_extras or {}
self._constructor_extras.update(kwargs)

@property
def constructor(self) -> Callable[..., T]:
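To illustrate the corrected keyword and the new `**kwargs` handling above, a small usage sketch; it assumes `Lazy.construct()` merges `constructor_extras` into the final constructor call, and `Greeter` is made up:

```python
from allennlp.common.lazy import Lazy


class Greeter:
    def __init__(self, greeting: str, name: str) -> None:
        self.message = f"{greeting}, {name}!"


# Extra kwargs passed to Lazy() itself now land in the same constructor_extras dict.
lazy_greeter = Lazy(Greeter, constructor_extras={"greeting": "Hello"})
greeter = lazy_greeter.construct(name="world")
print(greeter.message)  # Hello, world!
```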
2 changes: 1 addition & 1 deletion allennlp/common/testing/model_test_case.py
@@ -121,7 +121,7 @@ def ensure_model_can_train_save_and_load(
"""
save_dir = self.TEST_DIR / "save_and_load_test"
archive_file = save_dir / "model.tar.gz"
model = train_model_from_file(param_file, save_dir, overrides=overrides)
model = train_model_from_file(param_file, save_dir, overrides=overrides, return_model=True)
assert model is not None

metrics_file = save_dir / "metrics.json"
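The test helper above is the call site that motivates the new `return_model` flag. Using it directly might look like this; the config path and serialization directory are placeholders:

```python
from allennlp.commands.train import train_model_from_file

# With return_model=True the trained Model instance is returned even for runs
# where the new default (e.g. distributed training) would be to return None.
model = train_model_from_file(
    "experiments/my_experiment.jsonnet",  # placeholder path
    "/tmp/my_serialization_dir",
    return_model=True,
)
assert model is not None
```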
14 changes: 9 additions & 5 deletions allennlp/common/util.py
@@ -418,12 +418,14 @@ def peak_gpu_memory() -> Dict[int, int]:
if not torch.cuda.is_available():
return {}

device = torch.cuda.current_device()

results_dict: Dict[int, int] = {}
if is_distributed():
# If the backend is not 'nccl', we're training on CPU.
if dist.get_backend() != "nccl":
return {}

device = torch.cuda.current_device()
global_rank = dist.get_rank()
world_size = dist.get_world_size()
peak_bytes = torch.cuda.max_memory_allocated(device)
Expand All @@ -433,13 +435,15 @@ def peak_gpu_memory() -> Dict[int, int]:

dist.all_gather(gather_results, peak_bytes_tensor)

results_dict: Dict[int, int] = {}
for peak_bytes_tensor in gather_results:
results_dict[int(peak_bytes_tensor[0])] = int(peak_bytes_tensor[1])

return results_dict
else:
return {0: torch.cuda.max_memory_allocated()}
results_dict = {0: torch.cuda.max_memory_allocated()}

# Reset peak stats.
torch.cuda.reset_max_memory_allocated(device)

return results_dict


def ensure_list(iterable: Iterable[A]) -> List[A]:
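The `peak_gpu_memory()` change above resets the CUDA peak-memory counter after reading it, so successive calls report per-interval peaks rather than an all-time high. A standalone sketch of that pattern using the underlying PyTorch calls (tensor size is arbitrary):

```python
import torch

if torch.cuda.is_available():
    device = torch.cuda.current_device()

    x = torch.randn(1024, 1024, device=device)
    peak_bytes = torch.cuda.max_memory_allocated(device)
    print(f"Peak GPU memory this interval: {peak_bytes} bytes")

    # Reset so the next reading only reflects allocations made after this point.
    torch.cuda.reset_max_memory_allocated(device)
```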