Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Fix W&B callback for distributed training #5223

Merged
merged 4 commits into from
May 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- When `PretrainedTransformerIndexer` folds long sequences, it no longer loses the information from token type ids.
- Fixed documentation for `GradientDescentTrainer.cuda_device`.
- Fixed `wandb` callback to work in distributed training.


## [v2.4.0](https://github.com/allenai/allennlp/releases/tag/v2.4.0) - 2021-04-22
Expand Down
2 changes: 1 addition & 1 deletion allennlp/training/callbacks/log_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def on_batch(
batch_grad_norm: Optional[float] = None,
**kwargs,
) -> None:
if not is_training and not is_primary:
if not is_training or not is_primary:
return None
assert self.trainer is not None

Expand Down
34 changes: 23 additions & 11 deletions allennlp/training/callbacks/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,7 @@ def __init__(

self._watch_model = watch_model
self._files_to_save = files_to_save

import wandb

self.wandb = wandb
self.wandb.init(
self._wandb_kwargs: Dict[str, Any] = dict(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing wrong with this, just dict(...) is kind of an unusual way of writing {...}.

dir=os.path.abspath(serialization_dir),
project=project,
entity=entity,
Expand All @@ -105,9 +101,6 @@ def __init__(
**(wandb_kwargs or {}),
)

for fpath in self._files_to_save:
self.wandb.save(os.path.join(serialization_dir, fpath), base_path=serialization_dir)

@overrides
def log_scalars(
self,
Expand All @@ -122,7 +115,7 @@ def log_tensors(
self, tensors: Dict[str, torch.Tensor], log_prefix: str = "", epoch: Optional[int] = None
) -> None:
self._log(
{k: self.wandb.Histogram(v.cpu().data.numpy().flatten()) for k, v in tensors.items()},
{k: self.wandb.Histogram(v.cpu().data.numpy().flatten()) for k, v in tensors.items()}, # type: ignore
log_prefix=log_prefix,
epoch=epoch,
)
Expand All @@ -134,12 +127,31 @@ def _log(
dict_to_log = {f"{log_prefix}/{k}": v for k, v in dict_to_log.items()}
if epoch is not None:
dict_to_log["epoch"] = epoch
self.wandb.log(dict_to_log, step=self.trainer._batch_num_total) # type: ignore[union-attr]
self.wandb.log(dict_to_log, step=self.trainer._batch_num_total) # type: ignore

@overrides
def on_start(
self, trainer: "GradientDescentTrainer", is_primary: bool = True, **kwargs
) -> None:
super().on_start(trainer, is_primary=is_primary, **kwargs)

if not is_primary:
return None

import wandb

self.wandb = wandb
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no wandb object? wandb is always global? What if two systems want to use it at the same time?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Importing wandb may have side effects since at some point it spawns its own background worker(s).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is some unfortunate API design.

self.wandb.init(**self._wandb_kwargs)

for fpath in self._files_to_save:
self.wandb.save( # type: ignore
os.path.join(self.serialization_dir, fpath), base_path=self.serialization_dir
)

if self._watch_model:
self.wandb.watch(self.trainer.model) # type: ignore[union-attr]
self.wandb.watch(self.trainer.model) # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this about Item "None" of "Optional[Something]" has no attribute "watch"? I have been fixing that with assert self.wandb is not None.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's because MyPy sees it as undefined, not Optional.


@overrides
def close(self) -> None:
super().close()
self.wandb.finish() # type: ignore