@@ -343,82 +343,92 @@ end
343
343
# 8-bit `abs` device overrides. Each method forwards its argument to the
# `air.abs.<s|u>.i8` intrinsic: the `s` variant is used for `Int8`, the `u`
# variant for `UInt8`. The result has the same type as the argument.
for (T, sgn) in ((:Int8, "s"), (:UInt8, "u"))
    intr = "extern air.abs.$(sgn).i8"
    @eval @device_override Base.abs(x::$T) = ccall($intr, llvmcall, $T, ($T,), x)
end
345
345
346
- @device_override Base. min (x:: Int64 ) = ccall (" extern air.min.s.i64" , llvmcall, Int64, (Int64,), x)
347
- @device_override Base. min (x:: UInt64 ) = ccall (" extern air.min.u.i64" , llvmcall, UInt64, (UInt64,), x)
348
- @device_override Base. min (x:: Int32 ) = ccall (" extern air.min.s.i32" , llvmcall, Int32, (Int32,), x)
349
- @device_override Base. min (x:: UInt32 ) = ccall (" extern air.min.u.i32" , llvmcall, UInt32, (UInt32,), x)
350
- @device_override Base. min (x:: Int16 ) = ccall (" extern air.min.s.i16" , llvmcall, Int16, (Int16,), x)
351
- @device_override Base. min (x:: UInt16 ) = ccall (" extern air.min.u.i16" , llvmcall, UInt16, (UInt16,), x)
352
- @device_override Base. min (x:: Int8 ) = ccall (" extern air.min.s.i8" , llvmcall, Int8, (Int8,), x)
353
- @device_override Base. min (x:: UInt8 ) = ccall (" extern air.min.u.i8" , llvmcall, UInt8, (UInt8,), x)
354
-
355
- @device_override Base. max (x:: Int64 ) = ccall (" extern air.max.s.i64" , llvmcall, Int64, (Int64,), x)
356
- @device_override Base. max (x:: UInt64 ) = ccall (" extern air.max.u.i64" , llvmcall, UInt64, (UInt64,), x)
357
- @device_override Base. max (x:: Int32 ) = ccall (" extern air.max.s.i32" , llvmcall, Int32, (Int32,), x)
358
- @device_override Base. max (x:: UInt32 ) = ccall (" extern air.max.u.i32" , llvmcall, UInt32, (UInt32,), x)
359
- @device_override Base. max (x:: Int16 ) = ccall (" extern air.max.s.i16" , llvmcall, Int16, (Int16,), x)
360
- @device_override Base. max (x:: UInt16 ) = ccall (" extern air.max.u.i16" , llvmcall, UInt16, (UInt16,), x)
361
- @device_override Base. max (x:: Int8 ) = ccall (" extern air.max.s.i8" , llvmcall, Int8, (Int8,), x)
362
- @device_override Base. max (x:: UInt8 ) = ccall (" extern air.max.u.i8" , llvmcall, UInt8, (UInt8,), x)
363
-
364
- @device_function clz (x:: Int64 ) = ccall (" extern air.clz.i64" , llvmcall, Int64, (Int64,), x)
365
- @device_function clz (x:: UInt64 ) = ccall (" extern air.clz.i64" , llvmcall, UInt64, (UInt64,), x)
366
- @device_function clz (x:: Int32 ) = ccall (" extern air.clz.i32" , llvmcall, Int32, (Int32,), x)
367
- @device_function clz (x:: UInt32 ) = ccall (" extern air.clz.i32" , llvmcall, UInt32, (UInt32,), x)
368
- @device_function clz (x:: Int16 ) = ccall (" extern air.clz.i16" , llvmcall, Int16, (Int16,), x)
369
- @device_function clz (x:: UInt16 ) = ccall (" extern air.clz.i16" , llvmcall, UInt16, (UInt16,), x)
370
- @device_function clz (x:: Int8 ) = ccall (" extern air.clz.i8" , llvmcall, Int8, (Int8,), x)
371
- @device_function clz (x:: UInt8 ) = ccall (" extern air.clz.i8" , llvmcall, UInt8, (UInt8,), x)
372
-
373
- @device_function ctz (x:: Int64 ) = ccall (" extern air.ctz.i64" , llvmcall, Int64, (Int64,), x)
374
- @device_function ctz (x:: UInt64 ) = ccall (" extern air.ctz.i64" , llvmcall, UInt64, (UInt64,), x)
375
- @device_function ctz (x:: Int32 ) = ccall (" extern air.ctz.i32" , llvmcall, Int32, (Int32,), x)
376
- @device_function ctz (x:: UInt32 ) = ccall (" extern air.ctz.i32" , llvmcall, UInt32, (UInt32,), x)
377
- @device_function ctz (x:: Int16 ) = ccall (" extern air.ctz.i16" , llvmcall, Int16, (Int16,), x)
378
- @device_function ctz (x:: UInt16 ) = ccall (" extern air.ctz.i16" , llvmcall, UInt16, (UInt16,), x)
379
- @device_function ctz (x:: Int8 ) = ccall (" extern air.ctz.i8" , llvmcall, Int8, (Int8,), x)
380
- @device_function ctz (x:: UInt8 ) = ccall (" extern air.ctz.i8" , llvmcall, UInt8, (UInt8,), x)
381
-
382
- @device_function popcount (x:: Int64 ) = ccall (" extern air.popcount.i64" , llvmcall, Int64, (Int64,), x)
383
- @device_function popcount (x:: UInt64 ) = ccall (" extern air.popcount.i64" , llvmcall, UInt64, (UInt64,), x)
384
- @device_function popcount (x:: Int32 ) = ccall (" extern air.popcount.i32" , llvmcall, Int32, (Int32,), x)
385
- @device_function popcount (x:: UInt32 ) = ccall (" extern air.popcount.i32" , llvmcall, UInt32, (UInt32,), x)
386
- @device_function popcount (x:: Int16 ) = ccall (" extern air.popcount.i16" , llvmcall, Int16, (Int16,), x)
387
- @device_function popcount (x:: UInt16 ) = ccall (" extern air.popcount.i16" , llvmcall, UInt16, (UInt16,), x)
388
- @device_function popcount (x:: Int8 ) = ccall (" extern air.popcount.i8" , llvmcall, Int8, (Int8,), x)
389
- @device_function popcount (x:: UInt8 ) = ccall (" extern air.popcount.i8" , llvmcall, UInt8, (UInt8,), x)
390
-
391
- @device_function reverse_bits (x:: Int64 ) = ccall (" extern air.reverse_bits.i64" , llvmcall, Int64, (Int64,), x)
392
- @device_function reverse_bits (x:: UInt64 ) = ccall (" extern air.reverse_bits.i64" , llvmcall, UInt64, (UInt64,), x)
393
- @device_function reverse_bits (x:: Int32 ) = ccall (" extern air.reverse_bits.i32" , llvmcall, Int32, (Int32,), x)
394
- @device_function reverse_bits (x:: UInt32 ) = ccall (" extern air.reverse_bits.i32" , llvmcall, UInt32, (UInt32,), x)
395
- @device_function reverse_bits (x:: Int16 ) = ccall (" extern air.reverse_bits.i16" , llvmcall, Int16, (Int16,), x)
396
- @device_function reverse_bits (x:: UInt16 ) = ccall (" extern air.reverse_bits.i16" , llvmcall, UInt16, (UInt16,), x)
397
- @device_function reverse_bits (x:: Int8 ) = ccall (" extern air.reverse_bits.i8" , llvmcall, Int8, (Int8,), x)
398
- @device_function reverse_bits (x:: UInt8 ) = ccall (" extern air.reverse_bits.i8" , llvmcall, UInt8, (UInt8,), x)
399
-
400
-
401
- function _mulhi (a:: Int64 , b:: Int64 )
402
- shift = sizeof (a) * 4
403
- mask = typemax (UInt32)
404
- a1, a2 = (a >> shift), a & mask
405
- b1, b2 = (b >> shift), b & mask
406
- a1b1, a1b2, a2b1 = a1* b1, a1* b2, a2* b1
407
- t1 = a1b2 + _mulhi (a2 % UInt32, b2 % UInt32)
408
- t2 = a2b1 + (t1 & mask)
409
- a1b1 + (t1 >> shift) + (t2 >> shift)
410
- end
411
- @static if isdefined (Base. MultiplicativeInverses, :_mul_high )
412
- _mulhi (a:: T , b:: T ) where {T<: Union{Signed, Unsigned} } = Base. MultiplicativeInverses. _mul_high (a, b)
413
- @device_override Base. MultiplicativeInverses. _mul_high (a:: Int64 , b:: Int64 ) = _mulhi (a, b)
414
- else
415
- _mulhi (a:: T , b:: T ) where {T<: Union{Signed, Unsigned} } = ((widen (a)* b) >>> (sizeof (a)* 8 )) % T
416
- @device_override function Base. div (a:: Int64 , b:: Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64} )
417
- x = _mulhi (a, b. multiplier)
418
- x += (a* b. addmul) % Int64
419
- ifelse (abs (b. divisor) == 1 , a* b. divisor, (signbit (x) + (x >> b. shift)) % Int64)
420
- end
421
- end
346
# Two-argument `min` device overrides for the 8/16/32/64-bit integer types.
# Both operands must have the same type; they are forwarded to the
# `air.min.<s|u>.i<bits>` intrinsic (`s` for `Int*`, `u` for `UInt*`) and the
# result has that same type.
for bits in (64, 32, 16, 8)
    for (T, sgn) in ((Symbol(:Int, bits), "s"), (Symbol(:UInt, bits), "u"))
        intr = "extern air.min.$(sgn).i$(bits)"
        @eval begin
            @device_override Base.min(x::$T, y::$T) =
                ccall($intr, llvmcall, $T, ($T, $T), x, y)
        end
    end
