1
1
using TensorOperations
2
+ using TensorOperations: StridedBLAS, StridedNative
2
3
using Test
3
4
using ChainRulesTestUtils
4
5
@@ -19,7 +20,12 @@ precision(::Type{<:Union{Float64,Complex{Float64}}}) = 1e-8
19
20
β = rand (T)
20
21
A = rand (T₁, (2 , 3 , 4 , 2 , 5 ))
21
22
C = rand (T₂, size .(Ref (A), pC[1 ]))
23
+
22
24
test_rrule (tensortrace!, C, pC, A, pA, :N , α, β; atol, rtol)
25
+ test_rrule (tensortrace!, C, pC, A, pA, :C , α, β; atol, rtol)
26
+
27
+ test_rrule (tensortrace!, C, pC, A, pA, :C , α, β, StridedBLAS (); atol, rtol)
28
+ test_rrule (tensortrace!, C, pC, A, pA, :N , α, β, StridedNative (); atol, rtol)
23
29
end
24
30
25
31
@testset " tensoradd! ($T₁ , $T₂ )" for (T₁, T₂) in ((Float64, Float64), (Float32, Float64),
35
41
β = rand (T)
36
42
test_rrule (tensoradd!, C, pC, A, :N , α, β; atol, rtol)
37
43
test_rrule (tensoradd!, C, pC, A, :C , α, β; atol, rtol)
44
+
45
+ test_rrule (tensoradd!, C, pC, A, :N , α, β, StridedBLAS (); atol, rtol)
46
+ test_rrule (tensoradd!, C, pC, A, :C , α, β, StridedNative (); atol, rtol)
38
47
end
39
48
40
49
@testset " tensorcontract! ($T₁ , $T₂ )" for (T₁, T₂) in
58
67
test_rrule (tensorcontract!, C, pC, A, pA, :C , B, pB, :N , α, β; atol, rtol)
59
68
test_rrule (tensorcontract!, C, pC, A, pA, :N , B, pB, :C , α, β; atol, rtol)
60
69
test_rrule (tensorcontract!, C, pC, A, pA, :C , B, pB, :C , α, β; atol, rtol)
70
+
71
+ test_rrule (tensorcontract!, C, pC, A, pA, :N , B, pB, :N , α, β, StridedBLAS ();
72
+ atol, rtol)
73
+ test_rrule (tensorcontract!, C, pC, A, pA, :C , B, pB, :N , α, β, StridedNative ();
74
+ atol, rtol)
61
75
end
62
76
63
77
@testset " tensorscalar ($T )" for T in (Float32, Float64, ComplexF64)
0 commit comments