
scalar indexing of gpu array in Zygote gradient #1016

@CarloLucibello

This was found in Flux's test suite and corresponds to FluxML/Zygote.jl#1005.

using MLDataDevices, CUDA, cuDNN, Zygote
CUDA.allowscalar(false)  # make scalar indexing of GPU arrays an error

cpu = cpu_device()
gpu = gpu_device()

gradient(x -> cpu(2 .* gpu(x))[1], Float32[1,2,3])[1]  # errors with scalar indexing (trace below)
gradient(x -> cpu(gpu(x) * gpu(x))[1,2], Float32[1 2 3; 4 5 6; 7 8 9])[1]  # also errors

For the first gradient, the error is:

ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore should be avoided.

If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
to enable scalar iteration globally or for the operations in question.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] errorscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:155
  [3] _assertscalar(op::String, behavior::GPUArraysCore.ScalarIndexing)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:128
  [4] assertscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:116
  [5] getindex
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/indexing.jl:50 [inlined]
  [6] first
    @ ./abstractarray.jl:452 [inlined]
  [7] dot(x::ChainRules.OneElement{Float32, 1, Tuple{Int64}, Tuple{Base.OneTo{…}}}, y::CuArray{Float32, 1, CUDA.DeviceMemory})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/generic.jl:894
  [8] #568
    @ ~/.julia/packages/ChainRules/vdf7M/src/rulesets/Base/arraymath.jl:108 [inlined]
  [9] unthunk
    @ ~/.julia/packages/ChainRulesCore/6Pucz/src/tangent_types/thunks.jl:205 [inlined]
 [10] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/NRp5C/src/compiler/chainrules.jl:110 [inlined]
 [11] map
    @ ./tuple.jl:357 [inlined]
 [12] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/NRp5C/src/compiler/chainrules.jl:111 [inlined]
 [13] ZBack
    @ ~/.julia/packages/Zygote/NRp5C/src/compiler/chainrules.jl:212 [inlined]
 [14] (::Zygote.var"#3852#back#1233"{Zygote.ZBack{}})(Δ::ChainRules.OneElement{Float32, 1, Tuple{…}, Tuple{…}})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [15] #184
    @ ./REPL[4]:1 [inlined]
 [16] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface2.jl:0
 [17] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface.jl:91
 [18] gradient(f::Function, args::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface.jl:148
 [19] top-level scope
    @ REPL[4]:1
Some type information was truncated. Use `show(err)` to see complete types.
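Reading the trace, the failure does not seem to come from the user function itself: a ChainRules rule (frame [8], in `arraymath.jl`) calls `dot` on a `ChainRules.OneElement` cotangent and a `CuArray` while its thunk is being unthunked (frames [9] and [10]). That call lands in the generic `LinearAlgebra.dot` fallback in `generic.jl` (frame [7]), which calls `first` on the `CuArray` (frames [5] and [6]), i.e. a scalar `getindex`. The CPU-only sketch below reproduces that mechanism under stated assumptions: `NoScalarVector` is a hypothetical stand-in for a GPU array under `CUDA.allowscalar(false)`, and a dense one-hot `Vector` stands in for `OneElement`. Neither is part of any package.

```julia
using LinearAlgebra

# Hypothetical stand-in for a GPU array with scalar indexing disallowed:
# the data lives in an ordinary Vector, but element access throws,
# loosely mimicking a CuArray after CUDA.allowscalar(false).
struct NoScalarVector{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(v::NoScalarVector) = size(v.data)
Base.IndexStyle(::Type{<:NoScalarVector}) = IndexLinear()
Base.getindex(::NoScalarVector, ::Int) = error("Scalar indexing is disallowed.")

# Dense stand-in for the one-hot cotangent coming from the pullback of `[1]`.
onehot = Float32[1, 0, 0]
y = NoScalarVector(Float32[2, 4, 6])

# No specialized `dot` method matches (Vector, NoScalarVector), so the generic
# LinearAlgebra fallback runs; it calls `first(y)`, which is a scalar getindex.
err = try
    dot(onehot, y)
    nothing
catch e
    e
end
```

After running it, `err` holds the `ErrorException` raised by `getindex`, analogous to frames [1] to [7] of the real trace; with an actual `CuArray` the error is raised by `GPUArraysCore.assertscalar` instead.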