end
354
+
355
# Two-argument `max` device overrides, forwarding both operands to the
# `air.max.<s|u>.i<bits>` intrinsic (`s` for `Int*`, `u` for `UInt*`).
#
# XXX: the `Int64` method is intentionally NOT defined; enabling
#     @device_override Base.max(x::Int64, y::Int64) = ccall("extern air.max.s.i64", llvmcall, Int64, (Int64, Int64), x, y)
# breaks `mul!`. MWE:
#     using Revise, Metal; A, x = mtl(rand(Int32, 4, 4)), mtl(rand(Int32, 4)); A*x
for (T, sgn, bits) in (
    (:UInt64, "u", 64),
    (:Int32, "s", 32),
    (:UInt32, "u", 32),
    (:Int16, "s", 16),
    (:UInt16, "u", 16),
    (:Int8, "s", 8),
    (:UInt8, "u", 8),
)
    intr = "extern air.max.$(sgn).i$(bits)"
    @eval begin
        @device_override Base.max(x::$T, y::$T) =
            ccall($intr, llvmcall, $T, ($T, $T), x, y)
    end
end
364
+
365
# Three-argument `min` device overrides for the 8/16/32/64-bit integer types.
# All three operands share one type and are passed to the
# `air.min3.<s|u>.i<bits>` intrinsic (`s` for `Int*`, `u` for `UInt*`); the
# result has that same type.
for bits in (64, 32, 16, 8)
    for (T, sgn) in ((Symbol(:Int, bits), "s"), (Symbol(:UInt, bits), "u"))
        intr = "extern air.min3.$(sgn).i$(bits)"
        @eval begin
            @device_override Base.min(x::$T, y::$T, z::$T) =
                ccall($intr, llvmcall, $T, ($T, $T, $T), x, y, z)
        end
    end
