Skip to content

Commit 6e7001e

Browse files
authored
Merge pull request #273 from stan-dev/avoid-under-and-overflows-in-stacking
avoid under and overflows in stacking
2 parents b1f7a5a + 568f29b commit 6e7001e

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed

R/loo_model_weights.R

+6-9
Original file line numberDiff line numberDiff line change
@@ -257,15 +257,12 @@ stacking_weights <-
257257
stop("At least two models are required for stacking weights.")
258258
}
259259

260-
exp_lpd_point <- exp(lpd_point)
261260
negative_log_score_loo <- function(w) {
262261
# objective function: log score
263262
stopifnot(length(w) == K - 1)
264263
w_full <- c(w, 1 - sum(w))
265-
sum <- 0
266-
for (i in 1:N) {
267-
sum <- sum + log(exp(lpd_point[i, ]) %*% w_full)
268-
}
264+
# avoid over- and underflows using log weights and rowLogSumExps
265+
sum <- sum(matrixStats::rowLogSumExps(sweep(lpd_point[1:N,], 2, log(w_full), '+')))
269266
return(-as.numeric(sum))
270267
}
271268

@@ -274,11 +271,11 @@ stacking_weights <-
274271
stopifnot(length(w) == K - 1)
275272
w_full <- c(w, 1 - sum(w))
276273
grad <- rep(0, K - 1)
274+
# avoid over- and underflows using log weights, rowLogSumExps,
275+
# and by subtracting the row maximum of lpd_point
276+
mlpd <- matrixStats::rowMaxs(lpd_point)
277277
for (k in 1:(K - 1)) {
278-
for (i in 1:N) {
279-
grad[k] <- grad[k] +
280-
(exp_lpd_point[i, k] - exp_lpd_point[i, K]) / (exp_lpd_point[i,] %*% w_full)
281-
}
278+
grad[k] <- sum((exp(lpd_point[, k] - mlpd) - exp(lpd_point[, K] - mlpd)) / exp(matrixStats::rowLogSumExps(sweep(lpd_point, 2, log(w_full), '+')) - mlpd))
282279
}
283280
return(-grad)
284281
}

0 commit comments

Comments
 (0)