Closed
Description
I'm using Python 3.11.3 and skorch==1.0.0.
The error disappears when I drop the `compile` argument. Similar examples using `NeuralNetClassifier` don't seem to raise an error.
Minimal example
import numpy as np
import torch.nn.functional as F
from skorch import NeuralNetBinaryClassifier
from torch import nn
# Synthetic data: 200 samples x 100 standard-normal features (float32), with
# the first 100 samples labelled 1 and the remaining 100 labelled 0.
X = np.random.normal(size=(200, 100)).astype(np.float32)
y = np.concatenate(
    [np.ones(100, dtype=np.float32), np.zeros(100, dtype=np.float32)]
)
class MyNet(nn.Module):
    """Small two-layer network used to reproduce the ``compile=True`` failure.

    A 100 -> 50 linear layer with ReLU feeds a 50 -> 1 linear layer. The
    trailing size-1 dimension is squeezed away, so ``forward`` yields one raw
    (un-activated) score per sample.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(100, 50)
        self.output = nn.Linear(50, 1)

    def forward(self, input):  # noqa: A002 - parameter name kept for callers
        """Compute one raw score per sample.

        Args:
            input: Tensor whose last dimension has size 100.

        Returns:
            Tensor with the same leading dimensions as ``input`` (the final
            size-1 output dimension is removed), e.g. ``(batch, 100)`` ->
            ``(batch,)``.
        """
        out = self.linear(input)
        out = F.relu(out)
        out = self.output(out)
        # (..., 1) -> (...,): presumably the 1-d output shape that
        # NeuralNetBinaryClassifier expects -- confirm against skorch docs.
        return out.squeeze(-1)
# Wrap MyNet in a binary classifier with torch.compile enabled and train for a
# single epoch. Per the traceback in this report, the failure surfaces inside
# fit(): a scoring callback's on_epoch_end calls predict(), which raises.
# The reporter notes that dropping compile=True avoids the error.
net = NeuralNetBinaryClassifier(
    module=MyNet,
    max_epochs=1,
    compile=True,
)
net.fit(X, y)

# Predicted probabilities for the training inputs.
y_proba = net.predict_proba(X)
Raises
```pytb
Traceback (most recent call last):
  File "/home/jupyter/model_training_tracking/train_sklearn_single_nn.py", line 98, in <module>
    classifier.fit(
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/classifier.py", line 348, in fit
    return super().fit(X, y, **fit_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/net.py", line 1319, in fit
    self.partial_fit(X, y, **fit_params)
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/net.py", line 1278, in partial_fit
    self.fit_loop(X, y, **fit_params)
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/net.py", line 1196, in fit_loop
    self.notify("on_epoch_end", **on_epoch_kwargs)
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/net.py", line 386, in notify
    getattr(cb, method_name)(self, **cb_kwargs)
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/callbacks/scoring.py", line 489, in on_epoch_end
    current_score = self._scoring(cached_net, X_test, y_test)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/callbacks/scoring.py", line 181, in _scoring
    return scorer(net, X_test, y_test)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/sklearn/metrics/_scorer.py", line 253, in __call__
    return self._score(partial(_cached_call, None), estimator, X, y_true, **_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/sklearn/metrics/_scorer.py", line 345, in _score
    y_pred = method_caller(
             ^^^^^^^^^^^^^^
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/sklearn/metrics/_scorer.py", line 87, in _cached_call
    result, _ = _get_response_values(
                ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/sklearn/utils/_response.py", line 210, in _get_response_values
    y_pred = prediction_method(X)
             ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/classifier.py", line 381, in predict
    return (y_proba[:, 1] > self.threshold).astype('uint8')
```
Metadata
Metadata
Assignees
Labels
No labels