@@ -105,18 +105,26 @@ pub type Vulkan<F = f32, I = i32, B = u8> = Wgpu<F, I, B>;
105
105
/// Tensor backend that uses the wgpu crate to execute GPU compute shaders written in WGSL.
106
106
pub type WebGpu < F = f32 , I = i32 , B = u32 > = Wgpu < F , I , B > ;
107
107
108
+ #[ cfg( feature = "metal" ) ]
109
+ /// Tensor backend that leverages the Metal graphics API to execute GPU compute shaders compiled to MSL.
110
+ pub type Metal < F = f32 , I = i32 , B = u8 > = Wgpu < F , I , B > ;
111
+
108
112
#[ cfg( test) ]
109
113
mod tests {
110
114
use burn_cubecl:: CubeBackend ;
111
115
#[ cfg( feature = "vulkan" ) ]
112
116
pub use half:: f16;
117
+ #[ cfg( feature = "metal" ) ]
118
+ pub use half:: f16;
113
119
114
120
pub type TestRuntime = cubecl:: wgpu:: WgpuRuntime ;
115
121
116
122
// Don't test `flex32` for now, burn sees it as `f32` but is actually `f16` precision, so it
117
123
// breaks a lot of tests from precision issues
118
124
#[ cfg( feature = "vulkan" ) ]
119
125
burn_cubecl:: testgen_all!( [ f16, f32 ] , [ i8 , i16 , i32 , i64 ] , [ u8 , u32 ] ) ;
120
- #[ cfg( not( feature = "vulkan" ) ) ]
126
+ #[ cfg( feature = "metal" ) ]
127
+ burn_cubecl:: testgen_all!( [ f16, f32 ] , [ i16 , i32 ] , [ u32 ] ) ;
128
+ #[ cfg( all( not( feature = "vulkan" ) , not( feature = "metal" ) ) ) ]
121
129
burn_cubecl:: testgen_all!( [ f32 ] , [ i32 ] , [ u32 ] ) ;
122
130
}
0 commit comments