Skip to content

Commit a17f420

Browse files
committed
make compatible with TensorOperations v5
1 parent f348d12 commit a17f420

File tree

7 files changed

+86
-59
lines changed

7 files changed

+86
-59
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ jobs:
2121
fail-fast: false
2222
matrix:
2323
version:
24-
- '1.6' # LTS version
24+
- '1.8' # lowest TensorOperations version
25+
- '1.10' # LTS version
2526
- '1' # automatically expands to the latest stable 1.x release of Julia
2627
os:
2728
- ubuntu-latest

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseArrayKit"
22
uuid = "a9a3c162-d163-4c15-8926-b8794fbefed2"
33
authors = ["Lukas Devos <[email protected]>", "Maarten Van Damme <[email protected]>", "Jutho Haegeman <[email protected]>"]
4-
version = "0.3.1"
4+
version = "0.4"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -20,10 +20,10 @@ SparseArrayKitTensorOperations = "TensorOperations"
2020

2121
[compat]
2222
PackageExtensionCompat = "1"
23-
TensorOperations = "4"
23+
TensorOperations = "5"
2424
TupleTools = "1.1"
2525
VectorInterface = "0.4.1"
26-
julia = "1.6"
26+
julia = "1.8"
2727

2828
[extras]
2929
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"

ext/SparseArrayKitTensorOperations.jl

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,58 @@ import TensorOperations as TO
44
using TensorOperations: Index2Tuple, linearize, numind
55
using SparseArrayKit: tensoradd!, tensortrace!, tensorcontract!, SparseArray
66

7-
function TO.tensoradd!(C::SparseArray, pC::Index2Tuple,
8-
A::SparseArray, conjA::Symbol,
9-
α::Number, β::Number)
10-
return tensoradd!(C, linearize(pC), A, conjA, α, β)
7+
struct SparseArrayBackend <: TO.AbstractBackend
118
end
129

13-
function TO.tensortrace!(C::SparseArray, pC::Index2Tuple,
14-
A::SparseArray, pA::Index2Tuple, conjA::Symbol,
15-
α::Number, β::Number)
16-
return tensortrace!(C, linearize(pC), A, conjA, pA[1], pA[2], α, β)
10+
# ------------------------------------------------------------------------------------------
11+
# Default backend selection mechanism for AbstractArray instances
12+
# ------------------------------------------------------------------------------------------
13+
function TO.select_backend(::typeof(TO.tensoradd!), C::SparseArray, A::SparseArray)
14+
return SparseArrayBackend()
1715
end
1816

19-
function TO.tensorcontract!(C::SparseArray, pC::Index2Tuple,
20-
A::SparseArray, pA::Index2Tuple, conjA::Symbol,
21-
B::SparseArray, pB::Index2Tuple, conjB::Symbol,
22-
α::Number, β::Number)
23-
return tensorcontract!(C, linearize(pC), A, conjA, pA[1], pA[2], B, conjB, pB[2], pB[1],
24-
α, β)
17+
function TO.select_backend(::typeof(TO.tensortrace!), C::SparseArray, A::SparseArray)
18+
return SparseArrayBackend()
2519
end
2620

27-
function TO.tensoradd_type(TC, pA::Index2Tuple, ::SparseArray, ::Symbol)
21+
function TO.select_backend(::typeof(TO.tensorcontract!), C::SparseArray, A::SparseArray,
22+
B::SparseArray)
23+
return SparseArrayBackend()
24+
end
25+
26+
function TO.tensoradd!(C::SparseArray,
27+
A::SparseArray, pA::Index2Tuple, conjA::Bool,
28+
α::Number, β::Number,
29+
::SparseArrayBackend, allocator)
30+
return tensoradd!(C, A, conjA, linearize(pA), α, β)
31+
end
32+
33+
function TO.tensortrace!(C::SparseArray,
34+
A::SparseArray, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
35+
α::Number, β::Number,
36+
::SparseArrayBackend, allocator)
37+
return tensortrace!(C, A, conjA, linearize(p), q[1], q[2], α, β)
38+
end
39+
40+
function TO.tensorcontract!(C::SparseArray,
41+
A::SparseArray, pA::Index2Tuple, conjA::Bool,
42+
B::SparseArray, pB::Index2Tuple, conjB::Bool,
43+
pAB::Index2Tuple,
44+
α::Number, β::Number,
45+
::SparseArrayBackend, allocator)
46+
return tensorcontract!(C, A, conjA, pA[1], pA[2], B, conjB, pB[2], pB[1],
47+
linearize(pAB), α, β)
48+
end
49+
50+
function TO.tensoradd_type(TC, ::SparseArray, pA::Index2Tuple, ::Bool)
2851
return SparseArray{TC,numind(pA)}
2952
end
3053

