Skip to content

Commit 991f6df

Browse files
Add nextafter intrinsic (#529)
1 parent 6654291 commit 991f6df

File tree

4 files changed

+59
-4
lines changed

4 files changed

+59
-4
lines changed

src/Metal.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ include("device/utils.jl")
2727
include("device/pointer.jl")
2828
include("device/array.jl")
2929
include("device/runtime.jl")
30+
include("device/intrinsics/version.jl")
3031
include("device/intrinsics/arguments.jl")
3132
include("device/intrinsics/math.jl")
3233
include("device/intrinsics/synchronization.jl")
3334
include("device/intrinsics/memory.jl")
3435
include("device/intrinsics/simd.jl")
35-
include("device/intrinsics/version.jl")
3636
include("device/intrinsics/atomics.jl")
3737
include("device/quirks.jl")
3838

src/compiler/execution.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ certain extent arguments will be converted and managed automatically using `mtlc
1818
Finally, a call to `mtlcall` is performed, creating a command buffer in the current global
1919
command queue then committing it.
2020
21-
There is one supported keyword argument that influences the behavior of `@metal`:
21+
There are a few keyword arguments that influence the behavior of `@metal`:
2222
23-
- `launch`: whether to launch this kernel, defaults to `true`. If `false` the returned
23+
- `launch`: whether to launch this kernel, defaults to `true`. If `false`, the returned
2424
kernel object should be launched by calling it and passing arguments again.
2525
- `name`: the name of the kernel in the generated code. Defaults to an automatically-
2626
generated name.

src/device/intrinsics/math.jl

+15
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,21 @@ end
294294
@device_override Base.trunc(x::Float32) = ccall("extern air.trunc.f32", llvmcall, Cfloat, (Cfloat,), x)
295295
@device_override Base.trunc(x::Float16) = ccall("extern air.trunc.f16", llvmcall, Float16, (Float16,), x)
296296

297+
# nextafter(x, y) for Float32: step `x` to an adjacent representable value in the
# direction of `y`.
#
# Two code paths, selected by the Metal language version the kernel targets:
#   * Metal 3.1 and newer (macOS 14+): call the `air.nextafter.f32` intrinsic.
#   * older versions: emulate with `nextfloat(x, n)`, where `n` is the sign of
#     `y - x` truncated to an Int32 (+1 steps up, -1 steps down, 0 leaves `x` as is).
#
# NOTE(review): when `y - x` is NaN, `sign` yields NaN and `unsafe_trunc` of a
# non-representable value returns an arbitrary integer -- confirm whether NaN
# inputs need to be supported on the fallback path.
@device_function function nextafter(x::Float32, y::Float32)
    # Native intrinsic is only available from Metal 3.1 onwards (macOS 14+).
    metal_version() >= sv"3.1" &&
        return ccall("extern air.nextafter.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)

    # Fallback: the number of ulps to move is the sign of the difference.
    step = unsafe_trunc(Int32, sign(y - x))
    return nextfloat(x, step)
end
304+
# nextafter(x, y) for Float16: step `x` to an adjacent representable value in the
# direction of `y`.
#
# Two code paths, selected by the Metal language version the kernel targets:
#   * Metal 3.1 and newer (macOS 14+): call the `air.nextafter.f16` intrinsic.
#   * older versions: emulate with `nextfloat(x, n)`, where `n` is the sign of
#     `y - x` truncated to an Int16 (+1 steps up, -1 steps down, 0 leaves `x` as is).
#
# NOTE(review): when `y - x` is NaN, `sign` yields NaN and `unsafe_trunc` of a
# non-representable value returns an arbitrary integer -- confirm whether NaN
# inputs need to be supported on the fallback path.
@device_function function nextafter(x::Float16, y::Float16)
    # Native intrinsic is only available from Metal 3.1 onwards (macOS 14+).
    metal_version() >= sv"3.1" &&
        return ccall("extern air.nextafter.f16", llvmcall, Float16, (Float16, Float16), x, y)

    # Fallback: the number of ulps to move is the sign of the difference.
    step = unsafe_trunc(Int16, sign(y - x))
    return nextfloat(x, step)
end
311+
297312
# hypot without use of double
298313
#
299314
# taken from Cosmopolitan Libc

test/device/intrinsics.jl

+41-1
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,6 @@ MATH_INTR_FUNCS_2_ARG = [
159159
# frexp, # T frexp(T x, Ti &exponent)
160160
# ldexp, # T ldexp(T x, Ti k)
161161
# modf, # T modf(T x, T &intval)
162-
# nextafter, # T nextafter(T x, T y) # Metal 3.1+
163162
hypot, # NOT MSL but tested the same
164163
]
165164

@@ -353,6 +352,47 @@ end
353352
vec = Array(expm1.(buffer))
354353
@test vec ≈ expm1.(arr)
355354
end
355+
356+
357+
let # nextafter
    # Kernel under test: every thread replaces its own element of `X` with the
    # value adjacent to it in the direction of `y`.
    function nextafter_test(X, y)
        i = thread_position_in_grid_1d()
        X[i] = Metal.nextafter(X[i], y)
        return nothing
    end

    # Argument-free kernel that is only compiled, so that the generated LLVM IR
    # can be inspected for the expected intrinsic.
    outval = T(0)
    function nextafter_out_test()
        Metal.nextafter(outval, outval)
        return
    end

    nthreads = 4
    arr = rand(T, nthreads)
    # Stepping towards typemax(T) must match Base's `nextfloat`; stepping back
    # towards typemin(T) afterwards must restore the input exactly.
    stepped_up = nextfloat.(arr)

    # Native path: only exercised when the device supports Metal 3.1 (macOS >= v14).
    if metal_support() >= v"3.1"
        native_buf = MtlArray(arr)
        Metal.@sync @metal threads = nthreads nextafter_test(native_buf, typemax(T))
        @test Array(native_buf) == stepped_up
        Metal.@sync @metal threads = nthreads nextafter_test(native_buf, typemin(T))
        @test Array(native_buf) == arr

        # The native build is expected to call the AIR nextafter intrinsic.
        native_ir = sprint() do io
            @device_code_llvm io=io dump_module=true @metal nextafter_out_test()
        end
        @test occursin(Regex("@air\\.nextafter\\.f$(8*sizeof(T))"), native_ir)
    end

    # Fallback path: force compilation for Metal 3.0 (i.e. metal < 3.1).
    fallback_buf = MtlArray(arr)
    Metal.@sync @metal threads = nthreads metal = v"3.0" nextafter_test(fallback_buf, typemax(T))
    @test Array(fallback_buf) == stepped_up
    Metal.@sync @metal threads = nthreads metal = v"3.0" nextafter_test(fallback_buf, typemin(T))
    @test Array(fallback_buf) == arr

    # The emulated build is recognised by its call to the AIR sign intrinsic.
    fallback_ir = sprint() do io
        @device_code_llvm io=io dump_module=true @metal metal = v"3.0" nextafter_out_test()
    end
    @test occursin(Regex("@air\\.sign\\.f$(8*sizeof(T))"), fallback_ir)
end
356396
end
357397
end
358398

0 commit comments

Comments
 (0)