Skip to content

Commit ec687ff

Browse files
authored
Merge pull request #278 from stan-dev/fix-kfold_split_stratified-1-obs
Fix `kfold_split_stratified()` when a group has 1 observation
2 parents 93cdae8 + 117f030 commit ec687ff

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

R/kfold-helpers.R

+6-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,12 @@ kfold_split_stratified <- function(K = 10, x = NULL) {
8484
N <- length(x)
8585
xids <- numeric()
8686
for (l in 1:Nlev) {
87-
xids <- c(xids, sample(which(x==l)))
87+
idx <- which(x == l)
88+
if (length(idx) > 1) {
89+
xids <- c(xids, sample(idx))
90+
} else {
91+
xids <- c(xids, idx)
92+
}
8893
}
8994
bins <- rep(NA, N)
9095
bins[xids] <- rep(1:K, ceiling(N/K))[1:N]

tests/testthat/test_kfold_helpers.R

+8
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ test_that("kfold_split_stratified works", {
2222
y <- mtcars$cyl
2323
fold_strat <- kfold_split_stratified(10, y)
2424
expect_equal(range(table(fold_strat)), c(3, 4))
25+
26+
# test when a group has 1 observation
27+
# https://github.com/stan-dev/loo/issues/277
28+
y <- rep(c(1, 2, 3), times = c(20, 40, 1))
29+
expect_silent(fold_strat <- kfold_split_stratified(5, y)) # used to be a warning before fixing issue #277
30+
tab <- table(fold_strat, y)
31+
expect_equal(tab[1, ], c("1" = 4, "2" = 8, "3" = 1))
32+
for (i in 2:nrow(tab)) expect_equal(tab[i, ], c("1" = 4, "2" = 8, "3" = 0))
2533
})
2634

2735
test_that("kfold_split_grouped works", {

0 commit comments

Comments
 (0)