Add logger to Weights & Biases #607
`skorch/callbacks/__init__.py`:

```diff
@@ -33,5 +33,6 @@
     'TrainEndCheckpoint',
     'TensorBoard',
     'Unfreezer',
+    'WandbLogger',
     'WarmRestartLR',
 ]
```
`skorch/callbacks/logging.py`:

```diff
@@ -5,6 +5,7 @@
 from contextlib import suppress
 from numbers import Number
 from itertools import cycle
+from pathlib import Path

 import numpy as np
 import tqdm
```
```diff
@@ -14,7 +15,8 @@
 from skorch.dataset import get_len
 from skorch.callbacks import Callback

-__all__ = ['EpochTimer', 'NeptuneLogger', 'PrintLog', 'ProgressBar', 'TensorBoard']
+__all__ = ['EpochTimer', 'NeptuneLogger', 'WandbLogger', 'PrintLog', 'ProgressBar',
+           'TensorBoard']


 def filter_log_keys(keys, keys_ignored=None):
```
```diff
@@ -205,6 +207,76 @@ def on_train_end(self, net, **kwargs):
         if self.close_after_train:
             self.experiment.stop()
```

```python
class WandbLogger(Callback):
    """Logs best model and metrics to `Weights & Biases <https://docs.wandb.com/>`_

    Use this callback to automatically log the best trained model and all
    metrics from your net's history to Weights & Biases after each epoch.
```
> **Reviewer:** I think you should use the docstring to help skorch users who are unfamiliar with W&B to get started quickly. E.g., you could specify what package they must install for this to work, i.e. a pip (or conda) instruction. You should also indicate what kind of setup they need to make beforehand (say, starting a local server). For a nice example, look at the docstring for NeptuneLogger.
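For comparison, the kind of getting-started snippet the reviewer is asking for might look like the sketch below (assuming the hosted W&B service; `'my-project'` is a hypothetical name):

```python
# Install and authenticate the client first:
#   pip install wandb
#   wandb login   # paste your API key when prompted
import wandb

# Create a run to pass to WandbLogger ('my-project' is a hypothetical name).
wandb_run = wandb.init(project='my-project')
```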
```python
    Examples
    --------
    >>> import wandb
    >>> from skorch.callbacks import WandbLogger
    >>> wandb_run = wandb.init()
    >>> wandb.config.update({"learning rate": 1e-3, "batch size": 32})  # optional
```
> **Reviewer:** Could you indicate what this config update does?
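For context, `wandb.config` attaches hyperparameters to the run rather than logging them as metrics; they appear in the runs table of the W&B UI and can be used to group, filter, and compare runs. A minimal sketch:

```python
import wandb

run = wandb.init()
# Config values describe the run (hyperparameters, dataset info, ...).
run.config.update({'learning rate': 1e-3, 'batch size': 32})
# Metrics, by contrast, are logged per step or epoch:
run.log({'train_loss': 0.5})
```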
```python
    >>> net = NeuralNet(..., callbacks=[WandbLogger(wandb_run)])
    >>> net.fit(X, y)

    Parameters
    ----------
    wandb_run : wandb.wandb_run.Run
        wandb Run used to log data.

    save_model : bool (default=True)
        Saves best trained model.

    keys_ignored : str or list of str (default=None)
        Key or list of keys that should not be logged to
        wandb. Note that in addition to the keys provided by the
        user, keys such as those starting with 'event_' or ending on
        '_best' are ignored by default.
    """
```
```python
    # Record if watch has been called previously (even in another instance)
    _watch_called = False
```
> **Reviewer:** Is this really used anywhere? If it is, please move this inside …
```python
    def __init__(
            self,
            wandb_run,
            save_model=True,
```
> **Reviewer:** We already provide a checkpoint callback, I think this functionality is redundant.
>
> **Author:** This is to log the trained model to W&B.
>
> **Reviewer:** Interesting. How does that work? In this code, I don't see any interaction with W&B:
>
> ```python
> # save best model
> if self.save_model and hist['valid_loss_best']:
>     model_path = Path(self.wandb_run.dir) / 'best_model.pth'
>     with model_path.open('wb') as model_file:
>         net.save_params(f_params=model_file)
> ```
>
> Is this some code working in the background or is it simply the fact that the model parameters are stored in the `wandb_run.dir` directory?
>
> **Author:** All files stored in wandb_run.dir are automatically saved.
>
> **Reviewer:** Please leave a comment that states that the files in `wandb_run.dir` are automatically saved.
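The behavior under discussion, sketched for readers following along (this relies on W&B's documented convention that files written beneath the run directory are uploaded by the background process):

```python
from pathlib import Path

import wandb

run = wandb.init()
# Any file written inside run.dir is picked up by the wandb background
# process and uploaded together with the run; no explicit upload call
# is needed.
(Path(run.dir) / 'notes.txt').write_text('synced automatically')
```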
```python
            keys_ignored=None,
    ):
        self.wandb_run = wandb_run
        self.save_model = save_model
        self.keys_ignored = keys_ignored
        self.model_path = Path(wandb_run.dir) / 'best_model.pth'
```
> **Reviewer:** Please don't set any arguments in `__init__`.
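One way this could be addressed, sketched here as a guess at what the reviewer has in mind: keep `__init__` to plain attribute assignment (skorch's convention) and derive the path where it is used.

```python
# Hypothetical revision; Path is already imported at the top of logging.py.
def __init__(self, wandb_run, save_model=True, keys_ignored=None):
    # Only store the arguments as-is; no derived attributes here.
    self.wandb_run = wandb_run
    self.save_model = save_model
    self.keys_ignored = keys_ignored

def on_epoch_end(self, net, **kwargs):
    # (metric logging omitted for brevity)
    hist = net.history[-1]
    if self.save_model and hist['valid_loss_best']:
        # Derive the save path on demand instead of in __init__.
        model_path = Path(self.wandb_run.dir) / 'best_model.pth'
        with model_path.open('wb') as model_file:
            net.save_params(f_params=model_file)
```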
```python
    def initialize(self):
        keys_ignored = self.keys_ignored
        if isinstance(keys_ignored, str):
            keys_ignored = [keys_ignored]
        self.keys_ignored_ = set(keys_ignored or [])
        self.keys_ignored_.add('batches')
        return self
```
```python
    def on_train_begin(self, net, **kwargs):
        """Log model topology and add a hook for gradients"""
        if not WandbLogger._watch_called:
            WandbLogger._watch_called = True
            self.wandb_run.watch(net.module_)
```
```python
    def on_epoch_end(self, net, **kwargs):
        """Automatically log values from the last history step."""
        hist = net.history[-1]
        keys_kept = filter_log_keys(hist, keys_ignored=self.keys_ignored_)
        logged_vals = dict((k, hist[k]) for k in keys_kept if k in hist)
```
```python
        self.wandb_run.log(logged_vals)

        # save best model
        if self.save_model and hist['valid_loss_best']:
            with self.model_path.open('wb') as model_file:
                net.save_params(f_params=model_file)
```
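A possible usage note for the saved checkpoint (a sketch; `net` is assumed to be an already-initialized `NeuralNet`, and the filename mirrors the one used above):

```python
from pathlib import Path

# Restore the best parameters saved by WandbLogger into an initialized net.
net.load_params(f_params=str(Path(wandb_run.dir) / 'best_model.pth'))
```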
```python
class PrintLog(Callback):
    """Print useful information from the model's history as a table.
```