From 7a13e5846c01a436327057124a864db094ab9796 Mon Sep 17 00:00:00 2001 From: Max Date: Thu, 17 Apr 2025 12:43:00 +0900 Subject: [PATCH 01/22] add Qwen3.rs --- candle-transformers/src/models/mod.rs | 1 + candle-transformers/src/models/qwen3.rs | 436 ++++++++++++++++++++++++ 2 files changed, 437 insertions(+) create mode 100644 candle-transformers/src/models/qwen3.rs diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index bdb8d267b5..9e198f63a7 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -96,6 +96,7 @@ pub mod quantized_stable_lm; pub mod quantized_t5; pub mod qwen2; pub mod qwen2_moe; +pub mod qwen3; pub mod recurrent_gemma; pub mod repvgg; pub mod resnet; diff --git a/candle-transformers/src/models/qwen3.rs b/candle-transformers/src/models/qwen3.rs new file mode 100644 index 0000000000..a2db35d5ef --- /dev/null +++ b/candle-transformers/src/models/qwen3.rs @@ -0,0 +1,436 @@ +use crate::models::with_tracing::{linear, linear_no_bias, Linear}; +use candle::{DType, Device, Module, Result, Tensor, D}; +use candle_nn::{Activation, VarBuilder}; +use std::sync::Arc; + +#[derive(Debug, Clone, PartialEq, serde::Deserialize)] +pub struct Config { + pub vocab_size: usize, + pub hidden_size: usize, + pub intermediate_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub head_dim: usize, + pub attention_bias: bool, + pub num_key_value_heads: usize, + pub max_position_embeddings: usize, + pub sliding_window: Option, // ⬅️ Option 으로 수정 + pub max_window_layers: usize, + pub tie_word_embeddings: bool, + pub rope_theta: f64, + pub rms_norm_eps: f64, + pub use_sliding_window: bool, + pub hidden_act: Activation, +} + +#[derive(Debug, Clone)] +struct Qwen3RotaryEmbedding { + sin: Tensor, + cos: Tensor, +} + +impl Qwen3RotaryEmbedding { + fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { + let dim = cfg.head_dim; + let max_seq_len = cfg.max_position_embeddings; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(dtype)? 
+ .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + Ok(Self { + sin: freqs.sin()?, + cos: freqs.cos()?, + }) + } + + /// RoPE 적용 (q, k shape: B x H x L x D) + fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> { + let (_, _, seq_len, _) = q.dims4()?; + let cos = self.cos.narrow(0, offset, seq_len)?; + let sin = self.sin.narrow(0, offset, seq_len)?; + let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; + Ok((q_embed, k_embed)) + } +} + +#[derive(Debug, Clone)] +struct Qwen3RmsNorm { + weight: Tensor, + eps: f64, +} + +impl Qwen3RmsNorm { + fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result { + let weight = vb.get((dim,), "weight")?; + Ok(Self { weight, eps }) + } + + fn forward(&self, xs: &Tensor) -> Result { + let orig_dtype = xs.dtype(); + let xs = xs.to_dtype(DType::F32)?; + let var = (xs.clone() * &xs)?.mean_keepdim(D::Minus1)?; + let rms = (var + self.eps)?.powf(-0.5)?; + let xs = xs.broadcast_mul(&rms)?; + let ws = self.weight.reshape((1, -1))?; + let out = xs * &ws; + out?.to_dtype(orig_dtype) + } +} + +#[derive(Debug, Clone)] +struct Qwen3HeadRmsNorm { + weight: Tensor, + eps: f64, +} + +impl Qwen3HeadRmsNorm { + fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result { + let weight = vb.get((dim,), DType::F32)?; + Ok(Self { weight, eps }) + } + + fn forward(&self, xs: &Tensor) -> Result { + // xs: (B*L*H, D) + let orig_dtype = xs.dtype(); + let xs = xs.to_dtype(DType::F32)?; + let var = (xs.clone() * &xs)?.mean_keepdim(D::Minus1)?; + let rms = var.add(self.eps)?.powf(-0.5)?; + let xs = xs.broadcast_mul(&rms)?; + let ws = self.weight.reshape((1, -1))?; + let out = xs * &ws; + out?.to_dtype(orig_dtype) + } +} + +fn repeat_kv(kv: &Tensor, n_rep: usize) -> Result { + if n_rep == 1 { + return Ok(kv.clone()); + } + let (b, h_kv, l, d) = kv.dims4()?; + kv.unsqueeze(2)? + .expand((b, h_kv, n_rep, l, d))? 
+ .reshape((b, h_kv * n_rep, l, d)) +} + +#[derive(Debug, Clone)] +struct Qwen3MLP { + gate_proj: Linear, + up_proj: Linear, + down_proj: Linear, + act_fn: Activation, +} + +impl Qwen3MLP { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + Ok(Self { + gate_proj: linear_no_bias(cfg.hidden_size, cfg.intermediate_size, vb.pp("gate_proj"))?, + up_proj: linear_no_bias(cfg.hidden_size, cfg.intermediate_size, vb.pp("up_proj"))?, + down_proj: linear_no_bias(cfg.intermediate_size, cfg.hidden_size, vb.pp("down_proj"))?, + act_fn: cfg.hidden_act, + }) + } +} + +impl Module for Qwen3MLP { + fn forward(&self, x: &Tensor) -> Result { + let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?; + let rhs = x.apply(&self.up_proj)?; + (lhs * rhs)?.apply(&self.down_proj) + } +} + +#[derive(Debug, Clone)] +struct Qwen3Attention { + // projections + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + // norms + q_norm: Qwen3HeadRmsNorm, + k_norm: Qwen3HeadRmsNorm, + // hyper params + num_heads: usize, + num_kv_heads: usize, + num_kv_groups: usize, + head_dim: usize, + hidden_size: usize, + // sliding window + sliding_window: Option, + // utils + rotary_emb: Arc, + kv_cache: Option<(Tensor, Tensor)>, +} + +impl Qwen3Attention { + fn new( + cfg: &Config, + rotary_emb: Arc, + layer_idx: usize, + vb: VarBuilder, + ) -> Result { + let head_dim = cfg.head_dim; + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads; + let num_kv_groups = num_heads / num_kv_heads; + + let (q_proj, k_proj, v_proj, o_proj) = if cfg.attention_bias { + ( + linear(cfg.hidden_size, num_heads * head_dim, vb.pp("q_proj"))?, + linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?, + linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?, + linear(num_heads * head_dim, cfg.hidden_size, vb.pp("o_proj"))?, + ) + } else { + ( + linear_no_bias(cfg.hidden_size, num_heads * head_dim, vb.pp("q_proj"))?, + linear_no_bias(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?, + linear_no_bias(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?, + linear_no_bias(num_heads * head_dim, cfg.hidden_size, vb.pp("o_proj"))?, + ) + }; + + let q_norm = Qwen3HeadRmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?; + let k_norm = Qwen3HeadRmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?; + + let sliding_window = if cfg.use_sliding_window && layer_idx >= cfg.max_window_layers { + cfg.sliding_window + } else { + None + }; + + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + q_norm, + k_norm, + num_heads, + num_kv_heads, + num_kv_groups, + head_dim, + hidden_size: cfg.hidden_size, + sliding_window, + rotary_emb, + kv_cache: None, + }) + } + + fn forward(&mut self, x: &Tensor, attn_mask: Option<&Tensor>, offset: usize) -> Result { + let (b, l, _) = x.dims3()?; + + // 1. Proj + let q = self.q_proj.forward(x)?; + let k = self.k_proj.forward(x)?; + let v = self.v_proj.forward(x)?; + + // 2. Reshape: (B, L, H, D) -> (B, H, L, D) + let q = q + .reshape((b, l, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let k = k + .reshape((b, l, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let v = v + .reshape((b, l, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + + // 3. 
Per‑head RMSNorm + let q_flat = q.flatten(0, 2)?; // (B*H, L, D) -> (BHL, D) after transpose later + let k_flat = k.flatten(0, 2)?; + let q_flat = self.q_norm.forward(&q_flat)?; + let k_flat = self.k_norm.forward(&k_flat)?; + let q = q_flat.reshape((b, self.num_heads, l, self.head_dim))?; + let k = k_flat.reshape((b, self.num_kv_heads, l, self.head_dim))?; + + // 4. RoPE + let (q, k) = self.rotary_emb.apply(&q, &k, offset)?; + + // 5. KV 캐시 누적 + let (k, v) = match &self.kv_cache { + None => (k, v), + Some((prev_k, prev_v)) => ( + Tensor::cat(&[prev_k, &k], 2)?, + Tensor::cat(&[prev_v, &v], 2)?, + ), + }; + self.kv_cache = Some((k.clone(), v.clone())); + + // 6. GQA repeat_kv + let k = repeat_kv(&k, self.num_kv_groups)?; + let v = repeat_kv(&v, self.num_kv_groups)?; + + // 7. Attention score + let scale = 1.0 / (self.head_dim as f64).sqrt(); + let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?; + if let Some(m) = attn_mask { + scores = scores.broadcast_add(m)?; + } + let probs = candle_nn::ops::softmax_last_dim(&scores)?; + let ctx = probs.matmul(&v)?; // (B, H, L, D) + + // 8. Output proj + ctx.transpose(1, 2)? + .reshape((b, l, self.hidden_size))? + .apply(&self.o_proj) + } + + fn clear_kv_cache(&mut self) { + self.kv_cache = None; + } +} + +#[derive(Debug, Clone)] +struct DecoderLayer { + self_attn: Qwen3Attention, + mlp: Qwen3MLP, + ln1: Qwen3RmsNorm, + ln2: Qwen3RmsNorm, +} + +impl DecoderLayer { + fn new( + cfg: &Config, + rotary: Arc, + idx: usize, + vb: VarBuilder, + ) -> Result { + Ok(Self { + self_attn: Qwen3Attention::new(cfg, rotary, idx, vb.pp("self_attn"))?, + mlp: Qwen3MLP::new(cfg, vb.pp("mlp"))?, + ln1: Qwen3RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln1"))?, + ln2: Qwen3RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln2"))?, + }) + } + + fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result { + let h = self.ln1.forward(x)?; + let h = self.self_attn.forward(&h, mask, offset)?; + let x = x + h; + let h2 = self.ln2.forward(&x)?; + let h2 = h2.apply(&self.mlp)?; + Ok(x + h2?) + } + + fn clear_kv_cache(&mut self) { + self.self_attn.clear_kv_cache(); + } +} + +#[derive(Debug, Clone)] +pub struct Model { + embed_tokens: candle_nn::Embedding, + layers: Vec, + norm: Qwen3RmsNorm, + rotary: Arc, + device: Device, + dtype: DType, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let embed_tokens = + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?; + let rotary = Arc::new(Qwen3RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?); + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_l = vb.pp("layers"); + for i in 0..cfg.num_hidden_layers { + layers.push(DecoderLayer::new(cfg, rotary.clone(), i, vb_l.pp(i))?); + } + Ok(Self { + embed_tokens, + layers, + norm: Qwen3RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("norm"))?, + rotary, + device: vb.device().clone(), + dtype: vb.dtype(), + }) + } + + fn clear_kv_cache(&mut self) { + for l in &mut self.layers { + l.clear_kv_cache(); + } + } + + fn causal_mask( + &self, + b: usize, + tgt: usize, + offset: usize, + sw: Option, + ) -> Result { + let minf = f32::NEG_INFINITY; + let mask: Vec<_> = (0..tgt) + .flat_map(|i| { + (0..(tgt + offset)).map(move |j| { + let past_ok = j <= i + offset; + let sw_ok = match sw { + Some(w) => (i + offset) as i64 - j as i64 <= w as i64, + None => true, + }; + if past_ok && sw_ok { + 0. 
+ } else { + minf + } + }) + }) + .collect(); + Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype) + } + + pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result { + let (b, l) = input.dims2()?; + let mut h = self.embed_tokens.forward(input)?; + + let causal = if l == 1 { + None + } else { + Some(self.causal_mask(b, l, offset, None)?) + }; + + for layer in &mut self.layers { + h = layer.forward(&h, causal.as_ref(), offset)?; + } + self.norm.forward(&h) + } +} + +#[derive(Debug, Clone)] +pub struct ModelForCausalLM { + base: Model, + lm_head: Linear, +} + +impl ModelForCausalLM { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let base = Model::new(cfg, vb.clone())?; + let lm_head = if vb.contains_tensor("lm_head.weight") { + linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))? + } else { + Linear::from_weights(base.embed_tokens.embeddings().clone(), None) + }; + Ok(Self { base, lm_head }) + } + + pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result { + let (_, l) = input.dims2()?; + self.base + .forward(input, offset)? + .narrow(1, l - 1, 1)? + .apply(&self.lm_head) + } + + pub fn clear_kv_cache(&mut self) { + self.base.clear_kv_cache(); + } +} From e2da61909dfbd60ce4de7a7bd465f7e8b46d409f Mon Sep 17 00:00:00 2001 From: Max Date: Thu, 17 Apr 2025 13:23:04 +0900 Subject: [PATCH 02/22] fixed compile error --- candle-transformers/src/models/qwen3.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/candle-transformers/src/models/qwen3.rs b/candle-transformers/src/models/qwen3.rs index a2db35d5ef..d5073f969a 100644 --- a/candle-transformers/src/models/qwen3.rs +++ b/candle-transformers/src/models/qwen3.rs @@ -14,7 +14,7 @@ pub struct Config { pub attention_bias: bool, pub num_key_value_heads: usize, pub max_position_embeddings: usize, - pub sliding_window: Option, // ⬅️ Option 으로 수정 + pub sliding_window: Option, pub max_window_layers: usize, pub tie_word_embeddings: bool, pub rope_theta: f64, @@ -79,8 +79,7 @@ impl Qwen3RmsNorm { let rms = (var + self.eps)?.powf(-0.5)?; let xs = xs.broadcast_mul(&rms)?; let ws = self.weight.reshape((1, -1))?; - let out = xs * &ws; - out?.to_dtype(orig_dtype) + Ok((xs * &ws)?.to_dtype(orig_dtype)) } } @@ -104,8 +103,7 @@ impl Qwen3HeadRmsNorm { let rms = var.add(self.eps)?.powf(-0.5)?; let xs = xs.broadcast_mul(&rms)?; let ws = self.weight.reshape((1, -1))?; - let out = xs * &ws; - out?.to_dtype(orig_dtype) + Ok((xs * &ws)?.to_dtype(orig_dtype)) } } @@ -317,7 +315,7 @@ impl DecoderLayer { let x = x + h; let h2 = self.ln2.forward(&x)?; let h2 = h2.apply(&self.mlp)?; - Ok(x + h2?) 
+ Ok(x + h2) } fn clear_kv_cache(&mut self) { From 9c39581402031a68d70fe3cf3644f4aac6ea5ec9 Mon Sep 17 00:00:00 2001 From: keighbee Date: Mon, 28 Apr 2025 16:58:56 -0700 Subject: [PATCH 03/22] attempting to gett pr 2903 working with qwen weights --- candle-examples/examples/qwen/main.rs | 13 ++++++++++- candle-transformers/src/models/qwen3.rs | 30 +++++++++++++------------ 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/candle-examples/examples/qwen/main.rs b/candle-examples/examples/qwen/main.rs index 53f2f70dd1..e448cd9c56 100644 --- a/candle-examples/examples/qwen/main.rs +++ b/candle-examples/examples/qwen/main.rs @@ -9,6 +9,7 @@ use clap::Parser; use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase}; use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe}; +use candle_transformers::models::qwen3::{Config as Config3, Model as Model3}; use candle::{DType, Device, Tensor}; use candle_examples::token_output_stream::TokenOutputStream; @@ -20,6 +21,7 @@ use tokenizers::Tokenizer; enum Model { Base(ModelBase), Moe(ModelMoe), + Base3(Model3) } impl Model { @@ -27,6 +29,7 @@ impl Model { match self { Self::Moe(ref mut m) => m.forward(xs, s), Self::Base(ref mut m) => m.forward(xs, s), + Self::Base3(ref mut m) => m.forward(xs, s), } } } @@ -152,6 +155,8 @@ enum WhichModel { W2_7b, #[value(name = "2-72b")] W2_72b, + #[value(name = "3-8B")] + W3_8b, } #[derive(Parser, Debug)] @@ -254,6 +259,7 @@ fn main() -> Result<()> { WhichModel::W14b => ("1.5", "14B"), WhichModel::W72b => ("1.5", "72B"), WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"), + WhichModel::W3_8b => ("3", "8B"), }; format!("Qwen/Qwen{version}-{size}") } @@ -282,7 +288,8 @@ fn main() -> Result<()> { | WhichModel::W14b | WhichModel::W72b | WhichModel::W2_72b - | WhichModel::MoeA27b => { + | WhichModel::MoeA27b + | WhichModel::W3_8b => { candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")? } }, @@ -303,6 +310,10 @@ fn main() -> Result<()> { WhichModel::MoeA27b => { let config: ConfigMoe = serde_json::from_slice(&std::fs::read(config_file)?)?; Model::Moe(ModelMoe::new(&config, vb)?) + }, + WhichModel::W3_8b => { + let config: Config3 = serde_json::from_slice(&std::fs::read(config_file)?)?; + Model::Base3(Model3::new(&config, vb)?) } _ => { let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?; diff --git a/candle-transformers/src/models/qwen3.rs b/candle-transformers/src/models/qwen3.rs index d5073f969a..d56dc80406 100644 --- a/candle-transformers/src/models/qwen3.rs +++ b/candle-transformers/src/models/qwen3.rs @@ -78,8 +78,8 @@ impl Qwen3RmsNorm { let var = (xs.clone() * &xs)?.mean_keepdim(D::Minus1)?; let rms = (var + self.eps)?.powf(-0.5)?; let xs = xs.broadcast_mul(&rms)?; - let ws = self.weight.reshape((1, -1))?; - Ok((xs * &ws)?.to_dtype(orig_dtype)) + let ws = self.weight.unsqueeze(0)?.unsqueeze(1)?; + Ok((xs.broadcast_mul(&ws))?.to_dtype(orig_dtype)?) 
} } @@ -91,7 +91,7 @@ struct Qwen3HeadRmsNorm { impl Qwen3HeadRmsNorm { fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result { - let weight = vb.get((dim,), DType::F32)?; + let weight = vb.get((dim,), "weight")?; Ok(Self { weight, eps }) } @@ -100,10 +100,11 @@ impl Qwen3HeadRmsNorm { let orig_dtype = xs.dtype(); let xs = xs.to_dtype(DType::F32)?; let var = (xs.clone() * &xs)?.mean_keepdim(D::Minus1)?; - let rms = var.add(self.eps)?.powf(-0.5)?; + let eps_tensor = Tensor::new::(self.eps as f32, xs.device())?; + let rms = var.broadcast_add(&eps_tensor)?.powf(-0.5)?; let xs = xs.broadcast_mul(&rms)?; - let ws = self.weight.reshape((1, -1))?; - Ok((xs * &ws)?.to_dtype(orig_dtype)) + let ws = self.weight.unsqueeze(0)?; + Ok((xs.broadcast_mul(&ws))?.to_dtype(orig_dtype)?) } } @@ -304,18 +305,18 @@ impl DecoderLayer { Ok(Self { self_attn: Qwen3Attention::new(cfg, rotary, idx, vb.pp("self_attn"))?, mlp: Qwen3MLP::new(cfg, vb.pp("mlp"))?, - ln1: Qwen3RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln1"))?, - ln2: Qwen3RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln2"))?, + ln1: Qwen3RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?, + ln2: Qwen3RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("post_attention_layernorm"))?, }) } fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result { let h = self.ln1.forward(x)?; let h = self.self_attn.forward(&h, mask, offset)?; - let x = x + h; + let x = x.broadcast_add(&h)?; let h2 = self.ln2.forward(&x)?; let h2 = h2.apply(&self.mlp)?; - Ok(x + h2) + Ok(x.broadcast_add(&h2)?) } fn clear_kv_cache(&mut self) { @@ -336,17 +337,17 @@ pub struct Model { impl Model { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { let embed_tokens = - candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?; + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?; let rotary = Arc::new(Qwen3RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?); let mut layers = Vec::with_capacity(cfg.num_hidden_layers); - let vb_l = vb.pp("layers"); + let vb_l = vb.pp("model.layers"); for i in 0..cfg.num_hidden_layers { layers.push(DecoderLayer::new(cfg, rotary.clone(), i, vb_l.pp(i))?); } Ok(Self { embed_tokens, layers, - norm: Qwen3RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("norm"))?, + norm: Qwen3RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?, rotary, device: vb.device().clone(), dtype: vb.dtype(), @@ -399,7 +400,8 @@ impl Model { for layer in &mut self.layers { h = layer.forward(&h, causal.as_ref(), offset)?; } - self.norm.forward(&h) + let b = self.norm.forward(&h)?; + Ok(b) } } From a014a67740dce67dc2fd1df328a588731e063421 Mon Sep 17 00:00:00 2001 From: keighbee Date: Tue, 29 Apr 2025 16:09:03 -0700 Subject: [PATCH 04/22] different qwen variants working --- candle-examples/examples/qwen/main.rs | 27 +++++++++++++++++++------ candle-transformers/src/models/qwen3.rs | 18 +++++++++++------ 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/candle-examples/examples/qwen/main.rs b/candle-examples/examples/qwen/main.rs index e448cd9c56..b9ea876c09 100644 --- a/candle-examples/examples/qwen/main.rs +++ b/candle-examples/examples/qwen/main.rs @@ -9,7 +9,7 @@ use clap::Parser; use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase}; use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe}; -use candle_transformers::models::qwen3::{Config as Config3, Model as 
Model3}; +use candle_transformers::models::qwen3::{Config as Config3, ModelForCausalLM as Model3}; use candle::{DType, Device, Tensor}; use candle_examples::token_output_stream::TokenOutputStream; @@ -21,7 +21,7 @@ use tokenizers::Tokenizer; enum Model { Base(ModelBase), Moe(ModelMoe), - Base3(Model3) + Base3(Model3), } impl Model { @@ -155,6 +155,12 @@ enum WhichModel { W2_7b, #[value(name = "2-72b")] W2_72b, + #[value(name = "3-0.6B")] + W3_0_6b, + #[value(name = "3-1.7B")] + W3_1_7b, + #[value(name = "3-4B")] + W3_4b, #[value(name = "3-8B")] W3_8b, } @@ -259,6 +265,9 @@ fn main() -> Result<()> { WhichModel::W14b => ("1.5", "14B"), WhichModel::W72b => ("1.5", "72B"), WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"), + WhichModel::W3_0_6b => ("3", "0.6B"), + WhichModel::W3_1_7b => ("3", "1.7B"), + WhichModel::W3_4b => ("3", "4B"), WhichModel::W3_8b => ("3", "8B"), }; format!("Qwen/Qwen{version}-{size}") @@ -279,7 +288,11 @@ fn main() -> Result<()> { .map(std::path::PathBuf::from) .collect::>(), None => match args.model { - WhichModel::W0_5b | WhichModel::W2_0_5b | WhichModel::W2_1_5b | WhichModel::W1_8b => { + WhichModel::W0_5b + | WhichModel::W2_0_5b + | WhichModel::W2_1_5b + | WhichModel::W1_8b + | WhichModel::W3_0_6b => { vec![repo.get("model.safetensors")?] } WhichModel::W4b @@ -288,7 +301,9 @@ fn main() -> Result<()> { | WhichModel::W14b | WhichModel::W72b | WhichModel::W2_72b - | WhichModel::MoeA27b + | WhichModel::MoeA27b + | WhichModel::W3_1_7b + | WhichModel::W3_4b | WhichModel::W3_8b => { candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")? } @@ -310,8 +325,8 @@ fn main() -> Result<()> { WhichModel::MoeA27b => { let config: ConfigMoe = serde_json::from_slice(&std::fs::read(config_file)?)?; Model::Moe(ModelMoe::new(&config, vb)?) - }, - WhichModel::W3_8b => { + } + WhichModel::W3_0_6b | WhichModel::W3_1_7b | WhichModel::W3_4b | WhichModel::W3_8b => { let config: Config3 = serde_json::from_slice(&std::fs::read(config_file)?)?; Model::Base3(Model3::new(&config, vb)?) } diff --git a/candle-transformers/src/models/qwen3.rs b/candle-transformers/src/models/qwen3.rs index d56dc80406..4d180d3b8f 100644 --- a/candle-transformers/src/models/qwen3.rs +++ b/candle-transformers/src/models/qwen3.rs @@ -79,7 +79,7 @@ impl Qwen3RmsNorm { let rms = (var + self.eps)?.powf(-0.5)?; let xs = xs.broadcast_mul(&rms)?; let ws = self.weight.unsqueeze(0)?.unsqueeze(1)?; - Ok((xs.broadcast_mul(&ws))?.to_dtype(orig_dtype)?) + xs.broadcast_mul(&ws)?.to_dtype(orig_dtype) } } @@ -100,11 +100,10 @@ impl Qwen3HeadRmsNorm { let orig_dtype = xs.dtype(); let xs = xs.to_dtype(DType::F32)?; let var = (xs.clone() * &xs)?.mean_keepdim(D::Minus1)?; - let eps_tensor = Tensor::new::(self.eps as f32, xs.device())?; - let rms = var.broadcast_add(&eps_tensor)?.powf(-0.5)?; + let rms = (var + self.eps)?.powf(-0.5)?; let xs = xs.broadcast_mul(&rms)?; let ws = self.weight.unsqueeze(0)?; - Ok((xs.broadcast_mul(&ws))?.to_dtype(orig_dtype)?) 
+ xs.broadcast_mul(&ws)?.to_dtype(orig_dtype) } } @@ -205,6 +204,9 @@ impl Qwen3Attention { None }; + // Necessary because the hidden_size in the cofig isn't always accurate + let hidden_size = cfg.head_dim * cfg.num_attention_heads; + Ok(Self { q_proj, k_proj, @@ -216,7 +218,7 @@ impl Qwen3Attention { num_kv_heads, num_kv_groups, head_dim, - hidden_size: cfg.hidden_size, + hidden_size, sliding_window, rotary_emb, kv_cache: None, @@ -306,7 +308,11 @@ impl DecoderLayer { self_attn: Qwen3Attention::new(cfg, rotary, idx, vb.pp("self_attn"))?, mlp: Qwen3MLP::new(cfg, vb.pp("mlp"))?, ln1: Qwen3RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?, - ln2: Qwen3RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("post_attention_layernorm"))?, + ln2: Qwen3RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?, }) } From 8727cdf32beb572d4b13c2bc409000f3e141af02 Mon Sep 17 00:00:00 2001 From: keighbee Date: Tue, 29 Apr 2025 17:04:03 -0700 Subject: [PATCH 05/22] added moe model --- candle-examples/examples/qwen/main.rs | 13 +- candle-transformers/src/models/mod.rs | 1 + candle-transformers/src/models/qwen3_moe.rs | 559 ++++++++++++++++++++ 3 files changed, 572 insertions(+), 1 deletion(-) create mode 100644 candle-transformers/src/models/qwen3_moe.rs diff --git a/candle-examples/examples/qwen/main.rs b/candle-examples/examples/qwen/main.rs index b9ea876c09..7eab07db8e 100644 --- a/candle-examples/examples/qwen/main.rs +++ b/candle-examples/examples/qwen/main.rs @@ -10,6 +10,7 @@ use clap::Parser; use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase}; use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe}; use candle_transformers::models::qwen3::{Config as Config3, ModelForCausalLM as Model3}; +use candle_transformers::models::qwen3_moe::{Config as ConfigMoe3, Model as ModelMoe3}; use candle::{DType, Device, Tensor}; use candle_examples::token_output_stream::TokenOutputStream; @@ -22,6 +23,7 @@ enum Model { Base(ModelBase), Moe(ModelMoe), Base3(Model3), + Moe3(ModelMoe3), } impl Model { @@ -30,6 +32,7 @@ impl Model { Self::Moe(ref mut m) => m.forward(xs, s), Self::Base(ref mut m) => m.forward(xs, s), Self::Base3(ref mut m) => m.forward(xs, s), + Self::Moe3(ref mut m) => m.forward(xs, s), } } } @@ -163,6 +166,8 @@ enum WhichModel { W3_4b, #[value(name = "3-8B")] W3_8b, + #[value(name = "3-moe-a3b")] + W3MoeA3b, } #[derive(Parser, Debug)] @@ -269,6 +274,7 @@ fn main() -> Result<()> { WhichModel::W3_1_7b => ("3", "1.7B"), WhichModel::W3_4b => ("3", "4B"), WhichModel::W3_8b => ("3", "8B"), + WhichModel::W3MoeA3b => ("3", "30B-A3B"), }; format!("Qwen/Qwen{version}-{size}") } @@ -304,7 +310,8 @@ fn main() -> Result<()> { | WhichModel::MoeA27b | WhichModel::W3_1_7b | WhichModel::W3_4b - | WhichModel::W3_8b => { + | WhichModel::W3_8b + | WhichModel::W3MoeA3b => { candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")? } }, @@ -330,6 +337,10 @@ fn main() -> Result<()> { let config: Config3 = serde_json::from_slice(&std::fs::read(config_file)?)?; Model::Base3(Model3::new(&config, vb)?) } + WhichModel::W3MoeA3b => { + let config: ConfigMoe3 = serde_json::from_slice(&std::fs::read(config_file)?)?; + Model::Moe3(ModelMoe3::new(&config, vb)?) + } _ => { let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?; Model::Base(ModelBase::new(&config, vb)?) 
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 9e198f63a7..1cb25a133d 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -97,6 +97,7 @@ pub mod quantized_t5; pub mod qwen2; pub mod qwen2_moe; pub mod qwen3; +pub mod qwen3_moe; pub mod recurrent_gemma; pub mod repvgg; pub mod resnet; diff --git a/candle-transformers/src/models/qwen3_moe.rs b/candle-transformers/src/models/qwen3_moe.rs new file mode 100644 index 0000000000..ddf1738bf9 --- /dev/null +++ b/candle-transformers/src/models/qwen3_moe.rs @@ -0,0 +1,559 @@ +use crate::models::with_tracing::{linear, linear_no_bias, Linear}; +use candle::{DType, Device, Module, Result, Tensor, D}; +use candle_nn::{Activation, VarBuilder}; +use std::sync::Arc; + +#[derive(Debug, Clone, PartialEq, serde::Deserialize)] +pub struct Config { + pub vocab_size: usize, + pub hidden_size: usize, + pub intermediate_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub head_dim: usize, + pub attention_bias: bool, + pub num_key_value_heads: usize, + pub max_position_embeddings: usize, + pub sliding_window: Option, + pub max_window_layers: usize, + pub tie_word_embeddings: bool, + pub rope_theta: f64, + pub rms_norm_eps: f64, + pub use_sliding_window: bool, + pub hidden_act: Activation, + // MoE specific configuration + pub decoder_sparse_step: usize, + pub moe_intermediate_size: usize, + pub num_experts_per_tok: usize, + pub num_experts: usize, + pub norm_topk_prob: bool, +} + +#[derive(Debug, Clone)] +struct Qwen3RotaryEmbedding { + sin: Tensor, + cos: Tensor, +} + +impl Qwen3RotaryEmbedding { + fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { + let dim = cfg.head_dim; + let max_seq_len = cfg.max_position_embeddings; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(dtype)? 
+ .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + Ok(Self { + sin: freqs.sin()?, + cos: freqs.cos()?, + }) + } + + fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> { + let (_, _, seq_len, _) = q.dims4()?; + let cos = self.cos.narrow(0, offset, seq_len)?; + let sin = self.sin.narrow(0, offset, seq_len)?; + let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; + Ok((q_embed, k_embed)) + } +} + +#[derive(Debug, Clone)] +struct Qwen3RmsNorm { + weight: Tensor, + eps: f64, +} + +impl Qwen3RmsNorm { + fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result { + let weight = vb.get((dim,), "weight")?; + Ok(Self { weight, eps }) + } + + fn forward(&self, xs: &Tensor) -> Result { + let orig_dtype = xs.dtype(); + let xs = xs.to_dtype(DType::F32)?; + let var = (xs.clone() * &xs)?.mean_keepdim(D::Minus1)?; + let rms = (var + self.eps)?.powf(-0.5)?; + let xs = xs.broadcast_mul(&rms)?; + let ws = self.weight.unsqueeze(0)?.unsqueeze(1)?; + xs.broadcast_mul(&ws)?.to_dtype(orig_dtype) + } +} + +#[derive(Debug, Clone)] +struct Qwen3HeadRmsNorm { + weight: Tensor, + eps: f64, +} + +impl Qwen3HeadRmsNorm { + fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result { + let weight = vb.get((dim,), "weight")?; + Ok(Self { weight, eps }) + } + + fn forward(&self, xs: &Tensor) -> Result { + // xs: (B*L*H, D) + let orig_dtype = xs.dtype(); + let xs = xs.to_dtype(DType::F32)?; + let var = (xs.clone() * &xs)?.mean_keepdim(D::Minus1)?; + let rms = (var + self.eps)?.powf(-0.5)?; + let xs = xs.broadcast_mul(&rms)?; + let ws = self.weight.unsqueeze(0)?; + xs.broadcast_mul(&ws)?.to_dtype(orig_dtype) + } +} + +fn repeat_kv(kv: &Tensor, n_rep: usize) -> Result { + if n_rep == 1 { + return Ok(kv.clone()); + } + let (b, h_kv, l, d) = kv.dims4()?; + kv.unsqueeze(2)? + .expand((b, h_kv, n_rep, l, d))? 
+ .reshape((b, h_kv * n_rep, l, d)) +} + +#[derive(Debug, Clone)] +struct Qwen3MLP { + gate_proj: Linear, + up_proj: Linear, + down_proj: Linear, + act_fn: Activation, +} + +impl Qwen3MLP { + fn new(cfg: &Config, intermediate_size: usize, vb: VarBuilder) -> Result { + Ok(Self { + gate_proj: linear_no_bias(cfg.hidden_size, intermediate_size, vb.pp("gate_proj"))?, + up_proj: linear_no_bias(cfg.hidden_size, intermediate_size, vb.pp("up_proj"))?, + down_proj: linear_no_bias(intermediate_size, cfg.hidden_size, vb.pp("down_proj"))?, + act_fn: cfg.hidden_act, + }) + } +} + +impl Module for Qwen3MLP { + fn forward(&self, x: &Tensor) -> Result { + let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?; + let rhs = x.apply(&self.up_proj)?; + (lhs * rhs)?.apply(&self.down_proj) + } +} + +// Qwen3 Sparse MoE Block implementation +#[derive(Debug, Clone)] +struct Qwen3SparseMoeBlock { + gate: Linear, + experts: Vec, + norm_topk_prob: bool, + num_experts_per_tok: usize, +} + +impl Qwen3SparseMoeBlock { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let gate = linear_no_bias(cfg.hidden_size, cfg.num_experts, vb.pp("gate"))?; + let mut experts = Vec::with_capacity(cfg.num_experts); + let vb_e = vb.pp("experts"); + for idx in 0..cfg.num_experts { + let expert = Qwen3MLP::new(cfg, cfg.moe_intermediate_size, vb_e.pp(idx))?; + experts.push(expert) + } + Ok(Self { + gate, + experts, + norm_topk_prob: cfg.norm_topk_prob, + num_experts_per_tok: cfg.num_experts_per_tok, + }) + } +} + +impl Module for Qwen3SparseMoeBlock { + fn forward(&self, xs: &Tensor) -> Result { + let (b_size, seq_len, hidden_dim) = xs.dims3()?; + let xs = xs.reshape(((), hidden_dim))?; + let router_logits = xs.apply(&self.gate)?; + let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?; + + // Extract topk experts per token + let experts_per_tok = routing_weights + .arg_sort_last_dim(false)? + .narrow(D::Minus1, 0, self.num_experts_per_tok)? + .contiguous()?; + let routing_weights = routing_weights.gather(&experts_per_tok, D::Minus1)?; + + // Extract needed data + let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::()?; + let experts_per_tok = experts_per_tok.to_vec2::()?; + let mut top_x = vec![vec![]; self.experts.len()]; + let mut selected_experts = vec![vec![]; self.experts.len()]; + for (row_idx, (rw, expert_idxs)) in routing_weights + .iter() + .zip(experts_per_tok.iter()) + .enumerate() + { + let sum_rw = rw.iter().sum::(); + for (&rw, &expert_idx) in rw.iter().zip(expert_idxs.iter()) { + top_x[expert_idx as usize].push(row_idx as u32); + let rw = if self.norm_topk_prob { rw / sum_rw } else { rw }; + selected_experts[expert_idx as usize].push(rw) + } + } + + // Process through experts + let mut ys = xs.zeros_like()?; + for (expert_idx, expert_layer) in self.experts.iter().enumerate() { + let top_x = &top_x[expert_idx]; + if top_x.is_empty() { + continue; + } + let top_x = Tensor::new(top_x.as_slice(), xs.device())?; + let selected_experts = + Tensor::new(selected_experts[expert_idx].as_slice(), xs.device())? + .reshape(((), 1))? 
+ .to_dtype(xs.dtype())?; + + let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?; + let current_hidden_states = expert_layer.forward(¤t_state)?; + let current_hidden_states = current_hidden_states.broadcast_mul(&selected_experts)?; + ys = ys.index_add(&top_x, ¤t_hidden_states, 0)?; + } + + let ys = ys.reshape((b_size, seq_len, hidden_dim))?; + Ok(ys) + } +} + +// MLP or MoE decision enum +#[derive(Debug, Clone)] +enum Qwen3FeedForward { + MLP(Qwen3MLP), + MoE(Qwen3SparseMoeBlock), +} + +impl Module for Qwen3FeedForward { + fn forward(&self, xs: &Tensor) -> Result { + match self { + Self::MLP(m) => m.forward(xs), + Self::MoE(m) => m.forward(xs), + } + } +} + +#[derive(Debug, Clone)] +struct Qwen3Attention { + // projections + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + // norms + q_norm: Qwen3HeadRmsNorm, + k_norm: Qwen3HeadRmsNorm, + // hyper params + num_heads: usize, + num_kv_heads: usize, + num_kv_groups: usize, + head_dim: usize, + hidden_size: usize, + // sliding window + sliding_window: Option, + // utils + rotary_emb: Arc, + kv_cache: Option<(Tensor, Tensor)>, +} + +impl Qwen3Attention { + fn new( + cfg: &Config, + rotary_emb: Arc, + layer_idx: usize, + vb: VarBuilder, + ) -> Result { + let head_dim = cfg.head_dim; + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads; + let num_kv_groups = num_heads / num_kv_heads; + + let (q_proj, k_proj, v_proj, o_proj) = if cfg.attention_bias { + ( + linear(cfg.hidden_size, num_heads * head_dim, vb.pp("q_proj"))?, + linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?, + linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?, + linear(num_heads * head_dim, cfg.hidden_size, vb.pp("o_proj"))?, + ) + } else { + ( + linear_no_bias(cfg.hidden_size, num_heads * head_dim, vb.pp("q_proj"))?, + linear_no_bias(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?, + linear_no_bias(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?, + linear_no_bias(num_heads * head_dim, cfg.hidden_size, vb.pp("o_proj"))?, + ) + }; + + let q_norm = Qwen3HeadRmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?; + let k_norm = Qwen3HeadRmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?; + + let sliding_window = if cfg.use_sliding_window && layer_idx >= cfg.max_window_layers { + cfg.sliding_window + } else { + None + }; + + // Necessary because the hidden_size in the config isn't always accurate + let hidden_size = cfg.head_dim * cfg.num_attention_heads; + + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + q_norm, + k_norm, + num_heads, + num_kv_heads, + num_kv_groups, + head_dim, + hidden_size, + sliding_window, + rotary_emb, + kv_cache: None, + }) + } + + fn forward(&mut self, x: &Tensor, attn_mask: Option<&Tensor>, offset: usize) -> Result { + let (b, l, _) = x.dims3()?; + + // 1. Proj + let q = self.q_proj.forward(x)?; + let k = self.k_proj.forward(x)?; + let v = self.v_proj.forward(x)?; + + // 2. Reshape: (B, L, H, D) -> (B, H, L, D) + let q = q + .reshape((b, l, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let k = k + .reshape((b, l, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let v = v + .reshape((b, l, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + + // 3. 
Per‑head RMSNorm + let q_flat = q.flatten(0, 2)?; + let k_flat = k.flatten(0, 2)?; + let q_flat = self.q_norm.forward(&q_flat)?; + let k_flat = self.k_norm.forward(&k_flat)?; + let q = q_flat.reshape((b, self.num_heads, l, self.head_dim))?; + let k = k_flat.reshape((b, self.num_kv_heads, l, self.head_dim))?; + + // 4. RoPE + let (q, k) = self.rotary_emb.apply(&q, &k, offset)?; + + // 5. KV cache + let (k, v) = match &self.kv_cache { + None => (k, v), + Some((prev_k, prev_v)) => ( + Tensor::cat(&[prev_k, &k], 2)?, + Tensor::cat(&[prev_v, &v], 2)?, + ), + }; + self.kv_cache = Some((k.clone(), v.clone())); + + // 6. GQA repeat_kv + let k = repeat_kv(&k, self.num_kv_groups)?; + let v = repeat_kv(&v, self.num_kv_groups)?; + + // 7. Attention score + let scale = 1.0 / (self.head_dim as f64).sqrt(); + let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?; + if let Some(m) = attn_mask { + scores = scores.broadcast_add(m)?; + } + let probs = candle_nn::ops::softmax_last_dim(&scores)?; + let ctx = probs.matmul(&v)?; // (B, H, L, D) + + // 8. Output proj + ctx.transpose(1, 2)? + .reshape((b, l, self.hidden_size))? + .apply(&self.o_proj) + } + + fn clear_kv_cache(&mut self) { + self.kv_cache = None; + } +} + +#[derive(Debug, Clone)] +struct DecoderLayer { + self_attn: Qwen3Attention, + feed_forward: Qwen3FeedForward, + ln1: Qwen3RmsNorm, + ln2: Qwen3RmsNorm, +} + +impl DecoderLayer { + fn new( + layer_idx: usize, + cfg: &Config, + rotary: Arc, + vb: VarBuilder, + ) -> Result { + let self_attn = Qwen3Attention::new(cfg, rotary, layer_idx, vb.pp("self_attn"))?; + + // Decide whether to use MoE or regular MLP based on layer_idx and decoder_sparse_step + let feed_forward = if cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0 + { + Qwen3FeedForward::MoE(Qwen3SparseMoeBlock::new(cfg, vb.pp("mlp"))?) + } else { + Qwen3FeedForward::MLP(Qwen3MLP::new(cfg, cfg.intermediate_size, vb.pp("mlp"))?) + }; + + let ln1 = Qwen3RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let ln2 = Qwen3RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; + + Ok(Self { + self_attn, + feed_forward, + ln1, + ln2, + }) + } + + fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result { + let h = self.ln1.forward(x)?; + let h = self.self_attn.forward(&h, mask, offset)?; + let x = x.broadcast_add(&h)?; + let h2 = self.ln2.forward(&x)?; + let h2 = h2.apply(&self.feed_forward)?; + Ok(x.broadcast_add(&h2)?) 
+ } + + fn clear_kv_cache(&mut self) { + self.self_attn.clear_kv_cache(); + } +} + +#[derive(Debug, Clone)] +pub struct Model { + embed_tokens: candle_nn::Embedding, + layers: Vec, + norm: Qwen3RmsNorm, + rotary: Arc, + device: Device, + dtype: DType, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let embed_tokens = + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?; + let rotary = Arc::new(Qwen3RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?); + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_l = vb.pp("model.layers"); + for i in 0..cfg.num_hidden_layers { + layers.push(DecoderLayer::new(i, cfg, rotary.clone(), vb_l.pp(i))?); + } + Ok(Self { + embed_tokens, + layers, + norm: Qwen3RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?, + rotary, + device: vb.device().clone(), + dtype: vb.dtype(), + }) + } + + fn clear_kv_cache(&mut self) { + for l in &mut self.layers { + l.clear_kv_cache(); + } + } + + fn causal_mask( + &self, + b: usize, + tgt: usize, + offset: usize, + sw: Option, + ) -> Result { + let minf = f32::NEG_INFINITY; + let mask: Vec<_> = (0..tgt) + .flat_map(|i| { + (0..(tgt + offset)).map(move |j| { + let past_ok = j <= i + offset; + let sw_ok = match sw { + Some(w) => (i + offset) as i64 - j as i64 <= w as i64, + None => true, + }; + if past_ok && sw_ok { + 0. + } else { + minf + } + }) + }) + .collect(); + Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype) + } + + pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result { + let (b, l) = input.dims2()?; + let mut h = self.embed_tokens.forward(input)?; + + let causal = if l == 1 { + None + } else { + Some(self.causal_mask(b, l, offset, None)?) + }; + + for layer in &mut self.layers { + h = layer.forward(&h, causal.as_ref(), offset)?; + } + let b = self.norm.forward(&h)?; + Ok(b) + } +} + +#[derive(Debug, Clone)] +pub struct ModelForCausalLM { + base: Model, + lm_head: Linear, +} + +impl ModelForCausalLM { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let base = Model::new(cfg, vb.clone())?; + let lm_head = if vb.contains_tensor("lm_head.weight") { + linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))? + } else { + Linear::from_weights(base.embed_tokens.embeddings().clone(), None) + }; + Ok(Self { base, lm_head }) + } + + pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result { + let (_, l) = input.dims2()?; + self.base + .forward(input, offset)? + .narrow(1, l - 1, 1)? 
+ .apply(&self.lm_head) + } + + pub fn clear_kv_cache(&mut self) { + self.base.clear_kv_cache(); + } +} From 7b5aaddebea36b8abec3de345ef0d8b7540c920f Mon Sep 17 00:00:00 2001 From: keighbee Date: Tue, 29 Apr 2025 17:08:45 -0700 Subject: [PATCH 06/22] clippy --- candle-transformers/src/models/qwen3.rs | 13 ++++++------- candle-transformers/src/models/qwen3_moe.rs | 18 ++++++++---------- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/candle-transformers/src/models/qwen3.rs b/candle-transformers/src/models/qwen3.rs index 4d180d3b8f..73018d1f0f 100644 --- a/candle-transformers/src/models/qwen3.rs +++ b/candle-transformers/src/models/qwen3.rs @@ -161,7 +161,7 @@ struct Qwen3Attention { head_dim: usize, hidden_size: usize, // sliding window - sliding_window: Option, + _sliding_window: Option, // utils rotary_emb: Arc, kv_cache: Option<(Tensor, Tensor)>, @@ -219,7 +219,7 @@ impl Qwen3Attention { num_kv_groups, head_dim, hidden_size, - sliding_window, + _sliding_window: sliding_window, rotary_emb, kv_cache: None, }) @@ -322,7 +322,7 @@ impl DecoderLayer { let x = x.broadcast_add(&h)?; let h2 = self.ln2.forward(&x)?; let h2 = h2.apply(&self.mlp)?; - Ok(x.broadcast_add(&h2)?) + x.broadcast_add(&h2) } fn clear_kv_cache(&mut self) { @@ -335,7 +335,7 @@ pub struct Model { embed_tokens: candle_nn::Embedding, layers: Vec, norm: Qwen3RmsNorm, - rotary: Arc, + _rotary: Arc, device: Device, dtype: DType, } @@ -354,7 +354,7 @@ impl Model { embed_tokens, layers, norm: Qwen3RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?, - rotary, + _rotary: rotary, device: vb.device().clone(), dtype: vb.dtype(), }) @@ -406,8 +406,7 @@ impl Model { for layer in &mut self.layers { h = layer.forward(&h, causal.as_ref(), offset)?; } - let b = self.norm.forward(&h)?; - Ok(b) + self.norm.forward(&h) } } diff --git a/candle-transformers/src/models/qwen3_moe.rs b/candle-transformers/src/models/qwen3_moe.rs index ddf1738bf9..52ea0108db 100644 --- a/candle-transformers/src/models/qwen3_moe.rs +++ b/candle-transformers/src/models/qwen3_moe.rs @@ -227,8 +227,7 @@ impl Module for Qwen3SparseMoeBlock { ys = ys.index_add(&top_x, ¤t_hidden_states, 0)?; } - let ys = ys.reshape((b_size, seq_len, hidden_dim))?; - Ok(ys) + ys.reshape((b_size, seq_len, hidden_dim)) } } @@ -265,7 +264,7 @@ struct Qwen3Attention { head_dim: usize, hidden_size: usize, // sliding window - sliding_window: Option, + _sliding_window: Option, // utils rotary_emb: Arc, kv_cache: Option<(Tensor, Tensor)>, @@ -302,7 +301,7 @@ impl Qwen3Attention { let q_norm = Qwen3HeadRmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?; let k_norm = Qwen3HeadRmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?; - let sliding_window = if cfg.use_sliding_window && layer_idx >= cfg.max_window_layers { + let _sliding_window = if cfg.use_sliding_window && layer_idx >= cfg.max_window_layers { cfg.sliding_window } else { None @@ -323,7 +322,7 @@ impl Qwen3Attention { num_kv_groups, head_dim, hidden_size, - sliding_window, + _sliding_window, rotary_emb, kv_cache: None, }) @@ -439,7 +438,7 @@ impl DecoderLayer { let x = x.broadcast_add(&h)?; let h2 = self.ln2.forward(&x)?; let h2 = h2.apply(&self.feed_forward)?; - Ok(x.broadcast_add(&h2)?) 
+ x.broadcast_add(&h2) } fn clear_kv_cache(&mut self) { @@ -452,7 +451,7 @@ pub struct Model { embed_tokens: candle_nn::Embedding, layers: Vec, norm: Qwen3RmsNorm, - rotary: Arc, + _rotary: Arc, device: Device, dtype: DType, } @@ -471,7 +470,7 @@ impl Model { embed_tokens, layers, norm: Qwen3RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?, - rotary, + _rotary: rotary, device: vb.device().clone(), dtype: vb.dtype(), }) @@ -523,8 +522,7 @@ impl Model { for layer in &mut self.layers { h = layer.forward(&h, causal.as_ref(), offset)?; } - let b = self.norm.forward(&h)?; - Ok(b) + self.norm.forward(&h) } } From e71c4780248cab45e9fa85bc2e4564d641f8a0cf Mon Sep 17 00:00:00 2001 From: keighbee Date: Tue, 29 Apr 2025 17:16:41 -0700 Subject: [PATCH 07/22] added additional eos token --- candle-examples/examples/qwen/main.rs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/candle-examples/examples/qwen/main.rs b/candle-examples/examples/qwen/main.rs index 7eab07db8e..eb2d7b2cc7 100644 --- a/candle-examples/examples/qwen/main.rs +++ b/candle-examples/examples/qwen/main.rs @@ -91,6 +91,10 @@ impl TextGeneration { Some(token) => token, None => anyhow::bail!("cannot find the <|endoftext|> token"), }; + let eos_token2 = match self.tokenizer.get_token("<|im_end|>") { + Some(token) => token, + None => anyhow::bail!("cannot find the <|im_end|> token"), + }; let start_gen = std::time::Instant::now(); for index in 0..sample_len { let context_size = if index > 0 { 1 } else { tokens.len() }; @@ -113,7 +117,7 @@ impl TextGeneration { let next_token = self.logits_processor.sample(&logits)?; tokens.push(next_token); generated_tokens += 1; - if next_token == eos_token { + if next_token == eos_token || next_token == eos_token2 { break; } if let Some(t) = self.tokenizer.next_token(next_token)? { @@ -158,13 +162,13 @@ enum WhichModel { W2_7b, #[value(name = "2-72b")] W2_72b, - #[value(name = "3-0.6B")] + #[value(name = "3-0.6b")] W3_0_6b, - #[value(name = "3-1.7B")] + #[value(name = "3-1.7b")] W3_1_7b, - #[value(name = "3-4B")] + #[value(name = "3-4b")] W3_4b, - #[value(name = "3-8B")] + #[value(name = "3-8b")] W3_8b, #[value(name = "3-moe-a3b")] W3MoeA3b, From bbb490fb394bee22d72c168aec652a50f54f570d Mon Sep 17 00:00:00 2001 From: keighbee Date: Wed, 30 Apr 2025 12:20:12 -0700 Subject: [PATCH 08/22] translated Korean comments to English as well as I can --- candle-transformers/src/models/qwen3.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/candle-transformers/src/models/qwen3.rs b/candle-transformers/src/models/qwen3.rs index 73018d1f0f..a06dfd958e 100644 --- a/candle-transformers/src/models/qwen3.rs +++ b/candle-transformers/src/models/qwen3.rs @@ -49,7 +49,7 @@ impl Qwen3RotaryEmbedding { }) } - /// RoPE 적용 (q, k shape: B x H x L x D) + /// Apply RoPE (q, k shape: B x H x L x D) fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> { let (_, _, seq_len, _) = q.dims4()?; let cos = self.cos.narrow(0, offset, seq_len)?; @@ -255,7 +255,7 @@ impl Qwen3Attention { // 4. RoPE let (q, k) = self.rotary_emb.apply(&q, &k, offset)?; - // 5. KV 캐시 누적 + // 5. 
Accumulate KV cache let (k, v) = match &self.kv_cache { None => (k, v), Some((prev_k, prev_v)) => ( From eada460c41979e3ce5d71ad5c7adcce0b8fedfc3 Mon Sep 17 00:00:00 2001 From: keighbee Date: Wed, 30 Apr 2025 12:26:52 -0700 Subject: [PATCH 09/22] removed specialized Qwen3RmsNorm and replaced with generic Candle RmsNorm --- candle-transformers/src/models/qwen3.rs | 71 ++++----------------- candle-transformers/src/models/qwen3_moe.rs | 69 ++++---------------- 2 files changed, 23 insertions(+), 117 deletions(-) diff --git a/candle-transformers/src/models/qwen3.rs b/candle-transformers/src/models/qwen3.rs index a06dfd958e..e4708d7373 100644 --- a/candle-transformers/src/models/qwen3.rs +++ b/candle-transformers/src/models/qwen3.rs @@ -1,5 +1,5 @@ -use crate::models::with_tracing::{linear, linear_no_bias, Linear}; -use candle::{DType, Device, Module, Result, Tensor, D}; +use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm}; +use candle::{DType, Device, Module, Result, Tensor}; use candle_nn::{Activation, VarBuilder}; use std::sync::Arc; @@ -60,53 +60,6 @@ impl Qwen3RotaryEmbedding { } } -#[derive(Debug, Clone)] -struct Qwen3RmsNorm { - weight: Tensor, - eps: f64, -} - -impl Qwen3RmsNorm { - fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result { - let weight = vb.get((dim,), "weight")?; - Ok(Self { weight, eps }) - } - - fn forward(&self, xs: &Tensor) -> Result { - let orig_dtype = xs.dtype(); - let xs = xs.to_dtype(DType::F32)?; - let var = (xs.clone() * &xs)?.mean_keepdim(D::Minus1)?; - let rms = (var + self.eps)?.powf(-0.5)?; - let xs = xs.broadcast_mul(&rms)?; - let ws = self.weight.unsqueeze(0)?.unsqueeze(1)?; - xs.broadcast_mul(&ws)?.to_dtype(orig_dtype) - } -} - -#[derive(Debug, Clone)] -struct Qwen3HeadRmsNorm { - weight: Tensor, - eps: f64, -} - -impl Qwen3HeadRmsNorm { - fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result { - let weight = vb.get((dim,), "weight")?; - Ok(Self { weight, eps }) - } - - fn forward(&self, xs: &Tensor) -> Result { - // xs: (B*L*H, D) - let orig_dtype = xs.dtype(); - let xs = xs.to_dtype(DType::F32)?; - let var = (xs.clone() * &xs)?.mean_keepdim(D::Minus1)?; - let rms = (var + self.eps)?.powf(-0.5)?; - let xs = xs.broadcast_mul(&rms)?; - let ws = self.weight.unsqueeze(0)?; - xs.broadcast_mul(&ws)?.to_dtype(orig_dtype) - } -} - fn repeat_kv(kv: &Tensor, n_rep: usize) -> Result { if n_rep == 1 { return Ok(kv.clone()); @@ -152,8 +105,8 @@ struct Qwen3Attention { v_proj: Linear, o_proj: Linear, // norms - q_norm: Qwen3HeadRmsNorm, - k_norm: Qwen3HeadRmsNorm, + q_norm: RmsNorm, + k_norm: RmsNorm, // hyper params num_heads: usize, num_kv_heads: usize, @@ -195,8 +148,8 @@ impl Qwen3Attention { ) }; - let q_norm = Qwen3HeadRmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?; - let k_norm = Qwen3HeadRmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?; + let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?; + let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?; let sliding_window = if cfg.use_sliding_window && layer_idx >= cfg.max_window_layers { cfg.sliding_window @@ -293,8 +246,8 @@ impl Qwen3Attention { struct DecoderLayer { self_attn: Qwen3Attention, mlp: Qwen3MLP, - ln1: Qwen3RmsNorm, - ln2: Qwen3RmsNorm, + ln1: RmsNorm, + ln2: RmsNorm, } impl DecoderLayer { @@ -307,8 +260,8 @@ impl DecoderLayer { Ok(Self { self_attn: Qwen3Attention::new(cfg, rotary, idx, vb.pp("self_attn"))?, mlp: Qwen3MLP::new(cfg, vb.pp("mlp"))?, - ln1: Qwen3RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, 
vb.pp("input_layernorm"))?, - ln2: Qwen3RmsNorm::new( + ln1: RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?, + ln2: RmsNorm::new( cfg.hidden_size, cfg.rms_norm_eps, vb.pp("post_attention_layernorm"), @@ -334,7 +287,7 @@ impl DecoderLayer { pub struct Model { embed_tokens: candle_nn::Embedding, layers: Vec, - norm: Qwen3RmsNorm, + norm: RmsNorm, _rotary: Arc, device: Device, dtype: DType, @@ -353,7 +306,7 @@ impl Model { Ok(Self { embed_tokens, layers, - norm: Qwen3RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?, + norm: RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?, _rotary: rotary, device: vb.device().clone(), dtype: vb.dtype(), diff --git a/candle-transformers/src/models/qwen3_moe.rs b/candle-transformers/src/models/qwen3_moe.rs index 52ea0108db..7a03a97f59 100644 --- a/candle-transformers/src/models/qwen3_moe.rs +++ b/candle-transformers/src/models/qwen3_moe.rs @@ -1,4 +1,4 @@ -use crate::models::with_tracing::{linear, linear_no_bias, Linear}; +use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm}; use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{Activation, VarBuilder}; use std::sync::Arc; @@ -65,53 +65,6 @@ impl Qwen3RotaryEmbedding { } } -#[derive(Debug, Clone)] -struct Qwen3RmsNorm { - weight: Tensor, - eps: f64, -} - -impl Qwen3RmsNorm { - fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result { - let weight = vb.get((dim,), "weight")?; - Ok(Self { weight, eps }) - } - - fn forward(&self, xs: &Tensor) -> Result { - let orig_dtype = xs.dtype(); - let xs = xs.to_dtype(DType::F32)?; - let var = (xs.clone() * &xs)?.mean_keepdim(D::Minus1)?; - let rms = (var + self.eps)?.powf(-0.5)?; - let xs = xs.broadcast_mul(&rms)?; - let ws = self.weight.unsqueeze(0)?.unsqueeze(1)?; - xs.broadcast_mul(&ws)?.to_dtype(orig_dtype) - } -} - -#[derive(Debug, Clone)] -struct Qwen3HeadRmsNorm { - weight: Tensor, - eps: f64, -} - -impl Qwen3HeadRmsNorm { - fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result { - let weight = vb.get((dim,), "weight")?; - Ok(Self { weight, eps }) - } - - fn forward(&self, xs: &Tensor) -> Result { - // xs: (B*L*H, D) - let orig_dtype = xs.dtype(); - let xs = xs.to_dtype(DType::F32)?; - let var = (xs.clone() * &xs)?.mean_keepdim(D::Minus1)?; - let rms = (var + self.eps)?.powf(-0.5)?; - let xs = xs.broadcast_mul(&rms)?; - let ws = self.weight.unsqueeze(0)?; - xs.broadcast_mul(&ws)?.to_dtype(orig_dtype) - } -} - fn repeat_kv(kv: &Tensor, n_rep: usize) -> Result { if n_rep == 1 { return Ok(kv.clone()); @@ -255,8 +208,8 @@ struct Qwen3Attention { v_proj: Linear, o_proj: Linear, // norms - q_norm: Qwen3HeadRmsNorm, - k_norm: Qwen3HeadRmsNorm, + q_norm: RmsNorm, + k_norm: RmsNorm, // hyper params num_heads: usize, num_kv_heads: usize, @@ -298,8 +251,8 @@ impl Qwen3Attention { ) }; - let q_norm = Qwen3HeadRmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?; - let k_norm = Qwen3HeadRmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?; + let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?; + let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?; let _sliding_window = if cfg.use_sliding_window && layer_idx >= cfg.max_window_layers { cfg.sliding_window @@ -396,8 +349,8 @@ impl Qwen3Attention { struct DecoderLayer { self_attn: Qwen3Attention, feed_forward: Qwen3FeedForward, - ln1: Qwen3RmsNorm, - ln2: Qwen3RmsNorm, + ln1: RmsNorm, + ln2: RmsNorm, } impl DecoderLayer { @@ -417,8 +370,8 @@ impl DecoderLayer { 
Qwen3FeedForward::MLP(Qwen3MLP::new(cfg, cfg.intermediate_size, vb.pp("mlp"))?) }; - let ln1 = Qwen3RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; - let ln2 = Qwen3RmsNorm::new( + let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let ln2 = RmsNorm::new( cfg.hidden_size, cfg.rms_norm_eps, vb.pp("post_attention_layernorm"), @@ -450,7 +403,7 @@ impl DecoderLayer { pub struct Model { embed_tokens: candle_nn::Embedding, layers: Vec, - norm: Qwen3RmsNorm, + norm: RmsNorm, _rotary: Arc, device: Device, dtype: DType, @@ -469,7 +422,7 @@ impl Model { Ok(Self { embed_tokens, layers, - norm: Qwen3RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?, + norm: RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?, _rotary: rotary, device: vb.device().clone(), dtype: vb.dtype(), From 80170fd39fc5025d15fe0cd1bb066b8036b5730d Mon Sep 17 00:00:00 2001 From: keighbee Date: Wed, 30 Apr 2025 12:34:15 -0700 Subject: [PATCH 10/22] replaced custom repeat_kv implementation with candle's repeat_kv implementation --- candle-transformers/src/models/qwen3.rs | 19 ++++++------------- candle-transformers/src/models/qwen3_moe.rs | 19 ++++++------------- 2 files changed, 12 insertions(+), 26 deletions(-) diff --git a/candle-transformers/src/models/qwen3.rs b/candle-transformers/src/models/qwen3.rs index e4708d7373..9ea252422a 100644 --- a/candle-transformers/src/models/qwen3.rs +++ b/candle-transformers/src/models/qwen3.rs @@ -1,4 +1,7 @@ -use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm}; +use crate::{ + models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm}, + utils::repeat_kv, +}; use candle::{DType, Device, Module, Result, Tensor}; use candle_nn::{Activation, VarBuilder}; use std::sync::Arc; @@ -60,16 +63,6 @@ impl Qwen3RotaryEmbedding { } } -fn repeat_kv(kv: &Tensor, n_rep: usize) -> Result { - if n_rep == 1 { - return Ok(kv.clone()); - } - let (b, h_kv, l, d) = kv.dims4()?; - kv.unsqueeze(2)? - .expand((b, h_kv, n_rep, l, d))? - .reshape((b, h_kv * n_rep, l, d)) -} - #[derive(Debug, Clone)] struct Qwen3MLP { gate_proj: Linear, @@ -219,8 +212,8 @@ impl Qwen3Attention { self.kv_cache = Some((k.clone(), v.clone())); // 6. GQA repeat_kv - let k = repeat_kv(&k, self.num_kv_groups)?; - let v = repeat_kv(&v, self.num_kv_groups)?; + let k = repeat_kv(k, self.num_kv_groups)?; + let v = repeat_kv(v, self.num_kv_groups)?; // 7. Attention score let scale = 1.0 / (self.head_dim as f64).sqrt(); diff --git a/candle-transformers/src/models/qwen3_moe.rs b/candle-transformers/src/models/qwen3_moe.rs index 7a03a97f59..ac4bae9137 100644 --- a/candle-transformers/src/models/qwen3_moe.rs +++ b/candle-transformers/src/models/qwen3_moe.rs @@ -1,4 +1,7 @@ -use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm}; +use crate::{ + models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm}, + utils::repeat_kv, +}; use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{Activation, VarBuilder}; use std::sync::Arc; @@ -65,16 +68,6 @@ impl Qwen3RotaryEmbedding { } } -fn repeat_kv(kv: &Tensor, n_rep: usize) -> Result { - if n_rep == 1 { - return Ok(kv.clone()); - } - let (b, h_kv, l, d) = kv.dims4()?; - kv.unsqueeze(2)? - .expand((b, h_kv, n_rep, l, d))? - .reshape((b, h_kv * n_rep, l, d)) -} - #[derive(Debug, Clone)] struct Qwen3MLP { gate_proj: Linear, @@ -322,8 +315,8 @@ impl Qwen3Attention { self.kv_cache = Some((k.clone(), v.clone())); // 6. 
GQA repeat_kv - let k = repeat_kv(&k, self.num_kv_groups)?; - let v = repeat_kv(&v, self.num_kv_groups)?; + let k = repeat_kv(k, self.num_kv_groups)?; + let v = repeat_kv(v, self.num_kv_groups)?; // 7. Attention score let scale = 1.0 / (self.head_dim as f64).sqrt(); From 181f2ceae417fc08595f8e58e023d254be5181c1 Mon Sep 17 00:00:00 2001 From: keighbee Date: Wed, 30 Apr 2025 12:40:13 -0700 Subject: [PATCH 11/22] replace linear with linear_b in attention initalization --- candle-transformers/src/models/qwen3.rs | 41 +++++++++++++-------- candle-transformers/src/models/qwen3_moe.rs | 41 +++++++++++++-------- 2 files changed, 50 insertions(+), 32 deletions(-) diff --git a/candle-transformers/src/models/qwen3.rs b/candle-transformers/src/models/qwen3.rs index 9ea252422a..70bf84ebbf 100644 --- a/candle-transformers/src/models/qwen3.rs +++ b/candle-transformers/src/models/qwen3.rs @@ -1,5 +1,5 @@ use crate::{ - models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm}, + models::with_tracing::{linear_b, linear_no_bias, Linear, RmsNorm}, utils::repeat_kv, }; use candle::{DType, Device, Module, Result, Tensor}; @@ -125,21 +125,30 @@ impl Qwen3Attention { let num_kv_heads = cfg.num_key_value_heads; let num_kv_groups = num_heads / num_kv_heads; - let (q_proj, k_proj, v_proj, o_proj) = if cfg.attention_bias { - ( - linear(cfg.hidden_size, num_heads * head_dim, vb.pp("q_proj"))?, - linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?, - linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?, - linear(num_heads * head_dim, cfg.hidden_size, vb.pp("o_proj"))?, - ) - } else { - ( - linear_no_bias(cfg.hidden_size, num_heads * head_dim, vb.pp("q_proj"))?, - linear_no_bias(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?, - linear_no_bias(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?, - linear_no_bias(num_heads * head_dim, cfg.hidden_size, vb.pp("o_proj"))?, - ) - }; + let q_proj = linear_b( + cfg.hidden_size, + num_heads * head_dim, + cfg.attention_bias, + vb.pp("q_proj"), + )?; + let k_proj = linear_b( + cfg.hidden_size, + num_kv_heads * head_dim, + cfg.attention_bias, + vb.pp("k_proj"), + )?; + let v_proj = linear_b( + cfg.hidden_size, + num_kv_heads * head_dim, + cfg.attention_bias, + vb.pp("v_proj"), + )?; + let o_proj = linear_b( + num_heads * head_dim, + cfg.hidden_size, + cfg.attention_bias, + vb.pp("o_proj"), + )?; let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?; let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?; diff --git a/candle-transformers/src/models/qwen3_moe.rs b/candle-transformers/src/models/qwen3_moe.rs index ac4bae9137..36164e82ab 100644 --- a/candle-transformers/src/models/qwen3_moe.rs +++ b/candle-transformers/src/models/qwen3_moe.rs @@ -1,5 +1,5 @@ use crate::{ - models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm}, + models::with_tracing::{linear_b, linear_no_bias, Linear, RmsNorm}, utils::repeat_kv, }; use candle::{DType, Device, Module, Result, Tensor, D}; @@ -228,21 +228,30 @@ impl Qwen3Attention { let num_kv_heads = cfg.num_key_value_heads; let num_kv_groups = num_heads / num_kv_heads; - let (q_proj, k_proj, v_proj, o_proj) = if cfg.attention_bias { - ( - linear(cfg.hidden_size, num_heads * head_dim, vb.pp("q_proj"))?, - linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?, - linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?, - linear(num_heads * head_dim, cfg.hidden_size, vb.pp("o_proj"))?, - ) - } else { - ( - 
linear_no_bias(cfg.hidden_size, num_heads * head_dim, vb.pp("q_proj"))?, - linear_no_bias(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?, - linear_no_bias(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?, - linear_no_bias(num_heads * head_dim, cfg.hidden_size, vb.pp("o_proj"))?, - ) - }; + let q_proj = linear_b( + cfg.hidden_size, + num_heads * head_dim, + cfg.attention_bias, + vb.pp("q_proj"), + )?; + let k_proj = linear_b( + cfg.hidden_size, + num_kv_heads * head_dim, + cfg.attention_bias, + vb.pp("k_proj"), + )?; + let v_proj = linear_b( + cfg.hidden_size, + num_kv_heads * head_dim, + cfg.attention_bias, + vb.pp("v_proj"), + )?; + let o_proj = linear_b( + num_heads * head_dim, + cfg.hidden_size, + cfg.attention_bias, + vb.pp("o_proj"), + )?; let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?; let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?; From 02f0247108e31cc2a76a8ed7ad53e9ec08ad9747 Mon Sep 17 00:00:00 2001 From: keighbee Date: Wed, 30 Apr 2025 13:01:49 -0700 Subject: [PATCH 12/22] replaced custom custom kv_cache implementation with candle kv_cache --- candle-transformers/src/models/qwen3.rs | 21 ++++++++------------- candle-transformers/src/models/qwen3_moe.rs | 21 ++++++++------------- 2 files changed, 16 insertions(+), 26 deletions(-) diff --git a/candle-transformers/src/models/qwen3.rs b/candle-transformers/src/models/qwen3.rs index 70bf84ebbf..630b45fdf9 100644 --- a/candle-transformers/src/models/qwen3.rs +++ b/candle-transformers/src/models/qwen3.rs @@ -3,7 +3,7 @@ use crate::{ utils::repeat_kv, }; use candle::{DType, Device, Module, Result, Tensor}; -use candle_nn::{Activation, VarBuilder}; +use candle_nn::{kv_cache::KvCache, Activation, VarBuilder}; use std::sync::Arc; #[derive(Debug, Clone, PartialEq, serde::Deserialize)] @@ -110,7 +110,7 @@ struct Qwen3Attention { _sliding_window: Option, // utils rotary_emb: Arc, - kv_cache: Option<(Tensor, Tensor)>, + kv_cache: KvCache, } impl Qwen3Attention { @@ -160,7 +160,9 @@ impl Qwen3Attention { }; // Necessary because the hidden_size in the cofig isn't always accurate - let hidden_size = cfg.head_dim * cfg.num_attention_heads; + let hidden_size = head_dim * cfg.num_attention_heads; + + let kv_cache = KvCache::new(2, cfg.max_position_embeddings); Ok(Self { q_proj, @@ -176,7 +178,7 @@ impl Qwen3Attention { hidden_size, _sliding_window: sliding_window, rotary_emb, - kv_cache: None, + kv_cache, }) } @@ -211,14 +213,7 @@ impl Qwen3Attention { let (q, k) = self.rotary_emb.apply(&q, &k, offset)?; // 5. Accumulate KV cache - let (k, v) = match &self.kv_cache { - None => (k, v), - Some((prev_k, prev_v)) => ( - Tensor::cat(&[prev_k, &k], 2)?, - Tensor::cat(&[prev_v, &v], 2)?, - ), - }; - self.kv_cache = Some((k.clone(), v.clone())); + let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?; // 6. 
GQA repeat_kv let k = repeat_kv(k, self.num_kv_groups)?; @@ -240,7 +235,7 @@ impl Qwen3Attention { } fn clear_kv_cache(&mut self) { - self.kv_cache = None; + self.kv_cache.reset(); } } diff --git a/candle-transformers/src/models/qwen3_moe.rs b/candle-transformers/src/models/qwen3_moe.rs index 36164e82ab..5716f215a4 100644 --- a/candle-transformers/src/models/qwen3_moe.rs +++ b/candle-transformers/src/models/qwen3_moe.rs @@ -3,7 +3,7 @@ use crate::{ utils::repeat_kv, }; use candle::{DType, Device, Module, Result, Tensor, D}; -use candle_nn::{Activation, VarBuilder}; +use candle_nn::{kv_cache::KvCache, Activation, VarBuilder}; use std::sync::Arc; #[derive(Debug, Clone, PartialEq, serde::Deserialize)] @@ -213,7 +213,7 @@ struct Qwen3Attention { _sliding_window: Option, // utils rotary_emb: Arc, - kv_cache: Option<(Tensor, Tensor)>, + kv_cache: KvCache, } impl Qwen3Attention { @@ -263,7 +263,9 @@ impl Qwen3Attention { }; // Necessary because the hidden_size in the config isn't always accurate - let hidden_size = cfg.head_dim * cfg.num_attention_heads; + let hidden_size = head_dim * cfg.num_attention_heads; + + let kv_cache = KvCache::new(2, cfg.max_position_embeddings); Ok(Self { q_proj, @@ -279,7 +281,7 @@ impl Qwen3Attention { hidden_size, _sliding_window, rotary_emb, - kv_cache: None, + kv_cache, }) } @@ -314,14 +316,7 @@ impl Qwen3Attention { let (q, k) = self.rotary_emb.apply(&q, &k, offset)?; // 5. KV cache - let (k, v) = match &self.kv_cache { - None => (k, v), - Some((prev_k, prev_v)) => ( - Tensor::cat(&[prev_k, &k], 2)?, - Tensor::cat(&[prev_v, &v], 2)?, - ), - }; - self.kv_cache = Some((k.clone(), v.clone())); + let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?; // 6. GQA repeat_kv let k = repeat_kv(k, self.num_kv_groups)?; @@ -343,7 +338,7 @@ impl Qwen3Attention { } fn clear_kv_cache(&mut self) { - self.kv_cache = None; + self.kv_cache.reset(); } } From f2962f7d9256d62d290820336034187a7a92f235 Mon Sep 17 00:00:00 2001 From: keighbee Date: Wed, 30 Apr 2025 13:05:48 -0700 Subject: [PATCH 13/22] style --- candle-transformers/src/models/qwen3.rs | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/candle-transformers/src/models/qwen3.rs b/candle-transformers/src/models/qwen3.rs index 630b45fdf9..e43e700d7f 100644 --- a/candle-transformers/src/models/qwen3.rs +++ b/candle-transformers/src/models/qwen3.rs @@ -254,15 +254,19 @@ impl DecoderLayer { idx: usize, vb: VarBuilder, ) -> Result { + let self_attn = Qwen3Attention::new(cfg, rotary, idx, vb.pp("self_attn"))?; + let mlp = Qwen3MLP::new(cfg, vb.pp("mlp"))?; + let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let ln2 = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; Ok(Self { - self_attn: Qwen3Attention::new(cfg, rotary, idx, vb.pp("self_attn"))?, - mlp: Qwen3MLP::new(cfg, vb.pp("mlp"))?, - ln1: RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?, - ln2: RmsNorm::new( - cfg.hidden_size, - cfg.rms_norm_eps, - vb.pp("post_attention_layernorm"), - )?, + self_attn, + mlp, + ln1, + ln2, }) } From 410e11e1134e6f68ee4b399ce26da2ba646df2b4 Mon Sep 17 00:00:00 2001 From: keighbee Date: Wed, 30 Apr 2025 13:08:34 -0700 Subject: [PATCH 14/22] replaced explicit broadcast add with normal add in decoder layer --- candle-transformers/src/models/qwen3.rs | 4 ++-- candle-transformers/src/models/qwen3_moe.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git 
a/candle-transformers/src/models/qwen3.rs b/candle-transformers/src/models/qwen3.rs index e43e700d7f..fb6023f8d6 100644 --- a/candle-transformers/src/models/qwen3.rs +++ b/candle-transformers/src/models/qwen3.rs @@ -273,10 +273,10 @@ impl DecoderLayer { fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result { let h = self.ln1.forward(x)?; let h = self.self_attn.forward(&h, mask, offset)?; - let x = x.broadcast_add(&h)?; + let x = (x + h)?; let h2 = self.ln2.forward(&x)?; let h2 = h2.apply(&self.mlp)?; - x.broadcast_add(&h2) + x + h2 } fn clear_kv_cache(&mut self) { diff --git a/candle-transformers/src/models/qwen3_moe.rs b/candle-transformers/src/models/qwen3_moe.rs index 5716f215a4..fe09d5e531 100644 --- a/candle-transformers/src/models/qwen3_moe.rs +++ b/candle-transformers/src/models/qwen3_moe.rs @@ -385,10 +385,10 @@ impl DecoderLayer { fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result { let h = self.ln1.forward(x)?; let h = self.self_attn.forward(&h, mask, offset)?; - let x = x.broadcast_add(&h)?; + let x = (x + h)?; let h2 = self.ln2.forward(&x)?; let h2 = h2.apply(&self.feed_forward)?; - x.broadcast_add(&h2) + x + h2 } fn clear_kv_cache(&mut self) { From d99d10421f5bc5bba98918f7f0074155886380a6 Mon Sep 17 00:00:00 2001 From: keighbee Date: Wed, 30 Apr 2025 13:12:14 -0700 Subject: [PATCH 15/22] removed keeping the Rotary embedding layer in the model struct --- candle-transformers/src/models/qwen3.rs | 2 -- candle-transformers/src/models/qwen3_moe.rs | 2 -- 2 files changed, 4 deletions(-) diff --git a/candle-transformers/src/models/qwen3.rs b/candle-transformers/src/models/qwen3.rs index fb6023f8d6..e34dd40a27 100644 --- a/candle-transformers/src/models/qwen3.rs +++ b/candle-transformers/src/models/qwen3.rs @@ -289,7 +289,6 @@ pub struct Model { embed_tokens: candle_nn::Embedding, layers: Vec, norm: RmsNorm, - _rotary: Arc, device: Device, dtype: DType, } @@ -308,7 +307,6 @@ impl Model { embed_tokens, layers, norm: RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?, - _rotary: rotary, device: vb.device().clone(), dtype: vb.dtype(), }) diff --git a/candle-transformers/src/models/qwen3_moe.rs b/candle-transformers/src/models/qwen3_moe.rs index fe09d5e531..61605d8b3a 100644 --- a/candle-transformers/src/models/qwen3_moe.rs +++ b/candle-transformers/src/models/qwen3_moe.rs @@ -401,7 +401,6 @@ pub struct Model { embed_tokens: candle_nn::Embedding, layers: Vec, norm: RmsNorm, - _rotary: Arc, device: Device, dtype: DType, } @@ -420,7 +419,6 @@ impl Model { embed_tokens, layers, norm: RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?, - _rotary: rotary, device: vb.device().clone(), dtype: vb.dtype(), }) From a57c5ab46e3fcd511058b5709e70c6fc14f2fa6f Mon Sep 17 00:00:00 2001 From: keighbee Date: Wed, 30 Apr 2025 13:17:08 -0700 Subject: [PATCH 16/22] used tie_word_embeddings bool from config instead of relying on existence of weights for lm head in CasualLM --- candle-transformers/src/models/qwen3.rs | 2 +- candle-transformers/src/models/qwen3_moe.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/candle-transformers/src/models/qwen3.rs b/candle-transformers/src/models/qwen3.rs index e34dd40a27..d797e8b171 100644 --- a/candle-transformers/src/models/qwen3.rs +++ b/candle-transformers/src/models/qwen3.rs @@ -371,7 +371,7 @@ pub struct ModelForCausalLM { impl ModelForCausalLM { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { let base = Model::new(cfg, vb.clone())?; - let 
lm_head = if vb.contains_tensor("lm_head.weight") { + let lm_head = if cfg.tie_word_embeddings { linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))? } else { Linear::from_weights(base.embed_tokens.embeddings().clone(), None) diff --git a/candle-transformers/src/models/qwen3_moe.rs b/candle-transformers/src/models/qwen3_moe.rs index 61605d8b3a..232f5636e5 100644 --- a/candle-transformers/src/models/qwen3_moe.rs +++ b/candle-transformers/src/models/qwen3_moe.rs @@ -483,7 +483,7 @@ pub struct ModelForCausalLM { impl ModelForCausalLM { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { let base = Model::new(cfg, vb.clone())?; - let lm_head = if vb.contains_tensor("lm_head.weight") { + let lm_head = if cfg.tie_word_embeddings { linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))? } else { Linear::from_weights(base.embed_tokens.embeddings().clone(), None) From 595007717f2de48f2fd6e6dc4b239058cd9ff92b Mon Sep 17 00:00:00 2001 From: keighbee Date: Wed, 30 Apr 2025 13:44:19 -0700 Subject: [PATCH 17/22] removed duplicate code from qwen3_moe --- candle-transformers/src/models/qwen3.rs | 16 +- candle-transformers/src/models/qwen3_moe.rs | 250 ++++---------------- 2 files changed, 58 insertions(+), 208 deletions(-) diff --git a/candle-transformers/src/models/qwen3.rs b/candle-transformers/src/models/qwen3.rs index d797e8b171..23766635de 100644 --- a/candle-transformers/src/models/qwen3.rs +++ b/candle-transformers/src/models/qwen3.rs @@ -27,13 +27,13 @@ pub struct Config { } #[derive(Debug, Clone)] -struct Qwen3RotaryEmbedding { +pub(crate) struct Qwen3RotaryEmbedding { sin: Tensor, cos: Tensor, } impl Qwen3RotaryEmbedding { - fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { + pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { let dim = cfg.head_dim; let max_seq_len = cfg.max_position_embeddings; let inv_freq: Vec<_> = (0..dim) @@ -64,7 +64,7 @@ impl Qwen3RotaryEmbedding { } #[derive(Debug, Clone)] -struct Qwen3MLP { +pub(crate) struct Qwen3MLP { gate_proj: Linear, up_proj: Linear, down_proj: Linear, @@ -72,7 +72,7 @@ struct Qwen3MLP { } impl Qwen3MLP { - fn new(cfg: &Config, vb: VarBuilder) -> Result { + pub(crate) fn new(cfg: &Config, vb: VarBuilder) -> Result { Ok(Self { gate_proj: linear_no_bias(cfg.hidden_size, cfg.intermediate_size, vb.pp("gate_proj"))?, up_proj: linear_no_bias(cfg.hidden_size, cfg.intermediate_size, vb.pp("up_proj"))?, @@ -91,7 +91,7 @@ impl Module for Qwen3MLP { } #[derive(Debug, Clone)] -struct Qwen3Attention { +pub(crate) struct Qwen3Attention { // projections q_proj: Linear, k_proj: Linear, @@ -114,7 +114,7 @@ struct Qwen3Attention { } impl Qwen3Attention { - fn new( + pub(crate) fn new( cfg: &Config, rotary_emb: Arc, layer_idx: usize, @@ -182,7 +182,7 @@ impl Qwen3Attention { }) } - fn forward(&mut self, x: &Tensor, attn_mask: Option<&Tensor>, offset: usize) -> Result { + pub(crate) fn forward(&mut self, x: &Tensor, attn_mask: Option<&Tensor>, offset: usize) -> Result { let (b, l, _) = x.dims3()?; // 1. 
Proj @@ -234,7 +234,7 @@ impl Qwen3Attention { .apply(&self.o_proj) } - fn clear_kv_cache(&mut self) { + pub(crate) fn clear_kv_cache(&mut self) { self.kv_cache.reset(); } } diff --git a/candle-transformers/src/models/qwen3_moe.rs b/candle-transformers/src/models/qwen3_moe.rs index 232f5636e5..4b400ca606 100644 --- a/candle-transformers/src/models/qwen3_moe.rs +++ b/candle-transformers/src/models/qwen3_moe.rs @@ -1,9 +1,9 @@ -use crate::{ - models::with_tracing::{linear_b, linear_no_bias, Linear, RmsNorm}, - utils::repeat_kv, +use crate::models::{ + qwen3::{Config as Qwen3Config, Qwen3Attention, Qwen3MLP, Qwen3RotaryEmbedding}, + with_tracing::{linear_no_bias, Linear, RmsNorm}, }; use candle::{DType, Device, Module, Result, Tensor, D}; -use candle_nn::{kv_cache::KvCache, Activation, VarBuilder}; +use candle_nn::{Activation, VarBuilder}; use std::sync::Arc; #[derive(Debug, Clone, PartialEq, serde::Deserialize)] @@ -32,62 +32,57 @@ pub struct Config { pub norm_topk_prob: bool, } -#[derive(Debug, Clone)] -struct Qwen3RotaryEmbedding { - sin: Tensor, - cos: Tensor, -} - -impl Qwen3RotaryEmbedding { - fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { - let dim = cfg.head_dim; - let max_seq_len = cfg.max_position_embeddings; - let inv_freq: Vec<_> = (0..dim) - .step_by(2) - .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) - .collect(); - let inv_freq_len = inv_freq.len(); - let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; - let t = Tensor::arange(0u32, max_seq_len as u32, dev)? - .to_dtype(dtype)? - .reshape((max_seq_len, 1))?; - let freqs = t.matmul(&inv_freq)?; - Ok(Self { - sin: freqs.sin()?, - cos: freqs.cos()?, - }) - } - - fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> { - let (_, _, seq_len, _) = q.dims4()?; - let cos = self.cos.narrow(0, offset, seq_len)?; - let sin = self.sin.narrow(0, offset, seq_len)?; - let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; - let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; - Ok((q_embed, k_embed)) +impl From<&Config> for Qwen3Config { + fn from(val: &Config) -> Self { + Qwen3Config { + vocab_size: val.vocab_size, + hidden_size: val.hidden_size, + intermediate_size: val.intermediate_size, + num_hidden_layers: val.num_hidden_layers, + num_attention_heads: val.num_attention_heads, + head_dim: val.head_dim, + attention_bias: val.attention_bias, + num_key_value_heads: val.num_key_value_heads, + max_position_embeddings: val.max_position_embeddings, + sliding_window: val.sliding_window, + max_window_layers: val.max_window_layers, + tie_word_embeddings: val.tie_word_embeddings, + rope_theta: val.rope_theta, + rms_norm_eps: val.rms_norm_eps, + use_sliding_window: val.use_sliding_window, + hidden_act: val.hidden_act, + } } } #[derive(Debug, Clone)] -struct Qwen3MLP { +struct Qwen3MLPExpert { gate_proj: Linear, up_proj: Linear, down_proj: Linear, act_fn: Activation, } -impl Qwen3MLP { - fn new(cfg: &Config, intermediate_size: usize, vb: VarBuilder) -> Result { +impl Qwen3MLPExpert { + fn new(cfg: &Config, vb: VarBuilder) -> Result { Ok(Self { - gate_proj: linear_no_bias(cfg.hidden_size, intermediate_size, vb.pp("gate_proj"))?, - up_proj: linear_no_bias(cfg.hidden_size, intermediate_size, vb.pp("up_proj"))?, - down_proj: linear_no_bias(intermediate_size, cfg.hidden_size, vb.pp("down_proj"))?, + gate_proj: linear_no_bias( + cfg.hidden_size, + cfg.moe_intermediate_size, + vb.pp("gate_proj"), + )?, + up_proj: 
linear_no_bias(cfg.hidden_size, cfg.moe_intermediate_size, vb.pp("up_proj"))?, + down_proj: linear_no_bias( + cfg.moe_intermediate_size, + cfg.hidden_size, + vb.pp("down_proj"), + )?, act_fn: cfg.hidden_act, }) } } -impl Module for Qwen3MLP { +impl Module for Qwen3MLPExpert { fn forward(&self, x: &Tensor) -> Result { let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?; let rhs = x.apply(&self.up_proj)?; @@ -99,7 +94,7 @@ impl Module for Qwen3MLP { #[derive(Debug, Clone)] struct Qwen3SparseMoeBlock { gate: Linear, - experts: Vec, + experts: Vec, norm_topk_prob: bool, num_experts_per_tok: usize, } @@ -110,7 +105,7 @@ impl Qwen3SparseMoeBlock { let mut experts = Vec::with_capacity(cfg.num_experts); let vb_e = vb.pp("experts"); for idx in 0..cfg.num_experts { - let expert = Qwen3MLP::new(cfg, cfg.moe_intermediate_size, vb_e.pp(idx))?; + let expert = Qwen3MLPExpert::new(cfg, vb_e.pp(idx))?; experts.push(expert) } Ok(Self { @@ -180,168 +175,19 @@ impl Module for Qwen3SparseMoeBlock { // MLP or MoE decision enum #[derive(Debug, Clone)] enum Qwen3FeedForward { - MLP(Qwen3MLP), + Mlp(Qwen3MLP), MoE(Qwen3SparseMoeBlock), } impl Module for Qwen3FeedForward { fn forward(&self, xs: &Tensor) -> Result { match self { - Self::MLP(m) => m.forward(xs), + Self::Mlp(m) => m.forward(xs), Self::MoE(m) => m.forward(xs), } } } -#[derive(Debug, Clone)] -struct Qwen3Attention { - // projections - q_proj: Linear, - k_proj: Linear, - v_proj: Linear, - o_proj: Linear, - // norms - q_norm: RmsNorm, - k_norm: RmsNorm, - // hyper params - num_heads: usize, - num_kv_heads: usize, - num_kv_groups: usize, - head_dim: usize, - hidden_size: usize, - // sliding window - _sliding_window: Option, - // utils - rotary_emb: Arc, - kv_cache: KvCache, -} - -impl Qwen3Attention { - fn new( - cfg: &Config, - rotary_emb: Arc, - layer_idx: usize, - vb: VarBuilder, - ) -> Result { - let head_dim = cfg.head_dim; - let num_heads = cfg.num_attention_heads; - let num_kv_heads = cfg.num_key_value_heads; - let num_kv_groups = num_heads / num_kv_heads; - - let q_proj = linear_b( - cfg.hidden_size, - num_heads * head_dim, - cfg.attention_bias, - vb.pp("q_proj"), - )?; - let k_proj = linear_b( - cfg.hidden_size, - num_kv_heads * head_dim, - cfg.attention_bias, - vb.pp("k_proj"), - )?; - let v_proj = linear_b( - cfg.hidden_size, - num_kv_heads * head_dim, - cfg.attention_bias, - vb.pp("v_proj"), - )?; - let o_proj = linear_b( - num_heads * head_dim, - cfg.hidden_size, - cfg.attention_bias, - vb.pp("o_proj"), - )?; - - let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?; - let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?; - - let _sliding_window = if cfg.use_sliding_window && layer_idx >= cfg.max_window_layers { - cfg.sliding_window - } else { - None - }; - - // Necessary because the hidden_size in the config isn't always accurate - let hidden_size = head_dim * cfg.num_attention_heads; - - let kv_cache = KvCache::new(2, cfg.max_position_embeddings); - - Ok(Self { - q_proj, - k_proj, - v_proj, - o_proj, - q_norm, - k_norm, - num_heads, - num_kv_heads, - num_kv_groups, - head_dim, - hidden_size, - _sliding_window, - rotary_emb, - kv_cache, - }) - } - - fn forward(&mut self, x: &Tensor, attn_mask: Option<&Tensor>, offset: usize) -> Result { - let (b, l, _) = x.dims3()?; - - // 1. Proj - let q = self.q_proj.forward(x)?; - let k = self.k_proj.forward(x)?; - let v = self.v_proj.forward(x)?; - - // 2. Reshape: (B, L, H, D) -> (B, H, L, D) - let q = q - .reshape((b, l, self.num_heads, self.head_dim))? 
- .transpose(1, 2)?; - let k = k - .reshape((b, l, self.num_kv_heads, self.head_dim))? - .transpose(1, 2)?; - let v = v - .reshape((b, l, self.num_kv_heads, self.head_dim))? - .transpose(1, 2)?; - - // 3. Per‑head RMSNorm - let q_flat = q.flatten(0, 2)?; - let k_flat = k.flatten(0, 2)?; - let q_flat = self.q_norm.forward(&q_flat)?; - let k_flat = self.k_norm.forward(&k_flat)?; - let q = q_flat.reshape((b, self.num_heads, l, self.head_dim))?; - let k = k_flat.reshape((b, self.num_kv_heads, l, self.head_dim))?; - - // 4. RoPE - let (q, k) = self.rotary_emb.apply(&q, &k, offset)?; - - // 5. KV cache - let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?; - - // 6. GQA repeat_kv - let k = repeat_kv(k, self.num_kv_groups)?; - let v = repeat_kv(v, self.num_kv_groups)?; - - // 7. Attention score - let scale = 1.0 / (self.head_dim as f64).sqrt(); - let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?; - if let Some(m) = attn_mask { - scores = scores.broadcast_add(m)?; - } - let probs = candle_nn::ops::softmax_last_dim(&scores)?; - let ctx = probs.matmul(&v)?; // (B, H, L, D) - - // 8. Output proj - ctx.transpose(1, 2)? - .reshape((b, l, self.hidden_size))? - .apply(&self.o_proj) - } - - fn clear_kv_cache(&mut self) { - self.kv_cache.reset(); - } -} - #[derive(Debug, Clone)] struct DecoderLayer { self_attn: Qwen3Attention, @@ -357,14 +203,14 @@ impl DecoderLayer { rotary: Arc, vb: VarBuilder, ) -> Result { - let self_attn = Qwen3Attention::new(cfg, rotary, layer_idx, vb.pp("self_attn"))?; + let self_attn = Qwen3Attention::new(&cfg.into(), rotary, layer_idx, vb.pp("self_attn"))?; // Decide whether to use MoE or regular MLP based on layer_idx and decoder_sparse_step let feed_forward = if cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0 { Qwen3FeedForward::MoE(Qwen3SparseMoeBlock::new(cfg, vb.pp("mlp"))?) } else { - Qwen3FeedForward::MLP(Qwen3MLP::new(cfg, cfg.intermediate_size, vb.pp("mlp"))?) + Qwen3FeedForward::Mlp(Qwen3MLP::new(&cfg.into(), vb.pp("mlp"))?) 
}; let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; @@ -409,7 +255,11 @@ impl Model { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { let embed_tokens = candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?; - let rotary = Arc::new(Qwen3RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?); + let rotary = Arc::new(Qwen3RotaryEmbedding::new( + vb.dtype(), + &cfg.into(), + vb.device(), + )?); let mut layers = Vec::with_capacity(cfg.num_hidden_layers); let vb_l = vb.pp("model.layers"); for i in 0..cfg.num_hidden_layers { From c635621e401e022752d9225b7dc65d4340d2f07f Mon Sep 17 00:00:00 2001 From: keighbee Date: Wed, 30 Apr 2025 13:52:10 -0700 Subject: [PATCH 18/22] removed sliding window from qwen3 attention --- candle-transformers/src/models/qwen3.rs | 32 +++++++++------------ candle-transformers/src/models/qwen3_moe.rs | 2 +- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/candle-transformers/src/models/qwen3.rs b/candle-transformers/src/models/qwen3.rs index 23766635de..fad3061cd3 100644 --- a/candle-transformers/src/models/qwen3.rs +++ b/candle-transformers/src/models/qwen3.rs @@ -106,8 +106,6 @@ pub(crate) struct Qwen3Attention { num_kv_groups: usize, head_dim: usize, hidden_size: usize, - // sliding window - _sliding_window: Option, // utils rotary_emb: Arc, kv_cache: KvCache, @@ -117,9 +115,12 @@ impl Qwen3Attention { pub(crate) fn new( cfg: &Config, rotary_emb: Arc, - layer_idx: usize, vb: VarBuilder, ) -> Result { + if cfg.use_sliding_window { + candle::bail!("sliding window is not suppored") + } + let head_dim = cfg.head_dim; let num_heads = cfg.num_attention_heads; let num_kv_heads = cfg.num_key_value_heads; @@ -153,12 +154,6 @@ impl Qwen3Attention { let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?; let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?; - let sliding_window = if cfg.use_sliding_window && layer_idx >= cfg.max_window_layers { - cfg.sliding_window - } else { - None - }; - // Necessary because the hidden_size in the cofig isn't always accurate let hidden_size = head_dim * cfg.num_attention_heads; @@ -176,13 +171,17 @@ impl Qwen3Attention { num_kv_groups, head_dim, hidden_size, - _sliding_window: sliding_window, rotary_emb, kv_cache, }) } - pub(crate) fn forward(&mut self, x: &Tensor, attn_mask: Option<&Tensor>, offset: usize) -> Result { + pub(crate) fn forward( + &mut self, + x: &Tensor, + attn_mask: Option<&Tensor>, + offset: usize, + ) -> Result { let (b, l, _) = x.dims3()?; // 1. 
Proj @@ -248,13 +247,8 @@ struct DecoderLayer { } impl DecoderLayer { - fn new( - cfg: &Config, - rotary: Arc, - idx: usize, - vb: VarBuilder, - ) -> Result { - let self_attn = Qwen3Attention::new(cfg, rotary, idx, vb.pp("self_attn"))?; + fn new(cfg: &Config, rotary: Arc, vb: VarBuilder) -> Result { + let self_attn = Qwen3Attention::new(cfg, rotary, vb.pp("self_attn"))?; let mlp = Qwen3MLP::new(cfg, vb.pp("mlp"))?; let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; let ln2 = RmsNorm::new( @@ -301,7 +295,7 @@ impl Model { let mut layers = Vec::with_capacity(cfg.num_hidden_layers); let vb_l = vb.pp("model.layers"); for i in 0..cfg.num_hidden_layers { - layers.push(DecoderLayer::new(cfg, rotary.clone(), i, vb_l.pp(i))?); + layers.push(DecoderLayer::new(cfg, rotary.clone(), vb_l.pp(i))?); } Ok(Self { embed_tokens, diff --git a/candle-transformers/src/models/qwen3_moe.rs b/candle-transformers/src/models/qwen3_moe.rs index 4b400ca606..ba605a71d6 100644 --- a/candle-transformers/src/models/qwen3_moe.rs +++ b/candle-transformers/src/models/qwen3_moe.rs @@ -203,7 +203,7 @@ impl DecoderLayer { rotary: Arc, vb: VarBuilder, ) -> Result { - let self_attn = Qwen3Attention::new(&cfg.into(), rotary, layer_idx, vb.pp("self_attn"))?; + let self_attn = Qwen3Attention::new(&cfg.into(), rotary, vb.pp("self_attn"))?; // Decide whether to use MoE or regular MLP based on layer_idx and decoder_sparse_step let feed_forward = if cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0 From 86e69ddea328aacb9a10d2855ddfe6b829596a4e Mon Sep 17 00:00:00 2001 From: keighbee Date: Wed, 30 Apr 2025 14:03:37 -0700 Subject: [PATCH 19/22] removed MoE code --- candle-examples/examples/qwen/main.rs | 7 - candle-transformers/src/models/mod.rs | 1 - candle-transformers/src/models/qwen3_moe.rs | 355 -------------------- 3 files changed, 363 deletions(-) delete mode 100644 candle-transformers/src/models/qwen3_moe.rs diff --git a/candle-examples/examples/qwen/main.rs b/candle-examples/examples/qwen/main.rs index eb2d7b2cc7..5e72ab0234 100644 --- a/candle-examples/examples/qwen/main.rs +++ b/candle-examples/examples/qwen/main.rs @@ -10,7 +10,6 @@ use clap::Parser; use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase}; use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe}; use candle_transformers::models::qwen3::{Config as Config3, ModelForCausalLM as Model3}; -use candle_transformers::models::qwen3_moe::{Config as ConfigMoe3, Model as ModelMoe3}; use candle::{DType, Device, Tensor}; use candle_examples::token_output_stream::TokenOutputStream; @@ -23,7 +22,6 @@ enum Model { Base(ModelBase), Moe(ModelMoe), Base3(Model3), - Moe3(ModelMoe3), } impl Model { @@ -32,7 +30,6 @@ impl Model { Self::Moe(ref mut m) => m.forward(xs, s), Self::Base(ref mut m) => m.forward(xs, s), Self::Base3(ref mut m) => m.forward(xs, s), - Self::Moe3(ref mut m) => m.forward(xs, s), } } } @@ -341,10 +338,6 @@ fn main() -> Result<()> { let config: Config3 = serde_json::from_slice(&std::fs::read(config_file)?)?; Model::Base3(Model3::new(&config, vb)?) } - WhichModel::W3MoeA3b => { - let config: ConfigMoe3 = serde_json::from_slice(&std::fs::read(config_file)?)?; - Model::Moe3(ModelMoe3::new(&config, vb)?) - } _ => { let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?; Model::Base(ModelBase::new(&config, vb)?) 
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 1cb25a133d..9e198f63a7 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -97,7 +97,6 @@ pub mod quantized_t5; pub mod qwen2; pub mod qwen2_moe; pub mod qwen3; -pub mod qwen3_moe; pub mod recurrent_gemma; pub mod repvgg; pub mod resnet; diff --git a/candle-transformers/src/models/qwen3_moe.rs b/candle-transformers/src/models/qwen3_moe.rs deleted file mode 100644 index ba605a71d6..0000000000 --- a/candle-transformers/src/models/qwen3_moe.rs +++ /dev/null @@ -1,355 +0,0 @@ -use crate::models::{ - qwen3::{Config as Qwen3Config, Qwen3Attention, Qwen3MLP, Qwen3RotaryEmbedding}, - with_tracing::{linear_no_bias, Linear, RmsNorm}, -}; -use candle::{DType, Device, Module, Result, Tensor, D}; -use candle_nn::{Activation, VarBuilder}; -use std::sync::Arc; - -#[derive(Debug, Clone, PartialEq, serde::Deserialize)] -pub struct Config { - pub vocab_size: usize, - pub hidden_size: usize, - pub intermediate_size: usize, - pub num_hidden_layers: usize, - pub num_attention_heads: usize, - pub head_dim: usize, - pub attention_bias: bool, - pub num_key_value_heads: usize, - pub max_position_embeddings: usize, - pub sliding_window: Option, - pub max_window_layers: usize, - pub tie_word_embeddings: bool, - pub rope_theta: f64, - pub rms_norm_eps: f64, - pub use_sliding_window: bool, - pub hidden_act: Activation, - // MoE specific configuration - pub decoder_sparse_step: usize, - pub moe_intermediate_size: usize, - pub num_experts_per_tok: usize, - pub num_experts: usize, - pub norm_topk_prob: bool, -} - -impl From<&Config> for Qwen3Config { - fn from(val: &Config) -> Self { - Qwen3Config { - vocab_size: val.vocab_size, - hidden_size: val.hidden_size, - intermediate_size: val.intermediate_size, - num_hidden_layers: val.num_hidden_layers, - num_attention_heads: val.num_attention_heads, - head_dim: val.head_dim, - attention_bias: val.attention_bias, - num_key_value_heads: val.num_key_value_heads, - max_position_embeddings: val.max_position_embeddings, - sliding_window: val.sliding_window, - max_window_layers: val.max_window_layers, - tie_word_embeddings: val.tie_word_embeddings, - rope_theta: val.rope_theta, - rms_norm_eps: val.rms_norm_eps, - use_sliding_window: val.use_sliding_window, - hidden_act: val.hidden_act, - } - } -} - -#[derive(Debug, Clone)] -struct Qwen3MLPExpert { - gate_proj: Linear, - up_proj: Linear, - down_proj: Linear, - act_fn: Activation, -} - -impl Qwen3MLPExpert { - fn new(cfg: &Config, vb: VarBuilder) -> Result { - Ok(Self { - gate_proj: linear_no_bias( - cfg.hidden_size, - cfg.moe_intermediate_size, - vb.pp("gate_proj"), - )?, - up_proj: linear_no_bias(cfg.hidden_size, cfg.moe_intermediate_size, vb.pp("up_proj"))?, - down_proj: linear_no_bias( - cfg.moe_intermediate_size, - cfg.hidden_size, - vb.pp("down_proj"), - )?, - act_fn: cfg.hidden_act, - }) - } -} - -impl Module for Qwen3MLPExpert { - fn forward(&self, x: &Tensor) -> Result { - let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?; - let rhs = x.apply(&self.up_proj)?; - (lhs * rhs)?.apply(&self.down_proj) - } -} - -// Qwen3 Sparse MoE Block implementation -#[derive(Debug, Clone)] -struct Qwen3SparseMoeBlock { - gate: Linear, - experts: Vec, - norm_topk_prob: bool, - num_experts_per_tok: usize, -} - -impl Qwen3SparseMoeBlock { - fn new(cfg: &Config, vb: VarBuilder) -> Result { - let gate = linear_no_bias(cfg.hidden_size, cfg.num_experts, vb.pp("gate"))?; - let mut experts = 
Vec::with_capacity(cfg.num_experts); - let vb_e = vb.pp("experts"); - for idx in 0..cfg.num_experts { - let expert = Qwen3MLPExpert::new(cfg, vb_e.pp(idx))?; - experts.push(expert) - } - Ok(Self { - gate, - experts, - norm_topk_prob: cfg.norm_topk_prob, - num_experts_per_tok: cfg.num_experts_per_tok, - }) - } -} - -impl Module for Qwen3SparseMoeBlock { - fn forward(&self, xs: &Tensor) -> Result { - let (b_size, seq_len, hidden_dim) = xs.dims3()?; - let xs = xs.reshape(((), hidden_dim))?; - let router_logits = xs.apply(&self.gate)?; - let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?; - - // Extract topk experts per token - let experts_per_tok = routing_weights - .arg_sort_last_dim(false)? - .narrow(D::Minus1, 0, self.num_experts_per_tok)? - .contiguous()?; - let routing_weights = routing_weights.gather(&experts_per_tok, D::Minus1)?; - - // Extract needed data - let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::()?; - let experts_per_tok = experts_per_tok.to_vec2::()?; - let mut top_x = vec![vec![]; self.experts.len()]; - let mut selected_experts = vec![vec![]; self.experts.len()]; - for (row_idx, (rw, expert_idxs)) in routing_weights - .iter() - .zip(experts_per_tok.iter()) - .enumerate() - { - let sum_rw = rw.iter().sum::(); - for (&rw, &expert_idx) in rw.iter().zip(expert_idxs.iter()) { - top_x[expert_idx as usize].push(row_idx as u32); - let rw = if self.norm_topk_prob { rw / sum_rw } else { rw }; - selected_experts[expert_idx as usize].push(rw) - } - } - - // Process through experts - let mut ys = xs.zeros_like()?; - for (expert_idx, expert_layer) in self.experts.iter().enumerate() { - let top_x = &top_x[expert_idx]; - if top_x.is_empty() { - continue; - } - let top_x = Tensor::new(top_x.as_slice(), xs.device())?; - let selected_experts = - Tensor::new(selected_experts[expert_idx].as_slice(), xs.device())? - .reshape(((), 1))? - .to_dtype(xs.dtype())?; - - let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?; - let current_hidden_states = expert_layer.forward(¤t_state)?; - let current_hidden_states = current_hidden_states.broadcast_mul(&selected_experts)?; - ys = ys.index_add(&top_x, ¤t_hidden_states, 0)?; - } - - ys.reshape((b_size, seq_len, hidden_dim)) - } -} - -// MLP or MoE decision enum -#[derive(Debug, Clone)] -enum Qwen3FeedForward { - Mlp(Qwen3MLP), - MoE(Qwen3SparseMoeBlock), -} - -impl Module for Qwen3FeedForward { - fn forward(&self, xs: &Tensor) -> Result { - match self { - Self::Mlp(m) => m.forward(xs), - Self::MoE(m) => m.forward(xs), - } - } -} - -#[derive(Debug, Clone)] -struct DecoderLayer { - self_attn: Qwen3Attention, - feed_forward: Qwen3FeedForward, - ln1: RmsNorm, - ln2: RmsNorm, -} - -impl DecoderLayer { - fn new( - layer_idx: usize, - cfg: &Config, - rotary: Arc, - vb: VarBuilder, - ) -> Result { - let self_attn = Qwen3Attention::new(&cfg.into(), rotary, vb.pp("self_attn"))?; - - // Decide whether to use MoE or regular MLP based on layer_idx and decoder_sparse_step - let feed_forward = if cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0 - { - Qwen3FeedForward::MoE(Qwen3SparseMoeBlock::new(cfg, vb.pp("mlp"))?) - } else { - Qwen3FeedForward::Mlp(Qwen3MLP::new(&cfg.into(), vb.pp("mlp"))?) 
- }; - - let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; - let ln2 = RmsNorm::new( - cfg.hidden_size, - cfg.rms_norm_eps, - vb.pp("post_attention_layernorm"), - )?; - - Ok(Self { - self_attn, - feed_forward, - ln1, - ln2, - }) - } - - fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result { - let h = self.ln1.forward(x)?; - let h = self.self_attn.forward(&h, mask, offset)?; - let x = (x + h)?; - let h2 = self.ln2.forward(&x)?; - let h2 = h2.apply(&self.feed_forward)?; - x + h2 - } - - fn clear_kv_cache(&mut self) { - self.self_attn.clear_kv_cache(); - } -} - -#[derive(Debug, Clone)] -pub struct Model { - embed_tokens: candle_nn::Embedding, - layers: Vec, - norm: RmsNorm, - device: Device, - dtype: DType, -} - -impl Model { - pub fn new(cfg: &Config, vb: VarBuilder) -> Result { - let embed_tokens = - candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?; - let rotary = Arc::new(Qwen3RotaryEmbedding::new( - vb.dtype(), - &cfg.into(), - vb.device(), - )?); - let mut layers = Vec::with_capacity(cfg.num_hidden_layers); - let vb_l = vb.pp("model.layers"); - for i in 0..cfg.num_hidden_layers { - layers.push(DecoderLayer::new(i, cfg, rotary.clone(), vb_l.pp(i))?); - } - Ok(Self { - embed_tokens, - layers, - norm: RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?, - device: vb.device().clone(), - dtype: vb.dtype(), - }) - } - - fn clear_kv_cache(&mut self) { - for l in &mut self.layers { - l.clear_kv_cache(); - } - } - - fn causal_mask( - &self, - b: usize, - tgt: usize, - offset: usize, - sw: Option, - ) -> Result { - let minf = f32::NEG_INFINITY; - let mask: Vec<_> = (0..tgt) - .flat_map(|i| { - (0..(tgt + offset)).map(move |j| { - let past_ok = j <= i + offset; - let sw_ok = match sw { - Some(w) => (i + offset) as i64 - j as i64 <= w as i64, - None => true, - }; - if past_ok && sw_ok { - 0. - } else { - minf - } - }) - }) - .collect(); - Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype) - } - - pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result { - let (b, l) = input.dims2()?; - let mut h = self.embed_tokens.forward(input)?; - - let causal = if l == 1 { - None - } else { - Some(self.causal_mask(b, l, offset, None)?) - }; - - for layer in &mut self.layers { - h = layer.forward(&h, causal.as_ref(), offset)?; - } - self.norm.forward(&h) - } -} - -#[derive(Debug, Clone)] -pub struct ModelForCausalLM { - base: Model, - lm_head: Linear, -} - -impl ModelForCausalLM { - pub fn new(cfg: &Config, vb: VarBuilder) -> Result { - let base = Model::new(cfg, vb.clone())?; - let lm_head = if cfg.tie_word_embeddings { - linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))? - } else { - Linear::from_weights(base.embed_tokens.embeddings().clone(), None) - }; - Ok(Self { base, lm_head }) - } - - pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result { - let (_, l) = input.dims2()?; - self.base - .forward(input, offset)? - .narrow(1, l - 1, 1)? 
- .apply(&self.lm_head) - } - - pub fn clear_kv_cache(&mut self) { - self.base.clear_kv_cache(); - } -} From d943cbe41970c2cdcb2c8f4a3dcc1f5ba95b7ffb Mon Sep 17 00:00:00 2001 From: keighbee Date: Wed, 30 Apr 2025 14:04:57 -0700 Subject: [PATCH 20/22] removed unused option --- candle-examples/examples/qwen/main.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/candle-examples/examples/qwen/main.rs b/candle-examples/examples/qwen/main.rs index 5e72ab0234..d0e179e0ca 100644 --- a/candle-examples/examples/qwen/main.rs +++ b/candle-examples/examples/qwen/main.rs @@ -167,8 +167,6 @@ enum WhichModel { W3_4b, #[value(name = "3-8b")] W3_8b, - #[value(name = "3-moe-a3b")] - W3MoeA3b, } #[derive(Parser, Debug)] @@ -275,7 +273,6 @@ fn main() -> Result<()> { WhichModel::W3_1_7b => ("3", "1.7B"), WhichModel::W3_4b => ("3", "4B"), WhichModel::W3_8b => ("3", "8B"), - WhichModel::W3MoeA3b => ("3", "30B-A3B"), }; format!("Qwen/Qwen{version}-{size}") } @@ -311,8 +308,7 @@ fn main() -> Result<()> { | WhichModel::MoeA27b | WhichModel::W3_1_7b | WhichModel::W3_4b - | WhichModel::W3_8b - | WhichModel::W3MoeA3b => { + | WhichModel::W3_8b => { candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")? } }, From a3d7f6e1879a7a7645eedaae9b0a2d21a568c316 Mon Sep 17 00:00:00 2001 From: Kyle Birnbaum Date: Thu, 1 May 2025 13:30:27 -0700 Subject: [PATCH 21/22] Fixed Typo Co-authored-by: Laurent Mazare --- candle-transformers/src/models/qwen3.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-transformers/src/models/qwen3.rs b/candle-transformers/src/models/qwen3.rs index fad3061cd3..8411345929 100644 --- a/candle-transformers/src/models/qwen3.rs +++ b/candle-transformers/src/models/qwen3.rs @@ -154,7 +154,7 @@ impl Qwen3Attention { let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?; let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?; - // Necessary because the hidden_size in the cofig isn't always accurate + // Necessary because the hidden_size in the config isn't always accurate let hidden_size = head_dim * cfg.num_attention_heads; let kv_cache = KvCache::new(2, cfg.max_position_embeddings); From bdbefa56650b8d6382c7c784db271709c1b33741 Mon Sep 17 00:00:00 2001 From: keighbee Date: Thu, 1 May 2025 13:33:46 -0700 Subject: [PATCH 22/22] fixed tie word embeddings to use the correct embedding weights instead of the opposite --- candle-transformers/src/models/qwen3.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/candle-transformers/src/models/qwen3.rs b/candle-transformers/src/models/qwen3.rs index 8411345929..30ea3c1561 100644 --- a/candle-transformers/src/models/qwen3.rs +++ b/candle-transformers/src/models/qwen3.rs @@ -366,9 +366,9 @@ impl ModelForCausalLM { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { let base = Model::new(cfg, vb.clone())?; let lm_head = if cfg.tie_word_embeddings { - linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))? - } else { Linear::from_weights(base.embed_tokens.embeddings().clone(), None) + } else { + linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))? }; Ok(Self { base, lm_head }) }
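
Taken together, the KV-cache and lm_head changes in this series are easiest to see from the caller's side: when `tie_word_embeddings` is set the head now reuses the embedding matrix, and the `KvCache` lets each decoding step feed only the newly generated token plus a growing offset. Below is a minimal, hedged sketch of a greedy decode loop over the resulting `ModelForCausalLM`; the prompt ids, the `eos` id and argmax sampling are illustrative assumptions rather than anything defined in this series.

    use candle::{Device, Result, Tensor};
    use candle_transformers::models::qwen3::ModelForCausalLM;

    // Hedged sketch of a greedy decode loop; `prompt` and `eos` are assumed to come from the tokenizer.
    fn generate(
        model: &mut ModelForCausalLM,
        prompt: &[u32],
        eos: u32,
        max_new: usize,
        device: &Device,
    ) -> Result<Vec<u32>> {
        model.clear_kv_cache();
        let mut tokens = prompt.to_vec();
        let mut offset = 0usize;
        for _ in 0..max_new {
            // First step feeds the whole prompt; later steps feed only the last token,
            // since everything before `offset` is already held in the KV cache.
            let ctx = &tokens[offset..];
            let input = Tensor::new(ctx, device)?.unsqueeze(0)?; // (1, L)
            let logits = model.forward(&input, offset)?; // (1, 1, vocab): forward narrows to the last position
            let logits = logits.squeeze(0)?.squeeze(0)?; // (vocab)
            let next = logits.argmax(0)?.to_scalar::<u32>()?;
            offset = tokens.len();
            tokens.push(next);
            if next == eos {
                break;
            }
        }
        Ok(tokens)
    }
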