Skip to content

Commit a9e26b5

Browse files
committed
Add build.rs to avoid metal kernel jit compile overhead
1 parent 3633135 commit a9e26b5

File tree

3 files changed

+77
-7
lines changed

3 files changed

+77
-7
lines changed

candle-metal-kernels/build.rs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
use std::process::Command;
2+
use std::{env, str};
3+
use std::path::PathBuf;
4+
5+
const METAL_SOURCES: [&str; 1] = ["reduce"];
6+
7+
fn main() -> Result<(), String> {
8+
println!("cargo:rerun-if-changed=build.rs");
9+
println!("cargo:rerun-if-changed=*.metal");
10+
println!("cargo:rerun-if-changed=*.m");
11+
12+
let xcrun_output = Command::new("xcrun")
13+
.args(["--sdk", "macosx", "--show-sdk-path"])
14+
.output()
15+
.expect("xcrun command failed to start");
16+
17+
let sdk_path = str::from_utf8(&xcrun_output.stdout)
18+
.expect("Invalid UTF-8 from xcrun")
19+
.replace('\n', "");
20+
21+
println!("cargo:rerun-if-changed={sdk_path}");
22+
let current_dir = env::current_dir().expect("Failed to get current directory");
23+
let out_dir = PathBuf::from(std::env::var("OUT_DIR").map_err(|_|"OUT_DIR not set")?);
24+
25+
let sources = current_dir
26+
.join("src")
27+
.to_str()
28+
.unwrap()
29+
.to_string();
30+
31+
// Compile metal to air
32+
let mut compile_air_cmd = Command::new("xcrun");
33+
compile_air_cmd
34+
.arg("metal")
35+
.arg(format!("-working-directory={}", out_dir.to_str().ok_or("")?))
36+
.arg("-c")
37+
.arg("-frecord-sources")
38+
.arg("-w");
39+
for metal_file in METAL_SOURCES {
40+
compile_air_cmd.arg(format!("{sources}/{metal_file}.metal"));
41+
}
42+
compile_air_cmd.spawn().expect("Failed to compile air");
43+
44+
// Compile air to metallib
45+
let metallib = out_dir.join("candle.metallib");
46+
let mut compile_metallib_cmd = Command::new("xcrun");
47+
compile_metallib_cmd
48+
.arg("metal")
49+
.arg("-o")
50+
.arg(&metallib);
51+
52+
for metal_file in METAL_SOURCES {
53+
compile_metallib_cmd.arg(out_dir.join(format!("{metal_file}.air")));
54+
}
55+
56+
compile_metallib_cmd
57+
.spawn()
58+
.expect("Failed to compile metallib");
59+
60+
Ok(())
61+
}

candle-metal-kernels/src/lib.rs

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ use std::collections::HashMap;
66
use std::ffi::c_void;
77
use std::sync::RwLock;
88

9+
const CANDLE: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/candle.metallib"));
910
const AFFINE: &str = include_str!("affine.metal");
1011
const INDEXING: &str = include_str!("indexing.metal");
1112
const UNARY: &str = include_str!("unary.metal");
1213
const BINARY: &str = include_str!("binary.metal");
1314
const TERNARY: &str = include_str!("ternary.metal");
1415
const CAST: &str = include_str!("cast.metal");
1516
const CONV: &str = include_str!("conv.metal");
16-
const REDUCE: &str = include_str!("reduce.metal");
1717
const RANDOM: &str = include_str!("random.metal");
1818
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
1919
const QUANTIZED: &str = include_str!("quantized.metal");
@@ -114,13 +114,13 @@ macro_rules! set_params {
114114

115115
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
116116
pub enum Source {
117+
Candle,
117118
Affine,
118119
Indexing,
119120
Unary,
120121
Binary,
121122
Ternary,
122123
Cast,
123-
Reduce,
124124
Mfa,
125125
Conv,
126126
Random,
@@ -243,11 +243,10 @@ impl Kernels {
243243
Source::Ternary => TERNARY,
244244
Source::Indexing => INDEXING,
245245
Source::Cast => CAST,
246-
Source::Reduce => REDUCE,
247246
Source::Conv => CONV,
248247
Source::Random => RANDOM,
249248
Source::Quantized => QUANTIZED,
250-
Source::Mfa => panic!("Invalid lib"),
249+
_ => panic!("Invalid lib"),
251250
}
252251
}
253252

@@ -263,6 +262,14 @@ impl Kernels {
263262
Ok(lib.clone())
264263
} else {
265264
let lib = match source {
265+
Source::Candle => {
266+
let source_data = CANDLE;
267+
device.new_library_with_data(source_data).map_err(|e| {
268+
MetalKernelError::LoadLibraryError(format!(
269+
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
270+
))
271+
})?
272+
},
266273
Source::Mfa => {
267274
let source_data = MFA;
268275
device.new_library_with_data(source_data).map_err(|e| {
@@ -569,7 +576,7 @@ pub fn call_reduce_contiguous(
569576
} else {
570577
(format!("{kernel_name}").leak(), 1)
571578
};
572-
let pipeline = kernels.load_pipeline(device, Source::Reduce, name)?;
579+
let pipeline = kernels.load_pipeline(device, Source::Candle, name)?;
573580

574581
let encoder = command_buffer.new_compute_command_encoder();
575582
encoder.set_compute_pipeline_state(&pipeline);
@@ -628,7 +635,7 @@ pub fn call_reduce_strided(
628635
) -> Result<(), MetalKernelError> {
629636
let length: usize = shape.iter().product();
630637
let work_per_threadgroup = length / out_length;
631-
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
638+
let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?;
632639

633640
let encoder = command_buffer.new_compute_command_encoder();
634641
encoder.set_compute_pipeline_state(&pipeline);
@@ -697,7 +704,7 @@ pub fn call_last_softmax(
697704
(format!("{kernel_name}").leak(), 1)
698705
};
699706

700-
let pipeline = kernels.load_pipeline(device, Source::Reduce, name)?;
707+
let pipeline = kernels.load_pipeline(device, Source::Candle, name)?;
701708
let encoder = command_buffer.new_compute_command_encoder();
702709
encoder.set_compute_pipeline_state(&pipeline);
703710

candle-metal-kernels/src/reduce.metal

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ constexpr ushort granularity() {
1919
METAL_FUNC uint next_p2(uint x) {
2020
return 1 << (32 - clz(x - 1));
2121
}
22+
23+
2224
METAL_FUNC uint prev_p2(uint x) {
2325
return 1 << (31 - clz(x));
2426
}

0 commit comments

Comments
 (0)