Issues with backpropagation and adjoint construction #31

@ChrisRackauckas

I couldn't get either method of adjoint generation working over the @parallel stencil. For pure Zygote-based VJP calculations,

@parallel function diffusion3D_step!(T2, T, Ci, lam, dx, dy, dz)
    @inn(T2) = lam*@inn(Ci)*(@d2_xi(T)/dx^2 + @d2_yi(T)/dy^2 + @d2_zi(T)/dz^2);
    return
end

will fail because Zygote does not support array mutation. Are you planning to support adjoint rules on @parallel constructs for this? That would be required for GPU usage.
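To make concrete what such an adjoint rule would have to compute, here is a minimal sketch on a 1D analogue of the stencil, in plain Julia with no ParallelStencil macros. The names `diffusion1D_step!` and `diffusion1D_step_vjp` are hypothetical and not part of any package; the loop is serial, and a threaded or GPU version would need to handle the overlapping writes to `Tbar`.

```julia
# Hypothetical 1D analogue of the stencil: only interior points of T2 are written.
function diffusion1D_step!(T2, T, Ci, lam, dx)
    for i in 2:length(T)-1
        T2[i] = lam * Ci[i] * (T[i-1] - 2 * T[i] + T[i+1]) / dx^2
    end
    return nothing
end

# Hand-written VJP: given the cotangent T2bar of the output, return the
# cotangents of T and lam. Each interior output point scatters into its
# three input neighbours.
function diffusion1D_step_vjp(T2bar, T, Ci, lam, dx)
    Tbar = zero(T)
    lambar = zero(lam)
    for i in 2:length(T)-1
        c = lam * Ci[i] / dx^2
        Tbar[i-1] += c * T2bar[i]
        Tbar[i]   -= 2 * c * T2bar[i]
        Tbar[i+1] += c * T2bar[i]
        lambar += T2bar[i] * Ci[i] * (T[i-1] - 2 * T[i] + T[i+1]) / dx^2
    end
    return Tbar, lambar
end
```

Because the stencil is linear in `T`, the sketch can be checked with a dot-product test: for arbitrary vectors `v` and `w`, the inner product of `w` with the stencil applied to `v` should match the inner product of `v` with the `Tbar` returned for cotangent `w`.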

To avoid that issue, I tried ReverseDiffVJP instead, which runs into a different problem, shown by this MWE:

const USE_GPU = false
using ParallelStencil, OrdinaryDiffEq
using ParallelStencil.FiniteDifferences3D
@static if USE_GPU
    @init_parallel_stencil(CUDA, Float64, 3);
else
    @init_parallel_stencil(Threads, Float64, 3);
end

@parallel function diffusion3D_step!(T2, T, Ci, lam, dx, dy, dz)
    @inn(T2) = lam*@inn(Ci)*(@d2_xi(T)/dx^2 + @d2_yi(T)/dy^2 + @d2_zi(T)/dz^2);
    return
end

function diffusion3D(lam,alg)
    # Physics
    cp_min     = 1.0;                                        # Minimal heat capacity
    lx, ly, lz = 10.0, 10.0, 10.0;                           # Length of domain in dimensions x, y and z.

    # Numerics
    nx, ny, nz = 16, 16, 16;                              # Number of gridpoints dimensions x, y and z.
    nt         = 100;                                        # Number of time steps
    dx         = lx/(nx-1);                                  # Space step in x-dimension
    dy         = ly/(ny-1);                                  # Space step in y-dimension
    dz         = lz/(nz-1);                                  # Space step in z-dimension

    # Array initializations
    T   = @zeros(nx, ny, nz);
    T2  = @zeros(nx, ny, nz);
    Ci  = @zeros(nx, ny, nz);

    # Initial conditions (heat capacity and temperature with two Gaussian anomalies each)
    Ci .= 1.0./( cp_min .+ Data.Array([5*exp(-(((ix-1)*dx-lx/1.5))^2-(((iy-1)*dy-ly/2))^2-(((iz-1)*dz-lz/1.5))^2) +
                                       5*exp(-(((ix-1)*dx-lx/3.0))^2-(((iy-1)*dy-ly/2))^2-(((iz-1)*dz-lz/1.5))^2) for ix=1:size(T,1), iy=1:size(T,2), iz=1:size(T,3)]) )
    T  .= Data.Array([100*exp(-(((ix-1)*dx-lx/2)/2)^2-(((iy-1)*dy-ly/2)/2)^2-(((iz-1)*dz-lz/3.0)/2)^2) +
                       50*exp(-(((ix-1)*dx-lx/2)/2)^2-(((iy-1)*dy-ly/2)/2)^2-(((iz-1)*dz-lz/1.5)/2)^2) for ix=1:size(T,1), iy=1:size(T,2), iz=1:size(T,3)])
    T2 .= T;                                                 # Assign also T2 to get correct boundary conditions.

    dt = min(dx^2,dy^2,dz^2)*cp_min/8.1;                 # Time step for the 3D Heat diffusion
    function f(du,u,p,t)
        @show t
        @parallel diffusion3D_step!(du, u, Ci, p[1], dx, dy, dz);
    end
    prob = ODEProblem(f, T, (0.0,nt*dt), lam)
    sol = solve(prob, alg, save_everystep = false, save_start = false)
end

sol = diffusion3D([1.0],ROCK2())

using ForwardDiff, Zygote, DiffEqSensitivity
function loss(p)
    sum(diffusion3D(p,ROCK2()))
end
ForwardDiff.gradient(loss,[1.0])
Zygote.gradient(loss,[1.0])

TaskFailedException

    nested task error: BoundsError: attempt to access 22080-element Vector{ReverseDiff.AbstractInstruction} at index [474]
    Stacktrace:
      [1] setindex!
        @ .\array.jl:839 [inlined]
      [2] push!
        @ .\array.jl:930 [inlined]
      [3] record!
        @ C:\Users\accou\.julia\packages\ReverseDiff\E4Tzn\src\tape.jl:10 [inlined]
      [4] ForwardOptimize
        @ C:\Users\accou\.julia\packages\ReverseDiff\E4Tzn\src\macros.jl:105 [inlined]
      [5] -
        @ C:\Users\accou\.julia\packages\ReverseDiff\E4Tzn\src\derivatives\scalars.jl:9 [inlined]
      [6] macro expansion
        @ C:\Users\accou\.julia\packages\ParallelStencil\0VULM\src\parallel.jl:103 [inlined]
      [7] macro expansion
        @ D:\OneDrive\Computer\Desktop\test.jl:11 [inlined]
      [8] macro expansion
        @ C:\Users\accou\.julia\packages\ParallelStencil\0VULM\src\ParallelKernel\parallel.jl:330 [inlined]
      [9] (::var"#567#threadsfor_fun#7"{Array{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 3, Array{Float64, 3}, Array{Float64, 3}}}, 3}, ReverseDiff.TrackedArray{Float64, Float64, 3, Array{Float64, 3}, Array{Float64, 3}}, Array{Float64, 3}, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}, Float64, Float64, Float64, Tuple{UnitRange{Int64}, UnitRange{Int64}, UnitRange{Int64}}, UnitRange{Int64}})(onethread::Bool)
        @ Main .\threadingconstructs.jl:81
     [10] (::var"#567#threadsfor_fun#7"{Array{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 3, Array{Float64, 3}, Array{Float64, 3}}}, 3}, ReverseDiff.TrackedArray{Float64, Float64, 3, Array{Float64, 3}, Array{Float64, 3}}, Array{Float64, 3}, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}, Float64, Float64, Float64, Tuple{UnitRange{Int64}, UnitRange{Int64}, UnitRange{Int64}}, UnitRange{Int64}})()
        @ Main .\threadingconstructs.jl:48
wait at task.jl:322 [inlined]
threading_run(func::Function) at threadingconstructs.jl:34
macro expansion at threadingconstructs.jl:93 [inlined]
macro expansion at parallel.jl:328 [inlined]
diffusion3D_step! at shared.jl:77 [inlined]
(::var"#f#12"{Array{Float64, 3}, Float64, Float64, Float64})(du::Array{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 3, Array{Float64, 3}, Array{Float64, 3}}}, 3}, u::ReverseDiff.TrackedArray{Float64, Float64, 3, Array{Float64, 3}, Array{Float64, 3}}, p::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, t::ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}) at test.jl:42
ODEFunction at scimlfunctions.jl:334 [inlined]
(::DiffEqSensitivity.var"#33#37"{ODEFunction{true, var"#f#12"{Array{Float64, 3}, Float64, Float64, Float64}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}})(u::ReverseDiff.TrackedArray{Float64, Float64, 3, Array{Float64, 3}, Array{Float64, 3}}, p::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, t::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}) at derivative_wrappers.jl:345
ReverseDiff.GradientTape(f::Function, input::Tuple{Array{Float64, 3}, Vector{Float64}, Vector{Float64}}, cfg::ReverseDiff.GradientConfig{Tuple{ReverseDiff.TrackedArray{Float64, Float64, 3, Array{Float64, 3}, Array{Float64, 3}}, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}) at tape.jl:207
GradientTape at tape.jl:204 [inlined]
_vecjacobian!(dλ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, y::Array{Float64, 3}, λ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, p::Vector{Float64}, t::Float64, S::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, ReverseDiff.GradientTape{DiffEqSensitivity.var"#93#104"{ODEFunction{true, var"#f#12"{Array{Float64, 3}, Float64, Float64, Float64}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, Tuple{ReverseDiff.TrackedArray{Float64, Float64, 3, Array{Float64, 3}, Array{Float64, 3}}, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}, Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}}, Nothing, Noth...

While forward mode runs, it looks like the tracing required for adjoints causes issues with the threading constructs.
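The stack trace shows `push!` on a `Vector{ReverseDiff.AbstractInstruction}` being called from inside the threaded loop, so my guess is that several threads are appending to one shared tape at once. Below is a minimal sketch of that suspected failure mode, using a plain `Vector{Int}` as a hypothetical stand-in for the tape. It does not use ReverseDiff, and the lock is only there to illustrate the race, not as a proposed fix.

```julia
# Hypothetical stand-in for a shared tape: every thread appends to one Vector.
# With more than one thread this is a data race, since push! on a Vector is
# not thread-safe; entries can be lost or the call can throw.
function record_unsafely(n)
    tape = Int[]
    Threads.@threads for i in 1:n
        push!(tape, i)
    end
    return tape
end

# Same loop with the append serialized by a lock. Every entry is recorded,
# but the recording order is still nondeterministic across threads.
function record_with_lock(n)
    tape = Int[]
    lk = ReentrantLock()
    Threads.@threads for i in 1:n
        lock(lk) do
            push!(tape, i)
        end
    end
    return tape
end
```

Running Julia with a single thread (`JULIA_NUM_THREADS=1`) should make both variants behave the same, which may be a quick way to confirm whether threading is what breaks the tracing.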
