cuTENSOR not working with automatic differentiation #167

Closed
hxjz233 opened this issue Apr 14, 2024 · 5 comments · Fixed by #168

hxjz233 commented Apr 14, 2024

I ran into difficulties running my tensor calculations on a GPU, and the problem boils down to backpropagating through tensor operations. Here is a simplified example.

using TensorOperations
using ChainRulesCore, Zygote
using CUDA, cuTENSOR

function QuadMin(x)
    @cutensor res = x[i,j] * x[i,j]    # for demonstrating some tensor operations with explicit index order
    return res
end

function AD4CuArray()
    initval = ones(3, 3) * 1.0
    f(x) = QuadMin(x)
    g(x) = gradient( f, x )[ 1 ]
    println(g(initval))
    return nothing
end

AD4CuArray()

The code above runs fine if the target function uses @tensor instead of @cutensor. Should I modify my code, or wait for a later update? Or is making cuTENSOR work with backpropagation not possible to implement in principle?
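For comparison, here is a minimal CPU-only sketch of the `@tensor` variant that is reported to work. The name `quadmin_cpu` is made up for illustration and is not part of the original report. Since the contraction computes the sum of all `x[i,j]^2`, the gradient should equal `2x`, which gives a simple reference value to compare a GPU result against.

```julia
using TensorOperations
using Zygote

# Hypothetical CPU counterpart of QuadMin: same contraction, default backend.
function quadmin_cpu(x)
    @tensor res = x[i, j] * x[i, j]   # sum over i, j of x[i,j]^2
    return res
end

x0 = ones(3, 3)
grad_cpu = gradient(quadmin_cpu, x0)[1]

# Analytically, the derivative of sum(x.^2) with respect to x[i,j] is 2 * x[i,j].
println(grad_cpu ≈ 2 .* x0)
```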

lkdvos self-assigned this Apr 15, 2024

lkdvos (Collaborator) commented Apr 17, 2024

Hi hxjz233!

I had a look at this, and it is indeed a mistake on my end: the rrules assume that a default argument gets filled in somewhere, which stops working as soon as a backend other than the default one is specified (this is how @cutensor works: it implicitly inserts the cuTENSOR backend everywhere).
I think I have fixed it, and I wrote some additional tests to prevent future failures. Once the tests pass, I should be able to merge this.

On a separate note, I noticed that this uses VectorInterface for some of the implementations. By default, VectorInterface falls back to a broadcasting operation, which is not necessarily what you want for CuArrays. I'll write a fix for that and update you here once I finish it.

In any case, thanks for letting me know that this is broken. I hope to have it fixed asap, as this is definitely something that is wrong on our side of things.

lkdvos (Collaborator) commented Apr 18, 2024

Jutho/VectorInterface.jl#14 should also get rid of the warning message for scalar indexing with CuArrays. Feel free to re-open an issue if things still are not working the way you expect!
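For context on the scalar-indexing message, here is a hedged sketch of what scalar indexing on a `CuArray` looks like. It assumes a working CUDA GPU, so everything is guarded by `CUDA.functional()`. The 0-dimensional array is only a stand-in for the kind of scalar-valued result a full contraction produces; it does not reproduce the code path inside VectorInterface or TensorOperations.

```julia
using CUDA

if CUDA.functional()                 # only run when a usable GPU is present
    CUDA.allowscalar(false)          # treat scalar indexing as an error
    x = CuArray(fill(2.0))           # 0-dimensional CuArray (illustrative stand-in)

    # Reading a single element directly from the device array is scalar indexing.
    failed = try
        x[]
        false
    catch
        true
    end
    println("scalar indexing rejected: ", failed)

    # Two deliberate ways to read the value:
    v1 = Array(x)[]                  # copy to the host first, then index
    v2 = CUDA.@allowscalar x[]       # explicitly opt in for this one expression
    println(v1 == v2)
end
```

Copying to the host with `Array(x)` makes the transfer explicit, while `CUDA.@allowscalar` is better kept for small, intentional reads.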

hxjz233 (Author) commented Apr 19, 2024

Hi lkdvos, just FYI, the given code still fails, with the same error as before:

Error Message
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] assertscalar(op::String)
    @ GPUArraysCore D:\Julia\depot\packages\GPUArraysCore\uOYfN\src\GPUArraysCore.jl:103
  [3] getindex
    @ D:\Julia\depot\packages\GPUArrays\dAUOE\src\host\indexing.jl:48 [inlined]
  [4] scalar_getindex
    @ D:\Julia\depot\packages\GPUArrays\dAUOE\src\host\indexing.jl:34 [inlined]
  [5] _getindex
    @ D:\Julia\depot\packages\GPUArrays\dAUOE\src\host\indexing.jl:17 [inlined]
  [6] getindex
    @ D:\Julia\depot\packages\GPUArrays\dAUOE\src\host\indexing.jl:15 [inlined]
  [7] scale(x::CuArray{Float64, 0, CUDA.Mem.DeviceBuffer}, α::VectorInterface.Zero)
    @ VectorInterface D:\Julia\depot\packages\VectorInterface\TAlcJ\src\abstractarray.jl:39
  [8] #61
    @ D:\Julia\depot\packages\TensorOperations\dNaBM\ext\TensorOperationsChainRulesCoreExt.jl:93 [inlined]
  [9] unthunk
    @ D:\Julia\depot\packages\ChainRulesCore\zgT0R\src\tangent_types\thunks.jl:204 [inlined]
 [10] wrap_chainrules_output
    @ D:\Julia\depot\packages\Zygote\jxHJc\src\compiler\chainrules.jl:110 [inlined]
 [11] map (repeats 2 times)
    @ .\tuple.jl:276 [inlined]
 [12] wrap_chainrules_output
    @ D:\Julia\depot\packages\Zygote\jxHJc\src\compiler\chainrules.jl:111 [inlined]
 [13] ZBack
    @ D:\Julia\depot\packages\Zygote\jxHJc\src\compiler\chainrules.jl:211 [inlined]
 [14] Pullback
    @ D:\MagBEC\juliatest\t_adjulia\cuTensorAD.jl:6 [inlined]
 [15] (::Zygote.Pullback{Tuple{typeof(QuadMin), Matrix{Float64}}, Tuple{Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensorfree!_pullback#47"{Tuple{CuArray{Float64, 0, CUDA.Mem.DeviceBuffer}, TensorOperations.Backend{:cuTENSOR}}}}, Zygote.Pullback{Tuple{typeof(scalartype), Matrix{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Matrix{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}, Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensorscalar_pullback#49"{CuArray{Float64, 0, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(scalartype), Matrix{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Matrix{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}, Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensoralloc_contract_pullback#41"{Tuple{DataType, Tuple{Tuple{}, Tuple{}}, Matrix{Float64}, Tuple{Tuple{}, Tuple{Int64, Int64}}, Symbol, Matrix{Float64}, Tuple{Tuple{Int64, Int64}, Tuple{}}, Symbol, Bool, TensorOperations.Backend{:cuTENSOR}}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#pullback#67"{CuArray{Float64, 0, CUDA.Mem.DeviceBuffer}, Tuple{Tuple{}, Tuple{}}, Matrix{Float64}, Tuple{Tuple{}, Tuple{Int64, Int64}}, Symbol, Matrix{Float64}, Tuple{Tuple{Int64, Int64}, Tuple{}}, Symbol, VectorInterface.One, VectorInterface.Zero, Tuple{TensorOperations.Backend{:cuTENSOR}}, ProjectTo{Number, NamedTuple{(), Tuple{}}}, ProjectTo{Number, NamedTuple{(), Tuple{}}}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{}}}}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, 
Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(TensorOperations.promote_contract), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.Pullback{Tuple{typeof(Base.promote_op), typeof(TensorOperations.tensorop), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.ZBack{ChainRules.var"#apply_type_pullback#42"{Tuple{DataType, DataType}}}}}, Zygote.Pullback{Tuple{typeof(Core.Compiler.return_type), typeof(TensorOperations.tensorop), Type{Tuple{Float64, Float64}}}, Tuple{typeof(Core.Compiler.return_type)}}, Zygote.var"#2017#back#204"{typeof(identity)}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}}}}})(Δ::Float64)
    @ Zygote D:\Julia\depot\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [16] Pullback
    @ D:\MagBEC\juliatest\t_adjulia\cuTensorAD.jl:12 [inlined]
 [17] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#f#3", Matrix{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(QuadMin), Matrix{Float64}}, Tuple{Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensorfree!_pullback#47"{Tuple{CuArray{Float64, 0, CUDA.Mem.DeviceBuffer}, TensorOperations.Backend{:cuTENSOR}}}}, Zygote.Pullback{Tuple{typeof(scalartype), Matrix{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Matrix{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}, Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensorscalar_pullback#49"{CuArray{Float64, 0, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(scalartype), Matrix{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Matrix{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}, Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensoralloc_contract_pullback#41"{Tuple{DataType, Tuple{Tuple{}, Tuple{}}, Matrix{Float64}, Tuple{Tuple{}, Tuple{Int64, Int64}}, Symbol, Matrix{Float64}, Tuple{Tuple{Int64, Int64}, Tuple{}}, Symbol, Bool, TensorOperations.Backend{:cuTENSOR}}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#pullback#67"{CuArray{Float64, 0, CUDA.Mem.DeviceBuffer}, Tuple{Tuple{}, Tuple{}}, Matrix{Float64}, Tuple{Tuple{}, Tuple{Int64, Int64}}, Symbol, Matrix{Float64}, Tuple{Tuple{Int64, Int64}, Tuple{}}, Symbol, VectorInterface.One, VectorInterface.Zero, Tuple{TensorOperations.Backend{:cuTENSOR}}, ProjectTo{Number, NamedTuple{(), Tuple{}}}, ProjectTo{Number, NamedTuple{(), 
Tuple{}}}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{}}}}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(TensorOperations.promote_contract), 
Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.Pullback{Tuple{typeof(Base.promote_op), typeof(TensorOperations.tensorop), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.ZBack{ChainRules.var"#apply_type_pullback#42"{Tuple{DataType, DataType}}}}}, Zygote.Pullback{Tuple{typeof(Core.Compiler.return_type), typeof(TensorOperations.tensorop), Type{Tuple{Float64, Float64}}}, Tuple{typeof(Core.Compiler.return_type)}}, Zygote.var"#2017#back#204"{typeof(identity)}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}}}}}}}})(Δ::Float64)
    @ Zygote D:\Julia\depot\packages\Zygote\jxHJc\src\compiler\interface.jl:91
 [18] gradient(f::Function, args::Matrix{Float64})
    @ Zygote D:\Julia\depot\packages\Zygote\jxHJc\src\compiler\interface.jl:148
 [19] g
    @ D:\MagBEC\juliatest\t_adjulia\cuTensorAD.jl:13 [inlined]
 [20] AD4CuArray()
    @ Main D:\MagBEC\juliatest\t_adjulia\cuTensorAD.jl:14
 [21] top-level scope
    @ D:\MagBEC\juliatest\t_adjulia\cuTensorAD.jl:18
Version Info
julia> Pkg.status(["TensorOperations","ChainRulesCore","Zygote","Yota","CUDA","cuTENSOR"])
Status `D:\Julia\depot\environments\v1.9\Project.toml`
⌅ [052768ef] CUDA v5.1.2
  [d360d2e6] ChainRulesCore v1.23.0
  [6aa20fa7] TensorOperations v4.1.1
  [cd998857] Yota v0.8.5
  [e88e6eb3] Zygote v0.6.69
⌃ [011b41b2] cuTENSOR v1.2.1

However, if I switch to Yota and use g(x) = grad(f, x), it works, so there might be something to check on the Zygote side. Either way, there is a working solution for AD + cuTENSOR after all, and that is already quite cool!

lkdvos (Collaborator) commented Apr 19, 2024

The changes in VectorInterface were not yet tagged, but this should be resolved once the following registry PR is merged: JuliaRegistries/General#105225
Would you mind trying again with version v0.4.5 of VectorInterface? I am hoping that fixes it.
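One way to request that specific version is through the standard Pkg API, sketched below. It assumes v0.4.5 is already available in your registry and is compatible with the other packages in the active environment; if not, the resolver will report an error.

```julia
using Pkg

# Refresh the registries so that a newly tagged release becomes visible.
Pkg.Registry.update()

# Request the specific release mentioned above.
Pkg.add(name = "VectorInterface", version = "0.4.5")

# Confirm which version the active environment resolved to.
Pkg.status("VectorInterface")
```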

hxjz233 (Author) commented Apr 19, 2024

Yes, it solves the issue! Thanks for the effort! :)
