Skip to content

Commit 2b0bed5

Browse files
Merge pull request #178 from ChrisRackauckas/fix-formatting
Apply JuliaFormatter to fix code formatting
2 parents 877afc6 + 9557ba9 commit 2b0bed5

File tree

5 files changed

+43
-46
lines changed

5 files changed

+43
-46
lines changed

docs/src/tutorials/basic_mnist_deq.md

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ Construct the Lux Neural Network containing a DEQ layer.
4848

4949
```@example basic_mnist_deq
5050
function construct_model(solver; model_type::Symbol=:deq)
51-
down = Chain(Conv((3, 3), 1 => 64, gelu; stride=1), GroupNorm(64, 64),
52-
Conv((4, 4), 64 => 64; stride=2, pad=1))
51+
down = Chain(
52+
Conv((3, 3), 1 => 64, gelu; stride=1), GroupNorm(64, 64), Conv((4, 4), 64 => 64; stride=2, pad=1))
5353
5454
# The input layer of the DEQ
5555
deq_model = Chain(
@@ -72,8 +72,7 @@ function construct_model(solver; model_type::Symbol=:deq)
7272
deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false,
7373
linsolve_kwargs=(; maxiters=10), maxiters=10)
7474
75-
classifier = Chain(
76-
GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(), Dense(64, 10))
75+
classifier = Chain(GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(), Dense(64, 10))
7776
7877
model = Chain(; down, deq, classifier)
7978
@@ -133,8 +132,9 @@ function train_model(solver, model_type)
133132
@set! tstate.states = Lux.update_state(tstate.states, :fixed_depth, Val(5))
134133
135134
for _ in 1:2, (i, (x, y)) in enumerate(train_dataloader)
136-
_, loss, _, tstate = Training.single_train_step!(
137-
AutoZygote(), loss_function, (x, y), tstate)
135+
136+
_, loss,
137+
_, tstate = Training.single_train_step!(AutoZygote(), loss_function, (x, y), tstate)
138138
if i % 10 == 1
139139
@printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(train_dataloader) loss
140140
end
@@ -147,8 +147,9 @@ function train_model(solver, model_type)
147147
148148
for epoch in 1:3
149149
for (i, (x, y)) in enumerate(train_dataloader)
150-
_, loss, _, tstate = Training.single_train_step!(
151-
AutoZygote(), loss_function, (x, y), tstate)
150+
_, loss,
151+
_,
152+
tstate = Training.single_train_step!(AutoZygote(), loss_function, (x, y), tstate)
152153
if i % 10 == 1
153154
@printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(train_dataloader) loss
154155
end

docs/src/tutorials/reduced_dim_deq.md

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,23 +46,22 @@ function construct_model(solver; model_type::Symbol=:regdeq)
4646
# The input layer of the DEQ
4747
deq_model = Chain(
4848
Parallel(+,
49-
Dense(
50-
128 => 64, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01)), # Reduced dim of `128`
51-
Dense(
52-
512 => 64, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01))), # Original dim of `512`
49+
Dense(128 => 64, tanh; use_bias=false, init_weight=truncated_normal(;
50+
std=0.01)), # Reduced dim of `128`
51+
Dense(512 => 64, tanh; use_bias=false, init_weight=truncated_normal(;
52+
std=0.01))), # Original dim of `512`
5353
Dense(64 => 64, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01)),
5454
Dense(64 => 128; use_bias=false, init_weight=truncated_normal(; std=0.01))) # Return the reduced dim of `128`
5555
5656
if model_type === :skipdeq
57-
init = Dense(
58-
512 => 128, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01))
57+
init = Dense(512 => 128, tanh; use_bias=false, init_weight=truncated_normal(;
58+
std=0.01))
5959
elseif model_type === :regdeq
6060
error(":regdeq is not supported for reduced dim models")
6161
else
6262
# This should preferably done via `ChainRulesCore.@ignore_derivatives`. But here
6363
# we are only using Zygote so this is fine.
64-
init = WrappedFunction(x -> Zygote.@ignore(fill!(
65-
similar(x, 128, size(x, 2)), false)))
64+
init = WrappedFunction(x -> Zygote.@ignore(fill!(similar(x, 128, size(x, 2)), false)))
6665
end
6766
6867
deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false,
@@ -128,8 +127,9 @@ function train_model(solver, model_type)
128127
@set! tstate.states = Lux.update_state(tstate.states, :fixed_depth, Val(5))
129128
130129
for _ in 1:2, (i, (x, y)) in enumerate(train_dataloader)
131-
_, loss, _, tstate = Training.single_train_step!(
132-
AutoZygote(), loss_function, (x, y), tstate)
130+
131+
_, loss,
132+
_, tstate = Training.single_train_step!(AutoZygote(), loss_function, (x, y), tstate)
133133
if i % 10 == 1
134134
@printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(train_dataloader) loss
135135
end
@@ -142,8 +142,9 @@ function train_model(solver, model_type)
142142
143143
for epoch in 1:3
144144
for (i, (x, y)) in enumerate(train_dataloader)
145-
_, loss, _, tstate = Training.single_train_step!(
146-
AutoZygote(), loss_function, (x, y), tstate)
145+
_, loss,
146+
_,
147+
tstate = Training.single_train_step!(AutoZygote(), loss_function, (x, y), tstate)
147148
if i % 10 == 1
148149
@printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(train_dataloader) loss
149150
end

src/layers.jl

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,12 @@ end
3939

4040
function Base.show(io::IO, sol::DeepEquilibriumSolution)
4141
println(io, "DeepEquilibriumSolution")
42-
println(io, " * Initial Guess: ",
43-
sprint(print, sol.u0; context=(:compact => true, :limit => true)))
44-
println(io, " * Steady State: ",
45-
sprint(print, sol.z_star; context=(:compact => true, :limit => true)))
46-
println(io, " * Residual: ",
47-
sprint(print, sol.residual; context=(:compact => true, :limit => true)))
42+
println(io, " * Initial Guess: ", sprint(print, sol.u0; context=(
43+
:compact => true, :limit => true)))
44+
println(io, " * Steady State: ", sprint(print, sol.z_star; context=(
45+
:compact => true, :limit => true)))
46+
println(io, " * Residual: ", sprint(print, sol.residual; context=(
47+
:compact => true, :limit => true)))
4848
println(io, " * Jacobian Loss: ",
4949
sprint(print, sol.jacobian_loss; context=(:compact => true, :limit => true)))
5050
print(io, " * NFE: ", sol.nfe)
@@ -171,8 +171,7 @@ function DeepEquilibriumNetwork(
171171
model, solver; init=missing, jacobian_regularization=nothing,
172172
problem_type::Type=SteadyStateProblem{false}, kwargs...)
173173
if init === missing # Regular DEQ
174-
init = WrappedFunction(Base.Fix1(
175-
zeros_init, LuxOps.getproperty(model, Val(:scales))))
174+
init = WrappedFunction(Base.Fix1(zeros_init, LuxOps.getproperty(model, Val(:scales))))
176175
elseif init === nothing # SkipRegDEQ
177176
init = NoOpLayer()
178177
elseif !(init isa AbstractLuxLayer)
@@ -254,8 +253,7 @@ function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Ma
254253
if post_fuse_layer === nothing
255254
model = MultiScaleInputLayer(Chain(l1, l2), split_idxs, scales)
256255
else
257-
model = MultiScaleInputLayer(
258-
Chain(l1, l2, Parallel(nothing, post_fuse_layer...)), split_idxs, scales)
256+
model = MultiScaleInputLayer(Chain(l1, l2, Parallel(nothing, post_fuse_layer...)), split_idxs, scales)
259257
end
260258

261259
return DeepEquilibriumNetwork(model, solver; kwargs...)
@@ -291,8 +289,7 @@ Same arguments as [`MultiScaleDeepEquilibriumNetwork`](@ref) but sets `problem_t
291289
`ODEProblem{false}`.
292290
"""
293291
function MultiScaleNeuralODE(args...; kwargs...)
294-
return MultiScaleDeepEquilibriumNetwork(
295-
args...; kwargs..., problem_type=ODEProblem{false})
292+
return MultiScaleDeepEquilibriumNetwork(args...; kwargs..., problem_type=ODEProblem{false})
296293
end
297294

298295
## Generate Initial Condition

src/utils.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
@generated function split_and_reshape(
2-
x::AbstractMatrix, ::Val{idxs}, ::Val{shapes}) where {idxs, shapes}
1+
@generated function split_and_reshape(x::AbstractMatrix, ::Val{idxs}, ::Val{shapes}) where {
2+
idxs, shapes}
33
dims = [reshape((idxs[i] + 1):idxs[i + 1], shapes[i]...) for i in 1:(length(idxs) - 1)]
44
varnames = map(_ -> gensym("x_view"), dims)
55
calls = [:($(varnames[i]) = x[$(dims[i]), :]) for i in eachindex(dims)]
@@ -15,8 +15,7 @@ function split_and_reshape(y::AbstractMatrix, x)
1515
szs = [prod(size(xᵢ)[1:(end - 1)]) for xᵢ in x]
1616
counters = vcat(0, cumsum(szs)[1:(end - 1)])
1717
# Make the data contiguous
18-
return map((sz, c, xᵢ) -> copy(reshape(view(y, (c + 1):(c + sz), :), size(xᵢ))),
19-
szs, counters, x)
18+
return map((sz, c, xᵢ) -> copy(reshape(view(y, (c + 1):(c + sz), :), size(xᵢ))), szs, counters, x)
2019
end
2120

2221
flatten(x::AbstractVector) = reshape(x, length(x), 1)

test/layers_tests.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,16 @@ end
3434
jacobian_regularizations = ongpu ? _jacobian_regularizations[1:(end - 1)] :
3535
_jacobian_regularizations
3636

37-
@testset "Solver: $(nameof(typeof(solver))) | Model Type: $(mtype) | Jac. Reg: $(jacobian_regularization)" for solver in SOLVERS,
38-
mtype in model_type,
39-
jacobian_regularization in jacobian_regularizations
37+
@testset "Solver: $(nameof(typeof(solver))) | Model Type: $(mtype) | Jac. Reg: $(jacobian_regularization)" for solver in
38+
SOLVERS,
39+
mtype in model_type, jacobian_regularization in jacobian_regularizations
4040

41-
@testset "x_size: $(x_size)" for (base_model, init_model, x_size) in zip(
42-
base_models, init_models, x_sizes)
41+
@testset "x_size: $(x_size)" for (base_model, init_model, x_size) in
42+
zip(base_models, init_models, x_sizes)
4343
model = if mtype === :deq
4444
DeepEquilibriumNetwork(base_model, solver; jacobian_regularization)
4545
elseif mtype === :skipdeq
46-
SkipDeepEquilibriumNetwork(
47-
base_model, init_model, solver; jacobian_regularization)
46+
SkipDeepEquilibriumNetwork(base_model, init_model, solver; jacobian_regularization)
4847
elseif mtype === :skipregdeq
4948
SkipDeepEquilibriumNetwork(base_model, solver; jacobian_regularization)
5049
end
@@ -112,10 +111,10 @@ end
112111

113112
@testset "$mode" for (mode, aType, dev, ongpu) in MODES
114113
@testset "Solver: $(nameof(typeof(solver)))" for solver in SOLVERS,
115-
mtype in model_type,
116-
jacobian_regularization in jacobian_regularizations
114+
mtype in model_type, jacobian_regularization in jacobian_regularizations
117115

118-
@testset "x_size: $(x_size)" for (main_layer, mapping_layer, init_layer, x_size, scale) in zip(
116+
@testset "x_size: $(x_size)" for (
117+
main_layer, mapping_layer, init_layer, x_size, scale) in zip(
119118
main_layers, mapping_layers, init_layers, x_sizes, scales)
120119
model = if mtype === :deq
121120
MultiScaleDeepEquilibriumNetwork(main_layer, mapping_layer, nothing,

0 commit comments

Comments
 (0)