Skip to content

Commit 851de6f

Browse files
authored
Try to fix rasterLayer test in spcv_block() (#105)
1 parent f512993 commit 851de6f

23 files changed

+64
-54
lines changed

NAMESPACE

+10-1
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,20 @@ export(autoplot)
4343
import(checkmate)
4444
import(data.table, except = transpose)
4545
import(ggplot2)
46-
import(mlr3)
4746
import(mlr3misc)
4847
import(paradox)
4948
importFrom(R6,R6Class)
5049
importFrom(graphics,plot)
50+
importFrom(mlr3,Resampling)
51+
importFrom(mlr3,TaskClassif)
52+
importFrom(mlr3,TaskRegr)
53+
importFrom(mlr3,as_data_backend)
54+
importFrom(mlr3,assert_task)
55+
importFrom(mlr3,rsmp)
56+
importFrom(mlr3,rsmps)
57+
importFrom(mlr3,tsk)
58+
importFrom(stats,kmeans)
5159
importFrom(stats,quantile)
5260
importFrom(utils,bibentry)
61+
importFrom(utils,capture.output)
5362
importFrom(utils,globalVariables)

R/ResamplingRepeatedSpCVBlock.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ ResamplingRepeatedSpCVBlock = R6Class("ResamplingRepeatedSpCVBlock",
9494
#' A task to instantiate.
9595
instantiate = function(task) {
9696

97-
assert_task(task)
97+
mlr3::assert_task(task)
9898
checkmate::assert_multi_class(task, c("TaskClassifST", "TaskRegrST"))
9999
pv = self$param_set$values
100100
assert_numeric(pv$repeats, min.len = 1)

R/ResamplingRepeatedSpCVCoords.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ ResamplingRepeatedSpCVCoords = R6Class("ResamplingRepeatedSpCVCoords",
7070
#' A task to instantiate.
7171
instantiate = function(task) {
7272

73-
assert_task(task)
73+
mlr3::assert_task(task)
7474
checkmate::assert_multi_class(task, c("TaskClassifST", "TaskRegrST"))
7575
groups = task$groups
7676
if (!is.null(groups)) {

R/ResamplingRepeatedSpCVEnv.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ ResamplingRepeatedSpCVEnv = R6Class("ResamplingRepeatedSpCVEnv",
6767
#' A task to instantiate.
6868
instantiate = function(task) {
6969

70-
assert_task(task)
70+
mlr3::assert_task(task)
7171
checkmate::assert_multi_class(task, c("TaskClassifST", "TaskRegrST"))
7272
pv = self$param_set$values
7373

R/ResamplingRepeatedSptCVCluto.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ ResamplingRepeatedSptCVCluto = R6Class("ResamplingRepeatedSptCVCluto",
135135

136136
mlr3misc::require_namespaces("skmeans", quietly = TRUE)
137137

138-
assert_task(task)
138+
mlr3::assert_task(task)
139139
checkmate::assert_multi_class(task, c("TaskClassifST", "TaskRegrST"))
140140
checkmate::assert_subset(self$time_var, choices = task$feature_names)
141141
groups = task$groups

R/ResamplingRepeatedSptCVCstf.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ ResamplingRepeatedSptCVCstf = R6Class("ResamplingRepeatedSptCVCstf",
9797
#' A task to instantiate.
9898
instantiate = function(task) {
9999

100-
assert_task(task)
100+
mlr3::assert_task(task)
101101
checkmate::assert_multi_class(task, c("TaskClassifST", "TaskRegrST"))
102102
checkmate::assert_subset(self$time_var,
103103
choices = task$feature_names,

R/ResamplingSpCVBlock.R

+12-11
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#' @references
66
#' `r format_bib("valavi2018")`
77
#'
8+
#' @importFrom utils capture.output
9+
#'
810
#' @export
911
#' @examples
1012
#' if (mlr3misc::require_namespaces(c("sf", "blockCV"), quietly = TRUE)) {
@@ -48,7 +50,10 @@ ResamplingSpCVBlock = R6Class("ResamplingSpCVBlock",
4850
"checkerboard"), default = "random"),
4951
ParamUty$new("rasterLayer",
5052
default = NULL,
51-
custom_check = function(x) checkmate::check_class(x, "RasterLayer", null.ok = TRUE))
53+
custom_check = function(x) {
54+
checkmate::check_class(x, "RasterLayer",
55+
null.ok = TRUE)
56+
})
5257
))
5358
ps$values = list(folds = 10L)
5459
super$initialize(
@@ -64,7 +69,7 @@ ResamplingSpCVBlock = R6Class("ResamplingSpCVBlock",
6469
#' A task to instantiate.
6570
instantiate = function(task) {
6671

67-
assert_task(task)
72+
mlr3::assert_task(task)
6873
checkmate::assert_multi_class(task, c("TaskClassifST", "TaskRegrST"))
6974
pv = self$param_set$values
7075

@@ -107,8 +112,10 @@ ResamplingSpCVBlock = R6Class("ResamplingSpCVBlock",
107112
stopf("Grouping is not supported for spatial resampling methods.")
108113
}
109114
instance = private$.sample(
110-
task$row_ids, task$coordinates(),
111-
task$extra_args$crs)
115+
task$row_ids,
116+
task$coordinates(),
117+
task$extra_args$crs
118+
)
112119

113120
self$instance = instance
114121
self$task_hash = task$hash
@@ -160,14 +167,8 @@ ResamplingSpCVBlock = R6Class("ResamplingSpCVBlock",
160167
row_id = ids,
161168
fold = inds$foldID
162169
)
163-
# list(
164-
# resampling = data.table(
165-
# row_id = ids,
166-
# fold = inds$foldID
167-
# ),
168-
# blocks = blocks_sf
169-
# )
170170
},
171+
171172
# private get funs for train and test which are used by
172173
# Resampling$.get_set()
173174
.get_train = function(i) {

R/ResamplingSpCVBuffer.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ ResamplingSpCVBuffer = R6Class("ResamplingSpCVBuffer",
5151
#' A task to instantiate.
5252
instantiate = function(task) {
5353

54-
assert_task(task)
54+
mlr3::assert_task(task)
5555
checkmate::assert_multi_class(task, c("TaskClassifST", "TaskRegrST"))
5656
groups = task$groups
5757

R/ResamplingSpCVCoords.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ ResamplingSpCVCoords = R6Class("ResamplingSpCVCoords",
4747
#' A task to instantiate.
4848
instantiate = function(task) {
4949

50-
assert_task(task)
50+
mlr3::assert_task(task)
5151
checkmate::assert_multi_class(task, c("TaskClassifST", "TaskRegrST"))
5252
groups = task$groups
5353

R/ResamplingSpCVEnv.R

+3-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#' @references
66
#' `r format_bib("valavi2018")`
77
#'
8+
#' @importFrom stats kmeans
89
#' @export
910
#' @examples
1011
#' if (mlr3misc::require_namespaces(c("sf", "blockCV"), quietly = TRUE)) {
@@ -49,7 +50,7 @@ ResamplingSpCVEnv = R6Class("ResamplingSpCVEnv",
4950
#' A task to instantiate.
5051
instantiate = function(task) {
5152

52-
assert_task(task)
53+
mlr3::assert_task(task)
5354
checkmate::assert_multi_class(task, c("TaskClassifST", "TaskRegrST"))
5455
pv = self$param_set$values
5556

@@ -101,7 +102,7 @@ ResamplingSpCVEnv = R6Class("ResamplingSpCVEnv",
101102

102103
private = list(
103104
.sample = function(ids, data) {
104-
inds = kmeans(data, centers = self$param_set$values$folds)
105+
inds = stats::kmeans(data, centers = self$param_set$values$folds)
105106

106107
data.table(
107108
row_id = ids,

R/ResamplingSptCVCluto.R

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#' @title Spatioemporal Cluster Resampling
22
#'
3-
#' @import mlr3
43
#' @template rox_sptcv_cluto
54
#'
65
#' @references
@@ -101,7 +100,7 @@ ResamplingSptCVCluto = R6Class("ResamplingSptCVCluto",
101100

102101
mlr3misc::require_namespaces("skmeans", quietly = TRUE)
103102

104-
assert_task(task)
103+
mlr3::assert_task(task)
105104
checkmate::assert_multi_class(task, c("TaskClassifST", "TaskRegrST"))
106105
checkmate::assert_subset(self$time_var, choices = task$feature_names)
107106
groups = task$groups

R/ResamplingSptCVCstf.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ ResamplingSptCVCstf = R6Class("ResamplingSptCVCstf",
8484
#' Column name identifying a class unit (e.g. land cover).
8585
instantiate = function(task) {
8686

87-
assert_task(task)
87+
mlr3::assert_task(task)
8888
checkmate::assert_multi_class(task, c("TaskClassifST", "TaskRegrST"))
8989
checkmate::assert_subset(self$time_var,
9090
choices = task$feature_names,

R/TaskClassifST.R

-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#' @title SpatioTemporal Classification Task
22
#'
3-
#' @import mlr3
4-
#'
53
#' @description This task specializes [Task] and [TaskSupervised] for
64
#' spatiotemporal classification problems. The target column is assumed to be a
75
#' factor. The `task_type` is set to `"classif"` and `"spatiotemporal"`.

R/TaskRegrST.R

-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#' @title SpatioTemporal Regression Task
22
#'
3-
#' @import mlr3
4-
#'
53
#' @description
64
#' This task specializes [Task] and [TaskSupervised] for spatiotemporal
75
#' classification problems.

R/zzz.R

+13-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#' @rawNamespace import(data.table, except = transpose)
22
#' @importFrom R6 R6Class
3-
#' @import mlr3
3+
#' @importFrom mlr3 TaskClassif TaskRegr Resampling as_data_backend assert_task rsmp tsk rsmps
44
#' @import mlr3misc
55
#' @import checkmate
66
#' @import paradox
@@ -60,15 +60,15 @@ register_mlr3 = function() { # nocov start
6060

6161
# tasks --------------------------------------------------------------------
6262

63-
x = utils::getFromNamespace("mlr_tasks", ns = "mlr3")
63+
mlr_tasks = utils::getFromNamespace("mlr_tasks", ns = "mlr3")
6464

6565
mlr_tasks$add("ecuador", load_task_ecuador)
6666
mlr_tasks$add("diplodia", load_task_diplodia)
6767
mlr_tasks$add("cookfarm", load_task_cookfarm)
6868

6969
# resampling methods ---------------------------------------------------------
7070

71-
x = utils::getFromNamespace("mlr_resamplings", ns = "mlr3")
71+
mlr_resamplings = utils::getFromNamespace("mlr_resamplings", ns = "mlr3")
7272
mlr_resamplings$add("spcv_block", ResamplingSpCVBlock)
7373
mlr_resamplings$add("spcv_buffer", ResamplingSpCVBuffer)
7474
mlr_resamplings$add("sptcv_cstf", ResamplingSptCVCstf)
@@ -85,7 +85,7 @@ register_mlr3 = function() { # nocov start
8585

8686
utils::globalVariables(c(
8787
"row_id", "cookfarm_sample", "ecuador", "diplodia",
88-
"resampling", "task", "indicator", "fold"))
88+
"resampling", "task", "indicator", "fold", "id", "type"))
8989

9090
}
9191

@@ -94,3 +94,12 @@ register_mlr3 = function() { # nocov start
9494
setHook(packageEvent("mlr3", "onLoad"), function(...) register_mlr3(),
9595
action = "append")
9696
} # nocov end
97+
98+
.onUnload = function(libpath) { # nolint
99+
event = packageEvent("mlr3", "onLoad")
100+
hooks = getHook(event)
101+
pkgname = vapply(hooks, function(x) environment(x)$pkgname, NA_character_)
102+
setHook(event, hooks[pkgname != "mlr3spatiotempcv"], action = "replace")
103+
} # nocov end
104+
105+
leanify_package()

data-raw/cookfarm_sample.R

+5-5
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
library(GSIF)
55
# moved to package https://github.com/envirometrix/landmap after GSIF was archived on CRAN in 2021-03
66
data(cookfarm)
7-
saveRDS(cookfarm, "R/sysdata.rda", version = 2)
8-
cookfarm = readRDS("R/sysdata.rda")
7+
# saveRDS(cookfarm, "R/sysdata.rda", version = 2)
8+
# cookfarm = readRDS("R/sysdata.rda")
99
set.seed(42)
1010

1111
cookfarm_profiles = cookfarm$profiles
@@ -16,12 +16,12 @@ cookfarm_mlr3_sf = cookfarm_profiles %>%
1616
sf::st_as_sf(coords = c("Easting", "Northing"), crs = 26911) %>%
1717
dplyr::mutate(x = sf::st_coordinates(.)[, "X"]) %>%
1818
dplyr::mutate(y = sf::st_coordinates(.)[, "Y"]) %>%
19-
dplyr::mutate(Date = as.character(Date)) %>%
20-
dplyr::sample_n(500)
19+
dplyr::mutate(Date = as.character(Date))
2120

2221
cookfarm_sample = cookfarm_mlr3_sf %>%
2322
sf::st_set_geometry(NULL) %>%
24-
na.omit()
23+
na.omit() %>%
24+
dplyr::sample_n(500)
2525

2626
# mapview::mapview(cookfarm_sample)
2727

data/cookfarm_sample.rda

337 KB
Binary file not shown.

man-roxygen/rox_spcv_block.R

-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
#' @import mlr3
2-
#'
31
#' @description Spatial Block Cross validation implemented by the `blockCV`
42
#' package.
53
#'

man-roxygen/rox_spcv_buffer.R

-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
#' @import mlr3
2-
#'
31
#' @description Spatial Buffer Cross validation implemented by the `blockCV`
42
#' package.
53
#'

man-roxygen/rox_spcv_coords.R

-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,2 @@
1-
#' @import mlr3
2-
#'
31
#' @description Spatial Cross validation following the "k-means" approach after
42
#' Brenning 2012.

man-roxygen/rox_spcv_env.R

-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
#' @import mlr3
2-
#'
31
#' @description Environmental Block Cross Validation. This strategy uses k-means
42
#' clustering to specify blocks of similar environmental conditions. Only
53
#' numeric features can be used. The `features` used for building blocks can

tests/testthat/test-1-autoplot.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ test_that("plot() works for 'repeated_spcv_cstf'", {
312312

313313
test_that("autoplot time + space", {
314314
# special data with five temporal levels
315-
data = cookfarm
315+
data = cookfarm_sample
316316
data$Date = rep(c(
317317
"2020-01-01", "2020-02-01", "2020-03-01", "2020-04-01",
318318
"2020-05-01"), times = 1, each = 100)

tests/testthat/test-ResamplingSpCVBlock.R

+11-8
Original file line numberDiff line numberDiff line change
@@ -95,32 +95,35 @@ test_that("mlr3spatiotempcv indices are the same as blockCV indices: rasterLayer
9595
# same issue on macOS 3.6
9696
testthat::skip_if(as.numeric(R.version$major) < 4)
9797

98+
# unclear why this tests only errors on GHA
99+
testthat::skip_on_ci()
100+
98101
set.seed(42)
99102

100103
task = test_make_blockCV_test_task()
101104
testSF = test_make_blockCV_test_df()
102105

103-
r <- raster::raster(raster::extent(testSF), crs = "EPSG:4326")
104-
r[] <- 10
106+
rl <- raster::raster(raster::extent(testSF), crs = sf::st_crs(testSF)$wkt)
107+
vals <- seq_len(raster::ncell(rl))
108+
rl = raster::setValues(rl, vals)
105109

106-
rsmp <- rsmp("spcv_block",
110+
rsmp1 <- rsmp("spcv_block",
107111
range = 50000L,
108112
selection = "checkerboard",
109-
rasterLayer = r)
110-
rsmp$instantiate(task)
111-
113+
rasterLayer = rl)
114+
rsmp1$instantiate(task)
112115

113116
# blockCV
114117
capture.output(testBlock <- suppressMessages(
115118
blockCV::spatialBlock(
116119
speciesData = testSF,
117120
theRange = 50000L,
118121
selection = "checkerboard",
119-
rasterLayer = r,
122+
rasterLayer = rl,
120123
showBlocks = FALSE,
121124
verbose = FALSE,
122125
progress = FALSE)
123126
))
124127

125-
expect_equal(rsmp$instance$fold, testBlock$foldID)
128+
expect_equal(rsmp1$instance$fold, testBlock$foldID)
126129
})

0 commit comments

Comments
 (0)