Add logger to Weights & Biases #607


Merged: 14 commits, Apr 5, 2020
1 change: 1 addition & 0 deletions CHANGES.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added `NeptuneLogger` callback for logging experiment metadata to neptune.ai
- Add DataFrameTransformer, an sklearn compatible transformer that helps working with pandas DataFrames by transforming the DataFrame into a representation that works well with neural networks (#507)
- Added `WandbLogger` callback for logging to Weights & Biases

### Changed

1 change: 1 addition & 0 deletions skorch/callbacks/__init__.py
@@ -33,5 +33,6 @@
'TrainEndCheckpoint',
'TensorBoard',
'Unfreezer',
'WandbLogger',
'WarmRestartLR',
]
74 changes: 73 additions & 1 deletion skorch/callbacks/logging.py
@@ -5,6 +5,7 @@
from contextlib import suppress
from numbers import Number
from itertools import cycle
from pathlib import Path

import numpy as np
import tqdm
@@ -14,7 +15,8 @@
from skorch.dataset import get_len
from skorch.callbacks import Callback

__all__ = ['EpochTimer', 'NeptuneLogger', 'PrintLog', 'ProgressBar', 'TensorBoard']
__all__ = ['EpochTimer', 'NeptuneLogger', 'WandbLogger', 'PrintLog', 'ProgressBar',
'TensorBoard']


def filter_log_keys(keys, keys_ignored=None):
@@ -205,6 +207,76 @@ def on_train_end(self, net, **kwargs):
if self.close_after_train:
self.experiment.stop()

class WandbLogger(Callback):
"""Logs best model and metrics to `Weights & Biases <https://docs.wandb.com/>`_

"Use this callback to automatically log best trained model and all metrics from
Collaborator:

Suggested change:
-    "Use this callback to automatically log best trained model and all metrics from
+    Use this callback to automatically log best trained model and all metrics from

your net's history to Weights & Biases after each epoch.
Collaborator:

I think you should use the docstring to help skorch users who are unfamiliar with W&B to get started quickly. E.g., you could specify what package they must install for this to work, i.e. a pip (or conda) instruction. You should also indicate what kind of setup they need to make beforehand (say, starting a local server).

For a nice example, look at the docstring for NeptuneLogger.
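A sketch of the kind of setup note being asked for here (the package name matches the wandb import in the example below; the login step is standard W&B usage, not something stated in this thread):

    # Install the client and authenticate once before using the callback:
    #
    #     pip install wandb
    #     wandb login
    #
    # Runs are then synced to wandb.ai, or to a self-hosted W&B server
    # if one has been configured.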


Examples
--------
>>> import wandb
>>> from skorch.callbacks import WandbLogger
>>> wandb_run = wandb.init()
>>> wandb.config.update({"learning rate": 1e-3, "batch size": 32}) # optional
Collaborator:

Could you indicate what this config update does?
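For reference, in the wandb API this call records hyperparameters with the run rather than changing anything about training; update() merges the given dict into wandb.config, and the values then appear alongside the run in the W&B UI:

    # purely informational: stores hyperparameters with the run so they
    # are visible and searchable in the W&B dashboard
    wandb.config.update({"learning rate": 1e-3, "batch size": 32})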

>>> net = NeuralNet(..., callbacks=[WandbLogger(wandb_run)])
>>> net.fit(X, y)

Parameters
----------
wandb_run : wandb.wandb_run.Run
wandb Run used to log data.

save_model : bool (default=True)
Saves best trained model.
Collaborator:

Suggested change:
-    Saves best trained model.
+    Whether to save a checkpoint of the best model.


keys_ignored : str or list of str (default=None)
Key or list of keys that should not be logged to
wandb. Note that in addition to the keys provided by the
user, keys such as those starting with 'event_' or ending on
'_best' are ignored by default.
"""

# Record if watch has been called previously (even in another instance)
_watch_called = False
Collaborator:

Is this really used anywhere? If it is, please move this inside initialize and call it watch_called_.
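A sketch of that refactor; note that, unlike the class attribute, a per-instance watch_called_ set in initialize would reset for every new WandbLogger instance:

    def initialize(self):
        # the existing keys_ignored_ handling stays here as well
        self.watch_called_ = False
        return self

    def on_train_begin(self, net, **kwargs):
        """Log model topology and add a hook for gradients"""
        if not self.watch_called_:
            self.watch_called_ = True
            self.wandb_run.watch(net.module_)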


def __init__(
self,
wandb_run,
save_model=True,
Collaborator:

We already provide a checkpoint callback; I think this functionality is redundant.

Contributor (author):

This is to log the trained model to W&B.

Collaborator:

Interesting. How does that work? In this code, I don't see any interaction with W&B:

        # save best model
        if self.save_model and hist['valid_loss_best']:
            model_path = Path(self.wandb_run.dir) / 'best_model.pth'
            with model_path.open('wb') as model_file:
                net.save_params(f_params=model_file)

Is this some code working in the background, or is it simply the fact that the model parameters are stored in wandb_run.dir?

Contributor (author):

All files stored in wandb_run.dir are automatically saved.
You can see this in my example run, on the "files" tab.

Member:

Please leave a comment in on_epoch_end stating that the files in wandb_run.dir are automatically saved.
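A sketch of the requested comment on the save block (the auto-upload behavior it documents is confirmed by the author above):

    # save best model -- wandb automatically uploads every file written
    # to wandb_run.dir, so saving the parameters here is all that is
    # needed to log the model to W&B
    if self.save_model and hist['valid_loss_best']:
        with self.model_path.open('wb') as model_file:
            net.save_params(f_params=model_file)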

keys_ignored=None,
):
self.wandb_run = wandb_run
self.save_model = save_model
self.keys_ignored = keys_ignored
self.model_path = Path(wandb_run.dir) / 'best_model.pth'
Collaborator:

Please don't set any arguments in __init__ that are not passed by the user. So either allow them to pass the model_path argument (if that makes any sense) or instead set the model_path inside initialize (and call it model_path_).


def initialize(self):
keys_ignored = self.keys_ignored
if isinstance(keys_ignored, str):
keys_ignored = [keys_ignored]
self.keys_ignored_ = set(keys_ignored or [])
self.keys_ignored_.add('batches')
return self
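A sketch of the reviewer's second option (setting the path inside the initialize method above), assuming the best_model.pth filename is kept:

    def initialize(self):
        keys_ignored = self.keys_ignored
        if isinstance(keys_ignored, str):
            keys_ignored = [keys_ignored]
        self.keys_ignored_ = set(keys_ignored or [])
        self.keys_ignored_.add('batches')
        # derived attribute: created here instead of __init__ and given
        # a trailing underscore, following skorch convention
        self.model_path_ = Path(self.wandb_run.dir) / 'best_model.pth'
        return self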

def on_train_begin(self, net, **kwargs):
"""Log model topology and add a hook for gradients"""
if not WandbLogger._watch_called:
WandbLogger._watch_called = True
self.wandb_run.watch(net.module_)

def on_epoch_end(self, net, **kwargs):
"""Automatically log values from the last history step."""
hist = net.history[-1]
keys_kept = filter_log_keys(hist, keys_ignored=self.keys_ignored_)
logged_vals = dict((k, hist[k]) for k in keys_kept if k in hist)
Collaborator:

Suggested change:
-    logged_vals = dict((k, hist[k]) for k in keys_kept if k in hist)
+    logged_vals = {k: hist[k] for k in keys_kept if k in hist}

self.wandb_run.log(logged_vals)

# save best model
if self.save_model and hist['valid_loss_best']:
with self.model_path.open('wb') as model_file:
net.save_params(f_params=model_file)


class PrintLog(Callback):
"""Print useful information from the model's history as a table.