end
373
+
374
# Three-argument `max` device overrides for the 8/16/32/64-bit integer types.
# All three operands share one type and are passed to the
# `air.max3.<s|u>.i<bits>` intrinsic (`s` for `Int*`, `u` for `UInt*`); the
# result has that same type.
for bits in (64, 32, 16, 8)
    for (T, sgn) in ((Symbol(:Int, bits), "s"), (Symbol(:UInt, bits), "u"))
        intr = "extern air.max3.$(sgn).i$(bits)"
        @eval begin
            @device_override Base.max(x::$T, y::$T, z::$T) =
                ccall($intr, llvmcall, $T, ($T, $T, $T), x, y, z)
        end
    end
end
382
+
383
# `leading_zeros` device overrides backed by the `air.clz.i<bits>` intrinsics
# (the same intrinsic serves the signed and unsigned type of a given width).
#
# The intrinsic is declared to return the operand's own integer type, whereas
# `Base.leading_zeros` returns an `Int` for every bit-integer type. The result
# is therefore converted with `% Int`, so device code gets the same return type
# as host code (e.g. `leading_zeros(a) - leading_zeros(b)` on unsigned inputs
# must not wrap around).
for bits in (64, 32, 16, 8)
    for T in (Symbol(:Int, bits), Symbol(:UInt, bits))
        intr = "extern air.clz.i$(bits)"
        @eval begin
            @device_override Base.leading_zeros(x::$T) =
                ccall($intr, llvmcall, $T, ($T,), x) % Int
        end
    end