31-
function TO.tensorcontract_type(TC, pC, ::SparseArray, pA, conjA,
32-
::SparseArray, pB, conjB)
33-
return SparseArray{TC,numind(pC)}
54+
function TO.tensorcontract_type(TC,
55+
::SparseArray, pA::Index2Tuple, conjA::Bool,
56+
::SparseArray, pB::Index2Tuple, conjB::Bool,
57+
pAB::Index2Tuple)
58+
return SparseArray{TC,numind(pAB)}
3459
end
3560

3661
end

src/base.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ end
5454

5555
# array manipulation
5656
function Base.permutedims!(dst::SparseArray, src::SparseArray, p)
57-
return tensoradd!(dst, tuple(p...), src, :N, true, false)
57+
return tensoradd!(dst, src, false, tuple(p...), true, false)
5858
end
5959

6060
function Base.reshape(parent::SparseArray{T}, dims::Dims) where {T}

src/linearalgebra.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ const ASM{T} = Union{SparseArray{T,2},
7777

7878
LinearAlgebra.mul!(C::SM, A::ASM, B::ASM) = mul!(C, A, B, one(eltype(C)), zero(eltype(C)))
7979
function LinearAlgebra.mul!(C::SM, A::ASM, B::ASM, α::Number, β::Number)
80-
CA = A isa Adjoint ? :C : :N
81-
CB = B isa Adjoint ? :C : :N
80+
conjA = A isa Adjoint
81+
conjB = B isa Adjoint
8282
oindA = A isa Union{Adjoint,Transpose} ? (2,) : (1,)
8383
cindA = A isa Union{Adjoint,Transpose} ? (1,) : (2,)
8484
oindB = B isa Union{Adjoint,Transpose} ? (1,) : (2,)
@@ -87,8 +87,9 @@ function LinearAlgebra.mul!(C::SM, A::ASM, B::ASM, α::Number, β::Number)
8787
AA = A isa Union{Adjoint,Transpose} ? parent(A) : A
8888
BB = B isa Union{Adjoint,Transpose} ? parent(B) : B
8989

90-
return tensorcontract!(C, (1, 2), AA, CA, oindA, cindA, BB, CB, oindB, cindB, α, β)
90+
return tensorcontract!(C, AA, conjA, oindA, cindA, BB, conjB, oindB, cindB, (1, 2), α,
91+
β)
9192
end
9293

93-
LinearAlgebra.adjoint!(C::SM, A::SM) = tensoradd!(C, (2, 1), A, :C, One(), Zero())
94-
LinearAlgebra.transpose!(C::SM, A::SM) = tensoradd!(C, (2, 1), A, :N, One(), Zero())
94+
LinearAlgebra.adjoint!(C::SM, A::SM) = tensoradd!(C, A, true, (2, 1), One(), Zero())
95+
LinearAlgebra.transpose!(C::SM, A::SM) = tensoradd!(C, A, false, (2, 1), One(), Zero())

src/tensoroperations.jl

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,55 @@
11
# TensorOperations compatiblity
22
#-------------------------------
3-
function tensoradd!(C::SparseArray{<:Any,N}, indCinA,
4-
A::SparseArray{<:Any,N}, CA::Symbol,
3+
function tensoradd!(C::SparseArray{<:Any,N},
4+
A::SparseArray{<:Any,N}, conjA::Bool, pA,
55
α::Number=One(), β::Number=One()) where {N}
6-
(N == length(indCinA) && TupleTools.isperm(indCinA)) ||
7-
throw(ArgumentError("Invalid permutation of length $N: $indCinA"))
8-
size(C) == TupleTools.getindices(size(A), indCinA) ||
6+
(N == length(pA) && TupleTools.isperm(pA)) ||
7+
throw(ArgumentError("Invalid permutation of length $N: $pA"))
8+
size(C) == TupleTools.getindices(size(A), pA) ||
99
throw(DimensionMismatch("non-matching sizes while adding arrays"))
1010
scale!(C, β)
1111
for (IA, vA) in A.data
12-
IC = CartesianIndex(TupleTools.getindices(IA.I, indCinA))
13-
C[IC] += α * (CA == :C ? conj(vA) : vA)
12+
IC = CartesianIndex(TupleTools.getindices(IA.I, pA))
13+
C[IC] += α * (conjA ? conj(vA) : vA)
1414
end
1515
return C
1616
end
1717

18-
function tensortrace!(C::SparseArray{<:Any,NC}, indCinA,
19-
A::SparseArray{<:Any,NA}, CA::Symbol, cindA1, cindA2,
18+
function tensortrace!(C::SparseArray{<:Any,NC},
19+
A::SparseArray{<:Any,NA}, conjA::Bool, p, q1, q2,
2020
α::Number=One(), β::Number=Zero()) where {NA,NC}
21-
NC == length(indCinA) ||
22-
throw(ArgumentError("Invalid selection of $NC out of $NA: $indCinA"))
23-
NA - NC == 2 * length(cindA1) == 2 * length(cindA2) ||
21+
NC == length(p) ||
22+
throw(ArgumentError("Invalid selection of $NC out of $NA: $p"))
23+
NA - NC == 2 * length(q1) == 2 * length(q2) ||
2424
throw(ArgumentError("invalid number of trace dimension"))
25-
pA = (indCinA..., cindA1..., cindA2...)
25+
pA = (p..., q1..., q2...)
2626
TupleTools.isperm(pA) ||
2727
throw(ArgumentError("invalid permutation of length $(ndims(A)): $pA"))
2828

2929
sizeA = size(A)
3030
sizeC = size(C)
3131

32-
TupleTools.getindices(sizeA, cindA1) == TupleTools.getindices(sizeA, cindA2) ||
32+
TupleTools.getindices(sizeA, q1) == TupleTools.getindices(sizeA, q2) ||
3333
throw(DimensionMismatch("non-matching trace sizes"))
34-
sizeC == TupleTools.getindices(sizeA, indCinA) ||
34+
sizeC == TupleTools.getindices(sizeA, p) ||
3535
throw(DimensionMismatch("non-matching sizes"))
3636

3737
scale!(C, β)
3838
for (IA, v) in A.data
39-
IAc1 = CartesianIndex(TupleTools.getindices(IA.I, cindA1))
40-
IAc2 = CartesianIndex(TupleTools.getindices(IA.I, cindA2))
39+
IAc1 = CartesianIndex(TupleTools.getindices(IA.I, q1))
40+
IAc2 = CartesianIndex(TupleTools.getindices(IA.I, q2))
4141
IAc1 == IAc2 || continue
4242

43-
IC = CartesianIndex(TupleTools.getindices(IA.I, indCinA))
44-
C[IC] += α * (CA == :C ? conj(v) : v)
43+
IC = CartesianIndex(TupleTools.getindices(IA.I, p))
44+
C[IC] += α * (conjA ? conj(v) : v)
4545
end
4646
return C
4747
end
4848

49-
function tensorcontract!(C::SparseArray, indCinoAB,
50-
A::SparseArray, CA::Symbol, oindA, cindA,
51-
B::SparseArray, CB::Symbol, oindB, cindB,
49+
function tensorcontract!(C::SparseArray,
50+
A::SparseArray, conjA::Bool, oindA, cindA,
51+
B::SparseArray, conjB::Bool, oindB, cindB,
52+
indCinoAB,
5253
α::Number=One(), β::Number=Zero())
5354
pA = (oindA..., cindA...)
5455
(length(pA) == ndims(A) && TupleTools.isperm(pA)) ||
@@ -119,9 +120,8 @@ function tensorcontract!(C::SparseArray, indCinoAB,
119120
IABo = CartesianIndex(IAo, IBo)
120121
IC = CartesianIndex(TupleTools.getindices(IABo.I, indCinoAB))
121122
vA = A[IA]
122-
increaseindex!(C,
123-
α * (CA == :C ? conj(vA) : vA) *
124-
(CB == :C ? conj(vB) : vB), IC)
123+
v = α * (conjA ? conj(vA) : vA) * (conjB ? conj(vB) : vB)
124+
increaseindex!(C, v, IC)
125125
end
126126
end
127127
else
@@ -135,9 +135,8 @@ function tensorcontract!(C::SparseArray, indCinoAB,
135135
vB = B[IB]
136136
IABo = CartesianIndex(IAo, IBo)
137137
IC = CartesianIndex(TupleTools.getindices(IABo.I, indCinoAB))
138-
increaseindex!(C,
139-
α * (CA == :C ? conj(vA) : vA) *
140-
(CB == :C ? conj(vB) : vB), IC)
138+
v = α * (conjA ? conj(vA) : vA) * (conjB ? conj(vB) : vB)
139+
increaseindex!(C, v, IC)
141140
end
142141
end
143142
end

test/contractions.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,10 @@ end
6262
end
6363

6464
sparse_result = TensorOperations.ncon(tensors, indices, conjlist)
65-
dense_result = TensorOperations.ncon(Array.(tensors), indices, conjlist)
66-
67-
@test Array(sparse_result) dense_result
65+
if SparseArrayKit.nonzero_length(sparse_result) > 0
66+
dense_result = TensorOperations.ncon(Array.(tensors), indices, conjlist)
67+
@test Array(sparse_result) dense_result
68+
end
6869
end
6970
end
7071

0 commit comments

Comments
 (0)