Skip to content

Commit edbbbfa

Browse files
wsmosesglwagnersimone-silvestri
authored
Reactant: add make_tracer for grid (#4242)
* Reactant: add make_tracer for grid * tmp * wip * fixup * cleanup * cleanup * attempt fix * fix * Update src/ImmersedBoundaries/immersed_boundary_interface.jl * Update inactive_node.jl * Update src/ImmersedBoundaries/immersed_boundary_interface.jl * Extend _immersed_cell * Update src/ImmersedBoundaries/immersed_reductions.jl Co-authored-by: Simone Silvestri <[email protected]> * Update src/Operators/spacings_and_areas_and_volumes.jl Co-authored-by: Simone Silvestri <[email protected]> --------- Co-authored-by: Gregory L. Wagner <[email protected]> Co-authored-by: Gregory Wagner <[email protected]> Co-authored-by: Simone Silvestri <[email protected]>
1 parent d607af6 commit edbbbfa

File tree

7 files changed

+163
-12
lines changed

7 files changed

+163
-12
lines changed

ext/OceananigansReactantExt/OceananigansReactantExt.jl

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,13 @@ Base.@nospecializeinfer function Reactant.traced_type_inner(
104104
return Oceananigans.Grids.OrthogonalSphericalShellGrid{FT2, TX2, TY2, TZ2, Z2, Map2, CC2, FC2, CF2, FF2, Arch}
105105
end
106106

107+
@inline Reactant.make_tracer(
108+
seen,
109+
@nospecialize(prev::Oceananigans.Grids.OrthogonalSphericalShellGrid),
110+
args...;
111+
kwargs...
112+
) = Reactant.make_tracer_via_immutable_constructor(seen, prev, args...; kwargs...)
113+
107114
# https://github.com/CliMA/Oceananigans.jl/blob/d9b3b142d8252e8e11382d1b3118ac2a092b38a2/src/ImmersedBoundaries/immersed_boundary_grid.jl#L8
108115
Base.@nospecializeinfer function Reactant.traced_type_inner(
109116
@nospecialize(OA::Type{ImmersedBoundaryGrid{FT, TX, TY, TZ, G, I, M, S, Arch}}),
@@ -124,6 +131,84 @@ Base.@nospecializeinfer function Reactant.traced_type_inner(
124131
return Oceananigans.Grids.ImmersedBoundaryGrid{FT2, TX2, TY2, TZ2, G2, I2, M2, S2, Arch}
125132
end
126133

134+
struct Fix1v2{F,T}
135+
f::F
136+
t::T
137+
end
138+
139+
@inline function (s::Fix1v2)(args...)
140+
s.f(s.t, args...)
141+
end
142+
143+
function evalcond(c, i, j, k)
144+
Oceananigans.AbstractOperations.evaluate_condition(c.condition, i, j, k, c.grid, c)
145+
end
146+
147+
@inline function Reactant.TracedUtils.broadcast_to_size(c::Oceananigans.AbstractOperations.ConditionalOperation, rsize)
148+
if c == rsize
149+
return Reactant.TracedUtils.materialize_traced_array(c)
150+
end
151+
return c
152+
end
153+
154+
@inline function Reactant.TracedUtils.materialize_traced_array(c::Oceananigans.AbstractOperations.ConditionalOperation)
155+
N = ndims(c)
156+
axes2 = ntuple(Val(N)) do i
157+
reshape(Base.OneTo(size(c, i)), (ntuple(Val(N)) do j
158+
if i == j
159+
size(c, i)
160+
else
161+
1
162+
end
163+
end)...)
164+
end
165+
166+
tracedidxs = axes(c)
167+
tracedidxs = axes2
168+
169+
conds = Reactant.TracedUtils.materialize_traced_array(Reactant.call_with_reactant(Oceananigans.AbstractOperations.evaluate_condition, c.condition, tracedidxs..., c.grid, c))
170+
171+
172+
tvals = Reactant.Ops.fill(zero(Reactant.unwrapped_eltype(Base.eltype(c))), size(c))
173+
174+
gf = Reactant.call_with_reactant(getindex, c.operand, axes2...)
175+
Reactant.TracedRArrayOverrides._copyto!(tvals, Base.broadcasted(c.func, gf))
176+
177+
return Reactant.Ops.select(
178+
conds,
179+
tvals,
180+
Reactant.TracedUtils.broadcast_to_size(c.mask, size(c))
181+
)
182+
end
183+
184+
function evalkern(kern, i, j, k)
185+
kern.kernel_function(i, j, k, kern.grid, kern.arguments...)
186+
end
187+
188+
@inline function Reactant.TracedUtils.materialize_traced_array(c::Oceananigans.AbstractOperations.KernelFunctionOperation)
189+
N = ndims(c)
190+
axes2 = ntuple(Val(N)) do i
191+
reshape(Base.OneTo(size(c, i)), (ntuple(Val(N)) do j
192+
if i == j
193+
size(c, i)
194+
else
195+
1
196+
end
197+
end)...)
198+
end
199+
200+
tvals = Reactant.Ops.fill(Reactant.unwrapped_eltype(Base.eltype(c)), size(c))
201+
Reactant.TracedRArrayOverrides._copyto!(tvals, Base.broadcasted(Fix1v2(evalkern, c), axes2...))
202+
return tvals
203+
end
204+
205+
@inline function Reactant.TracedUtils.broadcast_to_size(c::Oceananigans.AbstractOperations.KernelFunctionOperation, rsize)
206+
if c == rsize
207+
return Reactant.TracedUtils.materialize_traced_array(c)
208+
end
209+
return c
210+
end
211+
127212
# These are additional modules that may need to be Reactantified in the future:
128213
#
129214
# include("Utils.jl")

src/Grids/inactive_node.jl

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,25 @@
22
const c = Center()
33
const f = Face()
44

5-
function build_condition(Topo, side, dim)
5+
function build_condition(Topo, side, dim, array::Bool)
66
if Topo == :Bounded
7-
return :(($side < 1) | ($side > grid.$dim))
7+
if array
8+
return :(($side .< 1) .| ($side .> grid.$dim))
9+
else
10+
return :(($side < 1) | ($side > grid.$dim))
11+
end
812
elseif Topo == :LeftConnected
9-
return :(($side > grid.$dim))
13+
if array
14+
return :(($side .> grid.$dim))
15+
else
16+
return :(($side > grid.$dim))
17+
end
1018
else # RightConnected
11-
return :(($side < 1))
19+
if array
20+
return :(($side .< 1))
21+
else
22+
return :(($side < 1))
23+
end
1224
end
1325
end
1426

@@ -39,43 +51,59 @@ Topos = (:Bounded, :LeftConnected, :RightConnected)
3951

4052
for PrimaryTopo in Topos
4153

42-
xcondition = build_condition(PrimaryTopo, :i, :Nx)
43-
ycondition = build_condition(PrimaryTopo, :j, :Ny)
44-
zcondition = build_condition(PrimaryTopo, :k, :Nz)
54+
xcondition = build_condition(PrimaryTopo, :i, :Nx, false)
55+
ycondition = build_condition(PrimaryTopo, :j, :Ny, false)
56+
zcondition = build_condition(PrimaryTopo, :k, :Nz, false)
57+
58+
xcondition_ar = build_condition(PrimaryTopo, :i, :Nx, true)
59+
ycondition_ar = build_condition(PrimaryTopo, :j, :Ny, true)
60+
zcondition_ar = build_condition(PrimaryTopo, :k, :Nz, true)
4561

4662
@eval begin
4763
XBoundedGrid = AbstractGrid{<:Any, <:$PrimaryTopo}
4864
YBoundedGrid = AbstractGrid{<:Any, <:Any, <:$PrimaryTopo}
4965
ZBoundedGrid = AbstractGrid{<:Any, <:Any, <:Any, <:$PrimaryTopo}
5066

5167
@inline inactive_cell(i, j, k, grid::XBoundedGrid) = $xcondition
68+
@inline inactive_cell(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::XBoundedGrid) = $xcondition_ar
5269
@inline inactive_cell(i, j, k, grid::YBoundedGrid) = $ycondition
70+
@inline inactive_cell(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::YBoundedGrid) = $ycondition_ar
5371
@inline inactive_cell(i, j, k, grid::ZBoundedGrid) = $zcondition
72+
@inline inactive_cell(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::ZBoundedGrid) = $zcondition_ar
5473
end
5574

5675
for SecondaryTopo in Topos
5776

58-
xycondition = :( $xcondition | $(build_condition(SecondaryTopo, :j, :Ny)))
59-
xzcondition = :( $xcondition | $(build_condition(SecondaryTopo, :k, :Nz)))
60-
yzcondition = :( $ycondition | $(build_condition(SecondaryTopo, :k, :Nz)))
77+
xycondition = :( $xcondition | $(build_condition(SecondaryTopo, :j, :Ny, false)))
78+
xzcondition = :( $xcondition | $(build_condition(SecondaryTopo, :k, :Nz, false)))
79+
yzcondition = :( $ycondition | $(build_condition(SecondaryTopo, :k, :Nz, false)))
80+
81+
xycondition_ar = :( $xcondition_ar .| $(build_condition(SecondaryTopo, :j, :Ny, true)))
82+
xzcondition_ar = :( $xcondition_ar .| $(build_condition(SecondaryTopo, :k, :Nz, true)))
83+
yzcondition_ar = :( $ycondition_ar .| $(build_condition(SecondaryTopo, :k, :Nz, true)))
6184

6285
@eval begin
6386
XYBoundedGrid = AbstractGrid{<:Any, <:$PrimaryTopo, <:$SecondaryTopo}
6487
XZBoundedGrid = AbstractGrid{<:Any, <:$PrimaryTopo, <:Any, <:$SecondaryTopo}
6588
YZBoundedGrid = AbstractGrid{<:Any, <:Any, <:$PrimaryTopo, <:$SecondaryTopo}
6689

6790
@inline inactive_cell(i, j, k, grid::XYBoundedGrid) = $xycondition
91+
@inline inactive_cell(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::XYBoundedGrid) = $xycondition_ar
6892
@inline inactive_cell(i, j, k, grid::XZBoundedGrid) = $xzcondition
93+
@inline inactive_cell(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::XZBoundedGrid) = $xzcondition_ar
6994
@inline inactive_cell(i, j, k, grid::YZBoundedGrid) = $yzcondition
95+
@inline inactive_cell(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::YZBoundedGrid) = $yzcondition_ar
7096
end
7197

7298
for TertiaryTopo in Topos
73-
xyzcondition = :( $xycondition | $(build_condition(TertiaryTopo, :k, :Nz)))
99+
xyzcondition = :( $xycondition | $(build_condition(TertiaryTopo, :k, :Nz, false)))
100+
xyzcondition_ar = :( $xyzcondition .| $(build_condition(TertiaryTopo, :k, :Nz, true)))
74101

75102
@eval begin
76103
XYZBoundedGrid = AbstractGrid{<:Any, <:$PrimaryTopo, <:$SecondaryTopo, <:$TertiaryTopo}
77104

78105
@inline inactive_cell(i, j, k, grid::XYZBoundedGrid) = $xyzcondition
106+
@inline inactive_cell(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::XYZBoundedGrid) = $xyzcondition_ar
79107
end
80108
end
81109
end

src/Grids/vertical_discretization.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ on_architecture(arch, coord::MutableVerticalDiscretization) =
115115
AUG = AbstractUnderlyingGrid
116116

117117
@inline rnode(i, j, k, grid, ℓx, ℓy, ℓz) = rnode(k, grid, ℓz)
118+
@inline rnode(i::AbstractArray, j::AbstractArray, k, grid, ℓx, ℓy, ℓz) = permutedims(Base.stack(collect(Base.stack(collect(rnode(k, grid, ℓz) for _ in j)) for _ in i)), (3, 2, 1))
118119
@inline rnode(k, grid, ::Center) = getnode(grid.z.cᵃᵃᶜ, k)
119120
@inline rnode(k, grid, ::Face) = getnode(grid.z.cᵃᵃᶠ, k)
120121

src/ImmersedBoundaries/grid_fitted_bottom.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,15 @@ end
129129
return z zb
130130
end
131131

132+
@inline function _immersed_cell(i, j, k::AbstractArray, underlying_grid, ib::GridFittedBottom)
133+
# We use `rnode` for the `immersed_cell` because we do not want to have
134+
# wetting or drying that could happen for a moving grid if we use znode
135+
z = rnode(i, j, k, underlying_grid, c, c, c)
136+
zb = @inbounds ib.bottom_height[i, j, 1]
137+
zb = Base.stack(collect(zb for _ in k))
138+
return z .≤ zb
139+
end
140+
132141
#####
133142
##### Static column depth
134143
#####

src/ImmersedBoundaries/immersed_boundary_interface.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,17 @@ As well as
4747
* `inactive_node(i-1, 1, 1, grid, f, c, c) = true`
4848
"""
4949
@inline inactive_cell(i, j, k, ibg::IBG) = immersed_cell(i, j, k, ibg) | inactive_cell(i, j, k, ibg.underlying_grid)
50+
@inline inactive_cell(i::AbstractArray, j::AbstractArray, k::AbstractArray, ibg::IBG) = immersed_cell(i, j, k, ibg) .| inactive_cell(i, j, k, ibg.underlying_grid)
5051

5152
# Isolate periphery of the immersed boundary
5253
@inline immersed_peripheral_node(i, j, k, ibg::IBG, LX, LY, LZ) = peripheral_node(i, j, k, ibg, LX, LY, LZ) &
5354
!peripheral_node(i, j, k, ibg.underlying_grid, LX, LY, LZ)
5455

55-
@inline immersed_inactive_node(i, j, k, ibg::IBG, LX, LY, LZ) = inactive_node(i, j, k, ibg, LX, LY, LZ) &
56+
@inline immersed_peripheral_node(i::AbstractArray, j::AbstractArray, k::AbstractArray, ibg::IBG, LX, LY, LZ) = peripheral_node(i, j, k, ibg, LX, LY, LZ) .&
57+
Base.broadcast(!, peripheral_node(i, j, k, ibg.underlying_grid, LX, LY, LZ))
58+
59+
@inline immersed_inactive_node(i, j, k, ibg::IBG, LX, LY, LZ) = inactive_node(i, j, k, ibg, LX, LY, LZ) &
5660
!inactive_node(i, j, k, ibg.underlying_grid, LX, LY, LZ)
61+
62+
@inline immersed_inactive_node(i::AbstractArray, j::AbstractArray, k::AbstractArray, ibg::IBG, LX, LY, LZ) = inactive_node(i, j, k, ibg, LX, LY, LZ) .&
63+
Base.broadcast(!, inactive_node(i, j, k, ibg.underlying_grid, LX, LY, LZ))

src/ImmersedBoundaries/immersed_reductions.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ end
3737
return !immersed & evaluate_condition(condition.func, i, j, k, ibg, args...)
3838
end
3939

40+
@inline function evaluate_condition(condition::NotImmersed, i::AbstractArray, j::AbstractArray, k::AbstractArray, ibg, co::ConditionalOperation, args...)
41+
ℓx, ℓy, ℓz = map(instantiate, location(co))
42+
immersed = immersed_peripheral_node(i, j, k, ibg, ℓx, ℓy, ℓz) .| inactive_node(i, j, k, ibg, ℓx, ℓy, ℓz)
43+
return Base.broadcast(!, immersed) .& evaluate_condition(condition.func, i, j, k, ibg, args...)
44+
end
45+
4046
#####
4147
##### Reduction operations on Reduced Fields test the immersed condition on the entirety of the immersed direction
4248
#####

src/Operators/spacings_and_areas_and_volumes.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,14 +228,29 @@ end
228228
### Linear spacings
229229

230230
@inline Δxᶜᶜᵃ(i, j, k, grid::OSSG) = @inbounds grid.Δxᶜᶜᵃ[i, j]
231+
232+
@inline Δxᶜᶜᵃ(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::OSSG) = Base.stack(collect(Δxᶜᶜᵃ(i, j, 1, grid) for _ in k))
233+
231234
@inline Δxᶠᶜᵃ(i, j, k, grid::OSSG) = @inbounds grid.Δxᶠᶜᵃ[i, j]
235+
@inline Δxᶠᶜᵃ(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::OSSG) = Base.stack(collect(Δxᶠᶜᵃ(i, j, 1, grid) for _ in k))
236+
232237
@inline Δxᶜᶠᵃ(i, j, k, grid::OSSG) = @inbounds grid.Δxᶜᶠᵃ[i, j]
238+
@inline Δxᶜᶠᵃ(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::OSSG) = Base.stack(collect(Δxᶜᶠᵃ(i, j, 1, grid) for _ in k))
239+
233240
@inline Δxᶠᶠᵃ(i, j, k, grid::OSSG) = @inbounds grid.Δxᶠᶠᵃ[i, j]
241+
@inline Δxᶠᶠᵃ(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::OSSG) = Base.stack(collect(Δxᶠᶠᵃ(i, j, 1, grid) for _ in k))
234242

235243
@inline Δyᶜᶜᵃ(i, j, k, grid::OSSG) = @inbounds grid.Δyᶜᶜᵃ[i, j]
244+
@inline Δyᶜᶜᵃ(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::OSSG) = Base.stack(collect(Δyᶜᶜᵃ(i, j, 1, grid) for _ in k))
245+
236246
@inline Δyᶠᶜᵃ(i, j, k, grid::OSSG) = @inbounds grid.Δyᶠᶜᵃ[i, j]
247+
@inline Δyᶠᶜᵃ(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::OSSG) = Base.stack(collect(Δyᶠᶜᵃ(i, j, 1, grid) for _ in k))
248+
237249
@inline Δyᶜᶠᵃ(i, j, k, grid::OSSG) = @inbounds grid.Δyᶜᶠᵃ[i, j]
250+
@inline Δyᶜᶠᵃ(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::OSSG) = Base.stack(collect(Δyᶜᶠᵃ(i, j, 1, grid) for _ in k))
251+
238252
@inline Δyᶠᶠᵃ(i, j, k, grid::OSSG) = @inbounds grid.Δyᶠᶠᵃ[i, j]
253+
@inline Δyᶠᶠᵃ(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::OSSG) = Base.stack(collect(Δyᶠᶠᵃ(i, j, 1, grid) for _ in k))
239254

240255
#####
241256
#####

0 commit comments

Comments
 (0)