Skip to content

Commit c2afd6e

Browse files
committed
Initial BFloat16 support
1 parent b020d21 commit c2afd6e

File tree

5 files changed

+17
-10
lines changed

5 files changed

+17
-10
lines changed

lib/mps/MPS.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ using ObjectiveC, .Foundation
1616

1717
import GPUArrays
1818

19-
using BFloat16s
19+
using BFloat16s: BFloat16
2020

21-
const MtlFloat = Union{Float32, Float16}
21+
const MtlFloat = Union{Float32, Float16, BFloat16}
2222

2323
const MPSShape = NSArray#{NSNumber}
2424
Base.convert(::Type{MPSShape}, tuple::Union{Vector{N},NTuple{N, <:Integer}}) where N = NSArray(NSNumber.(collect(tuple)))

src/Metal.jl

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using Preferences: @load_preference, load_preference
1010
using ExprTools: splitdef, combinedef
1111
using ObjectiveC, .CoreFoundation, .Foundation, .Dispatch, .OS
1212
import KernelAbstractions
13+
using BFloat16s
1314

1415
include("version.jl")
1516

src/compiler/compilation.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ function GPUCompiler.finish_ir!(@nospecialize(job::MetalCompilerJob),
1818
# pointer type information for typed intrinsics
1919
# (this is consumed by the LLVM IR downgrader)
2020
for (jltyp, llvmtyp) in (Int32 => :i32, Int64 => :i64,
21-
Float16 => :f16, Float32 => :f32),
21+
Float16 => :f16, Float32 => :f32,
22+
BFloat16 => :bf16),
2223
(as, asname) in (AS.Device => "global", AS.ThreadGroup => "local")
2324

2425
# map of intrinsics to pointer operand indices and eltypes

src/device/intrinsics/simd.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function convert_origin(origin::NTuple{2, Int64})
77
return (VecElement{Int64}(origin[1]-1), VecElement{Int64}(origin[2]-1))
88
end
99

10-
for (jltype, suffix) in ((:Float16, "f16"), (:Float32, "f32"))
10+
for (jltype, suffix) in ((:Float16, "f16"), (:Float32, "f32"), (:BFloat16, "bf18"))
1111
for as in (AS.Device, AS.ThreadGroup)
1212
@eval begin
1313
@device_function simdgroup_load(
@@ -88,6 +88,7 @@ Returns `a * b + c`.
8888

8989
simd_shuffle_map = ((Float32, "f32"),
9090
(Float16, "f16"),
91+
(BFloat16, "bf16"),
9192
(Int32, "s.i32"),
9293
(UInt32, "u.i32"),
9394
(Int16, "s.i16"),

test/device/intrinsics.jl

+10-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using SpecialFunctions
2+
using BFloat16s
23
using Metal: metal_support
34

45
@testset "arguments" begin
@@ -308,8 +309,9 @@ end
308309
############################################################################################
309310

310311
@testset "simd intrinsics" begin
311-
312-
@testset "shuffle($typ)" for typ in [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8]
312+
types = [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8]
313+
metal_support() >= v"3.1" && push!(types, BFloat16)
314+
@testset "shuffle($typ)" for typ in types
313315
function kernel(a::MtlDeviceVector{T}, b::MtlDeviceVector{T}) where T
314316
idx = thread_position_in_grid_1d()
315317
idx_in_simd = thread_index_in_simdgroup()
@@ -344,7 +346,9 @@ end
344346
end
345347

346348
@testset "matrix functions" begin
347-
@testset "load_store($typ)" for typ in [Float16, Float32]
349+
simdgroup_types = [Float16, Float32]
350+
metal_support() >= v"3.1" && push!(simdgroup_types, BFloat16)
351+
@testset "load_store($typ)" for typ in simdgroup_types
348352
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T},
349353
origin_a=(1, 1), origin_b=(1, 1)) where {T}
350354
sg_a = simdgroup_load(a, origin_a)
@@ -367,7 +371,7 @@ end
367371
end
368372
end
369373

370-
@testset "load_store_tg($typ)" for typ in [Float16, Float32]
374+
@testset "load_store_tg($typ)" for typ in simdgroup_types
371375
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}) where {T}
372376
pos = thread_position_in_threadgroup_2d()
373377

@@ -391,7 +395,7 @@ end
391395
@test Array(a) == Array(b)
392396
end
393397

394-
@testset "mul($typ)" for typ in [Float16, Float32]
398+
@testset "mul($typ)" for typ in simdgroup_types
395399
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, c::MtlDeviceArray{T}) where {T}
396400
sg_a = simdgroup_load(a)
397401
sg_b = simdgroup_load(b)
@@ -407,7 +411,7 @@ end
407411
@test Array(a) * Array(b) Array(c)
408412
end
409413

410-
@testset "mad($typ)" for typ in [Float16, Float32]
414+
@testset "mad($typ)" for typ in simdgroup_types
411415
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, c::MtlDeviceArray{T},
412416
d::MtlDeviceArray{T}) where {T}
413417
sg_a = simdgroup_load(a)

0 commit comments

Comments
 (0)