Skip to content

FIX NeuralNetBinaryClassifier with torch.compile #1058

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
### Fixed

- Fix an issue with using `NeuralNetBinaryClassifier` with `torch.compile` (#1058)

## [1.0.0] - 2024-05-27

The 1.0.0 release of skorch is here. We think that skorch is at a very stable point, which is why a 1.0.0 release is appropriate. There are no plans to add any breaking changes or major revisions in the future. Instead, our focus now is to keep skorch up-to-date with the latest versions of PyTorch and scikit-learn, and to fix any bugs that may arise.
Expand Down
33 changes: 33 additions & 0 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -4159,3 +4159,36 @@ def test_fit_and_predict_with_compile(self, net_cls, module_cls, data):
# compiled, we rely here on torch keeping this public attribute
assert hasattr(net.module_, 'dynamo_ctx')
assert hasattr(net.criterion_, 'dynamo_ctx')

def test_binary_classifier_with_compile(self, data):
# issue 1057 the problem was that compile would wrap the optimizer,
# resulting in _infer_predict_nonlinearity to return the wrong result
# because of a failing isinstance check
from skorch import NeuralNetBinaryClassifier

X, y = data[0], data[1].astype(np.float32)

class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.linear = nn.Linear(20, 10)
self.output = nn.Linear(10, 1)

def forward(self, input):
out = self.linear(input)
out = nn.functional.relu(out)
out = self.output(out)
return out.squeeze(-1)

net = NeuralNetBinaryClassifier(
MyNet,
max_epochs=3,
compile=True,
)
# check that no error is raised
net.fit(X, y)

y_proba = net.predict_proba(X)
y_pred = net.predict(X)
assert y_proba.shape == (X.shape[0], 2)
assert y_pred.shape == (X.shape[0],)
2 changes: 2 additions & 0 deletions skorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,8 @@ def _infer_predict_nonlinearity(net):
return _identity

criterion = getattr(net, net._criteria[0] + '_')
# unwrap optimizer in case of torch.compile being used
criterion = getattr(criterion, '_orig_mod', criterion)

if isinstance(criterion, CrossEntropyLoss):
return partial(torch.softmax, dim=-1)
Expand Down
Loading