Skip to content

Fix squeeze inside NeuralNetBinaryClassifier.infer #558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 17, 2019

Conversation

qtux
Copy link
Contributor

@qtux qtux commented Nov 12, 2019

Previously y_infer.squeeze() removed all dimensions with single entries,
which does not correspond to the docstring description. Instead, remove
only the last dimension, if it contains a single entry.

@qtux qtux force-pushed the master branch 2 times, most recently from 05ea797 to 381229c Compare November 12, 2019 16:32
@BenjaminBossan
Copy link
Collaborator

Thanks for the PR. I think this change should be okay.

Could you please add an entry to the CHANGES.md?

@qtux
Copy link
Contributor Author

qtux commented Nov 14, 2019

Sure, I will rebase the commit.

@qtux
Copy link
Contributor Author

qtux commented Nov 14, 2019

I added a test to show the bug and replaced the squeeze by a reshape (was not correct previously).

@@ -235,6 +235,16 @@ def test_net_learns(self, net_cls, module_cls, data):
valid_acc = net.history[-1, 'valid_acc']
assert valid_acc > 0.65

def test_batch_size_one(self, net_cls, module_cls, data):
X, y = data
print("test", y.shape)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
print("test", y.shape)

Previously y_infer.squeeze() removed all dimensions with single entries,
which does not correspond to the docstring description. Instead, remove
only the last dimension, if it contains a single entry (batch_size = 1).

Add a test to show the bug.
@qtux
Copy link
Contributor Author

qtux commented Nov 15, 2019 via email

@BenjaminBossan
Copy link
Collaborator

Thank you very much for the bugfix.

@BenjaminBossan BenjaminBossan merged commit f446ffd into skorch-dev:master Nov 17, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants