@@ -726,43 +726,54 @@ def fit_loop(self, X, y=None, epochs=None, **fit_params):
             'dataset_valid': dataset_valid,
         }

-        y_train_is_ph = uses_placeholder_y(dataset_train)
-        y_valid_is_ph = uses_placeholder_y(dataset_valid)
-
         for _ in range(epochs):
             self.notify('on_epoch_begin', **on_epoch_kwargs)

-            train_batch_count = 0
-            for data in self.get_iterator(dataset_train, training=True):
-                Xi, yi = unpack_data(data)
-                yi_res = yi if not y_train_is_ph else None
-                self.notify('on_batch_begin', X=Xi, y=yi_res, training=True)
-                step = self.train_step(Xi, yi, **fit_params)
-                train_batch_count += 1
-                self.history.record_batch('train_loss', step['loss'].item())
-                self.history.record_batch('train_batch_size', get_len(Xi))
-                self.notify('on_batch_end', X=Xi, y=yi_res, training=True, **step)
-            self.history.record("train_batch_count", train_batch_count)
-
-            if dataset_valid is None:
-                self.notify('on_epoch_end', **on_epoch_kwargs)
-                continue
+            self.run_single_epoch(dataset_train, training=True, prefix="train",
+                                  step_fn=self.train_step, **fit_params)

-            valid_batch_count = 0
-            for data in self.get_iterator(dataset_valid, training=False):
-                Xi, yi = unpack_data(data)
-                yi_res = yi if not y_valid_is_ph else None
-                self.notify('on_batch_begin', X=Xi, y=yi_res, training=False)
-                step = self.validation_step(Xi, yi, **fit_params)
-                valid_batch_count += 1
-                self.history.record_batch('valid_loss', step['loss'].item())
-                self.history.record_batch('valid_batch_size', get_len(Xi))
-                self.notify('on_batch_end', X=Xi, y=yi_res, training=False, **step)
-            self.history.record("valid_batch_count", valid_batch_count)
-
-            self.notify('on_epoch_end', **on_epoch_kwargs)
+            if dataset_valid is not None:
+                self.run_single_epoch(dataset_valid, training=False, prefix="valid",
+                                      step_fn=self.validation_step, **fit_params)
+
+            self.notify("on_epoch_end", **on_epoch_kwargs)

         return self

+    def run_single_epoch(self, dataset, training, prefix, step_fn, **fit_params):
+        """Compute a single epoch of train or validation.
+
+        Parameters
+        ----------
+        dataset : torch Dataset
+          The initialized dataset to loop over.
+
+        training : bool
+          Whether to set the module to train mode or not.
+
+        prefix : str
+          Prefix to use when saving to the history.
+
+        step_fn : callable
+          Function to call for each batch.
+
+        **fit_params : dict
+          Additional parameters passed to the ``step_fn``.
+        """
+        is_placeholder_y = uses_placeholder_y(dataset)
+
+        batch_count = 0
+        for data in self.get_iterator(dataset, training=training):
+            Xi, yi = unpack_data(data)
+            yi_res = yi if not is_placeholder_y else None
+            self.notify("on_batch_begin", X=Xi, y=yi_res, training=training)
+            step = step_fn(Xi, yi, **fit_params)
+            self.history.record_batch(prefix + "_loss", step["loss"].item())
+            self.history.record_batch(prefix + "_batch_size", get_len(Xi))
+            self.notify("on_batch_end", X=Xi, y=yi_res, training=training, **step)
+            batch_count += 1
+
+        self.history.record(prefix + "_batch_count", batch_count)
+
     # pylint: disable=unused-argument
     def partial_fit(self, X, y=None, classes=None, **fit_params):
         """Fit the module.