Skip to content

Commit 5a94e11

Browse files
committed
Add BFloat16 intrinsics
1 parent 145dc33 commit 5a94e11

File tree

5 files changed

+19
-12
lines changed

5 files changed

+19
-12
lines changed

lib/mps/MPS.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ using ObjectiveC, .Foundation
1616

1717
import GPUArrays
1818

19-
using BFloat16s
19+
using BFloat16s: BFloat16
2020

21-
const MtlFloat = Union{Float32, Float16}
21+
const MtlFloat = Union{Float32, Float16, BFloat16}
2222

2323
const MPSShape = NSArray#{NSNumber}
2424
Base.convert(::Type{MPSShape}, tuple::Union{Vector{T},NTuple{T, <:Integer}}) where T = NSArray(NSNumber.(collect(tuple)))

src/Metal.jl

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using ExprTools: splitdef, combinedef
1212
using ObjectiveC, .CoreFoundation, .Foundation, .Dispatch, .OS
1313
import ObjectiveC: is_macos, darwin_version, macos_version
1414
import KernelAbstractions
15+
using BFloat16s: BFloat16
1516
using ScopedValues
1617

1718
include("version.jl")

src/compiler/compilation.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ function GPUCompiler.finish_ir!(@nospecialize(job::MetalCompilerJob),
1818
# pointer type information for typed intrinsics
1919
# (this is consumed by the LLVM IR downgrader)
2020
for (jltyp, llvmtyp) in (Int32 => :i32, Int64 => :i64,
21-
Float16 => :f16, Float32 => :f32),
21+
Float16 => :f16, Float32 => :f32,
22+
BFloat16 => :bf16),
2223
(as, asname) in (AS.Device => "global", AS.ThreadGroup => "local")
2324

2425
# map of intrinsics to pointer operand indices and eltypes

src/device/intrinsics/simd.jl

+6-5
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function convert_origin(origin::NTuple{2, Int64})
77
return (VecElement{Int64}(origin[1]-1), VecElement{Int64}(origin[2]-1))
88
end
99

10-
for (jltype, suffix) in ((:Float16, "f16"), (:Float32, "f32"))
10+
for (jltype, suffix) in ((:Float16, "f16"), (:Float32, "f32"), (:BFloat16, "bf16"))
1111
for as in (AS.Device, AS.ThreadGroup)
1212
@eval begin
1313
@device_function simdgroup_load(
@@ -55,7 +55,7 @@ end
5555
simdgroup_load(data::MtlDeviceArray{T}, matrix_origin=(1, 1))
5656
5757
Loads data from device or threadgroup memory into an 8x8 SIMD-group matrix
58-
and returns it. `T` must be either `Float16` or `Float32`.
58+
and returns it. `T` must be either `Float16`, `Float32`, or `BFloat16`.
5959
6060
# Arguments
6161
- `matrix_origin::NTuple{2, Int64}=(1, 1)`: origin in the source memory to load from.
@@ -65,7 +65,7 @@ and returns it. `T` must be either `Float16` or `Float32`.
6565
simdgroup_store(src, dest::MtlDeviceArray{T}, matrix_origin=(1, 1))
6666
6767
Stores data from an 8x8 SIMD-group matrix into device or threadgroup memory.
68-
`T` must be either `Float16` or `Float32`.
68+
`T` must be either `Float16`, `Float32`, or `BFloat16`.
6969
7070
# Arguments
7171
- `matrix_origin::NTuple{2, Int64}=(1, 1)`: origin in the destination memory to store to.
@@ -88,6 +88,7 @@ Returns `a * b + c`.
8888

8989
simd_shuffle_map = ((Float32, "f32"),
9090
(Float16, "f16"),
91+
(BFloat16,"bf16"),
9192
(Int32, "s.i32"),
9293
(UInt32, "u.i32"),
9394
(Int16, "s.i16"),
@@ -118,7 +119,7 @@ The value for `delta` must be the same for all threads in the SIMD-group. This f
118119
doesn't modify the upper `delta` lanes of `data` because it doesn't wrap values around
119120
the SIMD-group.
120121
121-
T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
122+
T must be one of the following: Float32, Float16, BFloat16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
122123
"""
123124
simd_shuffle_down
124125

@@ -131,6 +132,6 @@ lane ID minus `delta`.
131132
The value of `delta` must be the same for all threads in a SIMD-group. This function doesn't
132133
modify the lower `delta` lanes of `data` because it doesn't wrap values around the SIMD-group.
133134
134-
T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
135+
T must be one of the following: Float32, Float16, BFloat16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
135136
"""
136137
simd_shuffle_up

test/device/intrinsics/simd.jl

+8-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using Metal: metal_support
2+
13
@testset "simd intrinsics" begin
24

35
@testset "$f($typ)" for typ in [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8], (f,res_idx) in [(simd_shuffle_down, 1), (simd_shuffle_up, 32)]
@@ -36,7 +38,9 @@
3638
end
3739

3840
@testset "matrix functions" begin
39-
@testset "load_store($typ)" for typ in [Float16, Float32]
41+
simdgroup_types = [Float16, Float32]
42+
metal_support() >= v"3.1" && push!(simdgroup_types, BFloat16)
43+
@testset "load_store($typ)" for typ in simdgroup_types
4044
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T},
4145
origin_a=(1, 1), origin_b=(1, 1)) where {T}
4246
sg_a = simdgroup_load(a, origin_a)
@@ -59,7 +63,7 @@ end
5963
end
6064
end
6165

62-
@testset "load_store_tg($typ)" for typ in [Float16, Float32]
66+
@testset "load_store_tg($typ)" for typ in simdgroup_types
6367
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}) where {T}
6468
pos = thread_position_in_threadgroup_2d()
6569

@@ -83,7 +87,7 @@ end
8387
@test Array(a) == Array(b)
8488
end
8589

86-
@testset "mul($typ)" for typ in [Float16, Float32]
90+
@testset "mul($typ)" for typ in simdgroup_types
8791
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, c::MtlDeviceArray{T}) where {T}
8892
sg_a = simdgroup_load(a)
8993
sg_b = simdgroup_load(b)
@@ -99,7 +103,7 @@ end
99103
@test Array(a) * Array(b) ≈ Array(c)
100104
end
101105

102-
@testset "mad($typ)" for typ in [Float16, Float32]
106+
@testset "mad($typ)" for typ in simdgroup_types
103107
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, c::MtlDeviceArray{T},
104108
d::MtlDeviceArray{T}) where {T}
105109
sg_a = simdgroup_load(a)

0 commit comments

Comments
 (0)