Skip to content

Commit be9b5d6

Browse files
committed
Update levy-stable kstest.
1 parent 24b177e commit be9b5d6

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

tests/distributions/test_stable.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,14 @@ def test_shape(sample_shape, batch_shape):
3535
def test_sample(alpha, beta):
3636
num_samples = 100
3737
d = dist.Stable(alpha, beta, coords="S")
38-
39-
def sampler(size):
40-
# Temporarily increase radius to test hole-patching logic.
41-
# Scipy doesn't handle values of alpha very close to 1.
42-
try:
43-
old = pyro.distributions.stable.RADIUS
44-
pyro.distributions.stable.RADIUS = 0.02
45-
return d.sample([size])
46-
finally:
47-
pyro.distributions.stable.RADIUS = old
38+
# Temporarily increase radius to test hole-patching logic.
39+
# Scipy doesn't handle values of alpha very close to 1.
40+
try:
41+
old = pyro.distributions.stable.RADIUS
42+
pyro.distributions.stable.RADIUS = 0.02
43+
samples = d.sample((num_samples,))
44+
finally:
45+
pyro.distributions.stable.RADIUS = old
4846

4947
def cdf(x):
5048
with warnings.catch_warnings(record=True) as w:
@@ -56,7 +54,7 @@ def cdf(x):
5654
pytest.xfail(reason="scipy.stats.levy_stable.cdf is unstable")
5755
return result
5856

59-
assert kstest(sampler, cdf, N=num_samples).pvalue > 0.1
57+
assert kstest(samples, cdf).pvalue > 0.1
6058

6159

6260
@pytest.mark.parametrize("beta", [-1.0, -0.5, 0.0, 0.5, 1.0])

0 commit comments

Comments
 (0)