Skip to content

Commit f047345

Browse files
github-actions[bot]CompatHelper Julialkdvos
authored
Update to cuTENSOR 2.0 (#160)
This PR updates to cuTENSOR v2. This necessarily also bumps the minimal required Julia version to 1.8. Importantly, this hoists a lot of the implementation to cuTENSOR itself, and attempts to be more of a minimal wrapper, in order to be maintenance-low in the future. --------- Co-authored-by: CompatHelper Julia <[email protected]> Co-authored-by: lkdvos <[email protected]>
1 parent 0a4ea36 commit f047345

File tree

7 files changed

+577
-594
lines changed

7 files changed

+577
-594
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
fail-fast: false
2222
matrix:
2323
version:
24-
- '1.6' # LTS version
24+
- '1.8' # lowest supported version
2525
- '1' # automatically expands to the latest stable 1.x release of Julia
2626
os:
2727
- ubuntu-latest

Project.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorOperations"
22
uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
33
authors = ["Lukas Devos <[email protected]>", "Maarten Van Damme <[email protected]>", "Jutho Haegeman <[email protected]>"]
4-
version = "4.1.1"
4+
version = "5.0.0-DEV"
55

66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -26,7 +26,7 @@ TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"]
2626

2727
[compat]
2828
Aqua = "0.6, 0.7, 0.8"
29-
CUDA = "4,5"
29+
CUDA = "5.4.0"
3030
ChainRulesCore = "1"
3131
ChainRulesTestUtils = "1"
3232
DynamicPolynomials = "0.5"
@@ -40,8 +40,8 @@ StridedViews = "0.2"
4040
Test = "1"
4141
TupleTools = "1.1"
4242
VectorInterface = "0.4.1"
43-
cuTENSOR = "1"
44-
julia = "1.6"
43+
cuTENSOR = "2.1.1"
44+
julia = "1.8"
4545

4646
[extras]
4747
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"

ext/TensorOperationscuTENSORExt.jl

Lines changed: 234 additions & 285 deletions
Large diffs are not rendered by default.

src/implementation/strided.jl

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,32 @@
11
#-------------------------------------------------------------------------------------------
22
# StridedView implementation
33
#-------------------------------------------------------------------------------------------
4+
5+
# default backends
6+
function tensoradd!(C::StridedView,
7+
A::StridedView, pA::Index2Tuple, conjA::Symbol,
8+
α::Number, β::Number)
9+
backend = eltype(C) isa BlasFloat ? StridedBLAS() : StridedNative()
10+
return tensoradd!(C, A, pA, conjA, α, β, backend)
11+
end
12+
function tensortrace!(C::StridedView,
13+
A::StridedView, p::Index2Tuple, q::Index2Tuple, conjA::Symbol,
14+
α::Number, β::Number)
15+
backend = eltype(C) isa BlasFloat ? StridedBLAS() : StridedNative()
16+
return tensortrace!(C, A, p, q, conjA, α, β, backend)
17+
end
18+
function tensorcontract!(C::StridedView,
19+
A::StridedView, pA::Index2Tuple, conjA::Symbol,
20+
B::StridedView, pB::Index2Tuple, conjB::Symbol,
21+
pAB::Index2Tuple, α::Number, β::Number)
22+
backend = eltype(C) isa BlasFloat ? StridedBLAS() : StridedNative()
23+
return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, backend)
24+
end
25+
426
function tensoradd!(C::StridedView,
527
A::StridedView, pA::Index2Tuple, conjA::Symbol,
628
α::Number, β::Number,
7-
backend::Union{StridedNative,StridedBLAS}=StridedNative())
29+
::Union{StridedNative,StridedBLAS})
830
argcheck_tensoradd(C, A, pA)
931
dimcheck_tensoradd(C, A, pA)
1032
if !istrivialpermutation(pA) && Base.mightalias(C, A)
@@ -21,7 +43,7 @@ end
2143
function tensortrace!(C::StridedView,
2244
A::StridedView, p::Index2Tuple, q::Index2Tuple, conjA::Symbol,
2345
α::Number, β::Number,
24-
backend::Union{StridedNative,StridedBLAS}=StridedNative())
46+
::Union{StridedNative,StridedBLAS})
2547
argcheck_tensortrace(C, A, p, q)
2648
dimcheck_tensortrace(C, A, p, q)
2749

@@ -41,12 +63,11 @@ function tensortrace!(C::StridedView,
4163
return C
4264
end
4365

44-
function tensorcontract!(C::StridedView{T},
66+
function tensorcontract!(C::StridedView,
4567
A::StridedView, pA::Index2Tuple, conjA::Symbol,
4668
B::StridedView, pB::Index2Tuple, conjB::Symbol,
4769
pAB::Index2Tuple,
48-
α::Number, β::Number,
49-
backend::StridedBLAS=StridedBLAS()) where {T<:LinearAlgebra.BlasFloat}
70+
α::Number, β::Number, ::StridedBLAS)
5071
argcheck_tensorcontract(C, A, pA, B, pB, pAB)
5172
dimcheck_tensorcontract(C, A, pA, B, pB, pAB)
5273

@@ -74,7 +95,7 @@ function tensorcontract!(C::StridedView{T,2},
7495
A::StridedView{T,2}, pA::Index2Tuple{1,1}, conjA::Symbol,
7596
B::StridedView{T,2}, pB::Index2Tuple{1,1}, conjB::Symbol,
7697
pAB::Index2Tuple{1,1}, α::Number, β::Number,
77-
backend::StridedBLAS=StridedBLAS()) where {T<:LinearAlgebra.BlasFloat}
98+
::StridedBLAS) where {T}
7899
argcheck_tensorcontract(C, A, pA, B, pB, pAB)
79100
dimcheck_tensorcontract(C, A, pA, B, pB, pAB)
80101

@@ -97,7 +118,7 @@ function tensorcontract!(C::StridedView,
97118
A::StridedView, pA::Index2Tuple, conjA::Symbol,
98119
B::StridedView, pB::Index2Tuple, conjB::Symbol,
99120
pAB::Index2Tuple, α::Number, β::Number,
100-
backend::StridedNative)
121+
::StridedNative)
101122
argcheck_tensorcontract(C, A, pA, B, pB, pAB)
102123
dimcheck_tensorcontract(C, A, pA, B, pB, pAB)
103124

0 commit comments

Comments
 (0)