
predict_type="prob" does not work with out_features=1 #374

Open
tdhock opened this issue Apr 3, 2025 · 3 comments · May be fixed by #375

Comments

@tdhock (Contributor) commented Apr 3, 2025

Hi @sebffischer, I expected to be able to use measures to evaluate predictions for binary classification with a torch learner whose last layer has out_features = 1. Here is an example adapted from #373 (thanks!):

library(mlr3torch)
# Wrap nn_bce_with_logits_loss so it accepts mlr3's integer class labels:
# flatten the (batch, 1) logits and map the {1, 2} targets to {0, 1} floats.
nn_bce_loss3 = nn_module(c("nn_bce_with_logits_loss3", "nn_loss"),
  initialize = function(weight = NULL, reduction = "mean", pos_weight = NULL) {
    self$loss = nn_bce_with_logits_loss(weight, reduction, pos_weight)
  },
  forward = function(input, target) {
    self$loss(input$reshape(-1), target$to(dtype = torch_float()) - 1)
  }
)
# Sanity check with targets in {1, 2}, as mlr3 encodes a binary factor.
loss = nn_bce_loss3()
loss(torch_randn(10, 1), torch_randint(1, 3, 10))
task = tsk("sonar")
# Binary classification graph whose last linear layer has a single output.
graph = po("torch_ingress_num") %>>%
  nn("linear", out_features = 1) %>>%
  po("torch_loss", loss = nn_bce_loss3) %>>%
  po("torch_optimizer") %>>%
  po("torch_model_classif",
     epochs = 1,
     batch_size = 32,
     predict_type = "prob")
glrn = as_learner(graph)
glrn$train(task)
glrn$predict(task)

The code above uses predict_type = "prob" with out_features = 1, so on the current main branch I get the following error:

if (FALSE) { # broken
  remotes::install_github("mlr-org/mlr3torch@6e99e02908788275622a7b723d211f357081699a")
  glrn$predict(task)
  ## Error in dimnames(x) <- dn :
  ##   length of 'dimnames' [2] not equal to array extent
  ## This happened PipeOp torch_model_classif's $predict()
}

The error happens because the torch model outputs only one column of scores, while later prediction code assumes there are two.
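As an aside, here is a minimal sketch (plain torch R, not the actual mlr3torch internals, with illustrative class labels) of the expansion the prediction code needs when the model emits one logit per row: apply the sigmoid and bind the complementary column.

library(torch)
logits = torch_randn(5, 1)                 # one score column from the model
p_pos = as.numeric(torch_sigmoid(logits))  # P(positive class)
prob = cbind(1 - p_pos, p_pos)             # two columns, as downstream code expects
colnames(prob) = c("M", "R")               # illustrative class labels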

I hacked a solution that fixes this (see below), and I will file a PR.

if (FALSE) { # fix
  remotes::install_github("tdhock/mlr3torch@69d4adda7a71c05403d561bf3bb1ffb279978d0d")
  glrn$predict(task)
  ## <PredictionClassif> for 208 observations:
  ##  row_ids truth response
  ##        1     R        M
  ##        2     R        M
  ##        3     R        M
  ##      ---   ---      ---
  ##      206     M        M
  ##      207     M        M
  ##      208     M        M
}
@sebffischer (Member)

Thanks again, I missed that we make this assumption here as well!

Maybe to understand a bit better where you are coming from:
what exactly is the advantage of encoding the output this way?
Using cross-entropy loss with output dim p = 2 is equivalent, right?
I would also expect the performance advantage to be minimal.

@sebffischer (Member) commented Apr 3, 2025

But I think you are right that we should have different output heads for binary classification.
I think that nn("head") should ideally do this correctly, which means:

  1. The module it creates should output only a single score, not two.
  2. It should register the correct way to load the labels, i.e. as [0, 1] floats rather than [1, 2] integers, with the correct shape (see the sketch below).
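For concreteness, a hedged sketch of what such a binary head might look like; nn_binary_head and its interface are hypothetical, not part of mlr3torch:

library(torch)
nn_binary_head = nn_module("nn_binary_head",
  initialize = function(in_features) {
    self$fc = nn_linear(in_features, 1)  # single score per observation
  },
  forward = function(x) {
    self$fc(x)$reshape(-1)  # shape (batch,), matching float {0, 1} targets
  }
)
bin_head = nn_binary_head(in_features = 16)
bin_head(torch_randn(4, 16))  # four scores, one per row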

@tdhock (Contributor, Author) commented Apr 3, 2025

Using cross-entropy loss with output dim p = 2 is not equivalent, because in that case you have twice the number of parameters to learn in the last layer.
Advantage: faster to learn with fewer parameters in the last layer. Folklore would also suggest that it is easier to learn (fewer parameters => less data required); the sketch below checks the parameter counts.
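A quick check of the parameter-count claim (a standalone sketch; p = 60 matches the sonar task's feature count):

library(torch)
p = 60
one = nn_linear(p, 1)  # p weights + 1 bias    = 61 parameters
two = nn_linear(p, 2)  # 2p weights + 2 biases = 122 parameters
sum(sapply(one$parameters, function(w) w$numel()))  # 61
sum(sapply(two$parameters, function(w) w$numel()))  # 122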
