@@ -6,14 +6,14 @@ use std::collections::HashMap;
6
6
use std:: ffi:: c_void;
7
7
use std:: sync:: RwLock ;
8
8
9
// Bytes of `candle.metallib`, read from the directory named by the `OUT_DIR` environment
// variable at compile time and embedded into the binary via `include_bytes!`.
// NOTE(review): presumably a build script writes `candle.metallib` into `OUT_DIR` — confirm.
+ const CANDLE : & [ u8 ] = include_bytes ! ( concat!( env!( "OUT_DIR" ) , "/candle.metallib" ) ) ;
9
10
const AFFINE : & str = include_str ! ( "affine.metal" ) ;
10
11
const INDEXING : & str = include_str ! ( "indexing.metal" ) ;
11
12
const UNARY : & str = include_str ! ( "unary.metal" ) ;
12
13
const BINARY : & str = include_str ! ( "binary.metal" ) ;
13
14
const TERNARY : & str = include_str ! ( "ternary.metal" ) ;
14
15
const CAST : & str = include_str ! ( "cast.metal" ) ;
15
16
const CONV : & str = include_str ! ( "conv.metal" ) ;
16
- const REDUCE : & str = include_str ! ( "reduce.metal" ) ;
17
17
const RANDOM : & str = include_str ! ( "random.metal" ) ;
18
18
const MFA : & [ u8 ] = include_bytes ! ( "libMetalFlashAttention.metallib" ) ;
19
19
const QUANTIZED : & str = include_str ! ( "quantized.metal" ) ;
@@ -114,13 +114,13 @@ macro_rules! set_params {
114
114
115
115
#[ derive( Debug , Clone , Copy , PartialEq , Eq , Hash ) ]
116
116
pub enum Source {
117
+ Candle ,
117
118
Affine ,
118
119
Indexing ,
119
120
Unary ,
120
121
Binary ,
121
122
Ternary ,
122
123
Cast ,
123
- Reduce ,
124
124
Mfa ,
125
125
Conv ,
126
126
Random ,
@@ -243,11 +243,10 @@ impl Kernels {
243
243
Source :: Ternary => TERNARY ,
244
244
Source :: Indexing => INDEXING ,
245
245
Source :: Cast => CAST ,
246
- Source :: Reduce => REDUCE ,
247
246
Source :: Conv => CONV ,
248
247
Source :: Random => RANDOM ,
249
248
Source :: Quantized => QUANTIZED ,
250
- Source :: Mfa => panic ! ( "Invalid lib" ) ,
249
+ _ => panic ! ( "Invalid lib" ) ,
251
250
}
252
251
}
253
252
@@ -263,6 +262,14 @@ impl Kernels {
263
262
Ok ( lib. clone ( ) )
264
263
} else {
265
264
let lib = match source {
265
// Build the library for `Source::Candle` from the embedded `CANDLE` bytes using
// `new_library_with_data`. A failure is wrapped in `MetalKernelError::LoadLibraryError`
// and propagated to the caller with `?`.
// NOTE(review): the message below ends with "cannot load mfa" although this arm loads
// `CANDLE`, and "> 13.0 or higher" reads ambiguously — confirm the intended wording.
+ Source :: Candle => {
266
+ let source_data = CANDLE ;
267
+ device. new_library_with_data ( source_data) . map_err ( |e| {
268
+ MetalKernelError :: LoadLibraryError ( format ! (
269
+ "Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
270
+ ) )
271
+ } ) ?
272
+ } ,
266
273
Source :: Mfa => {
267
274
let source_data = MFA ;
268
275
device. new_library_with_data ( source_data) . map_err ( |e| {
@@ -569,7 +576,7 @@ pub fn call_reduce_contiguous(
569
576
} else {
570
577
( format ! ( "{kernel_name}" ) . leak ( ) , 1 )
571
578
} ;
572
- let pipeline = kernels. load_pipeline ( device, Source :: Reduce , name) ?;
579
+ let pipeline = kernels. load_pipeline ( device, Source :: Candle , name) ?;
573
580
574
581
let encoder = command_buffer. new_compute_command_encoder ( ) ;
575
582
encoder. set_compute_pipeline_state ( & pipeline) ;
@@ -628,7 +635,7 @@ pub fn call_reduce_strided(
628
635
) -> Result < ( ) , MetalKernelError > {
629
636
let length: usize = shape. iter ( ) . product ( ) ;
630
637
let work_per_threadgroup = length / out_length;
631
- let pipeline = kernels. load_pipeline ( device, Source :: Reduce , kernel_name) ?;
638
+ let pipeline = kernels. load_pipeline ( device, Source :: Candle , kernel_name) ?;
632
639
633
640
let encoder = command_buffer. new_compute_command_encoder ( ) ;
634
641
encoder. set_compute_pipeline_state ( & pipeline) ;
@@ -697,7 +704,7 @@ pub fn call_last_softmax(
697
704
( format ! ( "{kernel_name}" ) . leak ( ) , 1 )
698
705
} ;
699
706
700
- let pipeline = kernels. load_pipeline ( device, Source :: Reduce , name) ?;
707
+ let pipeline = kernels. load_pipeline ( device, Source :: Candle , name) ?;
701
708
let encoder = command_buffer. new_compute_command_encoder ( ) ;
702
709
encoder. set_compute_pipeline_state ( & pipeline) ;
703
710
0 commit comments