Added Neptune logging #586

Merged
merged 15 commits into from Feb 16, 2020
Changes from 10 commits
2 changes: 2 additions & 0 deletions CHANGES.md
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added `NeptuneLogger` callback for logging experiment metadata to neptune.ai

### Changed

- When using caching in scoring callbacks, no longer uselessly iterate over the data; this can save time if iteration is slow (#552, #557)
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -3,6 +3,7 @@ flaky
future>=0.17.1
jupyter
matplotlib>=2.0.2
neptune-client>=0.4.103
numpydoc
openpyxl
pandas
2 changes: 1 addition & 1 deletion skorch/callbacks/__init__.py
@@ -13,7 +13,7 @@
from .training import *
from .lr_scheduler import *

__all__ = ['Callback', 'EpochTimer', 'PrintLog', 'ProgressBar',
__all__ = ['Callback', 'EpochTimer', 'NeptuneLogger', 'PrintLog', 'ProgressBar',
'LRScheduler', 'WarmRestartLR', 'GradientNormClipping',
'BatchScoring', 'EpochScoring', 'Checkpoint', 'EarlyStopping',
'Freezer', 'Unfreezer', 'Initializer', 'ParamMapper',
140 changes: 137 additions & 3 deletions skorch/callbacks/logging.py
@@ -14,8 +14,7 @@
from skorch.dataset import get_len
from skorch.callbacks import Callback


__all__ = ['EpochTimer', 'PrintLog', 'ProgressBar', 'TensorBoard']
__all__ = ['EpochTimer', 'NeptuneLogger', 'PrintLog', 'ProgressBar', 'TensorBoard']


def filter_log_keys(keys, keys_ignored=None):
@@ -62,6 +61,142 @@ def on_epoch_end(self, net, **kwargs):
net.history.record('dur', time.time() - self.epoch_start_time_)


class NeptuneLogger(Callback):
"""Logs results from history to Neptune

Neptune is a lightweight experiment tracking tool.
You can read more about it here: https://neptune.ai

Use this callback to automatically log all interesting values from
your net's history to Neptune.

The best way to log additional information is to log directly to the
experiment object or subclass the ``on_*`` methods.

To monitor resource consumption, install psutil:

>>> pip install psutil

You can view example experiment logs here:
https://ui.neptune.ai/o/shared/org/skorch-integration/e/SKOR-4/logs

Examples
--------
>>> # Install neptune
>>> pip install neptune-client
>>> # Create a neptune experiment object
>>> import neptune
...
... # We are using the api token for an anonymous user.
... # For your projects, use the token associated with your neptune.ai account
>>> neptune.init(api_token='eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vdWkubmVwdHVuZS5tbCIsImFwaV9rZXkiOiJiNzA2YmM4Zi03NmY5LTRjMmUtOTM5ZC00YmEwMzZmOTMyZTQifQ==',
Collaborator:

Maybe you could add the install instruction of neptune, as well as import neptune to the code example.

... project_qualified_name='shared/skorch-integration')
...
... experiment = neptune.create_experiment(
... name='skorch-basic-example',
... params={'max_epochs': 20,
... 'lr': 0.01},
... upload_source_files=['skorch_example.py'])

>>> # Create a neptune_logger callback
>>> neptune_logger = NeptuneLogger(experiment, close_after_train=False)

>>> # Pass a logger to net callbacks argument
>>> net = NeuralNetClassifier(
... ClassifierModule,
... max_epochs=20,
... lr=0.01,
... callbacks=[neptune_logger])

>>> # Log additional metrics after training has finished
>>> from sklearn.metrics import roc_auc_score
... y_pred = net.predict_proba(X)
... auc = roc_auc_score(y, y_pred[:, 1])
...
... neptune_logger.experiment.log_metric('roc_auc_score', auc)

>>> # log charts like ROC curve
... from scikitplot.metrics import plot_roc
... import matplotlib.pyplot as plt
...
... fig, ax = plt.subplots(figsize=(16, 12))
... plot_roc(y, y_pred, ax=ax)
... neptune_logger.experiment.log_image('roc_curve', fig)

>>> # log net object after training
... net.save_params(f_params='basic_model.pkl')
... neptune_logger.experiment.log_artifact('basic_model.pkl')

>>> # close experiment
... neptune_logger.experiment.stop()

Parameters
----------
experiment : neptune.experiments.Experiment
Instantiated ``Experiment`` class.

log_on_batch_end : bool (default=False)
Whether to log loss and other metrics on batch level.

close_after_train : bool (default=True)
Whether to close the ``Experiment`` object once training
finishes. Set this parameter to False if you want to continue
logging to the same Experiment or if you use it as a context
manager.

keys_ignored : str or list of str (default=None)
Key or list of keys that should not be logged to
Neptune. Note that in addition to the keys provided by the
user, keys such as those starting with 'event_' or ending on
'_best' are ignored by default.

.. _Neptune: https://www.neptune.ai

"""

def __init__(
self,
experiment,
log_on_batch_end=False,
close_after_train=True,
keys_ignored=None,
):
self.experiment = experiment
self.log_on_batch_end = log_on_batch_end
self.close_after_train = close_after_train
self.keys_ignored = keys_ignored

def initialize(self):
self.first_batch_ = True
Member:

Is first_batch_ used?

Collaborator:

I think this is for consistency with the TensorBoard callback. It is convenient to have so that you can, e.g., log an image of the network graph exactly once. You may not be able to use on_train_begin for this because that one gets the input X, not the one that is returned by the data loader.

Contributor Author (@jakubczakon, Feb 14, 2020):

Yeah, I simply copied it from TensorBoard (to be honest I haven't thought about it much).

Also, if I were to use it properly, I should have self.first_batch_ = False in on_batch_end, which is missing.

Contributor Author:

I've added self.first_batch_ = False to on_batch_end, but I can easily drop it from both since they are not used (to my understanding).

What do you think?

Collaborator:

The main reason I wanted to have it for TensorBoard was to be able to trace and add a graph of the network to TensorBoard. I think that option doesn't exist for neptune, does it? However, I think consistency is also nice, so I would leave it there.

Member:

Let's document the attribute in the docstring and add a quick test for first_batch_?

Collaborator:

Yes, good idea.
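For illustration, a minimal sketch of how a subclass could use first_batch_ to log something exactly once per training run; the subclass name and the 'n_parameters' metric are made up for this example and are not part of the PR:

from skorch.callbacks import NeptuneLogger

class OneOffNeptuneLogger(NeptuneLogger):
    """Hypothetical subclass that logs a one-off value on the first batch."""

    def on_batch_end(self, net, **kwargs):
        if self.first_batch_:
            # Log the number of trainable parameters exactly once per fit.
            n_params = sum(p.numel() for p in net.module_.parameters())
            self.experiment.log_metric('n_parameters', n_params)
            self.first_batch_ = False
        super().on_batch_end(net, **kwargs)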


keys_ignored = self.keys_ignored
if isinstance(keys_ignored, str):
keys_ignored = [keys_ignored]
self.keys_ignored_ = set(keys_ignored or [])
self.keys_ignored_.add('batches')
return self

def on_batch_end(self, net, **kwargs):
Collaborator:

I wonder if we really need batch-level logging. Maybe logging at epoch level is sufficient? At least, I think it would make sense to allow turning off batch-level logging through a parameter.

Contributor Author (@jakubczakon, Feb 7, 2020):

I often find batch-level logging valuable, but I agree there should be an option to turn it off.
Added it.
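For reference, a small usage sketch of that parameter, assuming an already created experiment object:

# default: only epoch-level metrics are sent to Neptune
neptune_logger = NeptuneLogger(experiment, log_on_batch_end=False)

# also log loss and other metrics after every batch
neptune_logger = NeptuneLogger(experiment, log_on_batch_end=True)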

if self.log_on_batch_end:
batch_logs = net.history[-1]['batches'][-1]

for key in filter_log_keys(batch_logs.keys(), self.keys_ignored_):
self.experiment.log_metric(key, batch_logs[key])

def on_epoch_end(self, net, **kwargs):
"""Automatically log values from the last history step."""
history = net.history
epoch_logs = history[-1]
epoch = epoch_logs['epoch']

for key in filter_log_keys(epoch_logs.keys(), self.keys_ignored_):
self.experiment.log_metric(key, x=epoch, y=epoch_logs[key])

def on_train_end(self, net, **kwargs):
if self.close_after_train:
self.experiment.stop()


class PrintLog(Callback):
"""Print useful information from the model's history as a table.

@@ -282,7 +417,6 @@ class ProgressBar(Callback):

>>> net.history[-1, 'batches', -1, key]
"""

def __init__(
self,
batches_per_epoch='auto',
162 changes: 157 additions & 5 deletions skorch/tests/callbacks/test_logging.py
@@ -3,14 +3,165 @@
from functools import partial
import os
from unittest.mock import Mock
from unittest.mock import patch
from unittest.mock import call, patch

import numpy as np
import pytest
import torch
from torch import nn

from skorch.tests.conftest import tensorboard_installed
from skorch.tests.conftest import neptune_installed, tensorboard_installed


@pytest.mark.skipif(
not neptune_installed, reason='neptune is not installed')
class TestNeptune:
@pytest.fixture
def net_cls(self):
from skorch import NeuralNetClassifier
return NeuralNetClassifier

@pytest.fixture
def data(self, classifier_data):
X, y = classifier_data
# accelerate training since we don't care for the loss
X, y = X[:40], y[:40]
return X, y

@pytest.fixture
def neptune_logger_cls(self):
from skorch.callbacks import NeptuneLogger
return NeptuneLogger

@pytest.fixture
def neptune_experiment_cls(self):
import neptune
neptune.init(project_qualified_name="tests/dry-run",
backend=neptune.OfflineBackend())
return neptune.create_experiment

@pytest.fixture
def mock_experiment(self, neptune_experiment_cls):
mock = Mock(spec=neptune_experiment_cls)
mock.log_metric = Mock()
mock.stop = Mock()
return mock

@pytest.fixture
def net_fitted(
self,
net_cls,
classifier_module,
data,
neptune_logger_cls,
mock_experiment,
):
return net_cls(
classifier_module,
callbacks=[neptune_logger_cls(mock_experiment)],
max_epochs=3,
).fit(*data)

def test_experiment_closed_automatically(self, net_fitted, mock_experiment):
assert mock_experiment.stop.call_count == 1

def test_experiment_not_closed(
self,
net_cls,
classifier_module,
data,
neptune_logger_cls,
mock_experiment,
):
net_cls(
classifier_module,
callbacks=[
neptune_logger_cls(mock_experiment, close_after_train=False)],
max_epochs=2,
).fit(*data)
assert mock_experiment.stop.call_count == 0

def test_ignore_keys(
self,
net_cls,
classifier_module,
data,
neptune_logger_cls,
mock_experiment,
):
# ignore 'dur' and 'valid_loss', 'unknown' doesn't exist but
# this should not cause a problem
npt = neptune_logger_cls(
mock_experiment, keys_ignored=['dur', 'valid_loss', 'unknown'])
net_cls(
classifier_module,
callbacks=[npt],
max_epochs=3,
).fit(*data)
Member:

Can we assert how many times log_metric should be called in this case?

Contributor Author:

Sure can do.
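A rough sketch of what that assertion could look like, assuming that after the default filtering and the explicit keys_ignored only train_loss and valid_acc remain as epoch-level metrics; the exact key set depends on the net's default callbacks, so the count here is illustrative:

# 3 epochs x 2 remaining epoch metrics (train_loss, valid_acc) = 6 calls
assert mock_experiment.log_metric.call_count == 6

# and the ignored keys were never logged
logged_keys = {args[0] for args, _ in mock_experiment.log_metric.call_args_list}
assert 'dur' not in logged_keys
assert 'valid_loss' not in logged_keys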


def test_keys_ignored_is_string(self, neptune_logger_cls, mock_experiment):
npt = neptune_logger_cls(mock_experiment,
keys_ignored='a-key').initialize()
expected = {'a-key', 'batches'}
assert npt.keys_ignored_ == expected

def test_fit_with_real_experiment(
self,
net_cls,
classifier_module,
data,
neptune_logger_cls,
neptune_experiment_cls,
):
net = net_cls(
classifier_module,
callbacks=[neptune_logger_cls(neptune_experiment_cls())],
max_epochs=5,
)
net.fit(*data)

def test_log_on_batch_level_on(
self,
net_cls,
classifier_module,
Collaborator:

argument not used (probably copied from tensorboard test that also doesn't need it). Tbh, I think this whole test can be removed here.

Contributor Author:

dropped it

data,
neptune_logger_cls,
mock_experiment,
):
net = net_cls(
classifier_module,
callbacks=[neptune_logger_cls(mock_experiment, log_on_batch_end=True)],
max_epochs=5,
batch_size=4,
train_split=False
)
net.fit(*data)

# 5 epochs x (40/4 batches x 2 batch metrics + 2 epoch metrics) = 110 calls
assert mock_experiment.log_metric.call_count == 110
mock_experiment.log_metric.assert_any_call('train_batch_size', 4)

def test_log_on_batch_level_off(
self,
net_cls,
classifier_module,
data,
neptune_logger_cls,
mock_experiment,
):
net = net_cls(
classifier_module,
callbacks=[neptune_logger_cls(mock_experiment, log_on_batch_end=False)],
max_epochs=5,
batch_size=4,
train_split=False
)
net.fit(*data)

# 5 epochs x 2 epoch metrics = 10 calls
assert mock_experiment.log_metric.call_count == 10
assert call('train_batch_size', 4) \
not in mock_experiment.log_metric.call_args_list


class TestPrintLog:
@@ -42,6 +193,7 @@ def odd_epoch_callback(self):
class OddEpochCallback(Callback):
def on_epoch_end(self, net, **kwargs):
net.history[-1]['event_odd'] = bool(len(net.history) % 2)

return OddEpochCallback().initialize()

@pytest.fixture
@@ -174,8 +326,8 @@ def test_with_event_key(self, history, print_log_cls):

odd_row = print_log.sink.call_args_list[2][0][0].split()
even_row = print_log.sink.call_args_list[3][0][0].split()
assert len(odd_row) == 6 # odd row has entries in every column
assert odd_row[4] == '+' # including '+' sign for the 'event_odd'
assert len(odd_row) == 6 # odd row has entries in every column
assert odd_row[4] == '+' # including '+' sign for the 'event_odd'
assert len(even_row) == 5 # even row does not have 'event_odd' entry

def test_witout_valid_data(
@@ -501,7 +653,7 @@ def test_fit_with_dict_input(
X, y = data

# create a dictionary with unordered keys
X_dict = {k: X[:, i:i+4] for k, i in zip('cebad', range(0, X.shape[1], 4))}
X_dict = {k: X[:, i:i + 4] for k, i in zip('cebad', range(0, X.shape[1], 4))}

class MyModule(MLPModule):
# use different order for args here