Skip to content

Commit 76f635f

Browse files
committed
Add BFloat16 intrinsics
1 parent ca1a3fa commit 76f635f

File tree

5 files changed

+21
-14
lines changed

5 files changed

+21
-14
lines changed

lib/mps/MPS.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ using ObjectiveC, .Foundation
1616

1717
import GPUArrays
1818

19-
using BFloat16s
19+
using BFloat16s: BFloat16
2020

21-
const MtlFloat = Union{Float32, Float16}
21+
const MtlFloat = Union{Float32, Float16, BFloat16}
2222

2323
const MPSShape = NSArray#{NSNumber}
2424
Base.convert(::Type{MPSShape}, tuple::Union{Vector{N},NTuple{N, <:Integer}}) where N = NSArray(NSNumber.(collect(tuple)))

src/Metal.jl

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using Preferences: @load_preference, load_preference
1010
using ExprTools: splitdef, combinedef
1111
using ObjectiveC, .CoreFoundation, .Foundation, .Dispatch, .OS
1212
import KernelAbstractions
13+
using BFloat16s
1314

1415
include("version.jl")
1516

src/compiler/compilation.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ function GPUCompiler.finish_ir!(@nospecialize(job::MetalCompilerJob),
1818
# pointer type information for typed intrinsics
1919
# (this is consumed by the LLVM IR downgrader)
2020
for (jltyp, llvmtyp) in (Int32 => :i32, Int64 => :i64,
21-
Float16 => :f16, Float32 => :f32),
21+
Float16 => :f16, Float32 => :f32,
22+
BFloat16 => :bf16),
2223
(as, asname) in (AS.Device => "global", AS.ThreadGroup => "local")
2324

2425
# map of intrinsics to pointer operand indices and eltypes

src/device/intrinsics/simd.jl

+6-5
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function convert_origin(origin::NTuple{2, Int64})
77
return (VecElement{Int64}(origin[1]-1), VecElement{Int64}(origin[2]-1))
88
end
99

10-
for (jltype, suffix) in ((:Float16, "f16"), (:Float32, "f32"))
10+
for (jltype, suffix) in ((:Float16, "f16"), (:Float32, "f32"), (:BFloat16, "bf16"))
1111
for as in (AS.Device, AS.ThreadGroup)
1212
@eval begin
1313
@device_function simdgroup_load(
@@ -55,7 +55,7 @@ end
5555
simdgroup_load(data::MtlDeviceArray{T}, matrix_origin=(1, 1))
5656
5757
Loads data from device or threadgroup memory into an 8x8 SIMD-group matrix
58-
and returns it. `T` must be either `Float16` or `Float32`.
58+
and returns it. `T` must be either `Float16`, `Float32`, or `BFloat16`.
5959
6060
# Arguments
6161
- `matrix_origin::NTuple{2, Int64}=(1, 1)`: origin in the source memory to load from.
@@ -65,7 +65,7 @@ and returns it. `T` must be either `Float16` or `Float32`.
6565
simdgroup_store(src, dest::MtlDeviceArray{T}, matrix_origin=(1, 1))
6666
6767
Stores data from an 8x8 SIMD-group matrix into device or threadgroup memory.
68-
`T` must be either `Float16` or `Float32`.
68+
`T` must be either `Float16`, `Float32`, `BFloat16`.
6969
7070
# Arguments
7171
- `matrix_origin::NTuple{2, Int64}=(1, 1)`: origin in the destination memory to store to.
@@ -88,6 +88,7 @@ Returns `a * b + c`.
8888

8989
simd_shuffle_map = ((Float32, "f32"),
9090
(Float16, "f16"),
91+
(BFloat16, "bf16"),
9192
(Int32, "s.i32"),
9293
(UInt32, "u.i32"),
9394
(Int16, "s.i16"),
@@ -118,7 +119,7 @@ The value for delta must be the same for all threads in the SIMD-group. This fun
118119
doesn’t modify the upper delta lanes of data because it doesn’t wrap values around
119120
the SIMD-group.
120121
121-
T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
122+
T must be one of the following: Float32, Float16, BFloat16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
122123
"""
123124
simd_shuffle_down
124125

@@ -131,6 +132,6 @@ lane ID minus delta.
131132
The value of delta must be the same for all threads in a SIMD-group. This function doesn’t
132133
modify the lower delta lanes of data because it doesn’t wrap values around the SIMD-group.
133134
134-
T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
135+
T must be one of the following: Float32, Float16, BFloat16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
135136
"""
136137
simd_shuffle_up

test/device/intrinsics.jl

+10-6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using BFloat16s
12
using Metal: metal_support
23
using Random
34
using SpecialFunctions
@@ -499,8 +500,9 @@ end
499500
############################################################################################
500501

501502
@testset "simd intrinsics" begin
502-
503-
@testset "shuffle($typ)" for typ in [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8]
503+
types = [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8]
504+
metal_support() >= v"3.1" && push!(types, BFloat16)
505+
@testset "shuffle($typ)" for typ in types
504506
function kernel(a::MtlDeviceVector{T}, b::MtlDeviceVector{T}) where T
505507
idx = thread_position_in_grid_1d()
506508
idx_in_simd = thread_index_in_simdgroup()
@@ -535,7 +537,9 @@ end
535537
end
536538

537539
@testset "matrix functions" begin
538-
@testset "load_store($typ)" for typ in [Float16, Float32]
540+
simdgroup_types = [Float16, Float32]
541+
metal_support() >= v"3.1" && push!(simdgroup_types, BFloat16)
542+
@testset "load_store($typ)" for typ in simdgroup_types
539543
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T},
540544
origin_a=(1, 1), origin_b=(1, 1)) where {T}
541545
sg_a = simdgroup_load(a, origin_a)
@@ -558,7 +562,7 @@ end
558562
end
559563
end
560564

561-
@testset "load_store_tg($typ)" for typ in [Float16, Float32]
565+
@testset "load_store_tg($typ)" for typ in simdgroup_types
562566
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}) where {T}
563567
pos = thread_position_in_threadgroup_2d()
564568

@@ -582,7 +586,7 @@ end
582586
@test Array(a) == Array(b)
583587
end
584588

585-
@testset "mul($typ)" for typ in [Float16, Float32]
589+
@testset "mul($typ)" for typ in simdgroup_types
586590
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, c::MtlDeviceArray{T}) where {T}
587591
sg_a = simdgroup_load(a)
588592
sg_b = simdgroup_load(b)
@@ -598,7 +602,7 @@ end
598602
@test Array(a) * Array(b) Array(c)
599603
end
600604

601-
@testset "mad($typ)" for typ in [Float16, Float32]
605+
@testset "mad($typ)" for typ in simdgroup_types
602606
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, c::MtlDeviceArray{T},
603607
d::MtlDeviceArray{T}) where {T}
604608
sg_a = simdgroup_load(a)

0 commit comments

Comments
 (0)