From 79ead54d61e673874a88769de73fe7fbabb28033 Mon Sep 17 00:00:00 2001
From: Boris Dayma
Date: Mon, 9 Mar 2020 19:05:31 -0500
Subject: [PATCH 01/13] feat(wandb): add logger to Weights & Biases

---
 skorch/callbacks/__init__.py |  2 +-
 skorch/callbacks/logging.py  | 78 +++++++++++++++++++++++++++++++++++-
 2 files changed, 78 insertions(+), 2 deletions(-)

diff --git a/skorch/callbacks/__init__.py b/skorch/callbacks/__init__.py
index c5830b2a2..ab119b5a9 100644
--- a/skorch/callbacks/__init__.py
+++ b/skorch/callbacks/__init__.py
@@ -17,4 +17,4 @@
            'LRScheduler', 'WarmRestartLR', 'GradientNormClipping',
            'BatchScoring', 'EpochScoring', 'Checkpoint', 'EarlyStopping',
            'Freezer', 'Unfreezer', 'Initializer', 'ParamMapper',
-           'LoadInitState', 'TrainEndCheckpoint']
+           'LoadInitState', 'TrainEndCheckpoint', 'WandbLogger']
diff --git a/skorch/callbacks/logging.py b/skorch/callbacks/logging.py
index 78375db81..2ffcf1045 100644
--- a/skorch/callbacks/logging.py
+++ b/skorch/callbacks/logging.py
@@ -5,6 +5,7 @@
 from contextlib import suppress
 from numbers import Number
 from itertools import cycle
+from pathlib import Path

 import numpy as np
 import tqdm
@@ -14,7 +15,8 @@
 from skorch.dataset import get_len
 from skorch.callbacks import Callback

-__all__ = ['EpochTimer', 'NeptuneLogger', 'PrintLog', 'ProgressBar', 'TensorBoard']
+__all__ = ['EpochTimer', 'NeptuneLogger', 'WandbLogger', 'PrintLog', 'ProgressBar',
+           'TensorBoard']


 def filter_log_keys(keys, keys_ignored=None):
@@ -205,6 +207,80 @@ def on_train_end(self, net, **kwargs):
         if self.close_after_train:
             self.experiment.stop()

+
+class WandbLogger(Callback):
+    """Logs best model and metrics to `Weights & Biases <https://docs.wandb.com/>`_
+
+    Use this callback to automatically log best trained model and all metrics from
+    your net's history to Weights & Biases after each epoch.
+
+    Examples
+    --------
+    >>> import wandb
+    >>> from skorch.callbacks import WandbLogger
+    >>> wandb.init()
+    >>> wandb.config.update({"learning rate": 1e-3, "batch size": 32})  # optional
+    >>> net = NeuralNet(..., callbacks=[WandbLogger()])
+    >>> net.fit(X, y)
+
+    Parameters
+    ----------
+    save_model : bool (default=True)
+      Saves best trained model.
+
+    keys_ignored : str or list of str (default=None)
+      Key or list of keys that should not be logged to
+      wandb. Note that in addition to the keys provided by the
+      user, keys such as those starting with 'event_' or ending on
+      '_best' are ignored by default.
+    """
+
+    # Record if watch has been called previously (even in another instance)
+    _watch_called = False
+
+    def __init__(
+            self,
+            save_model=True,
+            keys_ignored=None,
+    ):
+        try:
+            import wandb
+        except ImportError:
+            raise ImportError('Could not import wandb')
+        if wandb.run is None:
+            raise ValueError('You must call wandb.init() before WandbLogger()')
+
+        self.save_model = save_model
+        self.keys_ignored = keys_ignored
+        self.model_path = Path(wandb.run.dir) / 'best_model.pth'
+
+    def initialize(self):
+        keys_ignored = self.keys_ignored
+        if isinstance(keys_ignored, str):
+            keys_ignored = [keys_ignored]
+        self.keys_ignored_ = set(keys_ignored or [])
+        self.keys_ignored_.add('batches')
+        return self
+
+    def on_train_begin(self, net, **kwargs):
+        """Log model topology and add a hook for gradients"""
+        import wandb
+        if not WandbLogger._watch_called:
+            WandbLogger._watch_called = True
+            wandb.watch(net.module_)
+
+    def on_epoch_end(self, net, **kwargs):
+        """Automatically log values from the last history step."""
+        import wandb
+        hist = net.history[-1]
+        keys_kept = filter_log_keys(hist, keys_ignored=self.keys_ignored_)
+        logged_vals = dict((k, hist[k]) for k in keys_kept if k in hist)
+        wandb.log(logged_vals)
+
+        # save best model
+        if self.save_model and hist['valid_loss_best']:
+            with self.model_path.open('wb') as model_file:
+                net.save_params(f_params=model_file)
+
+
 class PrintLog(Callback):
     """Print useful information from the model's history as a table.

From 10fdefa2a5ef70deb557c4a87a4f9eb5b0389754 Mon Sep 17 00:00:00 2001
From: Boris Dayma
Date: Mon, 9 Mar 2020 19:06:34 -0500
Subject: [PATCH 02/13] docs(changes.md): add reference to WandbLogger

---
 CHANGES.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGES.md b/CHANGES.md
index be82f98ca..ffff8171f 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added `NeptuneLogger` callback for logging experiment metadata to neptune.ai
 - Add DataFrameTransformer, an sklearn compatible transformer that helps
   working with pandas DataFrames by transforming the DataFrame into a
   representation that works well with neural networks (#507)
+- Added `WandbLogger` callback for logging to Weights & Biases

 ### Changed

From d17c35aeed0786418e96f9266ea1e3b992004261 Mon Sep 17 00:00:00 2001
From: Boris Dayma
Date: Thu, 12 Mar 2020 20:05:20 -0500
Subject: [PATCH 03/13] feat(wandb): add run instance to callback

---
 skorch/callbacks/logging.py | 24 ++++++++++--------------
 1 file changed, 10 insertions(+), 14 deletions(-)

diff --git a/skorch/callbacks/logging.py b/skorch/callbacks/logging.py
index 2ffcf1045..0935dcecd 100644
--- a/skorch/callbacks/logging.py
+++ b/skorch/callbacks/logging.py
@@ -217,13 +217,16 @@ class WandbLogger(Callback):
     --------
     >>> import wandb
     >>> from skorch.callbacks import WandbLogger
-    >>> wandb.init()
+    >>> wandb_run = wandb.init()
     >>> wandb.config.update({"learning rate": 1e-3, "batch size": 32})  # optional
-    >>> net = NeuralNet(..., callbacks=[WandbLogger()])
+    >>> net = NeuralNet(..., callbacks=[WandbLogger(wandb_run)])
     >>> net.fit(X, y)

     Parameters
     ----------
+    wandb_run : wandb.wandb_run.Run
+      wandb Run used to log data.
+
     save_model : bool (default=True)
       Saves best trained model.
@@ -239,19 +242,14 @@ class WandbLogger(Callback):

     def __init__(
             self,
+            wandb_run,
             save_model=True,
             keys_ignored=None,
     ):
-        try:
-            import wandb
-        except ImportError:
-            raise ImportError('Could not import wandb')
-        if wandb.run is None:
-            raise ValueError('You must call wandb.init() before WandbLogger()')
-
+        self.wandb_run = wandb_run
         self.save_model = save_model
         self.keys_ignored = keys_ignored
-        self.model_path = Path(wandb.run.dir) / 'best_model.pth'
+        self.model_path = Path(wandb_run.dir) / 'best_model.pth'

     def initialize(self):
         keys_ignored = self.keys_ignored
@@ -263,18 +261,16 @@ def initialize(self):

     def on_train_begin(self, net, **kwargs):
         """Log model topology and add a hook for gradients"""
-        import wandb
         if not WandbLogger._watch_called:
             WandbLogger._watch_called = True
-            wandb.watch(net.module_)
+            self.wandb_run.watch(net.module_)

     def on_epoch_end(self, net, **kwargs):
         """Automatically log values from the last history step."""
-        import wandb
         hist = net.history[-1]
         keys_kept = filter_log_keys(hist, keys_ignored=self.keys_ignored_)
         logged_vals = dict((k, hist[k]) for k in keys_kept if k in hist)
-        wandb.log(logged_vals)
+        self.wandb_run.log(logged_vals)

         # save best model
         if self.save_model and hist['valid_loss_best']:

From 5d8f7807e242d78e68c967648c784c033ae3ab19 Mon Sep 17 00:00:00 2001
From: Boris Dayma
Date: Sun, 15 Mar 2020 13:34:15 -0500
Subject: [PATCH 04/13] test(wandb): added tests

---
 skorch/tests/callbacks/test_logging.py | 96 ++++++++++++++++++++++++++
 skorch/tests/conftest.py               |  9 +++
 2 files changed, 105 insertions(+)

diff --git a/skorch/tests/callbacks/test_logging.py b/skorch/tests/callbacks/test_logging.py
index eb86750ce..85192f99c 100644
--- a/skorch/tests/callbacks/test_logging.py
+++ b/skorch/tests/callbacks/test_logging.py
@@ -11,6 +11,7 @@
 from torch import nn

 from skorch.tests.conftest import neptune_installed
+from skorch.tests.conftest import wandb_installed
 from skorch.tests.conftest import tensorboard_installed


@@ -190,6 +191,101 @@ def test_first_batch_flag(
             npt.on_batch_end(net)
         assert npt.first_batch_ is False

+
+@pytest.mark.skipif(
+    not wandb_installed, reason='wandb is not installed')
+class TestWandb:
+    @pytest.fixture
+    def net_cls(self):
+        from skorch import NeuralNetClassifier
+        return NeuralNetClassifier
+
+    @pytest.fixture
+    def data(self, classifier_data):
+        X, y = classifier_data
+        # accelerate training since we don't care for the loss
+        X, y = X[:40], y[:40]
+        return X, y
+
+    @pytest.fixture
+    def wandb_logger_cls(self):
+        from skorch.callbacks import WandbLogger
+        return WandbLogger
+
+    @pytest.fixture
+    def wandb_run_cls(self):
+        import wandb
+        os.environ['WANDB_MODE'] = 'dryrun'  # run offline
+        with wandb.init(anonymous="allow") as run:
+            return run
+
+    @pytest.fixture
+    def mock_run(self):
+        mock = Mock()
+        mock.log = Mock()
+        mock.watch = Mock()
+        mock.dir = '.'
+        return mock
+
+    @pytest.fixture
+    def net_fitted(
+            self,
+            net_cls,
+            classifier_module,
+            data,
+            wandb_logger_cls,
+            mock_run,
+    ):
+        return net_cls(
+            classifier_module,
+            callbacks=[wandb_logger_cls(mock_run)],
+            max_epochs=3,
+        ).fit(*data)
+
+    def test_ignore_keys(
+            self,
+            net_cls,
+            classifier_module,
+            data,
+            wandb_logger_cls,
+            mock_run,
+    ):
+        # ignore 'dur' and 'valid_loss', 'unknown' doesn't exist but
+        # this should not cause a problem
+        wandb_cb = wandb_logger_cls(
+            mock_run, keys_ignored=['dur', 'valid_loss', 'unknown'])
+        net_cls(
+            classifier_module,
+            callbacks=[wandb_cb],
+            max_epochs=3,
+        ).fit(*data)
+
+        # 3 epochs = 3 calls
+        assert mock_run.log.call_count == 3
+        assert mock_run.watch.call_count == 1
+        call_args = [args[0][0] for args in mock_run.log.call_args_list]
+        assert all('valid_loss' not in logged_vals for logged_vals in call_args)
+
+    def test_keys_ignored_is_string(self, wandb_logger_cls, mock_run):
+        wandb_cb = wandb_logger_cls(
+            mock_run, keys_ignored='a-key').initialize()
+        expected = {'a-key', 'batches'}
+        assert wandb_cb.keys_ignored_ == expected
+
+    def test_fit_with_real_experiment(
+            self,
+            net_cls,
+            classifier_module,
+            data,
+            wandb_logger_cls,
+            wandb_run_cls,
+    ):
+        net = net_cls(
+            classifier_module,
+            callbacks=[wandb_logger_cls(wandb_run_cls)],
+            max_epochs=5,
+        )
+        net.fit(*data)
+
+
 class TestPrintLog:
     @pytest.fixture
     def print_log_cls(self):
diff --git a/skorch/tests/conftest.py b/skorch/tests/conftest.py
index e1ce233a3..d18abf504 100644
--- a/skorch/tests/conftest.py
+++ b/skorch/tests/conftest.py
@@ -142,6 +142,15 @@ def data():
 except ImportError:
     pass

+wandb_installed = False
+try:
+    # pylint: disable=unused-import
+    import wandb
+
+    wandb_installed = True
+except ImportError:
+    pass
+
 pandas_installed = False
 try:
     # pylint: disable=unused-import

From 590783293301c5b01c7844c0eba13cd515ead5de Mon Sep 17 00:00:00 2001
From: Boris Dayma
Date: Sun, 15 Mar 2020 13:35:24 -0500
Subject: [PATCH 05/13] feat(requirements-dev.txt): add wandb

---
 requirements-dev.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements-dev.txt b/requirements-dev.txt
index bc43d760b..06a6e94b8 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -14,3 +14,4 @@ pytest-cov
 sphinx
 sphinx_rtd_theme
 tensorboard>=1.14.0
+wandb

From 5ef1bde72190aa08a06b77506303b4469f3484e4 Mon Sep 17 00:00:00 2001
From: Boris Dayma
Date: Sun, 15 Mar 2020 17:19:10 -0500
Subject: [PATCH 06/13] docs(wandb): add documentation

---
 skorch/callbacks/logging.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/skorch/callbacks/logging.py b/skorch/callbacks/logging.py
index 0935dcecd..c74adbce4 100644
--- a/skorch/callbacks/logging.py
+++ b/skorch/callbacks/logging.py
@@ -210,15 +210,21 @@ def on_train_end(self, net, **kwargs):
 class WandbLogger(Callback):
     """Logs best model and metrics to `Weights & Biases <https://docs.wandb.com/>`_

-    Use this callback to automatically log best trained model and all metrics from
-    your net's history to Weights & Biases after each epoch.
+    Use this callback to automatically log best trained model, all metrics from
+    your net's history, model topology and computer resources to Weights & Biases
+    after each epoch.
+
+    See `example run `_

     Examples
     --------
+    >>> # Install wandb
+    ... pip install wandb
     >>> import wandb
     >>> from skorch.callbacks import WandbLogger
     >>> wandb_run = wandb.init()
-    >>> wandb.config.update({"learning rate": 1e-3, "batch size": 32})  # optional
+    >>> # Log hyper-parameters (optional)
+    ... wandb.config.update({"learning rate": 1e-3, "batch size": 32})
     >>> net = NeuralNet(..., callbacks=[WandbLogger(wandb_run)])
     >>> net.fit(X, y)

     Parameters
     ----------
     wandb_run : wandb.wandb_run.Run
       wandb Run used to log data.

     save_model : bool (default=True)
-      Saves best trained model.
+      Whether to save a checkpoint of the best model.

     keys_ignored : str or list of str (default=None)
       Key or list of keys that should not be logged to
@@ -266,7 +272,7 @@ def on_train_begin(self, net, **kwargs):
         self.wandb_run.watch(net.module_)

     def on_epoch_end(self, net, **kwargs):
-        """Automatically log values from the last history step."""
+        """Log values from the last history step and save best model"""
         hist = net.history[-1]
         keys_kept = filter_log_keys(hist, keys_ignored=self.keys_ignored_)
         logged_vals = dict((k, hist[k]) for k in keys_kept if k in hist)

From e54fe594c964c17313a0afd6266389271d54d230 Mon Sep 17 00:00:00 2001
From: Boris Dayma
Date: Sun, 15 Mar 2020 17:34:05 -0500
Subject: [PATCH 07/13] feat(wandb): update doc

---
 skorch/callbacks/logging.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/skorch/callbacks/logging.py b/skorch/callbacks/logging.py
index c74adbce4..79e01b37d 100644
--- a/skorch/callbacks/logging.py
+++ b/skorch/callbacks/logging.py
@@ -214,17 +214,22 @@ class WandbLogger(Callback):
     your net's history, model topology and computer resources to Weights & Biases
     after each epoch.

-    See `example run `_
+    See `example run
+    `_

     Examples
     --------
     >>> # Install wandb
     ... pip install wandb
+
     >>> import wandb
     >>> from skorch.callbacks import WandbLogger
-    >>> wandb_run = wandb.init()
+
+    >>> # Create a wandb Run
+    ... wandb_run = wandb.init()
+
     >>> # Log hyper-parameters (optional)
     ... wandb.config.update({"learning rate": 1e-3, "batch size": 32})
+
     >>> net = NeuralNet(..., callbacks=[WandbLogger(wandb_run)])
     >>> net.fit(X, y)

From 2991565b75a109c2371870af43907d761f5e9e6f Mon Sep 17 00:00:00 2001
From: Boris Dayma
Date: Sun, 15 Mar 2020 17:42:30 -0500
Subject: [PATCH 08/13] refactor(wandb): address comments

---
 skorch/callbacks/logging.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/skorch/callbacks/logging.py b/skorch/callbacks/logging.py
index 79e01b37d..fec6f0192 100644
--- a/skorch/callbacks/logging.py
+++ b/skorch/callbacks/logging.py
@@ -260,7 +260,6 @@ def __init__(
         self.wandb_run = wandb_run
         self.save_model = save_model
         self.keys_ignored = keys_ignored
-        self.model_path = Path(wandb_run.dir) / 'best_model.pth'

     def initialize(self):
         keys_ignored = self.keys_ignored
@@ -280,12 +279,13 @@ def on_epoch_end(self, net, **kwargs):
         """Log values from the last history step and save best model"""
         hist = net.history[-1]
         keys_kept = filter_log_keys(hist, keys_ignored=self.keys_ignored_)
-        logged_vals = dict((k, hist[k]) for k in keys_kept if k in hist)
+        logged_vals = {k: hist[k] for k in keys_kept if k in hist}
         self.wandb_run.log(logged_vals)

         # save best model
         if self.save_model and hist['valid_loss_best']:
-            with self.model_path.open('wb') as model_file:
+            model_path = Path(self.wandb_run.dir) / 'best_model.pth'
+            with model_path.open('wb') as model_file:
                 net.save_params(f_params=model_file)

From 91980111a6e4261b3063099088f0b6d52dbe65b5 Mon Sep 17 00:00:00 2001
From: Boris Dayma
Date: Mon, 16 Mar 2020 20:35:52 -0500
Subject: [PATCH 09/13] feat(wandb): remove ref to _watch_called

---
 skorch/callbacks/logging.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/skorch/callbacks/logging.py b/skorch/callbacks/logging.py
index fec6f0192..50d189016 100644
--- a/skorch/callbacks/logging.py
+++ b/skorch/callbacks/logging.py
@@ -248,9 +248,6 @@ class WandbLogger(Callback):
       '_best' are ignored by default.
     """

-    # Record if watch has been called previously (even in another instance)
-    _watch_called = False
-
     def __init__(
             self,
             wandb_run,
@@ -271,9 +268,7 @@ def initialize(self):

     def on_train_begin(self, net, **kwargs):
         """Log model topology and add a hook for gradients"""
-        if not WandbLogger._watch_called:
-            WandbLogger._watch_called = True
-            self.wandb_run.watch(net.module_)
+        self.wandb_run.watch(net.module_)

From 3ca574141461b8b351c62f125359e94529800bb5 Mon Sep 17 00:00:00 2001
From: Boris Dayma
Date: Thu, 19 Mar 2020 19:20:28 -0500
Subject: [PATCH 10/13] feat(wandb): set minimum version

---
 requirements-dev.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements-dev.txt b/requirements-dev.txt
index 06a6e94b8..ee0f8de20 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -14,4 +14,4 @@ pytest-cov
 sphinx
 sphinx_rtd_theme
 tensorboard>=1.14.0
-wandb
+wandb>=0.8.30

From 5785ac0ca6569cfe5456224fa44f61142c3999f9 Mon Sep 17 00:00:00 2001
From: Boris Dayma
Date: Sun, 22 Mar 2020 11:35:59 -0500
Subject: [PATCH 11/13] docs(wandb): log anonymous + upload model

---
 skorch/callbacks/logging.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/skorch/callbacks/logging.py b/skorch/callbacks/logging.py
index 50d189016..0563efb06 100644
--- a/skorch/callbacks/logging.py
+++ b/skorch/callbacks/logging.py
@@ -227,6 +227,9 @@ class WandbLogger(Callback):
     >>> # Create a wandb Run
     ... wandb_run = wandb.init()

+    >>> # Alternative: Create a wandb Run without having a W&B account
+    ... wandb_run = wandb.init(anonymous="allow")
+
     >>> # Log hyper-parameters (optional)
     ... wandb.config.update({"learning rate": 1e-3, "batch size": 32})

@@ -239,7 +242,8 @@ class WandbLogger(Callback):
         wandb Run used to log data.

     save_model : bool (default=True)
-      Whether to save a checkpoint of the best model.
+      Whether to save a checkpoint of the best model and upload it
+      to your Run on W&B servers.

     keys_ignored : str or list of str (default=None)

From 7d3a77572e2ae03edc34a80de0d72767f78ff97c Mon Sep 17 00:00:00 2001
From: Boris Dayma
Date: Sun, 22 Mar 2020 21:08:15 -0500
Subject: [PATCH 12/13] feat(wandb): simplify logged_vals

Co-Authored-By: Thomas J Fan
---
 skorch/callbacks/logging.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/skorch/callbacks/logging.py b/skorch/callbacks/logging.py
index 0563efb06..c6c3a18c9 100644
--- a/skorch/callbacks/logging.py
+++ b/skorch/callbacks/logging.py
@@ -278,7 +278,7 @@ def on_epoch_end(self, net, **kwargs):
         """Log values from the last history step and save best model"""
         hist = net.history[-1]
         keys_kept = filter_log_keys(hist, keys_ignored=self.keys_ignored_)
-        logged_vals = {k: hist[k] for k in keys_kept if k in hist}
+        logged_vals = {k: hist[k] for k in keys_kept}
         self.wandb_run.log(logged_vals)

         # save best model

From 878b6195bde611762e41635e2857c0896dd1a39a Mon Sep 17 00:00:00 2001
From: Boris Dayma
Date: Sun, 22 Mar 2020 21:13:36 -0500
Subject: [PATCH 13/13] feat(wandb): implement comments

---
 skorch/callbacks/logging.py            |  4 +++-
 skorch/tests/callbacks/test_logging.py | 15 ---------------
 2 files changed, 3 insertions(+), 16 deletions(-)

diff --git a/skorch/callbacks/logging.py b/skorch/callbacks/logging.py
index c6c3a18c9..25adca0f0 100644
--- a/skorch/callbacks/logging.py
+++ b/skorch/callbacks/logging.py
@@ -214,6 +214,8 @@ class WandbLogger(Callback):
     your net's history, model topology and computer resources to Weights & Biases
     after each epoch.

+    Every file saved in `wandb_run.dir` is automatically logged to W&B servers.
+
     See `example run
     `_

@@ -231,7 +233,7 @@ class WandbLogger(Callback):
     ... wandb_run = wandb.init(anonymous="allow")

     >>> # Log hyper-parameters (optional)
-    ... wandb.config.update({"learning rate": 1e-3, "batch size": 32})
+    ... wandb_run.config.update({"learning rate": 1e-3, "batch size": 32})

     >>> net = NeuralNet(..., callbacks=[WandbLogger(wandb_run)])
     >>> net.fit(X, y)
diff --git a/skorch/tests/callbacks/test_logging.py b/skorch/tests/callbacks/test_logging.py
index 85192f99c..c20d70973 100644
--- a/skorch/tests/callbacks/test_logging.py
+++ b/skorch/tests/callbacks/test_logging.py
@@ -226,21 +226,6 @@ def mock_run(self):
         mock.dir = '.'
         return mock

-    @pytest.fixture
-    def net_fitted(
-            self,
-            net_cls,
-            classifier_module,
-            data,
-            wandb_logger_cls,
-            mock_run,
-    ):
-        return net_cls(
-            classifier_module,
-            callbacks=[wandb_logger_cls(mock_run)],
-            max_epochs=3,
-        ).fit(*data)
-
     def test_ignore_keys(
             self,
             net_cls,
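
For reference, a minimal end-to-end sketch of the callback as it stands after PATCH 13/13. The module, the synthetic data, and the hyper-parameter values here are made-up placeholders for illustration; only the wandb and WandbLogger calls mirror the docstring and tests in the series above.

    import os

    import numpy as np
    import wandb
    from torch import nn

    from skorch import NeuralNetClassifier
    from skorch.callbacks import WandbLogger

    os.environ['WANDB_MODE'] = 'dryrun'  # run offline, as in the test suite


    class MyModule(nn.Module):
        """Tiny stand-in classifier, hypothetical."""
        def __init__(self):
            super().__init__()
            self.dense = nn.Linear(20, 2)
            self.softmax = nn.Softmax(dim=-1)

        def forward(self, X):
            return self.softmax(self.dense(X))


    # The Run must be created first; it is the callback's only required argument
    wandb_run = wandb.init(anonymous="allow")
    wandb_run.config.update({"learning rate": 1e-3, "batch size": 32})

    # random placeholder data
    X = np.random.randn(100, 20).astype('float32')
    y = np.random.randint(0, 2, size=100)

    net = NeuralNetClassifier(
        MyModule,
        lr=1e-3,
        batch_size=32,
        max_epochs=5,
        # 'batches' is always ignored; a single string is also accepted here
        callbacks=[WandbLogger(wandb_run, save_model=True, keys_ignored='dur')],
    )
    # Each epoch, filtered history keys are sent via wandb_run.log() and,
    # whenever valid_loss improves, best_model.pth is written to wandb_run.dir
    net.fit(X, y)

Dropping the WANDB_MODE line would sync the logged metrics, and the saved checkpoint, to the W&B servers instead of keeping them local.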