Fix W&B callback for distributed training #5223
Changes from all commits:
```diff
@@ -88,11 +88,7 @@ def __init__(
 
         self._watch_model = watch_model
         self._files_to_save = files_to_save
-
-        import wandb
-
-        self.wandb = wandb
-        self.wandb.init(
+        self._wandb_kwargs: Dict[str, Any] = dict(
             dir=os.path.abspath(serialization_dir),
             project=project,
             entity=entity,
@@ -105,9 +101,6 @@ def __init__(
             **(wandb_kwargs or {}),
         )
 
-        for fpath in self._files_to_save:
-            self.wandb.save(os.path.join(serialization_dir, fpath), base_path=serialization_dir)
-
     @overrides
     def log_scalars(
         self,
```
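These first two hunks carry the core of the fix: `__init__` no longer calls `wandb.init()` or `wandb.save()`; it only records the keyword arguments in `self._wandb_kwargs`, so the side-effectful setup can happen later, and only on the primary worker. Here is a minimal runnable sketch of that pattern; the class below and its plain-dict stand-in for the `wandb` module are illustrative, not AllenNLP's actual code:

```python
from typing import Any, Dict, Optional


class LoggerCallback:
    """Illustrative stand-in for a W&B-style trainer callback."""

    def __init__(self, **wandb_kwargs: Any) -> None:
        # Construction is cheap and side-effect free, so it is safe to run
        # in every distributed worker process.
        self._wandb_kwargs: Dict[str, Any] = dict(wandb_kwargs)
        self.run: Optional[Dict[str, Any]] = None

    def on_start(self, is_primary: bool = True) -> None:
        if not is_primary:
            return
        # Stand-in for `wandb.init(**self._wandb_kwargs)`: only the primary
        # worker ever creates a run.
        self.run = {"config": self._wandb_kwargs}

    def close(self) -> None:
        # Stand-in for `wandb.finish()`.
        self.run = None


# Every distributed worker constructs the callback...
callbacks = [LoggerCallback(project="demo") for _ in range(4)]
# ...but only rank 0 actually starts a run.
for rank, cb in enumerate(callbacks):
    cb.on_start(is_primary=(rank == 0))
assert callbacks[0].run is not None
assert all(cb.run is None for cb in callbacks[1:])
```

The remaining hunks adjust the rest of the callback to match, starting with the `# type: ignore` comments: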
```diff
@@ -122,7 +115,7 @@ def log_tensors(
         self, tensors: Dict[str, torch.Tensor], log_prefix: str = "", epoch: Optional[int] = None
     ) -> None:
         self._log(
-            {k: self.wandb.Histogram(v.cpu().data.numpy().flatten()) for k, v in tensors.items()},
+            {k: self.wandb.Histogram(v.cpu().data.numpy().flatten()) for k, v in tensors.items()},  # type: ignore
             log_prefix=log_prefix,
             epoch=epoch,
         )
```
```diff
@@ -134,12 +127,31 @@ def _log(
         dict_to_log = {f"{log_prefix}/{k}": v for k, v in dict_to_log.items()}
         if epoch is not None:
             dict_to_log["epoch"] = epoch
-        self.wandb.log(dict_to_log, step=self.trainer._batch_num_total)  # type: ignore[union-attr]
+        self.wandb.log(dict_to_log, step=self.trainer._batch_num_total)  # type: ignore
 
     @overrides
     def on_start(
         self, trainer: "GradientDescentTrainer", is_primary: bool = True, **kwargs
     ) -> None:
         super().on_start(trainer, is_primary=is_primary, **kwargs)
+
+        if not is_primary:
+            return None
+
+        import wandb
+
+        self.wandb = wandb
```
**Review comment:** There is no […]

**Reply:** Importing […]

**Reply:** That is some unfortunate API design.
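Whatever the truncated comments said in full, the function-local `import wandb` has a concrete effect worth noting: non-primary workers never even import the library. A trivial, self-contained illustration of the pattern, with `json` standing in for `wandb`:

```python
from typing import Optional


def start_run(is_primary: bool) -> Optional[str]:
    if not is_primary:
        return None  # non-primary workers never execute the import below
    import json  # stand-in for `import wandb` inside on_start
    return json.dumps({"run": "started"})


assert start_run(False) is None
assert start_run(True) == '{"run": "started"}'
```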
The rest of the new `on_start` body:

```diff
+        self.wandb.init(**self._wandb_kwargs)
+
+        for fpath in self._files_to_save:
+            self.wandb.save(  # type: ignore
+                os.path.join(self.serialization_dir, fpath), base_path=self.serialization_dir
+            )
+
         if self._watch_model:
-            self.wandb.watch(self.trainer.model)  # type: ignore[union-attr]
+            self.wandb.watch(self.trainer.model)  # type: ignore
```
**Review comment:** Is this about […]

**Reply:** It's because MyPy sees it as undefined, not […]
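For context on the error codes being discussed: mypy's `union-attr` code fires when an attribute is accessed through an `Optional`, and the scoped form `# type: ignore[code]` suppresses only that one code, while the bare `# type: ignore` suppresses whatever mypy reports on the line. An illustrative (not AllenNLP's) example:

```python
from typing import Optional


class Run:
    def log(self, metrics: dict) -> None:
        print(metrics)


class Callback:
    run: Optional[Run] = None  # assigned later, e.g. in on_start

    def log_metrics(self) -> None:
        # mypy: Item "None" of "Optional[Run]" has no attribute "log"  [union-attr]
        # The scoped comment below silences exactly that code; a bare
        # `# type: ignore` would silence any error on this line.
        self.run.log({"loss": 0.1})  # type: ignore[union-attr]
```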
And the hunk's final addition, an explicit `close()` override:

```diff
+
+    @overrides
+    def close(self) -> None:
+        super().close()
+        self.wandb.finish()  # type: ignore
```
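With `on_start` and `close` in place, the run has a well-defined lifecycle confined to the primary worker: initialize, register files to sync, log, finish. A condensed sketch of that sequence against the real `wandb` API (this requires `wandb` to be installed and configured; the project name and saved file below are placeholders):

```python
import os


def primary_only_run(serialization_dir: str, is_primary: bool) -> None:
    if not is_primary:
        return

    import wandb

    # What on_start now does: initialize the run from the stored kwargs...
    wandb.init(dir=os.path.abspath(serialization_dir), project="demo", anonymous="allow")
    # ...and register files to sync, relative to the serialization directory.
    wandb.save(os.path.join(serialization_dir, "config.json"), base_path=serialization_dir)

    wandb.log({"loss": 0.5}, step=1)

    # What close() now does: flush and end the run.
    wandb.finish()
```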
**Review comment** (on the `self._wandb_kwargs: Dict[str, Any] = dict(...)` line in the first hunk): Nothing wrong with this, just `dict(...)` is kind of an unusual way of writing `{...}`.
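For the record, the two spellings the reviewer contrasts build the same dictionary:

```python
a = dict(dir="/tmp/run", project="demo")
b = {"dir": "/tmp/run", "project": "demo"}
assert a == b

# One practical difference: dict(...) keys must be valid identifiers,
# so dict(my-key=1) is a syntax error while {"my-key": 1} is fine.
```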