@@ -38,7 +38,6 @@ const CuStridedView = StridedViewsCUDAExt.CuStridedView
38
38
const SUPPORTED_CUARRAYS = Union{AnyCuArray,CuStridedView}
39
39
const cuTENSORBackend = TO. Backend{:cuTENSOR }
40
40
41
-
42
41
function TO. tensorscalar (C:: SUPPORTED_CUARRAYS )
43
42
return ndims (C) == 0 ? tensorscalar (collect (C)) : throw (DimensionMismatch ())
44
43
end
60
59
61
60
# making sure that if no backend is specified, the cuTENSOR backend is used:
62
61
63
- function TO. tensoradd! (C:: SUPPORTED_CUARRAYS , pC:: Index2Tuple , A:: SUPPORTED_CUARRAYS , conjA:: Symbol ,
62
+ function TO. tensoradd! (C:: SUPPORTED_CUARRAYS , pC:: Index2Tuple , A:: SUPPORTED_CUARRAYS ,
63
+ conjA:: Symbol ,
64
64
α:: Number , β:: Number )
65
65
return tensoradd! (C, pC, A, conjA, α, β, cuTENSORBackend ())
66
66
end
@@ -171,7 +171,7 @@ function TO.tensorcontract!(C::CuArray, pC::Index2Tuple,
171
171
Ainds, Binds, Cinds = collect .(TO. contract_labels (pA, pB, pC))
172
172
opA = tensorop (A, conjA)
173
173
opB = tensorop (B, conjB)
174
-
174
+
175
175
# dispatch to cuTENSOR
176
176
return cuTENSOR. contract! (α,
177
177
A, Ainds, opA,
@@ -206,38 +206,38 @@ function plan_trace(@nospecialize(A::AbstractArray), Ainds::cuTENSOR.ModeType,
206
206
! cuTENSOR. is_unary (opA) && throw (ArgumentError (" opA must be a unary op!" ))
207
207
! cuTENSOR. is_unary (opC) && throw (ArgumentError (" opC must be a unary op!" ))
208
208
! cuTENSOR. is_binary (opReduce) && throw (ArgumentError (" opReduce must be a binary op!" ))
209
-
209
+
210
210
# TODO : check if this can be avoided, available in caller
211
211
# TODO : cuTENSOR will allocate sizes and strides anyways, could use that here
212
212
_, cindA1, cindA2 = TO. trace_indices (tuple (Ainds... ), tuple (Cinds... ))
213
-
213
+
214
214
# add strides of cindA2 to strides of cindA1 -> selects diagonal
215
215
stA = strides (A)
216
216
for (i, j) in zip (cindA1, cindA2)
217
217
stA = Base. setindex (stA, stA[i] + stA[j], i)
218
218
end
219
219
szA = TT. deleteat (size (A), cindA2)
220
220
stA′ = TT. deleteat (stA, cindA2)
221
-
221
+
222
222
descA = cuTENSOR. CuTensorDescriptor (A; size= szA, strides= stA′)
223
223
descC = cuTENSOR. CuTensorDescriptor (C)
224
-
224
+
225
225
modeA = collect (Cint, deleteat! (Ainds, cindA2))
226
226
modeC = collect (Cint, Cinds)
227
-
227
+
228
228
actual_compute_type = if compute_type === nothing
229
229
cuTENSOR. reduction_compute_types[(eltype (A), eltype (C))]
230
230
else
231
231
compute_type
232
232
end
233
-
233
+
234
234
desc = Ref {cuTENSOR.cutensorOperationDescriptor_t} ()
235
235
cuTENSOR. cutensorCreateReduction (cuTENSOR. handle (),
236
- desc,
237
- descA, modeA, opA,
238
- descC, modeC, opC,
239
- descC, modeC, opReduce,
240
- actual_compute_type)
236
+ desc,
237
+ descA, modeA, opA,
238
+ descC, modeC, opC,
239
+ descC, modeC, opReduce,
240
+ actual_compute_type)
241
241
242
242
plan_pref = Ref {cuTENSOR.cutensorPlanPreference_t} ()
243
243
cuTENSOR. cutensorCreatePlanPreference (cuTENSOR. handle (), plan_pref, algo, jit)
0 commit comments