Skip to content

Commit 68f7ce1

Browse files
committed
More bfloat support and test fixes
1 parent 48bda9f commit 68f7ce1

File tree

4 files changed

+15
-8
lines changed

4 files changed

+15
-8
lines changed

src/device/intrinsics/math.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ end
418418
j = fma(1.442695f0, a, 12582912.0f0)
419419
j = j - 12582912.0f0
420420
i = unsafe_trunc(Int32, j)
421-
f = fma(j, -6.93145752f-1, a) # log_2_hi
421+
f = fma(j, -6.93145752f-1, a) # log_2_hi
422422
f = fma(j, -1.42860677f-6, f) # log_2_lo
423423

424424
# approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2]
@@ -460,4 +460,4 @@ end
460460
end
461461

462462
return r
463-
end
463+
end

src/device/intrinsics/simd.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function convert_origin(origin::NTuple{2, Int64})
77
return (VecElement{Int64}(origin[1]-1), VecElement{Int64}(origin[2]-1))
88
end
99

10-
for (jltype, suffix) in ((:Float16, "f16"), (:Float32, "f32"), (:BFloat16, "bf18"))
10+
for (jltype, suffix) in ((:Float16, "f16"), (:Float32, "f32"), (:BFloat16, "bf16"))
1111
for as in (AS.Device, AS.ThreadGroup)
1212
@eval begin
1313
@device_function simdgroup_load(

test/device/intrinsics.jl

+5-4
Original file line numberDiff line numberDiff line change
@@ -275,9 +275,9 @@ end
275275
end
276276

277277
@testset "parametrically typed" begin
278-
typs = [Int32, Int64, Float32]
278+
types = [Int32, Int64, Float32]
279279
metal_support() >= v"3.1" && push!(types, BFloat16)
280-
@testset for typ in typs
280+
@testset for typ in types
281281
function kernel(d::MtlDeviceArray{T}, n) where {T}
282282
t = thread_position_in_threadgroup_1d()
283283
tr = n-t+1
@@ -405,8 +405,9 @@ end
405405
return
406406
end
407407

408-
a = MtlArray(rand(typ, 8, 8))
409-
b = MtlArray(rand(typ, 8, 8))
408+
#Use `ones` for figuring out issues
409+
a = MtlArray(ones(typ, 8, 8))
410+
b = MtlArray(ones(typ, 8, 8))
410411
c = MtlArray(zeros(typ, 8, 8))
411412
@metal threads=(8, 8) kernel(a, b, c)
412413
@test Array(a) * Array(b) Array(c)

test/runtests.jl

+7-1
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ const gpuarr_eltypes = [Int16, Int32, Int64,
8181
ComplexF16, ComplexF32]
8282
const gpuarr_eltypes_nobf16 = copy(gpuarr_eltypes)
8383

84+
# don't test BFloat16 for unsupported operations
85+
nobf16_tests = ["random", "reductions/reducedim!",
86+
"reductions/mapreducedim!_large", "reductions/mapreduce",
87+
"reductions/== isequal", "reductions/minimum maximum extrema",
88+
"reductions/sum prod", "reductions/mapreducedim!", "reductions/reduce"]
89+
8490
# Add BFloat16 for tests that use it
8591
Metal.metal_support() >= v"3.1" && push!(gpuarr_eltypes, BFloat16)
8692

@@ -90,7 +96,7 @@ for name in keys(TestSuite.tests)
9096
continue
9197
end
9298

93-
tmp_eltypes = name in ["random"] ? gpuarr_eltypes_nobf16 : gpuarr_eltypes
99+
tmp_eltypes = name in nobf16_tests ? gpuarr_eltypes_nobf16 : gpuarr_eltypes
94100

95101
push!(tests, "gpuarrays$(Base.Filesystem.path_separator)$name")
96102
test_runners["gpuarrays$(Base.Filesystem.path_separator)$name"] = ()->TestSuite.tests[name](MtlArray;eltypes=tmp_eltypes)

0 commit comments

Comments
 (0)