Skip to content

Commit cf56ea3

Browse files
committed
Add BFloat16 intrinsics
1 parent bec8c71 commit cf56ea3

File tree

5 files changed

+21
-14
lines changed

5 files changed

+21
-14
lines changed

lib/mps/MPS.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ using ObjectiveC, .Foundation
1616

1717
import GPUArrays
1818

19-
using BFloat16s
19+
using BFloat16s: BFloat16
2020

21-
const MtlFloat = Union{Float32, Float16}
21+
const MtlFloat = Union{Float32, Float16, BFloat16}
2222

2323
const MPSShape = NSArray#{NSNumber}
2424
Base.convert(::Type{MPSShape}, tuple::Union{Vector{N},NTuple{N, <:Integer}}) where N = NSArray(NSNumber.(collect(tuple)))

src/Metal.jl

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using ExprTools: splitdef, combinedef
1111
using ObjectiveC, .CoreFoundation, .Foundation, .Dispatch, .OS
1212
import ObjectiveC: is_macos, darwin_version, macos_version
1313
import KernelAbstractions
14+
using BFloat16s: BFloat16
1415

1516
include("version.jl")
1617

src/compiler/compilation.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ function GPUCompiler.finish_ir!(@nospecialize(job::MetalCompilerJob),
1818
# pointer type information for typed intrinsics
1919
# (this is consumed by the LLVM IR downgrader)
2020
for (jltyp, llvmtyp) in (Int32 => :i32, Int64 => :i64,
21-
Float16 => :f16, Float32 => :f32),
21+
Float16 => :f16, Float32 => :f32,
22+
BFloat16 => :bf16),
2223
(as, asname) in (AS.Device => "global", AS.ThreadGroup => "local")
2324

2425
# map of intrinsics to pointer operand indices and eltypes

src/device/intrinsics/simd.jl

+6-5
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function convert_origin(origin::NTuple{2, Int64})
77
return (VecElement{Int64}(origin[1]-1), VecElement{Int64}(origin[2]-1))
88
end
99

10-
for (jltype, suffix) in ((:Float16, "f16"), (:Float32, "f32"))
10+
for (jltype, suffix) in ((:Float16, "f16"), (:Float32, "f32"), (:BFloat16, "bf16"))
1111
for as in (AS.Device, AS.ThreadGroup)
1212
@eval begin
1313
@device_function simdgroup_load(
@@ -55,7 +55,7 @@ end
5555
simdgroup_load(data::MtlDeviceArray{T}, matrix_origin=(1, 1))
5656
5757
Loads data from device or threadgroup memory into an 8x8 SIMD-group matrix
58-
and returns it. `T` must be either `Float16` or `Float32`.
58+
and returns it. `T` must be either `Float16`, `Float32`, or `BFloat16`.
5959
6060
# Arguments
6161
- `matrix_origin::NTuple{2, Int64}=(1, 1)`: origin in the source memory to load from.
@@ -65,7 +65,7 @@ and returns it. `T` must be either `Float16` or `Float32`.
6565
simdgroup_store(src, dest::MtlDeviceArray{T}, matrix_origin=(1, 1))
6666
6767
Stores data from an 8x8 SIMD-group matrix into device or threadgroup memory.
68-
`T` must be either `Float16` or `Float32`.
68+
`T` must be either `Float16`, `Float32`, `BFloat16`.
6969
7070
# Arguments
7171
- `matrix_origin::NTuple{2, Int64}=(1, 1)`: origin in the destination memory to store to.
@@ -88,6 +88,7 @@ Returns `a * b + c`.
8888

8989
simd_shuffle_map = ((Float32, "f32"),
9090
(Float16, "f16"),
91+
(BFloat16, "bf16"),
9192
(Int32, "s.i32"),
9293
(UInt32, "u.i32"),
9394
(Int16, "s.i16"),
@@ -118,7 +119,7 @@ The value for delta must be the same for all threads in the SIMD-group. This fun
118119
doesn’t modify the upper delta lanes of data because it doesn’t wrap values around
119120
the SIMD-group.
120121
121-
T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
122+
T must be one of the following: Float32, Float16, BFloat16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
122123
"""
123124
simd_shuffle_down
124125

@@ -131,6 +132,6 @@ lane ID minus delta.
131132
The value of delta must be the same for all threads in a SIMD-group. This function doesn’t
132133
modify the lower delta lanes of data because it doesn’t wrap values around the SIMD-group.
133134
134-
T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
135+
T must be one of the following: Float32, Float16, BFloat16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
135136
"""
136137
simd_shuffle_up

test/device/intrinsics.jl

+10-6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using BFloat16s
12
using Metal: metal_support
23
using Random
34
using SpecialFunctions
@@ -534,8 +535,9 @@ end
534535
############################################################################################
535536

536537
@testset "simd intrinsics" begin
537-
538-
@testset "shuffle($typ)" for typ in [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8]
538+
types = [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8]
539+
metal_support() >= v"3.1" && push!(types, BFloat16)
540+
@testset "shuffle($typ)" for typ in types
539541
function kernel(a::MtlDeviceVector{T}, b::MtlDeviceVector{T}) where T
540542
idx = thread_position_in_grid_1d()
541543
idx_in_simd = thread_index_in_simdgroup()
@@ -570,7 +572,9 @@ end
570572
end
571573

572574
@testset "matrix functions" begin
573-
@testset "load_store($typ)" for typ in [Float16, Float32]
575+
simdgroup_types = [Float16, Float32]
576+
metal_support() >= v"3.1" && push!(simdgroup_types, BFloat16)
577+
@testset "load_store($typ)" for typ in simdgroup_types
574578
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T},
575579
origin_a=(1, 1), origin_b=(1, 1)) where {T}
576580
sg_a = simdgroup_load(a, origin_a)
@@ -593,7 +597,7 @@ end
593597
end
594598
end
595599

596-
@testset "load_store_tg($typ)" for typ in [Float16, Float32]
600+
@testset "load_store_tg($typ)" for typ in simdgroup_types
597601
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}) where {T}
598602
pos = thread_position_in_threadgroup_2d()
599603

@@ -617,7 +621,7 @@ end
617621
@test Array(a) == Array(b)
618622
end
619623

620-
@testset "mul($typ)" for typ in [Float16, Float32]
624+
@testset "mul($typ)" for typ in simdgroup_types
621625
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, c::MtlDeviceArray{T}) where {T}
622626
sg_a = simdgroup_load(a)
623627
sg_b = simdgroup_load(b)
@@ -633,7 +637,7 @@ end
633637
@test Array(a) * Array(b) Array(c)
634638
end
635639

636-
@testset "mad($typ)" for typ in [Float16, Float32]
640+
@testset "mad($typ)" for typ in simdgroup_types
637641
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, c::MtlDeviceArray{T},
638642
d::MtlDeviceArray{T}) where {T}
639643
sg_a = simdgroup_load(a)

0 commit comments

Comments
 (0)