Skip to content

Commit 5af28d2

Browse files
Integer & atomic Intrinsics improvements (#544)
1 parent 97121c8 commit 5af28d2

File tree

2 files changed

+202
-113
lines changed

2 files changed

+202
-113
lines changed

src/device/intrinsics/math.jl

+87-77
Original file line numberDiff line numberDiff line change
@@ -343,82 +343,92 @@ end
343343
# `Base.abs` for 8-bit integers, declared with `@device_override`: the signed
# method calls the external symbol `air.abs.s.i8`, the unsigned one
# `air.abs.u.i8`, each through an `llvmcall` ccall returning the argument's type.
@device_override Base.abs(x::Int8) = ccall("extern air.abs.s.i8", llvmcall, Int8, (Int8,), x)
344344
@device_override Base.abs(x::UInt8) = ccall("extern air.abs.u.i8", llvmcall, UInt8, (UInt8,), x)
345345

346-
# Pre-change `Base.min` methods (the "-" side of this diff), one per 8/16/32/64-bit
# signed (`air.min.s.*`) and unsigned (`air.min.u.*`) integer type.
# NOTE(review): every method here takes and forwards a single operand `x`;
# compare with the two-argument definitions on the "+" side of the diff.
@device_override Base.min(x::Int64) = ccall("extern air.min.s.i64", llvmcall, Int64, (Int64,), x)
347-
@device_override Base.min(x::UInt64) = ccall("extern air.min.u.i64", llvmcall, UInt64, (UInt64,), x)
348-
@device_override Base.min(x::Int32) = ccall("extern air.min.s.i32", llvmcall, Int32, (Int32,), x)
349-
@device_override Base.min(x::UInt32) = ccall("extern air.min.u.i32", llvmcall, UInt32, (UInt32,), x)
350-
@device_override Base.min(x::Int16) = ccall("extern air.min.s.i16", llvmcall, Int16, (Int16,), x)
351-
@device_override Base.min(x::UInt16) = ccall("extern air.min.u.i16", llvmcall, UInt16, (UInt16,), x)
352-
@device_override Base.min(x::Int8) = ccall("extern air.min.s.i8", llvmcall, Int8, (Int8,), x)
353-
@device_override Base.min(x::UInt8) = ccall("extern air.min.u.i8", llvmcall, UInt8, (UInt8,), x)
354-
355-
# Pre-change `Base.max` methods (the "-" side of this diff), one per 8/16/32/64-bit
# signed (`air.max.s.*`) and unsigned (`air.max.u.*`) integer type.
# NOTE(review): every method here takes and forwards a single operand `x`;
# compare with the two-argument definitions on the "+" side of the diff.
@device_override Base.max(x::Int64) = ccall("extern air.max.s.i64", llvmcall, Int64, (Int64,), x)
356-
@device_override Base.max(x::UInt64) = ccall("extern air.max.u.i64", llvmcall, UInt64, (UInt64,), x)
357-
@device_override Base.max(x::Int32) = ccall("extern air.max.s.i32", llvmcall, Int32, (Int32,), x)
358-
@device_override Base.max(x::UInt32) = ccall("extern air.max.u.i32", llvmcall, UInt32, (UInt32,), x)
359-
@device_override Base.max(x::Int16) = ccall("extern air.max.s.i16", llvmcall, Int16, (Int16,), x)
360-
@device_override Base.max(x::UInt16) = ccall("extern air.max.u.i16", llvmcall, UInt16, (UInt16,), x)
361-
@device_override Base.max(x::Int8) = ccall("extern air.max.s.i8", llvmcall, Int8, (Int8,), x)
362-
@device_override Base.max(x::UInt8) = ccall("extern air.max.u.i8", llvmcall, UInt8, (UInt8,), x)
363-
364-
# Pre-change `clz` methods (the "-" side of this diff), declared with
# `@device_function`. Signed and unsigned types of the same width share one
# `air.clz.iN` symbol; the declared return type equals the argument type.
@device_function clz(x::Int64) = ccall("extern air.clz.i64", llvmcall, Int64, (Int64,), x)
365-
@device_function clz(x::UInt64) = ccall("extern air.clz.i64", llvmcall, UInt64, (UInt64,), x)
366-
@device_function clz(x::Int32) = ccall("extern air.clz.i32", llvmcall, Int32, (Int32,), x)
367-
@device_function clz(x::UInt32) = ccall("extern air.clz.i32", llvmcall, UInt32, (UInt32,), x)
368-
@device_function clz(x::Int16) = ccall("extern air.clz.i16", llvmcall, Int16, (Int16,), x)
369-
@device_function clz(x::UInt16) = ccall("extern air.clz.i16", llvmcall, UInt16, (UInt16,), x)
370-
@device_function clz(x::Int8) = ccall("extern air.clz.i8", llvmcall, Int8, (Int8,), x)
371-
@device_function clz(x::UInt8) = ccall("extern air.clz.i8", llvmcall, UInt8, (UInt8,), x)
372-
373-
# Pre-change `ctz` methods (the "-" side of this diff), declared with
# `@device_function`. Signed and unsigned types of the same width share one
# `air.ctz.iN` symbol; the declared return type equals the argument type.
@device_function ctz(x::Int64) = ccall("extern air.ctz.i64", llvmcall, Int64, (Int64,), x)
374-
@device_function ctz(x::UInt64) = ccall("extern air.ctz.i64", llvmcall, UInt64, (UInt64,), x)
375-
@device_function ctz(x::Int32) = ccall("extern air.ctz.i32", llvmcall, Int32, (Int32,), x)
376-
@device_function ctz(x::UInt32) = ccall("extern air.ctz.i32", llvmcall, UInt32, (UInt32,), x)
377-
@device_function ctz(x::Int16) = ccall("extern air.ctz.i16", llvmcall, Int16, (Int16,), x)
378-
@device_function ctz(x::UInt16) = ccall("extern air.ctz.i16", llvmcall, UInt16, (UInt16,), x)
379-
@device_function ctz(x::Int8) = ccall("extern air.ctz.i8", llvmcall, Int8, (Int8,), x)
380-
@device_function ctz(x::UInt8) = ccall("extern air.ctz.i8", llvmcall, UInt8, (UInt8,), x)
381-
382-
# Pre-change `popcount` methods (the "-" side of this diff), declared with
# `@device_function`. Signed and unsigned types of the same width share one
# `air.popcount.iN` symbol; the declared return type equals the argument type.
@device_function popcount(x::Int64) = ccall("extern air.popcount.i64", llvmcall, Int64, (Int64,), x)
383-
@device_function popcount(x::UInt64) = ccall("extern air.popcount.i64", llvmcall, UInt64, (UInt64,), x)
384-
@device_function popcount(x::Int32) = ccall("extern air.popcount.i32", llvmcall, Int32, (Int32,), x)
385-
@device_function popcount(x::UInt32) = ccall("extern air.popcount.i32", llvmcall, UInt32, (UInt32,), x)
386-
@device_function popcount(x::Int16) = ccall("extern air.popcount.i16", llvmcall, Int16, (Int16,), x)
387-
@device_function popcount(x::UInt16) = ccall("extern air.popcount.i16", llvmcall, UInt16, (UInt16,), x)
388-
@device_function popcount(x::Int8) = ccall("extern air.popcount.i8", llvmcall, Int8, (Int8,), x)
389-
@device_function popcount(x::UInt8) = ccall("extern air.popcount.i8", llvmcall, UInt8, (UInt8,), x)
390-
391-
# Pre-change `reverse_bits` methods (the "-" side of this diff), declared with
# `@device_function`. Signed and unsigned types of the same width share one
# `air.reverse_bits.iN` symbol; the declared return type equals the argument type.
@device_function reverse_bits(x::Int64) = ccall("extern air.reverse_bits.i64", llvmcall, Int64, (Int64,), x)
392-
@device_function reverse_bits(x::UInt64) = ccall("extern air.reverse_bits.i64", llvmcall, UInt64, (UInt64,), x)
393-
@device_function reverse_bits(x::Int32) = ccall("extern air.reverse_bits.i32", llvmcall, Int32, (Int32,), x)
394-
@device_function reverse_bits(x::UInt32) = ccall("extern air.reverse_bits.i32", llvmcall, UInt32, (UInt32,), x)
395-
@device_function reverse_bits(x::Int16) = ccall("extern air.reverse_bits.i16", llvmcall, Int16, (Int16,), x)
396-
@device_function reverse_bits(x::UInt16) = ccall("extern air.reverse_bits.i16", llvmcall, UInt16, (UInt16,), x)
397-
@device_function reverse_bits(x::Int8) = ccall("extern air.reverse_bits.i8", llvmcall, Int8, (Int8,), x)
398-
@device_function reverse_bits(x::UInt8) = ccall("extern air.reverse_bits.i8", llvmcall, UInt8, (UInt8,), x)
399-
400-
401-
# Pre-change `_mulhi` specialisation for `Int64` (the "-" side of this diff).
# Splits both operands into an upper and a lower 32-bit part, forms the partial
# products, and combines them; the lower-by-lower product goes through the
# generic `_mulhi` on `UInt32` values.
# NOTE(review): presumably yields the upper 64 bits of the full 128-bit product,
# matching the generic `_mulhi` fallback defined elsewhere in this diff — confirm.
function _mulhi(a::Int64, b::Int64)
402-
# Half the operand's bit width: sizeof(Int64) * 4 == 32.
shift = sizeof(a) * 4
403-
# Mask selecting the low 32 bits.
mask = typemax(UInt32)
404-
# a1/b1: operand shifted right by 32; a2/b2: operand masked to its low 32 bits.
a1, a2 = (a >> shift), a & mask
405-
b1, b2 = (b >> shift), b & mask
406-
a1b1, a1b2, a2b1 = a1*b1, a1*b2, a2*b1
407-
# Adds the result of the 32-bit `_mulhi` on the low halves to one cross product.
t1 = a1b2 + _mulhi(a2 % UInt32, b2 % UInt32)
408-
t2 = a2b1 + (t1 & mask)
409-
# Final value: high-by-high product plus both intermediate sums shifted right by 32.
a1b1 + (t1 >> shift) + (t2 >> shift)
410-
end
411-
# Pre-change compatibility switch (the "-" side of this diff), chosen by `@static`
# on whether `Base.MultiplicativeInverses` defines `_mul_high`.
@static if isdefined(Base.MultiplicativeInverses, :_mul_high)
412-
# `_mul_high` exists: the generic `_mulhi` forwards to it, and the `Int64`
# method of `_mul_high` is overridden to call `_mulhi`.
_mulhi(a::T, b::T) where {T<:Union{Signed, Unsigned}} = Base.MultiplicativeInverses._mul_high(a, b)
413-
@device_override Base.MultiplicativeInverses._mul_high(a::Int64, b::Int64) = _mulhi(a, b)
414-
else
415-
# `_mul_high` is absent: the generic `_mulhi` widens `a`, multiplies by `b`,
# shifts right by the operand's bit width, and converts back to `T` with `%`.
_mulhi(a::T, b::T) where {T<:Union{Signed, Unsigned}} = ((widen(a)*b) >>> (sizeof(a)*8)) % T
416-
# Override of `Base.div` for an `Int64` dividend and a
# `SignedMultiplicativeInverse{Int64}` divisor, built on `_mulhi`.
@device_override function Base.div(a::Int64, b::Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64})
417-
x = _mulhi(a, b.multiplier)
418-
x += (a*b.addmul) % Int64
419-
# When |divisor| == 1 the result is `a*b.divisor`; otherwise it is `x` shifted
# right by `b.shift` with `signbit(x)` added.
ifelse(abs(b.divisor) == 1, a*b.divisor, (signbit(x) + (x >> b.shift)) % Int64)
420-
end
421-
end
346+
# Two-argument `Base.min` for same-typed 8/16/32/64-bit integers, declared with
# `@device_override`. Signed types call `air.min.s.iN`, unsigned types
# `air.min.u.iN`; both operands are forwarded and the result has the operand type.
@device_override Base.min(x::Int64, y::Int64) = ccall("extern air.min.s.i64", llvmcall, Int64, (Int64, Int64), x, y)
347+
@device_override Base.min(x::UInt64, y::UInt64) = ccall("extern air.min.u.i64", llvmcall, UInt64, (UInt64, UInt64), x, y)
348+
@device_override Base.min(x::Int32, y::Int32) = ccall("extern air.min.s.i32", llvmcall, Int32, (Int32, Int32), x, y)
349+
@device_override Base.min(x::UInt32, y::UInt32) = ccall("extern air.min.u.i32", llvmcall, UInt32, (UInt32, UInt32), x, y)
350+
@device_override Base.min(x::Int16, y::Int16) = ccall("extern air.min.s.i16", llvmcall, Int16, (Int16, Int16), x, y)
351+
@device_override Base.min(x::UInt16, y::UInt16) = ccall("extern air.min.u.i16", llvmcall, UInt16, (UInt16, UInt16), x, y)
352+
@device_override Base.min(x::Int8, y::Int8) = ccall("extern air.min.s.i8", llvmcall, Int8, (Int8, Int8), x, y)
353+
@device_override Base.min(x::UInt8, y::UInt8) = ccall("extern air.min.u.i8", llvmcall, UInt8, (UInt8, UInt8), x, y)
354+
355+
# XXX: Breaks `mul!` when uncommented. MWE: `using Revise, Metal; A, x = mtl(rand(Int32, 4, 4)), mtl(rand(Int32, 4)); A*x`
356+
# @device_override Base.max(x::Int64, y::Int64) = ccall("extern air.max.s.i64", llvmcall, Int64, (Int64, Int64), x, y)
357+
# Two-argument `Base.max` for same-typed integers, declared with
# `@device_override`. Signed types call `air.max.s.iN`, unsigned types
# `air.max.u.iN`. There is no active `Int64` signed method in this group: that
# definition is commented out immediately above (see its XXX note).
@device_override Base.max(x::UInt64, y::UInt64) = ccall("extern air.max.u.i64", llvmcall, UInt64, (UInt64, UInt64), x, y)
358+
@device_override Base.max(x::Int32, y::Int32) = ccall("extern air.max.s.i32", llvmcall, Int32, (Int32, Int32), x, y)
359+
@device_override Base.max(x::UInt32, y::UInt32) = ccall("extern air.max.u.i32", llvmcall, UInt32, (UInt32, UInt32), x, y)
360+
@device_override Base.max(x::Int16, y::Int16) = ccall("extern air.max.s.i16", llvmcall, Int16, (Int16, Int16), x, y)
361+
@device_override Base.max(x::UInt16, y::UInt16) = ccall("extern air.max.u.i16", llvmcall, UInt16, (UInt16, UInt16), x, y)
362+
@device_override Base.max(x::Int8, y::Int8) = ccall("extern air.max.s.i8", llvmcall, Int8, (Int8, Int8), x, y)
363+
@device_override Base.max(x::UInt8, y::UInt8) = ccall("extern air.max.u.i8", llvmcall, UInt8, (UInt8, UInt8), x, y)
364+
365+
# Three-argument `Base.min` for same-typed 8/16/32/64-bit integers, declared with
# `@device_override`. Signed types call `air.min3.s.iN`, unsigned types
# `air.min3.u.iN`; all three operands are forwarded in one call.
@device_override Base.min(x::Int64, y::Int64, z::Int64) = ccall("extern air.min3.s.i64", llvmcall, Int64, (Int64, Int64, Int64), x, y, z)
366+
@device_override Base.min(x::UInt64, y::UInt64, z::UInt64) = ccall("extern air.min3.u.i64", llvmcall, UInt64, (UInt64, UInt64, UInt64), x, y, z)
367+
@device_override Base.min(x::Int32, y::Int32, z::Int32) = ccall("extern air.min3.s.i32", llvmcall, Int32, (Int32, Int32, Int32), x, y, z)
368+
@device_override Base.min(x::UInt32, y::UInt32, z::UInt32) = ccall("extern air.min3.u.i32", llvmcall, UInt32, (UInt32, UInt32, UInt32), x, y, z)
369+
@device_override Base.min(x::Int16, y::Int16, z::Int16) = ccall("extern air.min3.s.i16", llvmcall, Int16, (Int16, Int16, Int16), x, y, z)
370+
@device_override Base.min(x::UInt16, y::UInt16, z::UInt16) = ccall("extern air.min3.u.i16", llvmcall, UInt16, (UInt16, UInt16, UInt16), x, y, z)
371+
@device_override Base.min(x::Int8, y::Int8, z::Int8) = ccall("extern air.min3.s.i8", llvmcall, Int8, (Int8, Int8, Int8), x, y, z)
372+
@device_override Base.min(x::UInt8, y::UInt8, z::UInt8) = ccall("extern air.min3.u.i8", llvmcall, UInt8, (UInt8, UInt8, UInt8), x, y, z)
373+
374+
# Three-argument `Base.max` for same-typed 8/16/32/64-bit integers, declared with
# `@device_override`. Signed types call `air.max3.s.iN`, unsigned types
# `air.max3.u.iN`; all three operands are forwarded in one call.
@device_override Base.max(x::Int64, y::Int64, z::Int64) = ccall("extern air.max3.s.i64", llvmcall, Int64, (Int64, Int64, Int64), x, y, z)
375+
@device_override Base.max(x::UInt64, y::UInt64, z::UInt64) = ccall("extern air.max3.u.i64", llvmcall, UInt64, (UInt64, UInt64, UInt64), x, y, z)
376+
@device_override Base.max(x::Int32, y::Int32, z::Int32) = ccall("extern air.max3.s.i32", llvmcall, Int32, (Int32, Int32, Int32), x, y, z)
377+
@device_override Base.max(x::UInt32, y::UInt32, z::UInt32) = ccall("extern air.max3.u.i32", llvmcall, UInt32, (UInt32, UInt32, UInt32), x, y, z)
378+
@device_override Base.max(x::Int16, y::Int16, z::Int16) = ccall("extern air.max3.s.i16", llvmcall, Int16, (Int16, Int16, Int16), x, y, z)
379+
@device_override Base.max(x::UInt16, y::UInt16, z::UInt16) = ccall("extern air.max3.u.i16", llvmcall, UInt16, (UInt16, UInt16, UInt16), x, y, z)
380+
@device_override Base.max(x::Int8, y::Int8, z::Int8) = ccall("extern air.max3.s.i8", llvmcall, Int8, (Int8, Int8, Int8), x, y, z)
381+
@device_override Base.max(x::UInt8, y::UInt8, z::UInt8) = ccall("extern air.max3.u.i8", llvmcall, UInt8, (UInt8, UInt8, UInt8), x, y, z)
382+
383+
# `Base.leading_zeros` for 8/16/32/64-bit integers, declared with
# `@device_override`. Signed and unsigned types of the same width share one
# `air.clz.iN` symbol. The trailing `const` binds the short name `clz` to
# `leading_zeros`.
# NOTE(review): each method returns the argument's own integer type, whereas
# Base's built-in methods appear to return `Int` — confirm callers tolerate this.
@device_override Base.leading_zeros(x::Int64) = ccall("extern air.clz.i64", llvmcall, Int64, (Int64,), x)
384+
@device_override Base.leading_zeros(x::UInt64) = ccall("extern air.clz.i64", llvmcall, UInt64, (UInt64,), x)
385+
@device_override Base.leading_zeros(x::Int32) = ccall("extern air.clz.i32", llvmcall, Int32, (Int32,), x)
386+
@device_override Base.leading_zeros(x::UInt32) = ccall("extern air.clz.i32", llvmcall, UInt32, (UInt32,), x)
387+
@device_override Base.leading_zeros(x::Int16) = ccall("extern air.clz.i16", llvmcall, Int16, (Int16,), x)
388+
@device_override Base.leading_zeros(x::UInt16) = ccall("extern air.clz.i16", llvmcall, UInt16, (UInt16,), x)
389+
@device_override Base.leading_zeros(x::Int8) = ccall("extern air.clz.i8", llvmcall, Int8, (Int8,), x)
390+
@device_override Base.leading_zeros(x::UInt8) = ccall("extern air.clz.i8", llvmcall, UInt8, (UInt8,), x)
391+
const clz = leading_zeros
392+
393+
# `Base.trailing_zeros` for 8/16/32/64-bit integers, declared with
# `@device_override`. Signed and unsigned types of the same width share one
# `air.ctz.iN` symbol. The trailing `const` binds the short name `ctz` to
# `trailing_zeros`.
# NOTE(review): each method returns the argument's own integer type, whereas
# Base's built-in methods appear to return `Int` — confirm callers tolerate this.
@device_override Base.trailing_zeros(x::Int64) = ccall("extern air.ctz.i64", llvmcall, Int64, (Int64,), x)
394+
@device_override Base.trailing_zeros(x::UInt64) = ccall("extern air.ctz.i64", llvmcall, UInt64, (UInt64,), x)
395+
@device_override Base.trailing_zeros(x::Int32) = ccall("extern air.ctz.i32", llvmcall, Int32, (Int32,), x)
396+
@device_override Base.trailing_zeros(x::UInt32) = ccall("extern air.ctz.i32", llvmcall, UInt32, (UInt32,), x)
397+
@device_override Base.trailing_zeros(x::Int16) = ccall("extern air.ctz.i16", llvmcall, Int16, (Int16,), x)
398+
@device_override Base.trailing_zeros(x::UInt16) = ccall("extern air.ctz.i16", llvmcall, UInt16, (UInt16,), x)
399+
@device_override Base.trailing_zeros(x::Int8) = ccall("extern air.ctz.i8", llvmcall, Int8, (Int8,), x)
400+
@device_override Base.trailing_zeros(x::UInt8) = ccall("extern air.ctz.i8", llvmcall, UInt8, (UInt8,), x)
401+
const ctz = trailing_zeros
402+
403+
# `Base.count_ones` for 8/16/32/64-bit integers, declared with
# `@device_override`. Signed and unsigned types of the same width share one
# `air.popcount.iN` symbol. The trailing `const` binds the short name `popcount`
# to `count_ones`.
# NOTE(review): each method returns the argument's own integer type, whereas
# Base's built-in methods appear to return `Int` — confirm callers tolerate this.
@device_override Base.count_ones(x::Int64) = ccall("extern air.popcount.i64", llvmcall, Int64, (Int64,), x)
404+
@device_override Base.count_ones(x::UInt64) = ccall("extern air.popcount.i64", llvmcall, UInt64, (UInt64,), x)
405+
@device_override Base.count_ones(x::Int32) = ccall("extern air.popcount.i32", llvmcall, Int32, (Int32,), x)
406+
@device_override Base.count_ones(x::UInt32) = ccall("extern air.popcount.i32", llvmcall, UInt32, (UInt32,), x)
407+
@device_override Base.count_ones(x::Int16) = ccall("extern air.popcount.i16", llvmcall, Int16, (Int16,), x)
408+
@device_override Base.count_ones(x::UInt16) = ccall("extern air.popcount.i16", llvmcall, UInt16, (UInt16,), x)
409+
@device_override Base.count_ones(x::Int8) = ccall("extern air.popcount.i8", llvmcall, Int8, (Int8,), x)
410+
@device_override Base.count_ones(x::UInt8) = ccall("extern air.popcount.i8", llvmcall, UInt8, (UInt8,), x)
411+
const popcount = count_ones
412+
413+
# `Base.bitreverse` for 8/16/32/64-bit integers, declared with
# `@device_override`. Signed and unsigned types of the same width share one
# `air.reverse_bits.iN` symbol, and the result has the argument's type. The
# trailing `const` binds the short name `reverse_bits` to `bitreverse`.
@device_override Base.bitreverse(x::Int64) = ccall("extern air.reverse_bits.i64", llvmcall, Int64, (Int64,), x)
414+
@device_override Base.bitreverse(x::UInt64) = ccall("extern air.reverse_bits.i64", llvmcall, UInt64, (UInt64,), x)
415+
@device_override Base.bitreverse(x::Int32) = ccall("extern air.reverse_bits.i32", llvmcall, Int32, (Int32,), x)
416+
@device_override Base.bitreverse(x::UInt32) = ccall("extern air.reverse_bits.i32", llvmcall, UInt32, (UInt32,), x)
417+
@device_override Base.bitreverse(x::Int16) = ccall("extern air.reverse_bits.i16", llvmcall, Int16, (Int16,), x)
418+
@device_override Base.bitreverse(x::UInt16) = ccall("extern air.reverse_bits.i16", llvmcall, UInt16, (UInt16,), x)
419+
@device_override Base.bitreverse(x::Int8) = ccall("extern air.reverse_bits.i8", llvmcall, Int8, (Int8,), x)
420+
@device_override Base.bitreverse(x::UInt8) = ccall("extern air.reverse_bits.i8", llvmcall, UInt8, (UInt8,), x)
421+
const reverse_bits = bitreverse
422+
423+
# `Base.MultiplicativeInverses._mul_high` for same-typed 8/16/32/64-bit integers,
# declared with `@device_override`. Signed types call `air.mul_hi.s.iN`, unsigned
# types `air.mul_hi.u.iN`; the result has the operand type. The trailing `const`
# binds the short name `mulhi` to `_mul_high`.
# NOTE(review): these lines reference `_mul_high` unconditionally, while the
# removed code in this diff guarded it with `isdefined` — confirm that every
# supported Julia version defines it.
@device_override Base.MultiplicativeInverses._mul_high(x::Int64, y::Int64) = ccall("extern air.mul_hi.s.i64", llvmcall, Int64, (Int64, Int64), x, y)
424+
@device_override Base.MultiplicativeInverses._mul_high(x::UInt64, y::UInt64) = ccall("extern air.mul_hi.u.i64", llvmcall, UInt64, (UInt64, UInt64), x, y)
425+
@device_override Base.MultiplicativeInverses._mul_high(x::Int32, y::Int32) = ccall("extern air.mul_hi.s.i32", llvmcall, Int32, (Int32, Int32), x, y)
426+
@device_override Base.MultiplicativeInverses._mul_high(x::UInt32, y::UInt32) = ccall("extern air.mul_hi.u.i32", llvmcall, UInt32, (UInt32, UInt32), x, y)
427+
@device_override Base.MultiplicativeInverses._mul_high(x::Int16, y::Int16) = ccall("extern air.mul_hi.s.i16", llvmcall, Int16, (Int16, Int16), x, y)
428+
@device_override Base.MultiplicativeInverses._mul_high(x::UInt16, y::UInt16) = ccall("extern air.mul_hi.u.i16", llvmcall, UInt16, (UInt16, UInt16), x, y)
429+
@device_override Base.MultiplicativeInverses._mul_high(x::Int8, y::Int8) = ccall("extern air.mul_hi.s.i8", llvmcall, Int8, (Int8, Int8), x, y)
430+
@device_override Base.MultiplicativeInverses._mul_high(x::UInt8, y::UInt8) = ccall("extern air.mul_hi.u.i8", llvmcall, UInt8, (UInt8, UInt8), x, y)
431+
const mulhi = Base.MultiplicativeInverses._mul_high
422432

423433
# From: https://forums.developer.nvidia.com/t/a-faster-and-more-accurate-implementation-of-expm1f/48085/2
424434
# Original license copied below:
@@ -495,4 +505,4 @@ end
495505
end
496506

497507
return r
498-
end
508+
end

0 commit comments

Comments
 (0)