Skip to content

Commit 16e865e

Browse files
committed
Formatter
1 parent f066505 commit 16e865e

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

ext/TensorOperationscuTENSORExt.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ const CuStridedView = StridedViewsCUDAExt.CuStridedView
3838
const SUPPORTED_CUARRAYS = Union{AnyCuArray,CuStridedView}
3939
const cuTENSORBackend = TO.Backend{:cuTENSOR}
4040

41-
4241
function TO.tensorscalar(C::SUPPORTED_CUARRAYS)
4342
return ndims(C) == 0 ? tensorscalar(collect(C)) : throw(DimensionMismatch())
4443
end
@@ -60,7 +59,8 @@ end
6059

6160
# making sure that if no backend is specified, the cuTENSOR backend is used:
6261

63-
function TO.tensoradd!(C::SUPPORTED_CUARRAYS, pC::Index2Tuple, A::SUPPORTED_CUARRAYS, conjA::Symbol,
62+
function TO.tensoradd!(C::SUPPORTED_CUARRAYS, pC::Index2Tuple, A::SUPPORTED_CUARRAYS,
63+
conjA::Symbol,
6464
α::Number, β::Number)
6565
return tensoradd!(C, pC, A, conjA, α, β, cuTENSORBackend())
6666
end
@@ -171,7 +171,7 @@ function TO.tensorcontract!(C::CuArray, pC::Index2Tuple,
171171
Ainds, Binds, Cinds = collect.(TO.contract_labels(pA, pB, pC))
172172
opA = tensorop(A, conjA)
173173
opB = tensorop(B, conjB)
174-
174+
175175
# dispatch to cuTENSOR
176176
return cuTENSOR.contract!(α,
177177
A, Ainds, opA,
@@ -206,38 +206,38 @@ function plan_trace(@nospecialize(A::AbstractArray), Ainds::cuTENSOR.ModeType,
206206
!cuTENSOR.is_unary(opA) && throw(ArgumentError("opA must be a unary op!"))
207207
!cuTENSOR.is_unary(opC) && throw(ArgumentError("opC must be a unary op!"))
208208
!cuTENSOR.is_binary(opReduce) && throw(ArgumentError("opReduce must be a binary op!"))
209-
209+
210210
# TODO: check if this can be avoided, available in caller
211211
# TODO: cuTENSOR will allocate sizes and strides anyways, could use that here
212212
_, cindA1, cindA2 = TO.trace_indices(tuple(Ainds...), tuple(Cinds...))
213-
213+
214214
# add strides of cindA2 to strides of cindA1 -> selects diagonal
215215
stA = strides(A)
216216
for (i, j) in zip(cindA1, cindA2)
217217
stA = Base.setindex(stA, stA[i] + stA[j], i)
218218
end
219219
szA = TT.deleteat(size(A), cindA2)
220220
stA′ = TT.deleteat(stA, cindA2)
221-
221+
222222
descA = cuTENSOR.CuTensorDescriptor(A; size=szA, strides=stA′)
223223
descC = cuTENSOR.CuTensorDescriptor(C)
224-
224+
225225
modeA = collect(Cint, deleteat!(Ainds, cindA2))
226226
modeC = collect(Cint, Cinds)
227-
227+
228228
actual_compute_type = if compute_type === nothing
229229
cuTENSOR.reduction_compute_types[(eltype(A), eltype(C))]
230230
else
231231
compute_type
232232
end
233-
233+
234234
desc = Ref{cuTENSOR.cutensorOperationDescriptor_t}()
235235
cuTENSOR.cutensorCreateReduction(cuTENSOR.handle(),
236-
desc,
237-
descA, modeA, opA,
238-
descC, modeC, opC,
239-
descC, modeC, opReduce,
240-
actual_compute_type)
236+
desc,
237+
descA, modeA, opA,
238+
descC, modeC, opC,
239+
descC, modeC, opReduce,
240+
actual_compute_type)
241241

242242
plan_pref = Ref{cuTENSOR.cutensorPlanPreference_t}()
243243
cuTENSOR.cutensorCreatePlanPreference(cuTENSOR.handle(), plan_pref, algo, jit)

test/cutensor.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,4 +303,4 @@ if cuTENSOR.has_cutensor()
303303
@test copy(C) Ccopy
304304
end
305305
end
306-
end
306+
end

0 commit comments

Comments
 (0)