Skip to content

Commit de41cd6

Browse files
Setup formatter (#17)
* apply formatter, add formatting rule * add Formatter action --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 7b261ae commit de41cd6

18 files changed

+345
-371
lines changed

.JuliaFormatter.toml

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
style = "blue"
3+
align_assignment = true
4+
align_struct_field = true
5+
align_pair_arrow = true
6+
align_matrix = true
7+
align_conditional = true

.github/workflows/Format.yml

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
name: Format suggestions
2+
3+
on:
4+
pull_request:
5+
6+
concurrency:
7+
# Skip intermediate builds: always.
8+
# Cancel intermediate builds: only if it is a pull request build.
9+
group: ${{ github.workflow }}-${{ github.ref }}
10+
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
11+
12+
jobs:
13+
format:
14+
runs-on: ubuntu-latest
15+
steps:
16+
- uses: actions/checkout@v4
17+
- uses: julia-actions/setup-julia@v2
18+
with:
19+
version: 1
20+
- run: |
21+
julia -e 'using Pkg; Pkg.add("JuliaFormatter")'
22+
julia -e 'using JuliaFormatter; format("."; verbose=true)'
23+
- uses: reviewdog/action-suggester@v1
24+
with:
25+
tool_name: JuliaFormatter
26+
fail_on_error: true

docs/make.jl

+2-5
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,8 @@ makedocs(;
2020
"Univariate Slice Sampling" => "univariate_slice.md",
2121
"Meta Multivariate Samplers" => "meta_multivariate.md",
2222
"Latent Slice Sampling" => "latent_slice.md",
23-
"Gibbsian Polar Slice Sampling" => "gibbs_polar.md"
23+
"Gibbsian Polar Slice Sampling" => "gibbs_polar.md",
2424
],
2525
)
2626

27-
deploydocs(;
28-
repo="github.com/TuringLang/SliceSampling.jl",
29-
push_preview=true
30-
)
27+
deploydocs(; repo="github.com/TuringLang/SliceSampling.jl", push_preview=true)

ext/SliceSamplingTuringExt.jl

+30-29
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ if isdefined(Base, :get_extension)
66
using Random
77
using SliceSampling
88
using Turing
9-
# using Turing: Turing, Experimental
9+
# using Turing: Turing, Experimental
1010
else
1111
using ..LogDensityProblemsAD
1212
using ..Random
@@ -17,46 +17,47 @@ end
1717

1818
# Required for using the slice samplers as `externalsampler`s in Turing
1919
# begin
20-
Turing.Inference.getparams(
21-
::Turing.DynamicPPL.Model,
22-
sample::SliceSampling.Transition
23-
) = sample.params
20+
function Turing.Inference.getparams(
21+
::Turing.DynamicPPL.Model, sample::SliceSampling.Transition
22+
)
23+
return sample.params
24+
end
2425
# end
2526

2627
# Required for using the slice samplers as `Experimental.Gibbs` samplers in Turing
2728
# begin
28-
Turing.Inference.getparams(
29-
::Turing.DynamicPPL.Model,
30-
state::SliceSampling.UnivariateSliceState
31-
) = state.transition.params
29+
function Turing.Inference.getparams(
30+
::Turing.DynamicPPL.Model, state::SliceSampling.UnivariateSliceState
31+
)
32+
return state.transition.params
33+
end
3234

33-
Turing.Inference.getparams(
34-
::Turing.DynamicPPL.Model,
35-
state::SliceSampling.GibbsState
36-
) = state.transition.params
35+
function Turing.Inference.getparams(
36+
::Turing.DynamicPPL.Model, state::SliceSampling.GibbsState
37+
)
38+
return state.transition.params
39+
end
3740

38-
Turing.Inference.getparams(
39-
::Turing.DynamicPPL.Model,
40-
state::SliceSampling.HitAndRunState
41-
) = state.transition.params
41+
function Turing.Inference.getparams(
42+
::Turing.DynamicPPL.Model, state::SliceSampling.HitAndRunState
43+
)
44+
return state.transition.params
45+
end
4246

43-
Turing.Experimental.gibbs_requires_recompute_logprob(
47+
function Turing.Experimental.gibbs_requires_recompute_logprob(
4448
model_dst,
4549
::Turing.DynamicPPL.Sampler{
46-
<: Turing.Inference.ExternalSampler{
47-
<: SliceSampling.AbstractSliceSampling, A, U
48-
}
50+
<:Turing.Inference.ExternalSampler{<:SliceSampling.AbstractSliceSampling,A,U}
4951
},
5052
sampler_src,
5153
state_dst,
52-
state_src
53-
) where {A,U} = false
54+
state_src,
55+
) where {A,U}
56+
return false
57+
end
5458
# end
5559

