Skip to content

Commit 34fa83d

Browse files
committed
Add BFloat16 intrinsics
1 parent 5af28d2 commit 34fa83d

File tree

5 files changed

+21
-14
lines changed

5 files changed

+21
-14
lines changed

lib/mps/MPS.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ using ObjectiveC, .Foundation
1616

1717
import GPUArrays
1818

19-
using BFloat16s
19+
using BFloat16s: BFloat16
2020

21-
const MtlFloat = Union{Float32, Float16}
21+
const MtlFloat = Union{Float32, Float16, BFloat16}
2222

2323
const MPSShape = NSArray#{NSNumber}
2424
Base.convert(::Type{MPSShape}, tuple::Union{Vector{N},NTuple{N, <:Integer}}) where N = NSArray(NSNumber.(collect(tuple)))

src/Metal.jl

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using ExprTools: splitdef, combinedef
1212
using ObjectiveC, .CoreFoundation, .Foundation, .Dispatch, .OS
1313
import ObjectiveC: is_macos, darwin_version, macos_version
1414
import KernelAbstractions
15+
using BFloat16s: BFloat16
1516

1617
include("version.jl")
1718

src/compiler/compilation.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ function GPUCompiler.finish_ir!(@nospecialize(job::MetalCompilerJob),
1818
# pointer type information for typed intrinsics
1919
# (this is consumed by the LLVM IR downgrader)
2020
for (jltyp, llvmtyp) in (Int32 => :i32, Int64 => :i64,
21-
Float16 => :f16, Float32 => :f32),
21+
Float16 => :f16, Float32 => :f32,
22+
BFloat16 => :bf16),
2223
(as, asname) in (AS.Device => "global", AS.ThreadGroup => "local")
2324

2425
# map of intrinsics to pointer operand indices and eltypes

src/device/intrinsics/simd.jl

+6-5
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function convert_origin(origin::NTuple{2, Int64})
77
return (VecElement{Int64}(origin[1]-1), VecElement{Int64}(origin[2]-1))
88
end
99

10-
for (jltype, suffix) in ((:Float16, "f16"), (:Float32, "f32"))
10+
for (jltype, suffix) in ((:Float16, "f16"), (:Float32, "f32"), (:BFloat16, "bf16"))
1111
for as in (AS.Device, AS.ThreadGroup)
1212
@eval begin
1313
@device_function simdgroup_load(
@@ -55,7 +55,7 @@ end
5555
simdgroup_load(data::MtlDeviceArray{T}, matrix_origin=(1, 1))
5656
5757
Loads data from device or threadgroup memory into an 8x8 SIMD-group matrix
58-
and returns it. `T` must be either `Float16` or `Float32`.
58+
and returns it. `T` must be either `Float16`, `Float32`, or `BFloat16`.
5959
6060
# Arguments
6161
- `matrix_origin::NTuple{2, Int64}=(1, 1)`: origin in the source memory to load from.
@@ -65,7 +65,7 @@ and returns it. `T` must be either `Float16` or `Float32`.
6565
simdgroup_store(src, dest::MtlDeviceArray{T}, matrix_origin=(1, 1))
6666
6767
Stores data from an 8x8 SIMD-group matrix into device or threadgroup memory.
68-
`T` must be either `Float16` or `Float32`.
68+
`T` must be either `Float16`, `Float32`, `BFloat16`.
6969
7070
# Arguments
7171
- `matrix_origin::NTuple{2, Int64}=(1, 1)`: origin in the destination memory to store to.
@@ -88,6 +88,7 @@ Returns `a * b + c`.
8888

8989
simd_shuffle_map = ((Float32, "f32"),
9090
(Float16, "f16"),
91+
(BFloat16, "bf16"),
9192
(Int32, "s.i32"),
9293
(UInt32, "u.i32"),
9394
(Int16, "s.i16"),
@@ -118,7 +119,7 @@ The value for delta must be the same for all threads in the SIMD-group. This fun
118119
doesn’t modify the upper delta lanes of data because it doesn’t wrap values around
119120
the SIMD-group.
120121
121-
T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
122+
T must be one of the following: Float32, Float16, BFloat16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
122123
"""
123124
simd_shuffle_down
124125

@@ -131,6 +132,6 @@ lane ID minus delta.
131132
The value of delta must be the same for all threads in a SIMD-group. This function doesn’t
132133
modify the lower delta lanes of data because it doesn’t wrap values around the SIMD-group.
133134
134-
T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
135+
T must be one of the following: Float32, Float16, BFloat16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
135136
"""
136137
simd_shuffle_up

test/device/intrinsics.jl

+10-6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using BFloat16s
12
using Metal: metal_support
23
using Random
34
using SpecialFunctions
@@ -624,8 +625,9 @@ end
624625
############################################################################################
625626

626627
@testset "simd intrinsics" begin
627-
628-
@testset "shuffle($typ)" for typ in [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8]
628+
types = [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8]
629+
metal_support() >= v"3.1" && push!(types, BFloat16)
630+
@testset "shuffle($typ)" for typ in types
629631
function kernel(a::MtlDeviceVector{T}, b::MtlDeviceVector{T}) where T
630632
idx = thread_position_in_grid_1d()
631633
idx_in_simd = thread_index_in_simdgroup()
@@ -660,7 +662,9 @@ end
660662
end
661663

662664
@testset "matrix functions" begin
663-
@testset "load_store($typ)" for typ in [Float16, Float32]
665+
simdgroup_types = [Float16, Float32]
666+
metal_support() >= v"3.1" && push!(simdgroup_types, BFloat16)
667+
@testset "load_store($typ)" for typ in simdgroup_types
664668
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T},
665669
origin_a=(1, 1), origin_b=(1, 1)) where {T}
666670
sg_a = simdgroup_load(a, origin_a)
@@ -683,7 +687,7 @@ end
683687
end
684688
end
685689

686-
@testset "load_store_tg($typ)" for typ in [Float16, Float32]
690+
@testset "load_store_tg($typ)" for typ in simdgroup_types
687691
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}) where {T}
688692
pos = thread_position_in_threadgroup_2d()
689693

@@ -707,7 +711,7 @@ end
707711
@test Array(a) == Array(b)
708712
end
709713

710-
@testset "mul($typ)" for typ in [Float16, Float32]
714+
@testset "mul($typ)" for typ in simdgroup_types
711715
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, c::MtlDeviceArray{T}) where {T}
712716
sg_a = simdgroup_load(a)
713717
sg_b = simdgroup_load(b)
@@ -723,7 +727,7 @@ end
723727
@test Array(a) * Array(b) Array(c)
724728
end
725729

726-
@testset "mad($typ)" for typ in [Float16, Float32]
730+
@testset "mad($typ)" for typ in simdgroup_types
727731
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, c::MtlDeviceArray{T},
728732
d::MtlDeviceArray{T}) where {T}
729733
sg_a = simdgroup_load(a)

0 commit comments

Comments
 (0)