@@ -35,16 +35,14 @@ def test_shape(sample_shape, batch_shape):
35
35
def test_sample (alpha , beta ):
36
36
num_samples = 100
37
37
d = dist .Stable (alpha , beta , coords = "S" )
38
-
39
- def sampler (size ):
40
- # Temporarily increase radius to test hole-patching logic.
41
- # Scipy doesn't handle values of alpha very close to 1.
42
- try :
43
- old = pyro .distributions .stable .RADIUS
44
- pyro .distributions .stable .RADIUS = 0.02
45
- return d .sample ([size ])
46
- finally :
47
- pyro .distributions .stable .RADIUS = old
38
+ # Temporarily increase radius to test hole-patching logic.
39
+ # Scipy doesn't handle values of alpha very close to 1.
40
+ try :
41
+ old = pyro .distributions .stable .RADIUS
42
+ pyro .distributions .stable .RADIUS = 0.02
43
+ samples = d .sample ((num_samples ,))
44
+ finally :
45
+ pyro .distributions .stable .RADIUS = old
48
46
49
47
def cdf (x ):
50
48
with warnings .catch_warnings (record = True ) as w :
@@ -56,7 +54,7 @@ def cdf(x):
56
54
pytest .xfail (reason = "scipy.stats.levy_stable.cdf is unstable" )
57
55
return result
58
56
59
- assert kstest (sampler , cdf , N = num_samples ).pvalue > 0.1
57
+ assert kstest (samples , cdf ).pvalue > 0.1
60
58
61
59
62
60
@pytest .mark .parametrize ("beta" , [- 1.0 , - 0.5 , 0.0 , 0.5 , 1.0 ])
0 commit comments