diff --git a/candle-examples/examples/qwen/main.rs b/candle-examples/examples/qwen/main.rs
index 53f2f70dd1..d0e179e0ca 100644
--- a/candle-examples/examples/qwen/main.rs
+++ b/candle-examples/examples/qwen/main.rs
@@ -9,6 +9,7 @@ use clap::Parser;
 
 use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
 use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
+use candle_transformers::models::qwen3::{Config as Config3, ModelForCausalLM as Model3};
 
 use candle::{DType, Device, Tensor};
 use candle_examples::token_output_stream::TokenOutputStream;
@@ -20,6 +21,7 @@ use tokenizers::Tokenizer;
 enum Model {
     Base(ModelBase),
     Moe(ModelMoe),
+    Base3(Model3),
 }
 
 impl Model {
@@ -27,6 +29,7 @@ impl Model {
         match self {
             Self::Moe(ref mut m) => m.forward(xs, s),
             Self::Base(ref mut m) => m.forward(xs, s),
+            Self::Base3(ref mut m) => m.forward(xs, s),
         }
     }
 }
@@ -85,6 +88,10 @@ impl TextGeneration {
             Some(token) => token,
             None => anyhow::bail!("cannot find the <|endoftext|> token"),
         };
+        let eos_token2 = match self.tokenizer.get_token("<|im_end|>") {
+            Some(token) => token,
+            None => anyhow::bail!("cannot find the <|im_end|> token"),
+        };
         let start_gen = std::time::Instant::now();
         for index in 0..sample_len {
             let context_size = if index > 0 { 1 } else { tokens.len() };
@@ -107,7 +114,7 @@ impl TextGeneration {
             let next_token = self.logits_processor.sample(&logits)?;
             tokens.push(next_token);
             generated_tokens += 1;
-            if next_token == eos_token {
+            if next_token == eos_token || next_token == eos_token2 {
                 break;
             }
             if let Some(t) = self.tokenizer.next_token(next_token)? {
@@ -152,6 +159,14 @@ enum WhichModel {
     W2_7b,
     #[value(name = "2-72b")]
     W2_72b,
+    #[value(name = "3-0.6b")]
+    W3_0_6b,
+    #[value(name = "3-1.7b")]
+    W3_1_7b,
+    #[value(name = "3-4b")]
+    W3_4b,
+    #[value(name = "3-8b")]
+    W3_8b,
 }
 
 #[derive(Parser, Debug)]
@@ -254,6 +269,10 @@ fn main() -> Result<()> {
                 WhichModel::W14b => ("1.5", "14B"),
                 WhichModel::W72b => ("1.5", "72B"),
                 WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"),
+                WhichModel::W3_0_6b => ("3", "0.6B"),
+                WhichModel::W3_1_7b => ("3", "1.7B"),
+                WhichModel::W3_4b => ("3", "4B"),
+                WhichModel::W3_8b => ("3", "8B"),
             };
             format!("Qwen/Qwen{version}-{size}")
         }
@@ -273,7 +292,11 @@ fn main() -> Result<()> {
             .map(std::path::PathBuf::from)
             .collect::<Vec<_>>(),
         None => match args.model {
-            WhichModel::W0_5b | WhichModel::W2_0_5b | WhichModel::W2_1_5b | WhichModel::W1_8b => {
+            WhichModel::W0_5b
+            | WhichModel::W2_0_5b
+            | WhichModel::W2_1_5b
+            | WhichModel::W1_8b
+            | WhichModel::W3_0_6b => {
                 vec![repo.get("model.safetensors")?]
             }
             WhichModel::W4b
@@ -282,7 +305,10 @@ fn main() -> Result<()> {
             | WhichModel::W14b
             | WhichModel::W72b
             | WhichModel::W2_72b
-            | WhichModel::MoeA27b => {
+            | WhichModel::MoeA27b
+            | WhichModel::W3_1_7b
+            | WhichModel::W3_4b
+            | WhichModel::W3_8b => {
                 candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
             }
         },
@@ -304,6 +330,10 @@ fn main() -> Result<()> {
             let config: ConfigMoe = serde_json::from_slice(&std::fs::read(config_file)?)?;
             Model::Moe(ModelMoe::new(&config, vb)?)
         }
+        WhichModel::W3_0_6b | WhichModel::W3_1_7b | WhichModel::W3_4b | WhichModel::W3_8b => {
+            let config: Config3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
+            Model::Base3(Model3::new(&config, vb)?)
+        }
         _ => {
             let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
             Model::Base(ModelBase::new(&config, vb)?)
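A minimal, untested sketch of driving the new qwen3::ModelForCausalLM directly, outside the example binary. It mirrors the wiring above; the Qwen/Qwen3-0.6B repo id comes from the WhichModel mapping in this diff, while the prompt, F32 dtype, and CPU device are arbitrary placeholder choices:

use anyhow::Result;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::qwen3::{Config, ModelForCausalLM};
use hf_hub::api::sync::Api;
use tokenizers::Tokenizer;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Fetch the smallest Qwen3 checkpoint; the repo id follows the mapping added above.
    let repo = Api::new()?.model("Qwen/Qwen3-0.6B".to_string());
    let config: Config = serde_json::from_slice(&std::fs::read(repo.get("config.json")?)?)?;
    let tokenizer = Tokenizer::from_file(repo.get("tokenizer.json")?).map_err(anyhow::Error::msg)?;
    let weights = vec![repo.get("model.safetensors")?];
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weights, DType::F32, &device)? };
    let mut model = ModelForCausalLM::new(&config, vb)?;

    // Prefill the whole prompt at offset 0; the per-layer KV cache keeps the state for later steps.
    let prompt = "Give me a short introduction to large language models.";
    let tokens = tokenizer
        .encode(prompt, true)
        .map_err(anyhow::Error::msg)?
        .get_ids()
        .to_vec();
    let input = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
    let logits = model.forward(&input, 0)?; // (1, 1, vocab_size): only the last position is returned
    println!("logits shape: {:?}", logits.dims());
    Ok(())
}

The example binary wires the same pieces through clap flags and TextGeneration instead.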
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index bdb8d267b5..9e198f63a7 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -96,6 +96,7 @@ pub mod quantized_stable_lm;
 pub mod quantized_t5;
 pub mod qwen2;
 pub mod qwen2_moe;
+pub mod qwen3;
 pub mod recurrent_gemma;
 pub mod repvgg;
 pub mod resnet;
diff --git a/candle-transformers/src/models/qwen3.rs b/candle-transformers/src/models/qwen3.rs
new file mode 100644
index 0000000000..30ea3c1561
--- /dev/null
+++ b/candle-transformers/src/models/qwen3.rs
@@ -0,0 +1,387 @@
+use crate::{
+    models::with_tracing::{linear_b, linear_no_bias, Linear, RmsNorm},
+    utils::repeat_kv,
+};
+use candle::{DType, Device, Module, Result, Tensor};
+use candle_nn::{kv_cache::KvCache, Activation, VarBuilder};
+use std::sync::Arc;
+
+#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
+pub struct Config {
+    pub vocab_size: usize,
+    pub hidden_size: usize,
+    pub intermediate_size: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub head_dim: usize,
+    pub attention_bias: bool,
+    pub num_key_value_heads: usize,
+    pub max_position_embeddings: usize,
+    pub sliding_window: Option<usize>,
+    pub max_window_layers: usize,
+    pub tie_word_embeddings: bool,
+    pub rope_theta: f64,
+    pub rms_norm_eps: f64,
+    pub use_sliding_window: bool,
+    pub hidden_act: Activation,
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct Qwen3RotaryEmbedding {
+    sin: Tensor,
+    cos: Tensor,
+}
+
+impl Qwen3RotaryEmbedding {
+    pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
+        let dim = cfg.head_dim;
+        let max_seq_len = cfg.max_position_embeddings;
+        let inv_freq: Vec<_> = (0..dim)
+            .step_by(2)
+            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
+            .collect();
+        let inv_freq_len = inv_freq.len();
+        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
+        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
+            .to_dtype(dtype)?
+            .reshape((max_seq_len, 1))?;
+        let freqs = t.matmul(&inv_freq)?;
+        Ok(Self {
+            sin: freqs.sin()?,
+            cos: freqs.cos()?,
+        })
+    }
+
+    /// Apply RoPE (q, k shape: B x H x L x D)
+    fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> {
+        let (_, _, seq_len, _) = q.dims4()?;
+        let cos = self.cos.narrow(0, offset, seq_len)?;
+        let sin = self.sin.narrow(0, offset, seq_len)?;
+        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
+        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
+        Ok((q_embed, k_embed))
+    }
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct Qwen3MLP {
+    gate_proj: Linear,
+    up_proj: Linear,
+    down_proj: Linear,
+    act_fn: Activation,
+}
+
+impl Qwen3MLP {
+    pub(crate) fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        Ok(Self {
+            gate_proj: linear_no_bias(cfg.hidden_size, cfg.intermediate_size, vb.pp("gate_proj"))?,
+            up_proj: linear_no_bias(cfg.hidden_size, cfg.intermediate_size, vb.pp("up_proj"))?,
+            down_proj: linear_no_bias(cfg.intermediate_size, cfg.hidden_size, vb.pp("down_proj"))?,
+            act_fn: cfg.hidden_act,
+        })
+    }
+}
+
+impl Module for Qwen3MLP {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?;
+        let rhs = x.apply(&self.up_proj)?;
+        (lhs * rhs)?.apply(&self.down_proj)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct Qwen3Attention {
+    // projections
+    q_proj: Linear,
+    k_proj: Linear,
+    v_proj: Linear,
+    o_proj: Linear,
+    // norms
+    q_norm: RmsNorm,
+    k_norm: RmsNorm,
+    // hyper params
+    num_heads: usize,
+    num_kv_heads: usize,
+    num_kv_groups: usize,
+    head_dim: usize,
+    hidden_size: usize,
+    // utils
+    rotary_emb: Arc<Qwen3RotaryEmbedding>,
+    kv_cache: KvCache,
+}
+
+impl Qwen3Attention {
+    pub(crate) fn new(
+        cfg: &Config,
+        rotary_emb: Arc<Qwen3RotaryEmbedding>,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        if cfg.use_sliding_window {
+            candle::bail!("sliding window is not supported")
+        }
+
+        let head_dim = cfg.head_dim;
+        let num_heads = cfg.num_attention_heads;
+        let num_kv_heads = cfg.num_key_value_heads;
+        let num_kv_groups = num_heads / num_kv_heads;
+
+        let q_proj = linear_b(
+            cfg.hidden_size,
+            num_heads * head_dim,
+            cfg.attention_bias,
+            vb.pp("q_proj"),
+        )?;
+        let k_proj = linear_b(
+            cfg.hidden_size,
+            num_kv_heads * head_dim,
+            cfg.attention_bias,
+            vb.pp("k_proj"),
+        )?;
+        let v_proj = linear_b(
+            cfg.hidden_size,
+            num_kv_heads * head_dim,
+            cfg.attention_bias,
+            vb.pp("v_proj"),
+        )?;
+        let o_proj = linear_b(
+            num_heads * head_dim,
+            cfg.hidden_size,
+            cfg.attention_bias,
+            vb.pp("o_proj"),
+        )?;
+
+        let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?;
+        let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;
+
+        // Necessary because the hidden_size in the config isn't always accurate
+        let hidden_size = head_dim * cfg.num_attention_heads;
+
+        let kv_cache = KvCache::new(2, cfg.max_position_embeddings);
+
+        Ok(Self {
+            q_proj,
+            k_proj,
+            v_proj,
+            o_proj,
+            q_norm,
+            k_norm,
+            num_heads,
+            num_kv_heads,
+            num_kv_groups,
+            head_dim,
+            hidden_size,
+            rotary_emb,
+            kv_cache,
+        })
+    }
+
+    pub(crate) fn forward(
+        &mut self,
+        x: &Tensor,
+        attn_mask: Option<&Tensor>,
+        offset: usize,
+    ) -> Result<Tensor> {
+        let (b, l, _) = x.dims3()?;
+
+        // 1. Proj
+        let q = self.q_proj.forward(x)?;
+        let k = self.k_proj.forward(x)?;
+        let v = self.v_proj.forward(x)?;
+
+        // 2. Reshape: (B, L, H, D) -> (B, H, L, D)
+        let q = q
+            .reshape((b, l, self.num_heads, self.head_dim))?
+            .transpose(1, 2)?;
+        let k = k
+            .reshape((b, l, self.num_kv_heads, self.head_dim))?
+            .transpose(1, 2)?;
+        let v = v
+            .reshape((b, l, self.num_kv_heads, self.head_dim))?
+            .transpose(1, 2)?;
+
+        // 3. Per-head RMSNorm
+        let q_flat = q.flatten(0, 2)?; // (B, H, L, D) -> (B*H*L, D)
+        let k_flat = k.flatten(0, 2)?;
+        let q_flat = self.q_norm.forward(&q_flat)?;
+        let k_flat = self.k_norm.forward(&k_flat)?;
+        let q = q_flat.reshape((b, self.num_heads, l, self.head_dim))?;
+        let k = k_flat.reshape((b, self.num_kv_heads, l, self.head_dim))?;
+
+        // 4. RoPE
+        let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;
+
+        // 5. Accumulate KV cache
+        let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
+
+        // 6. GQA repeat_kv
+        let k = repeat_kv(k, self.num_kv_groups)?;
+        let v = repeat_kv(v, self.num_kv_groups)?;
+
+        // 7. Attention score
+        let scale = 1.0 / (self.head_dim as f64).sqrt();
+        let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
+        if let Some(m) = attn_mask {
+            scores = scores.broadcast_add(m)?;
+        }
+        let probs = candle_nn::ops::softmax_last_dim(&scores)?;
+        let ctx = probs.matmul(&v)?; // (B, H, L, D)
+
+        // 8. Output proj
+        ctx.transpose(1, 2)?
+            .reshape((b, l, self.hidden_size))?
+            .apply(&self.o_proj)
+    }
+
+    pub(crate) fn clear_kv_cache(&mut self) {
+        self.kv_cache.reset();
+    }
+}
+
+#[derive(Debug, Clone)]
+struct DecoderLayer {
+    self_attn: Qwen3Attention,
+    mlp: Qwen3MLP,
+    ln1: RmsNorm,
+    ln2: RmsNorm,
+}
+
+impl DecoderLayer {
+    fn new(cfg: &Config, rotary: Arc<Qwen3RotaryEmbedding>, vb: VarBuilder) -> Result<Self> {
+        let self_attn = Qwen3Attention::new(cfg, rotary, vb.pp("self_attn"))?;
+        let mlp = Qwen3MLP::new(cfg, vb.pp("mlp"))?;
+        let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
+        let ln2 = RmsNorm::new(
+            cfg.hidden_size,
+            cfg.rms_norm_eps,
+            vb.pp("post_attention_layernorm"),
+        )?;
+        Ok(Self {
+            self_attn,
+            mlp,
+            ln1,
+            ln2,
+        })
+    }
+
+    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
+        let h = self.ln1.forward(x)?;
+        let h = self.self_attn.forward(&h, mask, offset)?;
+        let x = (x + h)?;
+        let h2 = self.ln2.forward(&x)?;
+        let h2 = h2.apply(&self.mlp)?;
+        x + h2
+    }
+
+    fn clear_kv_cache(&mut self) {
+        self.self_attn.clear_kv_cache();
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+    embed_tokens: candle_nn::Embedding,
+    layers: Vec<DecoderLayer>,
+    norm: RmsNorm,
+    device: Device,
+    dtype: DType,
+}
+
+impl Model {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let embed_tokens =
+            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
+        let rotary = Arc::new(Qwen3RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
+        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+        let vb_l = vb.pp("model.layers");
+        for i in 0..cfg.num_hidden_layers {
+            layers.push(DecoderLayer::new(cfg, rotary.clone(), vb_l.pp(i))?);
+        }
+        Ok(Self {
+            embed_tokens,
+            layers,
+            norm: RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?,
+            device: vb.device().clone(),
+            dtype: vb.dtype(),
+        })
+    }
+
+    fn clear_kv_cache(&mut self) {
+        for l in &mut self.layers {
+            l.clear_kv_cache();
+        }
+    }
+
+    fn causal_mask(
+        &self,
+        b: usize,
+        tgt: usize,
+        offset: usize,
+        sw: Option<usize>,
+    ) -> Result<Tensor> {
+        let minf = f32::NEG_INFINITY;
+        let mask: Vec<_> = (0..tgt)
+            .flat_map(|i| {
+                (0..(tgt + offset)).map(move |j| {
+                    let past_ok = j <= i + offset;
+                    let sw_ok = match sw {
+                        Some(w) => (i + offset) as i64 - j as i64 <= w as i64,
+                        None => true,
+                    };
+                    if past_ok && sw_ok {
+                        0.
+                    } else {
+                        minf
+                    }
+                })
+            })
+            .collect();
+        Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
+    }
+
+    pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
+        let (b, l) = input.dims2()?;
+        let mut h = self.embed_tokens.forward(input)?;
+
+        let causal = if l == 1 {
+            None
+        } else {
+            Some(self.causal_mask(b, l, offset, None)?)
+        };
+
+        for layer in &mut self.layers {
+            h = layer.forward(&h, causal.as_ref(), offset)?;
+        }
+        self.norm.forward(&h)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct ModelForCausalLM {
+    base: Model,
+    lm_head: Linear,
+}
+
+impl ModelForCausalLM {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let base = Model::new(cfg, vb.clone())?;
+        let lm_head = if cfg.tie_word_embeddings {
+            Linear::from_weights(base.embed_tokens.embeddings().clone(), None)
+        } else {
+            linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
+        };
+        Ok(Self { base, lm_head })
+    }
+
+    pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
+        let (_, l) = input.dims2()?;
+        self.base
+            .forward(input, offset)?
+            .narrow(1, l - 1, 1)?
+            .apply(&self.lm_head)
+    }
+
+    pub fn clear_kv_cache(&mut self) {
+        self.base.clear_kv_cache();
+    }
+}
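A note on the calling convention the example relies on: ModelForCausalLM::forward only returns logits for the last position, and each layer's KvCache accumulates keys and values across calls, so the prompt is fed once at offset 0 and every later call passes a single token with offset equal to the number of tokens already processed. A rough sketch of that loop, assuming a model, tokenizer, and device set up as in the earlier sketch (the seed value, greedy sampling via a None temperature, and stopping only on <|im_end|> are simplifications):

use candle::{DType, Tensor};
use candle_transformers::generation::LogitsProcessor;

// Assumes `model`, `tokenizer`, and `device` are already initialized as in the sketch above.
fn generate(
    model: &mut candle_transformers::models::qwen3::ModelForCausalLM,
    tokenizer: &tokenizers::Tokenizer,
    device: &candle::Device,
    prompt_tokens: Vec<u32>,
    max_new_tokens: usize,
) -> anyhow::Result<Vec<u32>> {
    // Temperature None means arg-max sampling; the seed is then irrelevant.
    let mut logits_processor = LogitsProcessor::new(299792458, None, None);
    let eos = tokenizer.token_to_id("<|im_end|>");
    let mut tokens = prompt_tokens;

    // Start from an empty cache so a previous prompt does not leak into this one.
    model.clear_kv_cache();
    for step in 0..max_new_tokens {
        // Step 0 prefills the whole prompt; later steps feed only the newest token.
        let context = if step == 0 {
            tokens.as_slice()
        } else {
            &tokens[tokens.len() - 1..]
        };
        let offset = tokens.len() - context.len();
        let input = Tensor::new(context, device)?.unsqueeze(0)?;
        let logits = model.forward(&input, offset)?; // (1, 1, vocab_size)
        let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
        let next = logits_processor.sample(&logits)?;
        tokens.push(next);
        if Some(next) == eos {
            break;
        }
    }
    Ok(tokens)
}

TextGeneration::run in the example follows the same pattern, except that it streams tokens as they are decoded and stops on either <|endoftext|> or <|im_end|>, matching the dual stop-token check added in this diff.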