56-
function SliceSampling.initial_sample(
57-
rng::Random.AbstractRNG,
58-
::Turing.LogDensityFunction
59-
)
60+
function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDensityFunction)
6061
model =.model
6162
spl = Turing.SampleFromUniform()
6263
vi = Turing.VarInfo(rng, model, spl)
@@ -67,14 +68,14 @@ function SliceSampling.initial_sample(
6768
if init_attempt_count == 10
6869
@warn "failed to find valid initial parameters in $(init_attempt_count) tries; consider providing explicit initial parameters using the `initial_params` keyword"
6970
end
70-
71+
7172
# NOTE: This will sample in the unconstrained space.
7273
vi = last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromUniform()))
7374
θ = vi[spl]
7475

7576
init_attempt_count += 1
7677
end
77-
θ
78+
return θ
7879
end
7980

8081
end

src/SliceSampling.jl

+26-29
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ Struct containing the results of the transition.
3232
- `lp::Real`: Log-target density of the samples.
3333
- `info::NamedTuple`: Named tuple containing information about the transition.
3434
"""
35-
struct Transition{P, L <: Real, I <: NamedTuple}
35+
struct Transition{P,L<:Real,I<:NamedTuple}
3636
"current state of the slice sampling chain"
3737
params::P
3838

@@ -53,47 +53,44 @@ Return the initial sample for the `model` using the random number generator `rng
5353
- `model`: The target `LogDensityProblem`.
5454
"""
5555
function initial_sample(::Random.AbstractRNG, ::Any)
56-
error(
56+
return error(
5757
"`initial_sample` is not implemented but an initialization wasn't provided. ",
58-
"Consider supplying an initialization to `initial_params`."
58+
"Consider supplying an initialization to `initial_params`.",
5959
)
6060
end
6161

6262
# If target is from `LogDensityProblemsAD`, unwrap target before calling `initial_sample`.
6363
# This is necessary since Turing wraps `DynamicPPL.Model`s when passed to an `externalsampler`.
64-
initial_sample(
65-
rng::Random.AbstractRNG,
66-
wrap::LogDensityProblemsAD.ADGradientWrapper
67-
) = initial_sample(rng, parent(wrap))
64+
function initial_sample(
65+
rng::Random.AbstractRNG, wrap::LogDensityProblemsAD.ADGradientWrapper
66+
)
67+
return initial_sample(rng, parent(wrap))
68+
end
6869

6970
function exceeded_max_prop(max_prop::Int)
70-
error("Exceeded maximum number of proposal $(max_prop), ",
71-
"which indicates an acceptance rate less than $(1/max_prop*100)%. ",
72-
"A quick fix is to increase `max_prop`, ",
73-
"but an acceptance rate that is too low often indicates that there is a problem. ",
74-
"Here are some possible causes:\n",
75-
"- The model might be broken or degenerate (most likely cause).\n",
76-
"- The tunable parameters of the sampler are suboptimal.\n",
77-
"- The initialization is pathologic. (try supplying a (different) `initial_params`)\n",
78-
"- There might be a bug in the sampler. (if this is suspected, file an issue to `SliceSampling`)\n"
79-
)
71+
return error(
72+
"Exceeded maximum number of proposal $(max_prop), ",
73+
"which indicates an acceptance rate less than $(1/max_prop*100)%. ",
74+
"A quick fix is to increase `max_prop`, ",
75+
"but an acceptance rate that is too low often indicates that there is a problem. ",
76+
"Here are some possible causes:\n",
77+
"- The model might be broken or degenerate (most likely cause).\n",
78+
"- The tunable parameters of the sampler are suboptimal.\n",
79+
"- The initialization is pathologic. (try supplying a (different) `initial_params`)\n",
80+
"- There might be a bug in the sampler. (if this is suspected, file an issue to `SliceSampling`)\n",
81+
)
8082
end
8183

8284
## Univariate Slice Sampling Algorithms
8385
export Slice, SliceSteppingOut, SliceDoublingOut
8486

85-
abstract type AbstractUnivariateSliceSampling <: AbstractSliceSampling end
87+
abstract type AbstractUnivariateSliceSampling <: AbstractSliceSampling end
8688

87-
accept_slice_proposal(
88-
::AbstractSliceSampling,
89-
::Any,
90-
::Real,
91-
::Real,
92-
::Real,
93-
::Real,
94-
::Real,
95-
::Real,
96-
) = true
89+
function accept_slice_proposal(
90+
::AbstractSliceSampling, ::Any, ::Real, ::Real, ::Real, ::Real, ::Real, ::Real
91+
)
92+
return true
93+
end
9794

9895
function find_interval end
9996

@@ -103,7 +100,7 @@ include("univariate/steppingout.jl")
103100
include("univariate/doublingout.jl")
104101

105102
## Multivariate slice sampling algorithms
106-
abstract type AbstractMultivariateSliceSampling <: AbstractSliceSampling end
103+
abstract type AbstractMultivariateSliceSampling <: AbstractSliceSampling end
107104

108105
# Meta Multivariate Samplers
109106
export RandPermGibbs, HitAndRun

0 commit comments

Comments
 (0)