Skip to content

Commit 33c0fde

Browse files
committed
Add tests for rrules with backend specified
1 parent a873e20 commit 33c0fde

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

test/ad.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using TensorOperations
2+
using TensorOperations: StridedBLAS, StridedNative
23
using Test
34
using ChainRulesTestUtils
45

@@ -19,7 +20,12 @@ precision(::Type{<:Union{Float64,Complex{Float64}}}) = 1e-8
1920
β = rand(T)
2021
A = rand(T₁, (2, 3, 4, 2, 5))
2122
C = rand(T₂, size.(Ref(A), pC[1]))
23+
2224
test_rrule(tensortrace!, C, pC, A, pA, :N, α, β; atol, rtol)
25+
test_rrule(tensortrace!, C, pC, A, pA, :C, α, β; atol, rtol)
26+
27+
test_rrule(tensortrace!, C, pC, A, pA, :C, α, β, StridedBLAS(); atol, rtol)
28+
test_rrule(tensortrace!, C, pC, A, pA, :N, α, β, StridedNative(); atol, rtol)
2329
end
2430

2531
@testset "tensoradd! ($T₁, $T₂)" for (T₁, T₂) in ((Float64, Float64), (Float32, Float64),
@@ -35,6 +41,9 @@ end
3541
β = rand(T)
3642
test_rrule(tensoradd!, C, pC, A, :N, α, β; atol, rtol)
3743
test_rrule(tensoradd!, C, pC, A, :C, α, β; atol, rtol)
44+
45+
test_rrule(tensoradd!, C, pC, A, :N, α, β, StridedBLAS(); atol, rtol)
46+
test_rrule(tensoradd!, C, pC, A, :C, α, β, StridedNative(); atol, rtol)
3847
end
3948

4049
@testset "tensorcontract! ($T₁, $T₂)" for (T₁, T₂) in
@@ -58,6 +67,11 @@ end
5867
test_rrule(tensorcontract!, C, pC, A, pA, :C, B, pB, :N, α, β; atol, rtol)
5968
test_rrule(tensorcontract!, C, pC, A, pA, :N, B, pB, :C, α, β; atol, rtol)
6069
test_rrule(tensorcontract!, C, pC, A, pA, :C, B, pB, :C, α, β; atol, rtol)
70+
71+
test_rrule(tensorcontract!, C, pC, A, pA, :N, B, pB, :N, α, β, StridedBLAS();
72+
atol, rtol)
73+
test_rrule(tensorcontract!, C, pC, A, pA, :C, B, pB, :N, α, β, StridedNative();
74+
atol, rtol)
6175
end
6276

6377
@testset "tensorscalar ($T)" for T in (Float32, Float64, ComplexF64)

0 commit comments

Comments
 (0)