This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Bump fairscale from 0.4.0 to 0.4.2 #5461

Merged: 4 commits, Nov 15, 2021
19 changes: 1 addition & 18 deletions allennlp/nn/checkpoint/fairscale_checkpoint_wrapper.py
@@ -22,25 +22,10 @@ class FairScaleCheckpointWrapper(CheckpointWrapper):
     :class:`allennlp.nn.parallel.fairscale_fsdp_accelerator.FairScaleFsdpAccelerator`.
     See the [T5 implementation](/api/modules/transformer/t5/) for an example
     of how to use the two together.
-
-    !!! Note
-        If using the `FairScaleFsdpAccelerator`, you need to set `maintain_forward_counter` to `True`.
-        For convenience, if `maintain_forward_counter` is not set, internally it will be
-        set to `True` if training in a distributed setup, or `False` otherwise.
     """

-    def __init__(
-        self, offload_to_cpu: Optional[bool] = True, maintain_forward_counter: Optional[bool] = None
-    ) -> None:
+    def __init__(self, offload_to_cpu: Optional[bool] = True) -> None:
         self._offload_to_cpu = offload_to_cpu
-        if maintain_forward_counter is None:
-            from allennlp.common.util import is_distributed
-
-            # Better to assume we need this in the distributed case, since we definitely
-            # need this when the model is wrapped with FairScale's FSDP.
-            self._maintain_forward_counter = is_distributed()
-        else:
-            self._maintain_forward_counter = maintain_forward_counter

     @overrides
     def wrap_module(
@@ -50,6 +35,4 @@ def wrap_module(
     ) -> nn.Module:
         if "offload_to_cpu" not in kwargs and self._offload_to_cpu is not None:
             kwargs["offload_to_cpu"] = self._offload_to_cpu
-        if "maintain_forward_counter" not in kwargs and self._maintain_forward_counter is not None:
-            kwargs["maintain_forward_counter"] = self._maintain_forward_counter
         return checkpoint_wrapper(module, **kwargs)
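
For context, here is a minimal usage sketch of the wrapper after this change. Assumptions beyond the diff: the class is importable as allennlp.nn.checkpoint.FairScaleCheckpointWrapper (mirroring the file path above), fairscale 0.4.2 and torch are installed, and the toy module and tensor shapes are illustrative only.

import torch
import torch.nn as nn

from allennlp.nn.checkpoint import FairScaleCheckpointWrapper  # assumed import path

# A toy block standing in for a real model component.
block = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))

# After this PR the constructor takes only `offload_to_cpu`;
# `maintain_forward_counter` is no longer accepted or forwarded.
wrapper = FairScaleCheckpointWrapper(offload_to_cpu=True)
checkpointed = wrapper.wrap_module(block)

# The wrapped module is used like any other nn.Module; its activations are
# recomputed during the backward pass (and offloaded to CPU here) rather than stored.
out = checkpointed(torch.randn(4, 16, requires_grad=True))
out.sum().backward()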
2 changes: 1 addition & 1 deletion setup.py
@@ -53,7 +53,7 @@
         "torch>=1.6.0,<1.11.0",
         "torchvision>=0.8.1,<0.12.0",
         "cached-path>=0.3.1,<0.4.0",
-        "fairscale==0.4.0",
+        "fairscale==0.4.2",
         "jsonnet>=0.10.0 ; sys.platform != 'win32'",
         "overrides==3.1.0",
         "nltk",
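
After upgrading, a quick sanity check along these lines shows which fairscale release is installed and whether its checkpoint_wrapper still exposes the dropped maintain_forward_counter argument. This is a sketch, not part of the PR; it only assumes that fairscale exports __version__ and that checkpoint_wrapper is importable from fairscale.nn, as in the 0.4.x releases.

import inspect

import fairscale
from fairscale.nn import checkpoint_wrapper

# setup.py now pins fairscale==0.4.2; print what is actually installed.
print("fairscale", fairscale.__version__)

# Inspect the parameters checkpoint_wrapper accepts in this release.
params = inspect.signature(checkpoint_wrapper).parameters
print("accepts maintain_forward_counter:", "maintain_forward_counter" in params)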