Skip to content

Commit 494dce1

Browse files
committed
fix Turing interface
1 parent 7d4c967 commit 494dce1

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

ext/SliceSamplingTuringExt.jl

+7-7
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,20 @@ end
5050
# end
5151

5252
function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDensityFunction)
53-
model =.model
54-
spl = Turing.SampleFromUniform()
55-
vi = Turing.VarInfo(rng, model, spl)
56-
θ = vi[spl]
53+
model =.model
54+
vi = Turing.VarInfo(rng, model, Turing.SampleFromUniform())
55+
vi_spl = last(Turing.DynamicPPL.evaluate!!(model, rng, vi, Turing.SampleFromUniform()))
56+
θ = vi_spl[:]
5757

5858
init_attempt_count = 1
59-
while !isfinite)
59+
while !all(isfinite.(θ))
6060
if init_attempt_count == 10
6161
@warn "failed to find valid initial parameters in $(init_attempt_count) tries; consider providing explicit initial parameters using the `initial_params` keyword"
6262
end
6363

6464
# NOTE: This will sample in the unconstrained space.
65-
vi = last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromUniform()))
66-
θ = vi[spl]
65+
vi_spl = last(Turing.DynamicPPL.evaluate!!(model, rng, vi, Turing.SampleFromUniform()))
66+
θ = vi_spl[:]
6767

6868
init_attempt_count += 1
6969
end

0 commit comments

Comments
 (0)