Skip to content

Error predicting when using compile=True with NeuralNetBinaryClassifier #1057

Closed
@foster999

Description

@foster999

I'm using Python 3.11.3 and skorch==1.0.0.

The error disappears when I drop the `compile` argument. Similar examples using `NeuralNetClassifier` don't seem to raise an error.

Minimal example

import numpy as np
import torch.nn.functional as F
from skorch import NeuralNetBinaryClassifier
from torch import nn

# 200 samples x 100 standard-normal features, cast to float32 for torch.
X = np.random.normal(loc=0.0, scale=1.0, size=(200, 100)).astype(np.float32)
# Balanced binary float32 targets: the first 100 rows are 1, the rest 0.
y = (np.arange(200) < 100).astype(np.float32)


class MyNet(nn.Module):
    """Small two-layer MLP that emits one raw (un-activated) score per sample.

    The input's last dimension must have size 100. It is mapped to a 50-unit
    ReLU hidden layer and then to a single output unit. The trailing size-1
    dimension is squeezed away, so an ``(N, 100)`` batch yields an ``(N,)``
    tensor.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order (hidden, then output) so that
        # parameter initialisation consumes the global RNG identically.
        self.linear = nn.Linear(100, 50)
        self.output = nn.Linear(50, 1)

    def forward(self, input):
        # NOTE: the parameter name ``input`` shadows the builtin, but it is
        # kept unchanged because callers may pass it by keyword.
        hidden = F.relu(self.linear(input))
        return self.output(hidden).squeeze(-1)


# Reproduction: binary classifier with torch.compile enabled (compile=True).
net = NeuralNetBinaryClassifier(module=MyNet, max_epochs=1, compile=True)

# Per the traceback below, the failure already surfaces inside fit(): the
# end-of-epoch scoring callback calls predict(), which indexes
# y_proba[:, 1]. Dropping compile=True makes the error disappear.
net.fit(X, y)
y_proba = net.predict_proba(X)

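Raises the traceback below. It was captured from my full training script rather than the minimal example. Note that the error is already raised during `fit`, when the end-of-epoch scoring callback calls `predict`. The final exception line is not included.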

Traceback (most recent call last):
  File "/home/jupyter/model_training_tracking/train_sklearn_single_nn.py", line 98, in <module>
    classifier.fit(
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/classifier.py", line 348, in fit
    return super().fit(X, y, **fit_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/net.py", line 1319, in fit
    self.partial_fit(X, y, **fit_params)
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/net.py", line 1278, in partial_fit
    self.fit_loop(X, y, **fit_params)
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/net.py", line 1196, in fit_loop
    self.notify("on_epoch_end", **on_epoch_kwargs)
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/net.py", line 386, in notify
    getattr(cb, method_name)(self, **cb_kwargs)
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/callbacks/scoring.py", line 489, in on_epoch_end
    current_score = self._scoring(cached_net, X_test, y_test)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/callbacks/scoring.py", line 181, in _scoring
    return scorer(net, X_test, y_test)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/sklearn/metrics/_scorer.py", line 253, in __call__
    return self._score(partial(_cached_call, None), estimator, X, y_true, **_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/sklearn/metrics/_scorer.py", line 345, in _score
    y_pred = method_caller(
             ^^^^^^^^^^^^^^
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/sklearn/metrics/_scorer.py", line 87, in _cached_call
    result, _ = _get_response_values(
                ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/sklearn/utils/_response.py", line 210, in _get_response_values
    y_pred = prediction_method(X)
             ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/python311/lib/python3.11/site-packages/skorch/classifier.py", line 381, in predict
    return (y_proba[:, 1] > self.threshold).astype('uint8')

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions