Skip to content

Commit aaf3179

Browse files
authored
[metal] Add CubeCL metal compiler support (#2993)
1 parent d28557c commit aaf3179

File tree

8 files changed

+32
-4
lines changed

8 files changed

+32
-4
lines changed

crates/burn-wgpu/Cargo.toml

+2
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@ template = ["burn-cubecl/template", "cubecl/template"]
2323
# Backends
2424
vulkan = ["cubecl-spirv"]
2525
webgpu = ["cubecl-wgsl"]
26+
metal = ["cubecl-msl"]
2627

2728
# Compilers
2829
cubecl-spirv = ["cubecl/wgpu-spirv"]
2930
cubecl-wgsl = []
31+
cubecl-msl = ["cubecl/wgpu-cpp-msl"]
3032

3133
[dependencies]
3234
cubecl = { workspace = true, features = ["wgpu"] }

crates/burn-wgpu/src/lib.rs

+9-1
Original file line numberDiff line numberDiff line change
@@ -105,18 +105,26 @@ pub type Vulkan<F = f32, I = i32, B = u8> = Wgpu<F, I, B>;
105105
/// Tensor backend that uses the wgpu crate to execute GPU compute shaders written in WGSL.
106106
pub type WebGpu<F = f32, I = i32, B = u32> = Wgpu<F, I, B>;
107107

108+
#[cfg(feature = "metal")]
109+
/// Tensor backend that leverages the Metal graphics API to execute GPU compute shaders compiled to MSL.
110+
pub type Metal<F = f32, I = i32, B = u8> = Wgpu<F, I, B>;
111+
108112
#[cfg(test)]
109113
mod tests {
110114
use burn_cubecl::CubeBackend;
111115
#[cfg(feature = "vulkan")]
112116
pub use half::f16;
117+
#[cfg(feature = "metal")]
118+
pub use half::f16;
113119

114120
pub type TestRuntime = cubecl::wgpu::WgpuRuntime;
115121

116122
// Don't test `flex32` for now, burn sees it as `f32` but is actually `f16` precision, so it
117123
// breaks a lot of tests from precision issues
118124
#[cfg(feature = "vulkan")]
119125
burn_cubecl::testgen_all!([f16, f32], [i8, i16, i32, i64], [u8, u32]);
120-
#[cfg(not(feature = "vulkan"))]
126+
#[cfg(feature = "metal")]
127+
burn_cubecl::testgen_all!([f16, f32], [i16, i32], [u32]);
128+
#[cfg(all(not(feature = "vulkan"), not(feature = "metal")))]
121129
burn_cubecl::testgen_all!([f32], [i32], [u32]);
122130
}

crates/burn/Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ fusion = ["burn-wgpu?/fusion", "burn-cuda?/fusion"]
8989
accelerate = ["burn-candle?/accelerate", "burn-ndarray?/blas-accelerate"]
9090
autotune = ["burn-wgpu?/autotune", "burn-cuda?/autotune", "burn-hip?/autotune"]
9191
blas-netlib = ["burn-ndarray?/blas-netlib"]
92-
metal = ["burn-candle?/metal"]
9392
openblas = ["burn-ndarray?/blas-openblas"]
9493
openblas-system = ["burn-ndarray?/blas-openblas-system"]
9594
remote = ["burn-remote/client"]
@@ -100,12 +99,14 @@ template = ["burn-wgpu?/template"]
10099

101100
candle = ["burn-candle"]
102101
candle-cuda = ["candle", "burn-candle/cuda"]
102+
candle-metal = ["burn-candle?/metal"]
103103
cuda = ["burn-cuda"]
104104
hip = ["burn-hip"]
105105
ndarray = ["burn-ndarray"]
106106
tch = ["burn-tch"]
107107
vulkan = ["wgpu", "burn-wgpu/vulkan"]
108108
webgpu = ["wgpu", "burn-wgpu/webgpu"]
109+
metal = ["wgpu", "burn-wgpu/metal"]
109110
wgpu = ["burn-wgpu"]
110111

111112
[dependencies]

crates/burn/src/backend.rs

+3
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ pub use burn_wgpu::WebGpu;
2727
#[cfg(feature = "vulkan")]
2828
pub use burn_wgpu::Vulkan;
2929

30+
#[cfg(feature = "metal")]
31+
pub use burn_wgpu::Metal;
32+
3033
#[cfg(feature = "cuda")]
3134
pub use burn_cuda as cuda;
3235

examples/mnist/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"]
1515
tch-cpu = ["burn/tch"]
1616
tch-gpu = ["burn/tch"]
1717
wgpu = ["burn/wgpu"]
18+
metal = ["burn/metal"]
1819

1920
[dependencies]
2021
burn = { path = "../../crates/burn", features = ["train"] }

examples/mnist/examples/mnist.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ mod tch_gpu {
3535
}
3636
}
3737

38-
#[cfg(feature = "wgpu")]
38+
#[cfg(any(feature = "wgpu", feature = "metal",))]
3939
mod wgpu {
4040
use burn::backend::{
4141
Autodiff,
@@ -75,6 +75,6 @@ fn main() {
7575
tch_gpu::run();
7676
#[cfg(feature = "tch-cpu")]
7777
tch_cpu::run();
78-
#[cfg(feature = "wgpu")]
78+
#[cfg(any(feature = "wgpu", feature = "metal"))]
7979
wgpu::run();
8080
}

examples/text-classification/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ vulkan = ["burn/vulkan", "burn/default"]
2020
remote = ["burn/remote"]
2121
cuda = ["burn/cuda"]
2222
hip = ["burn/hip"]
23+
metal = ["burn/metal"]
2324

2425
[dependencies]
2526
# Burn

examples/text-classification/examples/ag-news-train.rs

+12
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,16 @@ mod vulkan {
103103
}
104104
}
105105

106+
#[cfg(feature = "metal")]
107+
mod metal {
108+
use crate::{ElemType, launch};
109+
use burn::backend::{Autodiff, Metal};
110+
111+
pub fn run() {
112+
launch::<Autodiff<Metal<ElemType, i32>>>(vec![Default::default()]);
113+
}
114+
}
115+
106116
#[cfg(feature = "remote")]
107117
mod remote {
108118
use crate::{ElemType, launch};
@@ -155,4 +165,6 @@ fn main() {
155165
remote::run();
156166
#[cfg(feature = "vulkan")]
157167
vulkan::run();
168+
#[cfg(feature = "metal")]
169+
metal::run();
158170
}

0 commit comments

Comments
 (0)