Skip to content

Commit 48f590e

Browse files
committed
avoid error for 1-D unit_vector
the unit_vector isn't actually used anywhere in the situation when it errors (when K isn't >1) so we just make it size 2 in that case to avoid Stan's error for 1-D unit vectors. fixes #603
1 parent 92a877c commit 48f590e

File tree

4 files changed

+17
-2
lines changed

4 files changed

+17
-2
lines changed

src/stan_files/lm.stan

+2-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ transformed data {
5555
}
5656
parameters {
5757
// must not call with init="0"
58-
array[K > 1 ? J : 0] unit_vector[K] u; // primitives for coefficients
58+
// https://github.com/stan-dev/rstanarm/issues/603#issuecomment-1785928224
59+
array[K > 1 ? J : 0] unit_vector[K > 1 ? K : 2] u; // primitives for coefficients
5960
array[J * has_intercept] real z_alpha; // primitives for intercepts
6061
array[J] real<lower=(K > 1 ? 0 : -1), upper=1> R2; // proportions of variance explained
6162
vector[J * (1 - prior_PD)] log_omega; // under/overfitting factors

src/stan_files/polr.stan

+3-1
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,9 @@ transformed data {
175175
}
176176
parameters {
177177
simplex[J] pi;
178-
array[K > 1] unit_vector[K] u;
178+
// avoid error by making unit_vector have 2 elements when K <= 1
179+
// https://github.com/stan-dev/rstanarm/issues/603#issuecomment-1785928224
180+
array[K > 1] unit_vector[K > 1 ? K : 2] u;
179181
real<lower=(K > 1 ? 0 : -1), upper=1> R2;
180182
array[is_skewed] real<lower=0> alpha;
181183
}

tests/testthat/test_stan_lm.R

+7
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,13 @@ test_that("stan_lm doesn't break with vb algorithms", {
119119
expect_stanreg(fit2)
120120
})
121121

122+
test_that("stan_lm works with 1 predictor", {
123+
SW(fit <- stan_lm(mpg ~ wt, data = mtcars,
124+
prior = R2(0.5, "mean"), refresh = 0,
125+
seed = SEED))
126+
expect_stanreg(fit)
127+
})
128+
122129
test_that("stan_lm throws error if only intercept", {
123130
expect_error(stan_lm(mpg ~ 1, data = mtcars, prior = R2(location = 0.75)),
124131
regexp = "not suitable for estimating a mean")

tests/testthat/test_stan_polr.R

+5
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ test_that("stan_polr runs for esoph example", {
5656
expect_stanreg(fit2vb)
5757
})
5858

59+
test_that("stan_polr runs with 1 predictor", {
60+
esoph$x1 <- rnorm(nrow(esoph))
61+
expect_stanreg(stan_polr(tobgp ~ x1, data = esoph, prior = R2(0.5, "mean")))
62+
})
63+
5964
test_that("stan_polr throws error if formula excludes intercept", {
6065
expect_error(stan_polr(tobgp ~ 0 + agegp + alcgp, data = esoph,
6166
method = "loglog", prior = R2(0.4, "median")),

0 commit comments

Comments
 (0)