Skip to content

Commit 2b3ede6

Browse files
authored
Merge pull request #5 from skorch-dev/master
CLEAN : rm duplicate code in fit_loop. (skorch-dev#564)
2 parents 73da8fb + d47357a commit 2b3ede6

File tree

1 file changed

+42
-31
lines changed

1 file changed

+42
-31
lines changed

skorch/net.py

+42-31
Original file line numberDiff line numberDiff line change
@@ -726,43 +726,54 @@ def fit_loop(self, X, y=None, epochs=None, **fit_params):
726726
'dataset_valid': dataset_valid,
727727
}
728728

729-
y_train_is_ph = uses_placeholder_y(dataset_train)
730-
y_valid_is_ph = uses_placeholder_y(dataset_valid)
731-
732729
for _ in range(epochs):
733730
self.notify('on_epoch_begin', **on_epoch_kwargs)
734731

735-
train_batch_count = 0
736-
for data in self.get_iterator(dataset_train, training=True):
737-
Xi, yi = unpack_data(data)
738-
yi_res = yi if not y_train_is_ph else None
739-
self.notify('on_batch_begin', X=Xi, y=yi_res, training=True)
740-
step = self.train_step(Xi, yi, **fit_params)
741-
train_batch_count += 1
742-
self.history.record_batch('train_loss', step['loss'].item())
743-
self.history.record_batch('train_batch_size', get_len(Xi))
744-
self.notify('on_batch_end', X=Xi, y=yi_res, training=True, **step)
745-
self.history.record("train_batch_count", train_batch_count)
746-
747-
if dataset_valid is None:
748-
self.notify('on_epoch_end', **on_epoch_kwargs)
749-
continue
732+
self.run_single_epoch(dataset_train, training=True, prefix="train",
733+
step_fn=self.train_step, **fit_params)
750734

751-
valid_batch_count = 0
752-
for data in self.get_iterator(dataset_valid, training=False):
753-
Xi, yi = unpack_data(data)
754-
yi_res = yi if not y_valid_is_ph else None
755-
self.notify('on_batch_begin', X=Xi, y=yi_res, training=False)
756-
step = self.validation_step(Xi, yi, **fit_params)
757-
valid_batch_count += 1
758-
self.history.record_batch('valid_loss', step['loss'].item())
759-
self.history.record_batch('valid_batch_size', get_len(Xi))
760-
self.notify('on_batch_end', X=Xi, y=yi_res, training=False, **step)
761-
self.history.record("valid_batch_count", valid_batch_count)
762-
763-
self.notify('on_epoch_end', **on_epoch_kwargs)
735+
if dataset_valid is not None:
736+
self.run_single_epoch(dataset_valid, training=False, prefix="valid",
737+
step_fn=self.validation_step, **fit_params)
738+
739+
self.notify("on_epoch_end", **on_epoch_kwargs)
764740
return self
765741

742+
def run_single_epoch(self, dataset, training, prefix, step_fn, **fit_params):
743+
"""Compute a single epoch of train or validation.
744+
745+
Parameters
746+
----------
747+
dataset : torch Dataset
748+
The initialized dataset to loop over.
749+
750+
training : bool
751+
Whether to set the module to train mode or not.
752+
753+
prefix : str
754+
Prefix to use when saving to the history.
755+
756+
step_fn : callable
757+
Function to call for each batch.
758+
759+
**fit_params : dict
760+
Additional parameters passed to the ``step_fn``.
761+
"""
762+
is_placeholder_y = uses_placeholder_y(dataset)
763+
764+
batch_count = 0
765+
for data in self.get_iterator(dataset, training=training):
766+
Xi, yi = unpack_data(data)
767+
yi_res = yi if not is_placeholder_y else None
768+
self.notify("on_batch_begin", X=Xi, y=yi_res, training=training)
769+
step = step_fn(Xi, yi, **fit_params)
770+
self.history.record_batch(prefix + "_loss", step["loss"].item())
771+
self.history.record_batch(prefix + "_batch_size", get_len(Xi))
772+
self.notify("on_batch_end", X=Xi, y=yi_res, training=training, **step)
773+
batch_count += 1
774+
775+
self.history.record(prefix + "_batch_count", batch_count)
776+
766777
# pylint: disable=unused-argument
767778
def partial_fit(self, X, y=None, classes=None, **fit_params):
768779
"""Fit the module.

0 commit comments

Comments
 (0)