Skip to content

Reactant: add make_tracer for grid #4242

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Mar 22, 2025
Merged
85 changes: 85 additions & 0 deletions ext/OceananigansReactantExt/OceananigansReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@ Base.@nospecializeinfer function Reactant.traced_type_inner(
return Oceananigans.Grids.OrthogonalSphericalShellGrid{FT2, TX2, TY2, TZ2, Z2, Map2, CC2, FC2, CF2, FF2, Arch}
end

@inline Reactant.make_tracer(
seen,
@nospecialize(prev::Oceananigans.Grids.OrthogonalSphericalShellGrid),
args...;
kwargs...
) = Reactant.make_tracer_via_immutable_constructor(seen, prev, args...; kwargs...)

# https://github.com/CliMA/Oceananigans.jl/blob/d9b3b142d8252e8e11382d1b3118ac2a092b38a2/src/ImmersedBoundaries/immersed_boundary_grid.jl#L8
Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(OA::Type{ImmersedBoundaryGrid{FT, TX, TY, TZ, G, I, M, S, Arch}}),
Expand All @@ -124,6 +131,84 @@ Base.@nospecializeinfer function Reactant.traced_type_inner(
return Oceananigans.Grids.ImmersedBoundaryGrid{FT2, TX2, TY2, TZ2, G2, I2, M2, S2, Arch}
end

struct Fix1v2{F,T}
f::F
t::T
end

@inline function (s::Fix1v2)(args...)
s.f(s.t, args...)
end

function evalcond(c, i, j, k)
Oceananigans.AbstractOperations.evaluate_condition(c.condition, i, j, k, c.grid, c)
end

@inline function Reactant.TracedUtils.broadcast_to_size(c::Oceananigans.AbstractOperations.ConditionalOperation, rsize)
if c == rsize
return Reactant.TracedUtils.materialize_traced_array(c)
end
return c
end

@inline function Reactant.TracedUtils.materialize_traced_array(c::Oceananigans.AbstractOperations.ConditionalOperation)
N = ndims(c)
axes2 = ntuple(Val(N)) do i
reshape(Base.OneTo(size(c, i)), (ntuple(Val(N)) do j
if i == j
size(c, i)
else
1
end
end)...)
end

tracedidxs = axes(c)
tracedidxs = axes2

conds = Reactant.TracedUtils.materialize_traced_array(Reactant.call_with_reactant(Oceananigans.AbstractOperations.evaluate_condition, c.condition, tracedidxs..., c.grid, c))


tvals = Reactant.Ops.fill(zero(Reactant.unwrapped_eltype(Base.eltype(c))), size(c))

gf = Reactant.call_with_reactant(getindex, c.operand, axes2...)
Reactant.TracedRArrayOverrides._copyto!(tvals, Base.broadcasted(c.func, gf))

return Reactant.Ops.select(
conds,
tvals,
Reactant.TracedUtils.broadcast_to_size(c.mask, size(c))
)
end

function evalkern(kern, i, j, k)
kern.kernel_function(i, j, k, kern.grid, kern.arguments...)
end

@inline function Reactant.TracedUtils.materialize_traced_array(c::Oceananigans.AbstractOperations.KernelFunctionOperation)
N = ndims(c)
axes2 = ntuple(Val(N)) do i
reshape(Base.OneTo(size(c, i)), (ntuple(Val(N)) do j
if i == j
size(c, i)
else
1
end
end)...)
end

tvals = Reactant.Ops.fill(Reactant.unwrapped_eltype(Base.eltype(c)), size(c))
Reactant.TracedRArrayOverrides._copyto!(tvals, Base.broadcasted(Fix1v2(evalkern, c), axes2...))
return tvals
end

@inline function Reactant.TracedUtils.broadcast_to_size(c::Oceananigans.AbstractOperations.KernelFunctionOperation, rsize)
if c == rsize
return Reactant.TracedUtils.materialize_traced_array(c)
end
return c
end

# These are additional modules that may need to be Reactantified in the future:
#
# include("Utils.jl")
Expand Down
50 changes: 39 additions & 11 deletions src/Grids/inactive_node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,25 @@
const c = Center()
const f = Face()

function build_condition(Topo, side, dim)
function build_condition(Topo, side, dim, array::Bool)
if Topo == :Bounded
return :(($side < 1) | ($side > grid.$dim))
if array
return :(($side .< 1) .| ($side .> grid.$dim))
else
return :(($side < 1) | ($side > grid.$dim))
end
elseif Topo == :LeftConnected
return :(($side > grid.$dim))
if array
return :(($side .> grid.$dim))
else
return :(($side > grid.$dim))
end
else # RightConnected
return :(($side < 1))
if array
return :(($side .< 1))
else
return :(($side < 1))
end
end
end

Expand Down Expand Up @@ -39,43 +51,59 @@ Topos = (:Bounded, :LeftConnected, :RightConnected)

for PrimaryTopo in Topos

xcondition = build_condition(PrimaryTopo, :i, :Nx)
ycondition = build_condition(PrimaryTopo, :j, :Ny)
zcondition = build_condition(PrimaryTopo, :k, :Nz)
xcondition = build_condition(PrimaryTopo, :i, :Nx, false)
ycondition = build_condition(PrimaryTopo, :j, :Ny, false)
zcondition = build_condition(PrimaryTopo, :k, :Nz, false)

xcondition_ar = build_condition(PrimaryTopo, :i, :Nx, true)
ycondition_ar = build_condition(PrimaryTopo, :j, :Ny, true)
zcondition_ar = build_condition(PrimaryTopo, :k, :Nz, true)

@eval begin
XBoundedGrid = AbstractGrid{<:Any, <:$PrimaryTopo}
YBoundedGrid = AbstractGrid{<:Any, <:Any, <:$PrimaryTopo}
ZBoundedGrid = AbstractGrid{<:Any, <:Any, <:Any, <:$PrimaryTopo}

@inline inactive_cell(i, j, k, grid::XBoundedGrid) = $xcondition
@inline inactive_cell(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::XBoundedGrid) = $xcondition_ar
@inline inactive_cell(i, j, k, grid::YBoundedGrid) = $ycondition
@inline inactive_cell(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::YBoundedGrid) = $ycondition_ar
@inline inactive_cell(i, j, k, grid::ZBoundedGrid) = $zcondition
@inline inactive_cell(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::ZBoundedGrid) = $zcondition_ar
end

for SecondaryTopo in Topos

xycondition = :( $xcondition | $(build_condition(SecondaryTopo, :j, :Ny)))
xzcondition = :( $xcondition | $(build_condition(SecondaryTopo, :k, :Nz)))
yzcondition = :( $ycondition | $(build_condition(SecondaryTopo, :k, :Nz)))
xycondition = :( $xcondition | $(build_condition(SecondaryTopo, :j, :Ny, false)))
xzcondition = :( $xcondition | $(build_condition(SecondaryTopo, :k, :Nz, false)))
yzcondition = :( $ycondition | $(build_condition(SecondaryTopo, :k, :Nz, false)))

xycondition_ar = :( $xcondition_ar .| $(build_condition(SecondaryTopo, :j, :Ny, true)))
xzcondition_ar = :( $xcondition_ar .| $(build_condition(SecondaryTopo, :k, :Nz, true)))
yzcondition_ar = :( $ycondition_ar .| $(build_condition(SecondaryTopo, :k, :Nz, true)))

@eval begin
XYBoundedGrid = AbstractGrid{<:Any, <:$PrimaryTopo, <:$SecondaryTopo}
XZBoundedGrid = AbstractGrid{<:Any, <:$PrimaryTopo, <:Any, <:$SecondaryTopo}
YZBoundedGrid = AbstractGrid{<:Any, <:Any, <:$PrimaryTopo, <:$SecondaryTopo}

@inline inactive_cell(i, j, k, grid::XYBoundedGrid) = $xycondition
@inline inactive_cell(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::XYBoundedGrid) = $xycondition_ar
@inline inactive_cell(i, j, k, grid::XZBoundedGrid) = $xzcondition
@inline inactive_cell(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::XZBoundedGrid) = $xzcondition_ar
@inline inactive_cell(i, j, k, grid::YZBoundedGrid) = $yzcondition
@inline inactive_cell(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::YZBoundedGrid) = $yzcondition_ar
end

for TertiaryTopo in Topos
xyzcondition = :( $xycondition | $(build_condition(TertiaryTopo, :k, :Nz)))
xyzcondition = :( $xycondition | $(build_condition(TertiaryTopo, :k, :Nz, false)))
xyzcondition_ar = :( $xyzcondition .| $(build_condition(TertiaryTopo, :k, :Nz, true)))

@eval begin
XYZBoundedGrid = AbstractGrid{<:Any, <:$PrimaryTopo, <:$SecondaryTopo, <:$TertiaryTopo}

@inline inactive_cell(i, j, k, grid::XYZBoundedGrid) = $xyzcondition
@inline inactive_cell(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::XYZBoundedGrid) = $xyzcondition_ar
end
end
end
Expand Down
1 change: 1 addition & 0 deletions src/Grids/vertical_discretization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ on_architecture(arch, coord::MutableVerticalDiscretization) =
AUG = AbstractUnderlyingGrid

@inline rnode(i, j, k, grid, ℓx, ℓy, ℓz) = rnode(k, grid, ℓz)
@inline rnode(i::AbstractArray, j::AbstractArray, k, grid, ℓx, ℓy, ℓz) = permutedims(Base.stack(collect(Base.stack(collect(rnode(k, grid, ℓz) for _ in j)) for _ in i)), (3, 2, 1))
@inline rnode(k, grid, ::Center) = getnode(grid.z.cᵃᵃᶜ, k)
@inline rnode(k, grid, ::Face) = getnode(grid.z.cᵃᵃᶠ, k)

Expand Down
9 changes: 9 additions & 0 deletions src/ImmersedBoundaries/grid_fitted_bottom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,15 @@ end
return z ≤ zb
end

@inline function _immersed_cell(i, j, k::AbstractArray, underlying_grid, ib::GridFittedBottom)
# We use `rnode` for the `immersed_cell` because we do not want to have
# wetting or drying that could happen for a moving grid if we use znode
z = rnode(i, j, k, underlying_grid, c, c, c)
zb = @inbounds ib.bottom_height[i, j, 1]
zb = Base.stack(collect(zb for _ in k))
return z .≤ zb
end

#####
##### Static column depth
#####
Expand Down
9 changes: 8 additions & 1 deletion src/ImmersedBoundaries/immersed_boundary_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,17 @@ As well as
* `inactive_node(i-1, 1, 1, grid, f, c, c) = true`
"""
@inline inactive_cell(i, j, k, ibg::IBG) = immersed_cell(i, j, k, ibg) | inactive_cell(i, j, k, ibg.underlying_grid)
@inline inactive_cell(i::AbstractArray, j::AbstractArray, k::AbstractArray, ibg::IBG) = immersed_cell(i, j, k, ibg) .| inactive_cell(i, j, k, ibg.underlying_grid)

# Isolate periphery of the immersed boundary
@inline immersed_peripheral_node(i, j, k, ibg::IBG, LX, LY, LZ) = peripheral_node(i, j, k, ibg, LX, LY, LZ) &
!peripheral_node(i, j, k, ibg.underlying_grid, LX, LY, LZ)

@inline immersed_inactive_node(i, j, k, ibg::IBG, LX, LY, LZ) = inactive_node(i, j, k, ibg, LX, LY, LZ) &
@inline immersed_peripheral_node(i::AbstractArray, j::AbstractArray, k::AbstractArray, ibg::IBG, LX, LY, LZ) = peripheral_node(i, j, k, ibg, LX, LY, LZ) .&
Base.broadcast(!, peripheral_node(i, j, k, ibg.underlying_grid, LX, LY, LZ))

@inline immersed_inactive_node(i, j, k, ibg::IBG, LX, LY, LZ) = inactive_node(i, j, k, ibg, LX, LY, LZ) &
!inactive_node(i, j, k, ibg.underlying_grid, LX, LY, LZ)

@inline immersed_inactive_node(i::AbstractArray, j::AbstractArray, k::AbstractArray, ibg::IBG, LX, LY, LZ) = inactive_node(i, j, k, ibg, LX, LY, LZ) .&
Base.broadcast(!, inactive_node(i, j, k, ibg.underlying_grid, LX, LY, LZ))
6 changes: 6 additions & 0 deletions src/ImmersedBoundaries/immersed_reductions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ end
return !immersed & evaluate_condition(condition.func, i, j, k, ibg, args...)
end

@inline function evaluate_condition(condition::NotImmersed, i::AbstractArray, j::AbstractArray, k::AbstractArray, ibg, co::ConditionalOperation, args...)
ℓx, ℓy, ℓz = map(instantiate, location(co))
immersed = immersed_peripheral_node(i, j, k, ibg, ℓx, ℓy, ℓz) .| inactive_node(i, j, k, ibg, ℓx, ℓy, ℓz)
return Base.broadcast(!, immersed) .& evaluate_condition(condition.func, i, j, k, ibg, args...)
end

#####
##### Reduction operations on Reduced Fields test the immersed condition on the entirety of the immersed direction
#####
Expand Down
15 changes: 15 additions & 0 deletions src/Operators/spacings_and_areas_and_volumes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,29 @@ end
### Linear spacings

@inline Δxᶜᶜᵃ(i, j, k, grid::OSSG) = @inbounds grid.Δxᶜᶜᵃ[i, j]

@inline Δxᶜᶜᵃ(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::OSSG) = Base.stack(collect(Δxᶜᶜᵃ(i, j, 1, grid) for _ in k))

@inline Δxᶠᶜᵃ(i, j, k, grid::OSSG) = @inbounds grid.Δxᶠᶜᵃ[i, j]
@inline Δxᶠᶜᵃ(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::OSSG) = Base.stack(collect(Δxᶠᶜᵃ(i, j, 1, grid) for _ in k))

@inline Δxᶜᶠᵃ(i, j, k, grid::OSSG) = @inbounds grid.Δxᶜᶠᵃ[i, j]
@inline Δxᶜᶠᵃ(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::OSSG) = Base.stack(collect(Δxᶜᶠᵃ(i, j, 1, grid) for _ in k))

@inline Δxᶠᶠᵃ(i, j, k, grid::OSSG) = @inbounds grid.Δxᶠᶠᵃ[i, j]
@inline Δxᶠᶠᵃ(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::OSSG) = Base.stack(collect(Δxᶠᶠᵃ(i, j, 1, grid) for _ in k))

@inline Δyᶜᶜᵃ(i, j, k, grid::OSSG) = @inbounds grid.Δyᶜᶜᵃ[i, j]
@inline Δyᶜᶜᵃ(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::OSSG) = Base.stack(collect(Δyᶜᶜᵃ(i, j, 1, grid) for _ in k))

@inline Δyᶠᶜᵃ(i, j, k, grid::OSSG) = @inbounds grid.Δyᶠᶜᵃ[i, j]
@inline Δyᶠᶜᵃ(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::OSSG) = Base.stack(collect(Δyᶠᶜᵃ(i, j, 1, grid) for _ in k))

@inline Δyᶜᶠᵃ(i, j, k, grid::OSSG) = @inbounds grid.Δyᶜᶠᵃ[i, j]
@inline Δyᶜᶠᵃ(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::OSSG) = Base.stack(collect(Δyᶜᶠᵃ(i, j, 1, grid) for _ in k))

@inline Δyᶠᶠᵃ(i, j, k, grid::OSSG) = @inbounds grid.Δyᶠᶠᵃ[i, j]
@inline Δyᶠᶠᵃ(i::AbstractArray, j::AbstractArray, k::AbstractArray, grid::OSSG) = Base.stack(collect(Δyᶠᶠᵃ(i, j, 1, grid) for _ in k))

#####
#####
Expand Down