end
# Short alias named after the intrinsic.
const clz = leading_zeros
392
+
393
# `trailing_zeros` device overrides backed by the `air.ctz.i<bits>` intrinsics
# (the same intrinsic serves the signed and unsigned type of a given width).
#
# The intrinsic is declared to return the operand's own integer type, whereas
# `Base.trailing_zeros` returns an `Int` for every bit-integer type. The result
# is therefore converted with `% Int`, so device code gets the same return type
# as host code and arithmetic on the count does not wrap for unsigned inputs.
for bits in (64, 32, 16, 8)
    for T in (Symbol(:Int, bits), Symbol(:UInt, bits))
        intr = "extern air.ctz.i$(bits)"
        @eval begin
            @device_override Base.trailing_zeros(x::$T) =
                ccall($intr, llvmcall, $T, ($T,), x) % Int
        end
    end
end
# Short alias named after the intrinsic.
const ctz = trailing_zeros
402
+
403
# `count_ones` device overrides backed by the `air.popcount.i<bits>` intrinsics
# (the same intrinsic serves the signed and unsigned type of a given width).
#
# The intrinsic is declared to return the operand's own integer type, whereas
# `Base.count_ones` returns an `Int` for every bit-integer type. The result is
# therefore converted with `% Int`, so device code gets the same return type as
# host code (e.g. `count_ones(x) - 5` on a `UInt64` must yield a negative `Int`,
# not a wrapped-around `UInt64`).
for bits in (64, 32, 16, 8)
    for T in (Symbol(:Int, bits), Symbol(:UInt, bits))
        intr = "extern air.popcount.i$(bits)"
        @eval begin
            @device_override Base.count_ones(x::$T) =
                ccall($intr, llvmcall, $T, ($T,), x) % Int
        end
    end
end
# Short alias named after the intrinsic.
const popcount = count_ones
412
+
413
# `bitreverse` device overrides backed by the `air.reverse_bits.i<bits>`
# intrinsics (the same intrinsic serves the signed and unsigned type of a given
# width). The result has the same type as the argument.
for bits in (64, 32, 16, 8)
    for T in (Symbol(:Int, bits), Symbol(:UInt, bits))
        intr = "extern air.reverse_bits.i$(bits)"
        @eval begin
            @device_override Base.bitreverse(x::$T) =
                ccall($intr, llvmcall, $T, ($T,), x)
        end
    end
end
# Short alias named after the intrinsic.
const reverse_bits = bitreverse
422
+
423
# Device overrides for `Base.MultiplicativeInverses._mul_high` on the
# 8/16/32/64-bit integer types. Both operands share one type and are forwarded
# to the `air.mul_hi.<s|u>.i<bits>` intrinsic (`s` for `Int*`, `u` for `UInt*`);
# the result has that same type.
#
# NOTE(review): an earlier revision only touched `_mul_high` behind
# `isdefined(Base.MultiplicativeInverses, :_mul_high)`; confirm that every
# supported Julia version defines it, since this code references it
# unconditionally.
for bits in (64, 32, 16, 8)
    for (T, sgn) in ((Symbol(:Int, bits), "s"), (Symbol(:UInt, bits), "u"))
        intr = "extern air.mul_hi.$(sgn).i$(bits)"
        @eval begin
            @device_override Base.MultiplicativeInverses._mul_high(x::$T, y::$T) =
                ccall($intr, llvmcall, $T, ($T, $T), x, y)
        end
    end
end
# Short alias named after the intrinsic.
const mulhi = Base.MultiplicativeInverses._mul_high
422
432
423
433
# From: https://forums.developer.nvidia.com/t/a-faster-and-more-accurate-implementation-of-expm1f/48085/2
424
434
# Original license copied below:
495
505
end
496
506
497
507
return r
498
- end
508
+ end
0 commit comments