Add logger to Weights & Biases #607

Merged
merged 14 commits on Apr 5, 2020
1 change: 1 addition & 0 deletions CHANGES.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added `NeptuneLogger` callback for logging experiment metadata to neptune.ai
- Added `DataFrameTransformer`, an sklearn-compatible transformer that helps working with pandas DataFrames by transforming the DataFrame into a representation that works well with neural networks (#507)
- Added `WandbLogger` callback for logging to Weights & Biases

### Changed

1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -14,3 +14,4 @@ pytest-cov
sphinx
sphinx_rtd_theme
tensorboard>=1.14.0
wandb>=0.8.30
1 change: 1 addition & 0 deletions skorch/callbacks/__init__.py
@@ -33,5 +33,6 @@
    'TrainEndCheckpoint',
    'TensorBoard',
    'Unfreezer',
    'WandbLogger',
    'WarmRestartLR',
]
86 changes: 85 additions & 1 deletion skorch/callbacks/logging.py
@@ -5,6 +5,7 @@
from contextlib import suppress
from numbers import Number
from itertools import cycle
from pathlib import Path

import numpy as np
import tqdm
@@ -14,7 +15,8 @@
from skorch.dataset import get_len
from skorch.callbacks import Callback

__all__ = ['EpochTimer', 'NeptuneLogger', 'PrintLog', 'ProgressBar', 'TensorBoard']
__all__ = ['EpochTimer', 'NeptuneLogger', 'WandbLogger', 'PrintLog', 'ProgressBar',
           'TensorBoard']


def filter_log_keys(keys, keys_ignored=None):
@@ -205,6 +207,88 @@ def on_train_end(self, net, **kwargs):
        if self.close_after_train:
            self.experiment.stop()

class WandbLogger(Callback):
    """Logs best model and metrics to `Weights & Biases <https://docs.wandb.com/>`_

    Use this callback to automatically log the best trained model, all metrics
    from your net's history, the model topology, and compute resources to
    Weights & Biases after each epoch.

    Every file saved in `wandb_run.dir` is automatically logged to W&B servers.

    See `example run
    <https://app.wandb.ai/borisd13/skorch/runs/s20or4ct/overview?workspace=user-borisd13>`_

    Examples
    --------
    >>> # Install wandb
    ... pip install wandb

    >>> import wandb
    >>> from skorch.callbacks import WandbLogger

    >>> # Create a wandb Run
    ... wandb_run = wandb.init()
    >>> # Alternative: Create a wandb Run without having a W&B account
    ... wandb_run = wandb.init(anonymous="allow")

    >>> # Log hyper-parameters (optional)
    ... wandb_run.config.update({"learning rate": 1e-3, "batch size": 32})

    >>> net = NeuralNet(..., callbacks=[WandbLogger(wandb_run)])
    >>> net.fit(X, y)

    Parameters
    ----------
    wandb_run : wandb.wandb_run.Run
        wandb Run used to log data.

    save_model : bool (default=True)
        Whether to save a checkpoint of the best model and upload it
        to your Run on W&B servers.

    keys_ignored : str or list of str (default=None)
        Key or list of keys that should not be logged to
        W&B. Note that in addition to the keys provided by the
        user, keys such as those starting with 'event_' or ending on
        '_best' are ignored by default.
    """

    def __init__(
            self,
            wandb_run,
            save_model=True,
Collaborator

We already provide a checkpoint callback; I think this functionality is redundant.

Contributor Author

This is to log the trained model to W&B.

Collaborator

Interesting. How does that work? In this code, I don't see any interaction with W&B:

        # save best model
        if self.save_model and hist['valid_loss_best']:
            model_path = Path(self.wandb_run.dir) / 'best_model.pth'
            with model_path.open('wb') as model_file:
                net.save_params(f_params=model_file)

Is this some code working in the background, or is it simply the fact that the model parameters are stored in `wandb_run.dir`?

Contributor Author

All files stored in `wandb_run.dir` are automatically saved. You can see this in my example run, on the "Files" tab.

Member

Please leave a comment stating that the files in `wandb_run.dir` are automatically saved in `on_epoch_end`.
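
For context, here is a minimal, self-contained sketch of the behavior discussed in this thread, assuming `wandb`, `skorch`, `torch`, and `scikit-learn` are installed; the toy module and data below are placeholders, not part of this PR. Anything written under `wandb_run.dir` is picked up by wandb's background sync, so saving a checkpoint into that directory is enough for it to show up on the run's "Files" tab:

# Sketch only: wandb syncs files placed under wandb_run.dir in the background,
# so no explicit upload call is needed after net.save_params().
import os
from pathlib import Path

import numpy as np
import torch
import wandb
from sklearn.datasets import make_classification
from skorch import NeuralNetClassifier

os.environ['WANDB_MODE'] = 'dryrun'  # run offline, as in the tests below
wandb_run = wandb.init(anonymous="allow")

# toy data and module, stand-ins for a real training setup
X, y = make_classification(n_samples=100, n_features=20, random_state=0)
X, y = X.astype(np.float32), y.astype(np.int64)
module = torch.nn.Sequential(
    torch.nn.Linear(20, 10), torch.nn.ReLU(), torch.nn.Linear(10, 2))

net = NeuralNetClassifier(module, criterion=torch.nn.CrossEntropyLoss,
                          max_epochs=2)
net.fit(X, y)

# writing the checkpoint into wandb_run.dir is the only step required;
# the file is then uploaded (or staged, in dryrun mode) automatically
model_path = Path(wandb_run.dir) / 'best_model.pth'
with model_path.open('wb') as model_file:
    net.save_params(f_params=model_file)

# the checkpoint can later be restored with net.load_params
net.load_params(f_params=str(model_path))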

            keys_ignored=None,
    ):
        self.wandb_run = wandb_run
        self.save_model = save_model
        self.keys_ignored = keys_ignored

    def initialize(self):
        keys_ignored = self.keys_ignored
        if isinstance(keys_ignored, str):
            keys_ignored = [keys_ignored]
        self.keys_ignored_ = set(keys_ignored or [])
        self.keys_ignored_.add('batches')
        return self

    def on_train_begin(self, net, **kwargs):
        """Log model topology and add a hook for gradients"""
        self.wandb_run.watch(net.module_)

    def on_epoch_end(self, net, **kwargs):
        """Log values from the last history step and save best model"""
        hist = net.history[-1]
        keys_kept = filter_log_keys(hist, keys_ignored=self.keys_ignored_)
        logged_vals = {k: hist[k] for k in keys_kept}
        self.wandb_run.log(logged_vals)

        # save best model; every file written to wandb_run.dir is
        # automatically uploaded to W&B servers
        if self.save_model and hist['valid_loss_best']:
            model_path = Path(self.wandb_run.dir) / 'best_model.pth'
            with model_path.open('wb') as model_file:
                net.save_params(f_params=model_file)


class PrintLog(Callback):
"""Print useful information from the model's history as a table.
81 changes: 81 additions & 0 deletions skorch/tests/callbacks/test_logging.py
@@ -11,6 +11,7 @@
from torch import nn

from skorch.tests.conftest import neptune_installed
from skorch.tests.conftest import wandb_installed
from skorch.tests.conftest import tensorboard_installed


@@ -190,6 +191,86 @@ def test_first_batch_flag(
        npt.on_batch_end(net)
        assert npt.first_batch_ is False

@pytest.mark.skipif(
    not wandb_installed, reason='wandb is not installed')
class TestWandb:
    @pytest.fixture
    def net_cls(self):
        from skorch import NeuralNetClassifier
        return NeuralNetClassifier

    @pytest.fixture
    def data(self, classifier_data):
        X, y = classifier_data
        # accelerate training since we don't care for the loss
        X, y = X[:40], y[:40]
        return X, y

    @pytest.fixture
    def wandb_logger_cls(self):
        from skorch.callbacks import WandbLogger
        return WandbLogger

    @pytest.fixture
    def wandb_run_cls(self):
        import wandb
        os.environ['WANDB_MODE'] = 'dryrun'  # run offline
        with wandb.init(anonymous="allow") as run:
            return run

    @pytest.fixture
    def mock_run(self):
        mock = Mock()
        mock.log = Mock()
        mock.watch = Mock()
        mock.dir = '.'
        return mock

    def test_ignore_keys(
            self,
            net_cls,
            classifier_module,
            data,
            wandb_logger_cls,
            mock_run,
    ):
        # ignore 'dur' and 'valid_loss', 'unknown' doesn't exist but
        # this should not cause a problem
        wandb_cb = wandb_logger_cls(
            mock_run, keys_ignored=['dur', 'valid_loss', 'unknown'])
        net_cls(
            classifier_module,
            callbacks=[wandb_cb],
            max_epochs=3,
        ).fit(*data)

        # 3 epochs = 3 calls
        assert mock_run.log.call_count == 3
        assert mock_run.watch.call_count == 1
        logged_keys = {
            key for call in mock_run.log.call_args_list for key in call[0][0]}
        assert 'valid_loss' not in logged_keys
        assert 'dur' not in logged_keys

    def test_keys_ignored_is_string(self, wandb_logger_cls, mock_run):
        wandb_cb = wandb_logger_cls(
            mock_run, keys_ignored='a-key').initialize()
        expected = {'a-key', 'batches'}
        assert wandb_cb.keys_ignored_ == expected

    def test_fit_with_real_experiment(
            self,
            net_cls,
            classifier_module,
            data,
            wandb_logger_cls,
            wandb_run_cls,
    ):
        net = net_cls(
            classifier_module,
            callbacks=[wandb_logger_cls(wandb_run_cls)],
            max_epochs=5,
        )
        net.fit(*data)

class TestPrintLog:
    @pytest.fixture
    def print_log_cls(self):
9 changes: 9 additions & 0 deletions skorch/tests/conftest.py
@@ -142,6 +142,15 @@ def data():
except ImportError:
    pass

wandb_installed = False
try:
    # pylint: disable=unused-import
    import wandb

    wandb_installed = True
except ImportError:
    pass

pandas_installed = False
try:
    # pylint: disable=unused-import