Skip to content

CLEAN : rm duplicate code in fit_loop. #564

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 10, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 42 additions & 31 deletions skorch/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,43 +725,54 @@ def fit_loop(self, X, y=None, epochs=None, **fit_params):
'dataset_valid': dataset_valid,
}

y_train_is_ph = uses_placeholder_y(dataset_train)
y_valid_is_ph = uses_placeholder_y(dataset_valid)

for _ in range(epochs):
self.notify('on_epoch_begin', **on_epoch_kwargs)

train_batch_count = 0
for data in self.get_iterator(dataset_train, training=True):
Xi, yi = unpack_data(data)
yi_res = yi if not y_train_is_ph else None
self.notify('on_batch_begin', X=Xi, y=yi_res, training=True)
step = self.train_step(Xi, yi, **fit_params)
train_batch_count += 1
self.history.record_batch('train_loss', step['loss'].item())
self.history.record_batch('train_batch_size', get_len(Xi))
self.notify('on_batch_end', X=Xi, y=yi_res, training=True, **step)
self.history.record("train_batch_count", train_batch_count)

if dataset_valid is None:
self.notify('on_epoch_end', **on_epoch_kwargs)
continue
self.run_single_epoch(dataset_train, training=True, prefix="train",
step_fn=self.train_step, **fit_params)

valid_batch_count = 0
for data in self.get_iterator(dataset_valid, training=False):
Xi, yi = unpack_data(data)
yi_res = yi if not y_valid_is_ph else None
self.notify('on_batch_begin', X=Xi, y=yi_res, training=False)
step = self.validation_step(Xi, yi, **fit_params)
valid_batch_count += 1
self.history.record_batch('valid_loss', step['loss'].item())
self.history.record_batch('valid_batch_size', get_len(Xi))
self.notify('on_batch_end', X=Xi, y=yi_res, training=False, **step)
self.history.record("valid_batch_count", valid_batch_count)

self.notify('on_epoch_end', **on_epoch_kwargs)
if dataset_valid is not None:
self.run_single_epoch(dataset_valid, training=False, prefix="valid",
step_fn=self.validation_step, **fit_params)

self.notify("on_epoch_end", **on_epoch_kwargs)
return self

def run_single_epoch(self, dataset, training, prefix, step_fn, **fit_params):
"""Compute a single epoch of train or validation.

Parameters
----------
dataset : torch Dataset
The initialized dataset to loop over.

training : bool
Whether to set the module to train mode or not.

prefix : str
Prefix to use when saving to the history.

step_fn : callable
Function to call for each batch.

**fit_params : dict
Additional parameters passed to the ``step_fn``.
"""
is_placeholder_y = uses_placeholder_y(dataset)

batch_count = 0
for data in self.get_iterator(dataset, training=training):
Xi, yi = unpack_data(data)
yi_res = yi if not is_placeholder_y else None
self.notify("on_batch_begin", X=Xi, y=yi_res, training=training)
step = step_fn(Xi, yi, **fit_params)
self.history.record_batch(prefix + "_loss", step["loss"].item())
self.history.record_batch(prefix + "_batch_size", get_len(Xi))
self.notify("on_batch_end", X=Xi, y=yi_res, training=training, **step)
batch_count += 1

self.history.record(prefix + "_batch_count", batch_count)

# pylint: disable=unused-argument
def partial_fit(self, X, y=None, classes=None, **fit_params):
"""Fit the module.
Expand Down