@@ -7,7 +7,7 @@ function convert_origin(origin::NTuple{2, Int64})
7
7
return (VecElement {Int64} (origin[1 ]- 1 ), VecElement {Int64} (origin[2 ]- 1 ))
8
8
end
9
9
10
- for (jltype, suffix) in ((:Float16 , " f16" ), (:Float32 , " f32" ))
10
+ for (jltype, suffix) in ((:Float16 , " f16" ), (:Float32 , " f32" ), ( :BFloat16 , " bf16 " ) )
11
11
for as in (AS. Device, AS. ThreadGroup)
12
12
@eval begin
13
13
@device_function simdgroup_load (
55
55
simdgroup_load(data::MtlDeviceArray{T}, matrix_origin=(1, 1))
56
56
57
57
Loads data from device or threadgroup memory into an 8x8 SIMD-group matrix
58
- and returns it. `T` must be either `Float16` or `Float32 `.
58
+ and returns it. `T` must be either `Float16`, `Float32`, or `BFloat16 `.
59
59
60
60
# Arguments
61
61
- `matrix_origin::NTuple{2, Int64}=(1, 1)`: origin in the source memory to load from.
@@ -65,7 +65,7 @@ and returns it. `T` must be either `Float16` or `Float32`.
65
65
simdgroup_store(src, dest::MtlDeviceArray{T}, matrix_origin=(1, 1))
66
66
67
67
Stores data from an 8x8 SIMD-group matrix into device or threadgroup memory.
68
- `T` must be either `Float16` or `Float32 `.
68
+ `T` must be either `Float16`, `Float32`, or `BFloat16 `.
69
69
70
70
# Arguments
71
71
- `matrix_origin::NTuple{2, Int64}=(1, 1)`: origin in the destination memory to store to.
@@ -88,6 +88,7 @@ Returns `a * b + c`.
88
88
89
89
simd_shuffle_map = ((Float32, " f32" ),
90
90
(Float16, " f16" ),
91
+ (BFloat16," bf16" ),
91
92
(Int32, " s.i32" ),
92
93
(UInt32, " u.i32" ),
93
94
(Int16, " s.i16" ),
@@ -118,7 +119,7 @@ The value for `delta` must be the same for all threads in the SIMD-group. This f
118
119
doesn't modify the upper `delta` lanes of `data` because it doesn't wrap values around
119
120
the SIMD-group.
120
121
121
- T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
122
+ T must be one of the following: Float32, Float16, BFloat16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
122
123
"""
123
124
simd_shuffle_down
124
125
@@ -131,6 +132,6 @@ lane ID minus `delta`.
131
132
The value of `delta` must be the same for all threads in a SIMD-group. This function doesn't
132
133
modify the lower `delta` lanes of `data` because it doesn't wrap values around the SIMD-group.
133
134
134
- T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
135
+ T must be one of the following: Float32, Float16, BFloat16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
135
136
"""
136
137
simd_shuffle_up
0 commit comments