Skip to content

Commit ac76ca8

Browse files
committed
Fix alignment computation
1 parent c82eac7 commit ac76ca8

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

ext/TensorOperationscuTENSORExt.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -232,14 +232,14 @@ function cuTENSOR.CuTensorDescriptor(a::CuStridedView; size=size(a), strides=str
232232
eltype=eltype(a))
233233
sz = collect(Int64, size)
234234
st = collect(Int64, strides)
235-
# compute largest possible alignment
236-
alignment::UInt32 = 256
237-
while (alignment > Base.aligned_sizeof(eltype)) && (alignment % 2 != 0)
238-
alignment >>= 1
239-
end
235+
alignment = find_alignment(a)
240236
return cuTENSOR.CuTensorDescriptor(sz, st, eltype, alignment)
241237
end
242238

239+
const MAX_ALIGNMENT = UInt32(256) # This is the largest alignment of CUDA memory
240+
"find the alignment of the first element of the view"
241+
find_alignment(A::CuStridedView) = gcd(MAX_ALIGNMENT, UInt32(pointer(A)))
242+
243243
# trace!
244244
# ------
245245
# not actually part of cuTENSOR, just a special case of reduce

0 commit comments

Comments
 (0)