Skip to content

Commit 7a15595

Browse files
authored
Add 2D plotting method for Cstf method (#106)
- Add 2D plotting method for Cstf method (#106)
1 parent e969c28 commit 7a15595

17 files changed

+8063
-126
lines changed

NAMESPACE

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ importFrom(mlr3,TaskClassif)
5252
importFrom(mlr3,TaskRegr)
5353
importFrom(mlr3,as_data_backend)
5454
importFrom(mlr3,assert_task)
55+
importFrom(mlr3,lrn)
5556
importFrom(mlr3,rsmp)
5657
importFrom(mlr3,rsmps)
5758
importFrom(mlr3,tsk)

R/autoplot_spcv_cstf.R

+167-108
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,16 @@
1111
#' other methods and plotting these without a train/test indicator would not
1212
#' make sense.
1313
#'
14+
#' @section 2D vs 3D plotting:
15+
#' This method has both a 2D and a 3D plotting method.
16+
#' The 2D method returns a \pkg{ggplot} with x and y axes representing the spatial
17+
#' coordinates.
18+
#' The 3D method uses \pkg{plotly} to create an interactive 3D plot.
19+
#' Set `plot3D = TRUE` to use the 3D method.
20+
#'
21+
#' Note that spatiotemporal datasets usually suffer from overplotting in 2D
22+
#' mode.
23+
#'
1424
#' @name autoplot.ResamplingSptCVCstf
1525
#' @inheritParams autoplot.ResamplingSpCVBlock
1626
#'
@@ -42,6 +52,9 @@
4252
#' working directory.
4353
#' @param show_omitted `[logical]`\cr
4454
#' Whether to show points not used in train or test set for the current fold.
55+
#' @param plot3D `[logical]`\cr
56+
#' Whether to create a 2D image via \pkg{ggplot2} or a 3D plot via
57+
#' \pkg{plotly}.
4558
#' @param ... Passed down to `plotly::orca()`. Only effective when
4659
#' `static_image = TRUE`.
4760
#' @export
@@ -85,6 +98,7 @@ autoplot.ResamplingSptCVCstf = function( # nolint
8598
axis_label_fontsize = 11,
8699
static_image = FALSE,
87100
show_omitted = FALSE,
101+
plot3D = NULL,
88102
...) {
89103

90104
dots = list(...)
@@ -108,149 +122,193 @@ autoplot.ResamplingSptCVCstf = function( # nolint
108122
resampling_sub$instance = resampling_sub$instance[[repeats_id]]
109123
}
110124

111-
if (!is.null(fold_id)) {
112-
113-
if (length(fold_id) == 1) {
114-
### only one fold
115-
116-
data_coords = prepare_autoplot_cstf(task, resampling_sub)
117-
118-
# suppress undefined global variables note
119-
data_coords$indicator = ""
120-
121-
row_id_test = resampling$instance$test[[fold_id]]
122-
row_id_train = resampling$instance$train[[fold_id]]
123-
124-
data_coords[row_id %in% row_id_test, indicator := "Test"]
125-
data_coords[row_id %in% row_id_train, indicator := "Train"]
126-
127-
if (show_omitted && nrow(data_coords[indicator == ""]) > 0) {
128-
data_coords[indicator == "", indicator := "Omitted"]
129-
130-
plot_single_plotly = plotly::plot_ly(data_coords,
131-
x = ~x, y = ~y, z = ~Date,
132-
color = ~indicator, colors = c(
133-
"grey", "#E18727", "#0072B5"
134-
),
135-
sizes = c(20, 100)
136-
)
137-
} else {
138-
data_coords = data_coords[indicator != ""]
139-
plot_single_plotly = plotly::plot_ly(data_coords,
140-
x = ~x, y = ~y, z = ~Date,
141-
color = ~indicator, colors = c(
142-
"#E18727", "#0072B5"
143-
),
144-
sizes = c(20, 100)
145-
)
146-
}
125+
# check if we are in a 2D or 3D scenario
126+
if (is.null(plot3D)) {
127+
if (!is.null(resampling_sub$space_var) &&
128+
!is.null(resampling_sub$time_var)) {
129+
plot3D = TRUE
130+
} else {
131+
plot3D = FALSE
132+
}
133+
}
147134

148-
plot_single_plotly = plotly::add_markers(plot_single_plotly,
149-
marker = list(size = point_size))
150-
plot_single_plotly = plotly::layout(plot_single_plotly,
151-
title = sprintf(
152-
"Fold %s, Repetition %s", fold_id,
153-
repeats_id),
154-
autosize = TRUE,
155-
scene = list(
156-
xaxis = list(title = "Lat", nticks = nticks_x),
157-
yaxis = list(title = "Lon", nticks = nticks_y),
158-
zaxis = list(
159-
title = "Time",
160-
type = "date",
161-
tickformat = tickformat_date,
162-
tickfont = list(size = axis_label_fontsize)
163-
),
164-
camera = list(eye = list(z = 1.50))
165-
)
166-
)
135+
# 2D -------------------------------------------------------------------------
136+
137+
if (!plot3D) {
138+
139+
# bring into correct format (complicated alternative to reshape::melt)
140+
resampling_sub$instance = data.table::rbindlist(
141+
lapply(resampling_sub$instance$test, as.data.table),
142+
idcol = "fold")
143+
setnames(resampling_sub$instance, c("fold", "row_id"))
144+
145+
plot = autoplot_spatial(
146+
resampling = resampling_sub,
147+
task = task,
148+
fold_id = fold_id,
149+
# repeats_id = repeats_id,
150+
plot_as_grid = plot_as_grid,
151+
train_color = train_color,
152+
test_color = test_color,
153+
crs = crs,
154+
show_blocks = FALSE,
155+
show_labels = FALSE,
156+
...)
157+
return(invisible(plot))
158+
}
167159

168-
if (static_image) {
169-
plotly::orca(plot_single_plotly, ...)
170-
}
160+
# 3D -------------------------------------------------------------------------
171161

172-
print(plot_single_plotly)
173-
return(invisible(plot_single_plotly))
174-
} else {
162+
if (plot3D) {
175163

176-
### Multiplot of multiple partitions with train and test set
164+
if (!is.null(fold_id)) {
177165

178-
plot = mlr3misc::map(fold_id, function(.x) {
166+
if (length(fold_id) == 1) {
167+
### only one fold
179168

180169
data_coords = prepare_autoplot_cstf(task, resampling_sub)
181170

182-
# get test and train indices
183-
row_id_test = resampling$instance$test[[.x]]
184-
row_id_train = resampling$instance$train[[.x]]
171+
# suppress undefined global variables note
172+
data_coords$indicator = ""
173+
174+
row_id_test = resampling_sub$instance$test[[fold_id]]
175+
row_id_train = resampling_sub$instance$train[[fold_id]]
185176

186-
# assign test or train to columns matching the respective row ids
187177
data_coords[row_id %in% row_id_test, indicator := "Test"]
188178
data_coords[row_id %in% row_id_train, indicator := "Train"]
189179

190-
data_coords$Date = as.Date(data_coords$Date)
191-
192-
if (show_omitted) {
180+
if (show_omitted && nrow(data_coords[indicator == ""]) > 0) {
193181
data_coords[indicator == "", indicator := "Omitted"]
194182

195-
pl = plotly::plot_ly(data_coords,
183+
plot_single_plotly = plotly::plot_ly(data_coords,
196184
x = ~x, y = ~y, z = ~Date,
197185
color = ~indicator, colors = c(
198186
"grey", "#E18727", "#0072B5"
199187
),
200-
# # this is needed for later when doing 3D subplots
201-
scene = paste0("scene", .x),
202-
showlegend = ifelse(.x == 1, TRUE, FALSE)
188+
sizes = c(20, 100)
203189
)
204190
} else {
205191
data_coords = data_coords[indicator != ""]
206-
pl = plotly::plot_ly(data_coords,
192+
plot_single_plotly = plotly::plot_ly(data_coords,
207193
x = ~x, y = ~y, z = ~Date,
208194
color = ~indicator, colors = c(
209195
"#E18727", "#0072B5"
210196
),
211-
# # this is needed for later when doing 3D subplots
212-
scene = paste0("scene", .x),
213-
showlegend = ifelse(.x == 1, TRUE, FALSE)
214-
# sizes = c(20, 100)
197+
sizes = c(20, 100)
215198
)
216199
}
217200

218-
pl = plotly::add_markers(pl, marker = list(size = point_size))
219-
layout_args = list(pl,
220-
"title" = sprintf("Fold #%s", .x),
221-
list(
222-
xaxis = list(
223-
title = "Lat",
224-
nticks = nticks_x,
225-
tickfont = list(size = axis_label_fontsize)),
226-
yaxis = list(
227-
title = "Lon",
228-
nticks = nticks_y,
229-
tickfont = list(size = axis_label_fontsize)),
201+
plot_single_plotly = plotly::add_markers(plot_single_plotly,
202+
marker = list(size = point_size))
203+
plot_single_plotly = plotly::layout(plot_single_plotly,
204+
title = sprintf(
205+
"Fold %s, Repetition %s", fold_id,
206+
repeats_id),
207+
autosize = TRUE,
208+
scene = list(
209+
xaxis = list(title = "Lat", nticks = nticks_x),
210+
yaxis = list(title = "Lon", nticks = nticks_y),
230211
zaxis = list(
231212
title = "Time",
232213
type = "date",
233214
tickformat = tickformat_date,
234215
tickfont = list(size = axis_label_fontsize)
235-
# sets size of axis titles
236-
# titlefont = list(size = 5)
237216
),
238217
camera = list(eye = list(z = 1.50))
239218
)
240219
)
241-
# -`p` is the name of the plotly object.
242-
# - title sets the title of the plot
243-
# - the "scene" name is dynamically generated and refers to the scene
244-
# name in the `plot_ly()` call
245-
names(layout_args) = c(
246-
"p",
247-
"title",
248-
paste0("scene", .x)
249-
)
250220

251-
pl = mlr3misc::invoke(plotly::layout, .args = layout_args)
221+
if (static_image) {
222+
plotly::orca(plot_single_plotly, ...)
223+
}
224+
225+
print(plot_single_plotly)
226+
return(invisible(plot_single_plotly))
227+
} else {
228+
229+
### Multiplot of multiple partitions with train and test set
230+
231+
plot = mlr3misc::map(fold_id, function(.x) {
232+
233+
data_coords = prepare_autoplot_cstf(task, resampling_sub)
234+
235+
# get test and train indices
236+
row_id_test = resampling_sub$instance$test[[.x]]
237+
row_id_train = resampling_sub$instance$train[[.x]]
238+
239+
# assign test or train to columns matching the respective row ids
240+
data_coords[row_id %in% row_id_test, indicator := "Test"]
241+
data_coords[row_id %in% row_id_train, indicator := "Train"]
242+
243+
data_coords$Date = as.Date(data_coords$Date)
244+
245+
if (show_omitted) {
246+
data_coords[indicator == "", indicator := "Omitted"]
247+
248+
pl = plotly::plot_ly(data_coords,
249+
x = ~x, y = ~y, z = ~Date,
250+
color = ~indicator, colors = c(
251+
"grey", "#E18727", "#0072B5"
252+
),
253+
# # this is needed for later when doing 3D subplots
254+
scene = paste0("scene", .x),
255+
showlegend = ifelse(.x == 1, TRUE, FALSE)
256+
)
257+
} else {
258+
data_coords = data_coords[indicator != ""]
259+
pl = plotly::plot_ly(data_coords,
260+
x = ~x, y = ~y, z = ~Date,
261+
color = ~indicator, colors = c(
262+
"#E18727", "#0072B5"
263+
),
264+
# # this is needed for later when doing 3D subplots
265+
scene = paste0("scene", .x),
266+
showlegend = ifelse(.x == 1, TRUE, FALSE)
267+
# sizes = c(20, 100)
268+
)
269+
}
270+
271+
pl = plotly::add_markers(pl, marker = list(size = point_size))
272+
layout_args = list(pl,
273+
"title" = sprintf("Fold #%s", .x),
274+
list(
275+
xaxis = list(
276+
title = "Lat",
277+
nticks = nticks_x,
278+
tickfont = list(size = axis_label_fontsize)),
279+
yaxis = list(
280+
title = "Lon",
281+
nticks = nticks_y,
282+
tickfont = list(size = axis_label_fontsize)),
283+
zaxis = list(
284+
title = "Time",
285+
type = "date",
286+
tickformat = tickformat_date,
287+
tickfont = list(size = axis_label_fontsize)
288+
# sets size of axis titles
289+
# titlefont = list(size = 5)
290+
),
291+
camera = list(eye = list(z = 1.50))
292+
)
293+
)
294+
# -`p` is the name of the plotly object.
295+
# - title sets the title of the plot
296+
# - the "scene" name is dynamically generated and refers to the scene
297+
# name in the `plot_ly()` call
298+
names(layout_args) = c(
299+
"p",
300+
"title",
301+
paste0("scene", .x)
302+
)
252303

253-
})
304+
pl = mlr3misc::invoke(plotly::layout, .args = layout_args)
305+
306+
})
307+
}
308+
}
309+
310+
else {
311+
stop("This method requires to set argument 'fold_id'. See ?autoplot.ResamplingSptCVCstf for more information.") # nolint
254312
}
255313

256314
# is a grid requested?
@@ -274,9 +332,8 @@ autoplot.ResamplingSptCVCstf = function( # nolint
274332
return(plot)
275333
return(invisible(plot))
276334
}
277-
} else {
278-
stop("This method requires to set argument 'fold_id'. See ?autoplot.ResamplingSptCVCstf for more information.") # nolint
279335
}
336+
280337
}
281338

282339
#' @rdname autoplot.ResamplingSptCVCstf
@@ -295,6 +352,7 @@ autoplot.ResamplingRepeatedSptCVCstf = function( # nolint
295352
nticks_y = 3,
296353
point_size = 3,
297354
axis_label_fontsize = 11,
355+
plot3D = NULL,
298356
...) {
299357

300358
autoplot.ResamplingSptCVCstf(
@@ -310,9 +368,10 @@ autoplot.ResamplingRepeatedSptCVCstf = function( # nolint
310368
nticks_y = nticks_y,
311369
point_size = point_size,
312370
axis_label_fontsize = axis_label_fontsize,
313-
...
371+
plot3D = plot3D,
372+
...,
314373
# ellipsis
315-
# repeats_id = repeats_id
374+
repeats_id = repeats_id
316375
)
317376
}
318377

R/zzz.R

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#' @rawNamespace import(data.table, except = transpose)
22
#' @importFrom R6 R6Class
3-
#' @importFrom mlr3 TaskClassif TaskRegr Resampling as_data_backend assert_task rsmp tsk rsmps
3+
#' @importFrom mlr3 TaskClassif TaskRegr Resampling as_data_backend assert_task rsmp tsk rsmps lrn
44
#' @import mlr3misc
55
#' @import checkmate
66
#' @import paradox
@@ -93,13 +93,13 @@ register_mlr3 = function() { # nocov start
9393
register_mlr3()
9494
setHook(packageEvent("mlr3", "onLoad"), function(...) register_mlr3(),
9595
action = "append")
96-
} # nocov end
96+
}
9797

9898
.onUnload = function(libpath) { # nolint
9999
event = packageEvent("mlr3", "onLoad")
100100
hooks = getHook(event)
101101
pkgname = vapply(hooks, function(x) environment(x)$pkgname, NA_character_)
102102
setHook(event, hooks[pkgname != "mlr3spatiotempcv"], action = "replace")
103-
} # nocov end
103+
}
104104

105-
leanify_package()
105+
leanify_package() # nocov end

0 commit comments

Comments
 (0)