Skip to content

Commit e697c63

Browse files
authored
Fix missing non-optional argument in rrules with backend. (#168)
* Add non-optional arguments in chainrules implementation `tensorcontract(args..., alfa, backend)` only has a default value for `alfa` when no backend is supplied. * Add tests for rrules with backend specified
1 parent 4743131 commit e697c63

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

ext/TensorOperationsChainRulesCoreExt.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,13 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
5555
_dα = tensorscalar(tensorcontract(((), ()), A, ((), linearize(pC)),
5656
_conj(conjA), ΔC,
5757
(trivtuple(numind(pC)),
58-
()), :N, backend...))
58+
()), :N, One(), backend...))
5959
return projectα(_dα)
6060
end
6161
= @thunk begin
6262
_dβ = tensorscalar(tensorcontract(((), ()), C,
6363
((), trivtuple(numind(pC))), :C, ΔC,
64-
(trivtuple(numind(pC)), ()), :N,
64+
(trivtuple(numind(pC)), ()), :N, One(),
6565
backend...))
6666
return projectβ(_dβ)
6767
end
@@ -116,17 +116,17 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
116116
= @thunk begin
117117
_dα = tensorscalar(tensorcontract(((), ()),
118118
tensorcontract(pC, A, pA, conjA, B, pB,
119-
conjB),
119+
conjB, One(), backend...),
120120
((), trivtuple(numind(pC))),
121121
:C, ΔC,
122-
(trivtuple(numind(pC)), ()), :N,
122+
(trivtuple(numind(pC)), ()), :N, One(),
123123
backend...))
124124
return projectα(_dα)
125125
end
126126
= @thunk begin
127127
_dβ = tensorscalar(tensorcontract(((), ()), C,
128128
((), trivtuple(numind(pC))), :C, ΔC,
129-
(trivtuple(numind(pC)), ()), :N,
129+
(trivtuple(numind(pC)), ()), :N, One(),
130130
backend...))
131131
return projectβ(_dβ)
132132
end
@@ -172,14 +172,14 @@ function ChainRulesCore.rrule(::typeof(tensortrace!), C, pC::Index2Tuple, A,
172172
tensortrace(pC, A, pA),
173173
((), trivtuple(numind(pC))),
174174
_conj(conjA), ΔC,
175-
(trivtuple(numind(pC)), ()), :N,
175+
(trivtuple(numind(pC)), ()), :N, One(),
176176
backend...))
177177
return projectα(_dα)
178178
end
179179
= @thunk begin
180180
_dβ = tensorscalar(tensorcontract(((), ()), C,
181181
((), trivtuple(numind(pC))), :C, ΔC,
182-
(trivtuple(numind(pC)), ()), :N,
182+
(trivtuple(numind(pC)), ()), :N, One(),
183183
backend...))
184184
return projectβ(_dβ)
185185
end

test/ad.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using TensorOperations
2+
using TensorOperations: StridedBLAS, StridedNative
23
using Test
34
using ChainRulesTestUtils
45

@@ -19,7 +20,12 @@ precision(::Type{<:Union{Float64,Complex{Float64}}}) = 1e-8
1920
β = rand(T)
2021
A = rand(T₁, (2, 3, 4, 2, 5))
2122
C = rand(T₂, size.(Ref(A), pC[1]))
23+
2224
test_rrule(tensortrace!, C, pC, A, pA, :N, α, β; atol, rtol)
25+
test_rrule(tensortrace!, C, pC, A, pA, :C, α, β; atol, rtol)
26+
27+
test_rrule(tensortrace!, C, pC, A, pA, :C, α, β, StridedBLAS(); atol, rtol)
28+
test_rrule(tensortrace!, C, pC, A, pA, :N, α, β, StridedNative(); atol, rtol)
2329
end
2430

2531
@testset "tensoradd! ($T₁, $T₂)" for (T₁, T₂) in ((Float64, Float64), (Float32, Float64),
@@ -35,6 +41,9 @@ end
3541
β = rand(T)
3642
test_rrule(tensoradd!, C, pC, A, :N, α, β; atol, rtol)
3743
test_rrule(tensoradd!, C, pC, A, :C, α, β; atol, rtol)
44+
45+
test_rrule(tensoradd!, C, pC, A, :N, α, β, StridedBLAS(); atol, rtol)
46+
test_rrule(tensoradd!, C, pC, A, :C, α, β, StridedNative(); atol, rtol)
3847
end
3948

4049
@testset "tensorcontract! ($T₁, $T₂)" for (T₁, T₂) in
@@ -58,6 +67,11 @@ end
5867
test_rrule(tensorcontract!, C, pC, A, pA, :C, B, pB, :N, α, β; atol, rtol)
5968
test_rrule(tensorcontract!, C, pC, A, pA, :N, B, pB, :C, α, β; atol, rtol)
6069
test_rrule(tensorcontract!, C, pC, A, pA, :C, B, pB, :C, α, β; atol, rtol)
70+
71+
test_rrule(tensorcontract!, C, pC, A, pA, :N, B, pB, :N, α, β, StridedBLAS();
72+
atol, rtol)
73+
test_rrule(tensorcontract!, C, pC, A, pA, :C, B, pB, :N, α, β, StridedNative();
74+
atol, rtol)
6175
end
6276

6377
@testset "tensorscalar ($T)" for T in (Float32, Float64, ComplexF64)

0 commit comments

Comments
 (0)