[MRG] Adds Batch Count in History #445
Conversation
Why are we not leveraging the fit loop for this?

Edit: I mean something like

```python
for idx, data in enumerate(self.get_dataloader(...)):
    # ...
    self.history.record_batch('train_batch_count', idx + 1)
```

but this would not solve the counting problem in the callbacks, since you would still need to count each batch manually, or you would suffer the O(n) complexity introduced by repeatedly evaluating the count from the history. An alternative would be a method on `History` itself:

```python
def new_batch(self, is_training):
    # ...
    key = 'train_batch_count' if is_training else 'valid_batch_count'
    self.record(key, self[-1, key] + 1)
```

This would give us live updates and there would be no noisy addition in the fit loop. Ideally you would be able to easily, i.e. without much overhead, retrieve these quantities; currently we are only able to obtain (3).
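For reference, a minimal sketch of what "repeatedly evaluating the count from the history" amounts to, assuming each epoch row carries a `batches` list the way skorch records it; the helper name is made up for illustration:

```python
# Hypothetical helper: derive the total batch count from an existing
# skorch History. Walking all epoch rows makes this O(n), which is the
# complexity concern if callbacks re-evaluate it on every batch.
def total_batch_count(history):
    return sum(len(batches) for batches in history[:, 'batches'])
```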
I haven't looked at the whole PR but wanted to comment on some parts:

We do have

The proposed change does not look all that elegant and requires changing the signature of

Lines 316 to 318 in 72a3c87

(and we should really fix the line break there!)
Oh well, I completely overlooked that we track the batch size for each batch. So the total sample count could be easily computed from those.

If we do it via

```python
key = 'train_batch_count' if training else 'valid_batch_count'
self.history.record(key, self.history[-1, key] + 1)
```

I think I prefer the fit-loop solution over this one.
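To make the "computed from the batch sizes" remark concrete, here is a hedged sketch; `train_batch_size` is the key skorch records per batch, but treat the exact indexing as an assumption:

```python
# Sketch: total training sample count from the per-batch sizes that
# skorch records in the history (assumes the 'train_batch_size' key).
def total_sample_count(history):
    return sum(b['train_batch_size']
               for batches in history[:, 'batches']
               for b in batches
               if 'train_batch_size' in b)
```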
+1
Maybe
If you mean that the LR scheduler should write the batch count to the history, I would object, since this would simply move the 'bad coupling' argument to the LR scheduler.

The best solution for me would be to have everything that is necessary to resume training at the same point. In my opinion we should just take the solution as is, think about the batch count being updated during looping instead of at the end, and move on.
```diff
 batch_size = 100
-max_epochs = 2
+max_epochs = 1
```
Wouldn't `max_epochs == 2` be a better test for this, since it also covers the case of multiple epochs?
The original intention was to use `fit` and `partial_fit` to simulate two epochs. This PR has been updated to use `max_epochs=2` and a single call to `fit`. The `test_lr_callback_batch_steps_correctly_fallback` test covers `partial_fit` and checks that the fallback works.
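For illustration, a self-contained sketch of the `fit`/`partial_fit` pattern described above; the module, data, and CyclicLR hyperparameters are all made-up assumptions, not code from this PR (`step_every='batch'` is the newer skorch parameter for per-batch stepping):

```python
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import CyclicLR
from skorch import NeuralNetClassifier
from skorch.callbacks import LRScheduler

# Toy module, for illustration only.
class ToyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(20, 2)

    def forward(self, X):
        return F.softmax(self.lin(X), dim=-1)

X = np.random.randn(100, 20).astype(np.float32)
y = np.random.randint(0, 2, size=100)

# A per-batch policy such as CyclicLR depends on the resumed batch count.
lr_policy = LRScheduler(CyclicLR, base_lr=1e-3, max_lr=1e-2,
                        step_every='batch')
net = NeuralNetClassifier(ToyModule, max_epochs=1, batch_size=10,
                          callbacks=[lr_policy])

net.fit(X, y)          # first epoch
net.partial_fit(X, y)  # second epoch: the scheduler resumes its count
```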
```diff
 lr_policy = LRScheduler(policy, **kwargs)
 net = NeuralNetClassifier(classifier_module(), max_epochs=max_epochs,
                           batch_size=batch_size, callbacks=[lr_policy])
 net.fit(X, y)
-expected = (num_examples // batch_size) * max_epochs - 1
+net.partial_fit(X, y)
+expected = int(1.6 * num_examples / batch_size) * max_epochs
```
Can you clarify this expression? I find it non-obvious (where does the 1.6 come from?), and having `max_epochs` here with `max_epochs=1` doesn't make sense.
The PR was updated to break the steps up when calculating the expected batch_idx.
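For readers puzzling over the 1.6: it appears to be 0.8 * 2, where 0.8 reflects skorch's default internal 80/20 train/valid split and 2 the two calls (`fit` plus `partial_fit`). A hypothetical broken-up version, assuming `0.8 * num_examples / batch_size` comes out whole:

```python
# Hypothetical reformulation of the expected batch count. The 0.8 is
# skorch's default 80/20 train/valid split; each of fit() and
# partial_fit() trains on 0.8 * num_examples samples per epoch.
steps_per_call = int(0.8 * num_examples / batch_size)
expected = steps_per_call * 2  # one fit() plus one partial_fit()
```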
```python
    kwargs,
):
    batch_size = 100
    max_epochs = 1
```
Same argument as above.
@thomasjpfan Are you still on this?
@BenjaminBossan Thanks for the ping! I have updated this PR to address the comments.
skorch/callbacks/lr_scheduler.py
Outdated
```python
try:
    self.batch_idx_ = sum(net.history[:, 'train_batch_count'])
except KeyError:
    self.batch_idx_ = sum(len(b)
                          for b in net.history[:, 'batches'])
```
Very minor nitpick: I don't like the indentation here, this should fit in one line (88 chars is okay).
I have gotten too used to sklearn's 80-character limit. :)
This PR has been updated.
Apart from the minor nitpick concerning the line break, this LGTM. @ottonemo what do you think?
```python
net.fit(X, y)

# Removes batch count information in the last two epochs
for i in range(2):
```
Shouldn't 2 be `max_epochs` here?
Yup! PR updated to reflect this.
LGTM. @BenjaminBossan, can you merge this?
Fixes #418