[MRG] Adds Batch Count in History #445


Merged
merged 7 commits into skorch-dev:master
May 1, 2019

Conversation

thomasjpfan
Member

Fixes #418

@thomasjpfan thomasjpfan changed the title [MRG] Adds Batch count in history [MRG] Adds Batch Count in History Mar 17, 2019
@ottonemo
Member

ottonemo commented Mar 19, 2019

Why are we not leveraging new_batch in History for this? It is basically the entry point for registering a new batch, we could just count from there.

Edit:
To elaborate a bit more on this, I'm generally OK with extending the fit loop but in this case I'm not sure about the way it is done: I feel that there is not much return for what is added. For example, now you still need to count in your callback since there's no live update. You could work around this by counting in the loop, e.g.:

for i, data in enumerate(self.get_dataloader(...)):
    # ...
    self.history.record_batch('train_batch_count', i + 1)

but this would not solve the counting problem in the callbacks since you would still need to count each batch manually or you would suffer the O(n) complexity introduced by repeatedly evaluating train_batch_count = sum(history[:, 'train_batch_count']). This would also leave the validation batch count undefined if there's no validation set (might be OK but inconvenient at times).
Another solution would be to extend history.new_batch:

def new_batch(self, is_training):
    # ...
    key = 'train_batch_count' if is_training else 'valid_batch_count'
    try:
        count = self[-1, key]
    except KeyError:
        count = 0  # first batch of the epoch: key not recorded yet
    self.record(key, count + 1)

This would give us live updates and there would be no noisy addition in the fit loop,
but there would still be the issue of summing everything to get the count.

Ideally you would be able to easily, i.e. without much overhead, retrieve

  1. total number of seen samples
  2. total number of seen batches
  3. total number of seen epochs

Currently we are only able to obtain (3): len(history).
This PR would introduce a solution for (2), namely, sum(history[:, 'train_batch_count']).
The total number of samples is another beast. Since we cannot guarantee that each
batch is equal in size, we would need to track this value as well. If we are going to
touch this, maybe we should think about adding a sample counter too?
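To make the three counts concrete, here is a minimal standalone sketch; plain dicts stand in for skorch's History rows, and the key names follow the ones discussed in this thread (this is an illustration, not skorch's actual data structure):

```python
# Two epochs, each with 4 recorded batches of 100 samples.
history = [
    {"batches": [{"train_batch_size": 100}] * 4, "train_batch_count": 4},
    {"batches": [{"train_batch_size": 100}] * 4, "train_batch_count": 4},
]

n_epochs = len(history)                                          # (3)
n_batches = sum(row["train_batch_count"] for row in history)     # (2)
n_samples = sum(b["train_batch_size"]
                for row in history for b in row["batches"])      # (1)
print(n_epochs, n_batches, n_samples)  # 2 8 800
```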

@BenjaminBossan
Collaborator

I haven't looked at the whole PR, but I wanted to comment on some parts:

The total number of samples is another beast

We do have train_batch_size and valid_batch_size, which can be used to sum the total number of samples seen.

Why are we not leveraging new_batch in History for this? It is basically the entry point for registering a new batch, we could just count from there.

The proposed change does not look all that elegant and requires changing the signature of new_batch. I understand the desire to keep the fit loop lean, but I'm not convinced this is the best approach. Perhaps we can somehow use net.on_batch_end, since that already receives training=False/True:

skorch/skorch/net.py

Lines 316 to 318 in 72a3c87

def on_batch_end(self, net,
                 Xi=None, yi=None, training=False, **kwargs):
    pass

(and we should really fix the line break there!)

@ottonemo
Member

The total number of samples is another beast

We do have train_batch_size and valid_batch_size, which can be used to sum the total number of samples seen.

Oh well, I completely overlooked that we track the batch size for each batch. So the total sample count could be easily computed via training_samples = sum(history[:, 'batches', :, 'train_batch_size']). Never mind me, then :)

Why are we not leveraging new_batch in History for this? It is basically the entry point for registering a new batch, we could just count from there.

The proposed change does not look all that elegant and requires changing the signature of new_batch. I understand the desire to keep the fit loop lean, but I'm not convinced this is the best approach. Perhaps we can somehow use net.on_batch_end, since that already receives training=False/True:

If we do it on_batch_end, the code would look like something like this, I presume:

key = 'train_batch_count' if training else 'valid_batch_count'
self.history.record(key, self.history[-1, key] + 1)

I think I prefer the fit-loop solution over this one.
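For comparison, here is a standalone sketch of what the on_batch_end idea amounts to: a counter incremented per batch, keyed by training/validation. The BatchCounter class and its wiring are hypothetical stand-ins for skorch's callback machinery, not its actual API:

```python
# Hypothetical per-epoch batch counter mirroring the on_batch_end proposal.
class BatchCounter:
    def __init__(self):
        self.counts = []  # one dict per epoch

    def on_epoch_begin(self):
        self.counts.append({"train_batch_count": 0, "valid_batch_count": 0})

    def on_batch_end(self, training):
        key = "train_batch_count" if training else "valid_batch_count"
        self.counts[-1][key] += 1

counter = BatchCounter()
counter.on_epoch_begin()
for _ in range(4):                       # 4 training batches
    counter.on_batch_end(training=True)
counter.on_batch_end(training=False)     # 1 validation batch
print(counter.counts)  # [{'train_batch_count': 4, 'valid_batch_count': 1}]
```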

(and we should really fix the line break there!)

+1

@thomasjpfan
Member Author

Maybe LRScheduler should be recording the batch counts, since it is the only callback using them.

@ottonemo
Member

If you mean that the LR scheduler should write the batch count to the history, I would object, since this would simply move the 'bad coupling' argument to the LR scheduler.

NeuralNet is a good place to do this since the batch count is a very general thing and NeuralNet is a very general class. A specific callback for this would work as well, but the case is too small to justify it, I think. Integrating it into the LR scheduler only makes sense if it doesn't "leak" to the outside, e.g. by writing to the history (except for what it is supposed to do, namely learning rate scheduling, of course), or else we risk having callbacks depend on LRScheduler just for the batch count.

The best solution for me would be for everything the LRScheduler callback needs to resume training at the same point to be stored locally in the callback. This would work for pickled networks but not for freshly initialized networks that continue training from a non-pickle checkpoint.

In my opinion we should just take the solution as is, think about the batch count being updated during looping instead of at the end and go on.
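The resume logic later in this PR recomputes the batch index from the history with a fallback. Here is a standalone sketch of that lookup pattern; plain dicts stand in for skorch's History, and the second epoch deliberately lacks the new key to exercise the fallback:

```python
history = [
    {"batches": [{}] * 5, "train_batch_count": 5},
    {"batches": [{}] * 5},  # e.g. an epoch recorded before this change
]

# Prefer the recorded count; fall back to counting recorded batch rows.
try:
    batch_idx = sum(row["train_batch_count"] for row in history)
except KeyError:
    batch_idx = sum(len(row["batches"]) for row in history)
print(batch_idx)  # 10 (fallback path: the second epoch lacks the key)
```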

batch_size = 100
max_epochs = 2
max_epochs = 1
Member

Wouldn't max_epochs == 2 be a better test for this since it also covers the case of multiple epochs?

Member Author

@thomasjpfan thomasjpfan Apr 9, 2019

The original intention was to use fit and partial_fit to simulate two epochs.

This PR has been updated to use max_epochs=2 with a single fit. The test_lr_callback_batch_steps_correctly_fallback test covers partial_fit and checks that the fallback works.

lr_policy = LRScheduler(policy, **kwargs)
net = NeuralNetClassifier(classifier_module(), max_epochs=max_epochs,
                          batch_size=batch_size, callbacks=[lr_policy])
net.fit(X, y)
expected = (num_examples // batch_size) * max_epochs - 1
net.partial_fit(X, y)
expected = int(1.6 * num_examples / batch_size) * max_epochs
Member

Can you clarify this expression? I find it non-obvious (where does the 1.6 come from?) and having max_epochs here with max_epochs=1 doesn't make sense.

Member Author

The PR was updated to break the steps up when calculating the expected batch_idx.

kwargs,
):
batch_size = 100
max_epochs = 1
Member

Same argument as above.

@BenjaminBossan
Collaborator

@thomasjpfan Are you still on this?

@thomasjpfan
Member Author

@BenjaminBossan Thanks for the ping! I have updated this PR to address the comments.

try:
    self.batch_idx_ = sum(net.history[:, 'train_batch_count'])
except KeyError:
    self.batch_idx_ = sum(len(b)
Collaborator

Very minor nitpick: I don't like the indentation here, this should fit in one line (88 chars is okay).

Member Author

I have gotten too used to sklearn's 80 character limit. :)

This PR has been updated.

@BenjaminBossan
Collaborator

Apart from the minor nitpick concerning the line break, this LGTM. @ottonemo what do you think?

net.fit(X, y)

# Removes batch count information in the last two epochs
for i in range(2):
Collaborator

shouldn't 2 be max_epochs here?

Member Author

Yup! PR updated to reflect this.

@githubnemo
Collaborator

LGTM @BenjaminBossan can you merge this?

@BenjaminBossan BenjaminBossan merged commit e1dc86d into skorch-dev:master May 1, 2019
Successfully merging this pull request may close these issues.

LRScheduler's batch_idx_ includes validation batches
4 participants