predict_type="prob" does not work with out_features=1 #374

Closed
@tdhock

Description

Hi @sebffischer, I expected to be able to use measures to evaluate predictions for binary classification with a torch learner that has out_features=1. Here is an example adapted from #373 (thanks!):

library(mlr3torch)
nn_bce_loss3 = nn_module(c("nn_bce_with_logits_loss3", "nn_loss"),
  initialize = function(weight = NULL, reduction = "mean", pos_weight = NULL) {
    self$loss = nn_bce_with_logits_loss(weight, reduction, pos_weight)
  },
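  forward = function(input, target) {
    # Flatten the (batch, 1) logits and shift class codes 1/2 to targets 0/1.
    self$loss(input$reshape(-1), target$to(dtype = torch_float()) - 1)
  }
)
loss = nn_bce_loss3()
# Quick check with class codes 1 or 2 (the upper bound of torch_randint is exclusive).
loss(torch_randn(10, 1), torch_randint(1, 3, 10))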
task = tsk("sonar")
graph = po("torch_ingress_num") %>>%
  nn("linear", out_features = 1) %>>%
  po("torch_loss", loss = nn_bce_loss3) %>>%
  po("torch_optimizer") %>>%
  po("torch_model_classif",
     epochs = 1,
     batch_size = 32,
     predict_type="prob")
glrn = as_learner(graph)
glrn$train(task)
glrn$predict(task)
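
To make the `- 1` in the loss above concrete, here is a small base R sketch of the arithmetic that a binary cross-entropy on logits performs. It assumes the targets arrive as class codes 1 and 2, which is what the shift in `forward` suggests; the logits and codes are made-up values, and `bce_with_logits` is a hypothetical helper, not part of torch or mlr3torch.

```r
# Made-up inputs: three logits and integer class codes in {1, 2}.
logits = c(-1.2, 0.3, 2.0)
codes = c(1L, 2L, 2L)

# Shift class codes 1/2 to the 0/1 targets that binary cross-entropy expects.
y = codes - 1

# Hypothetical helper: numerically stable mean BCE computed from logits.
bce_with_logits = function(x, y) {
  mean(pmax(x, 0) - x * y + log1p(exp(-abs(x))))
}
loss_stable = bce_with_logits(logits, y)

# Same quantity via the textbook formula on sigmoid probabilities.
p = plogis(logits)
loss_naive = -mean(y * log(p) + (1 - y) * log(1 - p))

print(c(stable = loss_stable, naive = loss_naive))
```

Both values should agree up to floating point error. Without the shift, the targets would fall outside 0/1 and the loss would no longer be a cross-entropy.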

The code above combines predict_type="prob" with out_features=1, which gives the following error on current main:

if(FALSE){#broke
  remotes::install_github("mlr-org/mlr3torch@6e99e02908788275622a7b723d211f357081699a")
  glrn$predict(task)
  ## Error in dimnames(x) <- dn : 
  ##   length of 'dimnames' [2] not equal to array extent
  ## This happened PipeOp torch_model_classif's $predict()
}

The error happens because the torch model outputs only one column, but some later code assumes there are two.
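
The mismatch can be reproduced in base R without torch. The sketch below is only an illustration under stated assumptions, not the code from the fix: the logit values are made up, and it assumes the sigmoid of the single logit is the probability of the second class level, matching the `- 1` shift in the loss above.

```r
# One column of logits, as a model with out_features = 1 would produce
# (made-up values), and the two class names of the sonar task.
logits = matrix(c(-1.2, 0.3, 2.0), ncol = 1)
class_names = c("M", "R")

# Naming a one-column matrix with two class names triggers the dimnames error.
err = tryCatch({
  bad = logits
  colnames(bad) = class_names
  NA_character_
}, error = function(e) conditionMessage(e))
print(err)

# Possible workaround (an assumption, not necessarily what the PR does):
# expand the single logit into a two-column probability matrix.
p_second = plogis(logits[, 1])      # sigmoid gives P(second level)
prob = cbind(1 - p_second, p_second)
colnames(prob) = class_names
print(prob)
```

Each row of `prob` sums to one, so a matrix of this shape has one column per class, which is what the later naming step expects.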

I hacked a solution that fixes this (see below), and I will file a PR.

if(FALSE){#fix
  remotes::install_github("tdhock/mlr3torch@69d4adda7a71c05403d561bf3bb1ffb279978d0d")
  glrn$predict(task)
  ## <PredictionClassif> for 208 observations:
  ##  row_ids truth response
  ##        1     R        M
  ##        2     R        M
  ##        3     R        M
  ##      ---   ---      ---
  ##      206     M        M
  ##      207     M        M
  ##      208     M        M
}
