Skip to content

Updating Add qwen3 (PR 2903) to use HF weights #2930

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
May 2, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
7a13e58
add Qwen3.rs
maximizemaxwell Apr 17, 2025
e2da619
fixed compile error
maximizemaxwell Apr 17, 2025
9c39581
attempting to get pr 2903 working with qwen weights
greenrazer Apr 28, 2025
a014a67
different qwen variants working
greenrazer Apr 29, 2025
8727cdf
added moe model
greenrazer Apr 30, 2025
7b5aadd
clippy
greenrazer Apr 30, 2025
e71c478
added additional eos token
greenrazer Apr 30, 2025
bbb490f
translated Korean comments to English as well as I can
greenrazer Apr 30, 2025
eada460
removed specialized Qwen3RmsNorm and replaced with generic Candle Rms…
greenrazer Apr 30, 2025
80170fd
replaced custom repeat_kv implementation with candle's repeat_kv impl…
greenrazer Apr 30, 2025
181f2ce
replace linear with linear_b in attention initialization
greenrazer Apr 30, 2025
02f0247
replaced custom kv_cache implementation with candle kv_cache
greenrazer Apr 30, 2025
f2962f7
style
greenrazer Apr 30, 2025
410e11e
replaced explicit broadcast add with normal add in decoder layer
greenrazer Apr 30, 2025
d99d104
removed keeping the Rotary embedding layer in the model struct
greenrazer Apr 30, 2025
a57c5ab
used tie_word_embeddings bool from config instead of relying on exist…
greenrazer Apr 30, 2025
5950077
removed duplicate code from qwen3_moe
greenrazer Apr 30, 2025
c635621
removed sliding window from qwen3 attention
greenrazer Apr 30, 2025
86e69dd
removed MoE code
greenrazer Apr 30, 2025
d943cbe
removed unused option
greenrazer Apr 30, 2025
a3d7f6e
Fixed Typo
greenrazer May 1, 2025
bdbefa5
fixed tie word embeddings to use the correct embedding weights instea…
greenrazer May 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions candle-examples/examples/qwen/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use clap::Parser;

use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
use candle_transformers::models::qwen3::{Config as Config3, ModelForCausalLM as Model3};
use candle_transformers::models::qwen3_moe::{Config as ConfigMoe3, Model as ModelMoe3};

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
Expand All @@ -20,13 +22,17 @@ use tokenizers::Tokenizer;
/// Wrapper over the supported Qwen model families so the example can drive
/// any of them through a single `forward` entry point.
enum Model {
    // Qwen1.5 / Qwen2 dense causal-LM.
    Base(ModelBase),
    // Qwen1.5 / Qwen2 mixture-of-experts variant.
    Moe(ModelMoe),
    // Qwen3 dense causal-LM.
    Base3(Model3),
    // Qwen3 mixture-of-experts variant.
    Moe3(ModelMoe3),
}

impl Model {
    /// Runs one forward pass on whichever underlying model variant is active.
    ///
    /// `xs` holds the input token ids and `s` is the starting position
    /// (offset into the KV cache); the call is delegated unchanged to the
    /// wrapped model's own `forward`.
    fn forward(&mut self, xs: &Tensor, s: usize) -> candle::Result<Tensor> {
        // Match ergonomics: `self` is `&mut Self`, so each arm binds the
        // payload as `&mut _` without an explicit `ref mut`.
        match self {
            Self::Base(m) => m.forward(xs, s),
            Self::Base3(m) => m.forward(xs, s),
            Self::Moe(m) => m.forward(xs, s),
            Self::Moe3(m) => m.forward(xs, s),
        }
    }
}
Expand Down Expand Up @@ -85,6 +91,10 @@ impl TextGeneration {
Some(token) => token,
None => anyhow::bail!("cannot find the <|endoftext|> token"),
};
let eos_token2 = match self.tokenizer.get_token("<|im_end|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <|im_end|> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
Expand All @@ -107,7 +117,7 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
if next_token == eos_token || next_token == eos_token2 {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
Expand Down Expand Up @@ -152,6 +162,16 @@ enum WhichModel {
W2_7b,
#[value(name = "2-72b")]
W2_72b,
#[value(name = "3-0.6b")]
W3_0_6b,
#[value(name = "3-1.7b")]
W3_1_7b,
#[value(name = "3-4b")]
W3_4b,
#[value(name = "3-8b")]
W3_8b,
#[value(name = "3-moe-a3b")]
W3MoeA3b,
}

#[derive(Parser, Debug)]
Expand Down Expand Up @@ -254,6 +274,11 @@ fn main() -> Result<()> {
WhichModel::W14b => ("1.5", "14B"),
WhichModel::W72b => ("1.5", "72B"),
WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"),
WhichModel::W3_0_6b => ("3", "0.6B"),
WhichModel::W3_1_7b => ("3", "1.7B"),
WhichModel::W3_4b => ("3", "4B"),
WhichModel::W3_8b => ("3", "8B"),
WhichModel::W3MoeA3b => ("3", "30B-A3B"),
};
format!("Qwen/Qwen{version}-{size}")
}
Expand All @@ -273,7 +298,11 @@ fn main() -> Result<()> {
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.model {
WhichModel::W0_5b | WhichModel::W2_0_5b | WhichModel::W2_1_5b | WhichModel::W1_8b => {
WhichModel::W0_5b
| WhichModel::W2_0_5b
| WhichModel::W2_1_5b
| WhichModel::W1_8b
| WhichModel::W3_0_6b => {
vec![repo.get("model.safetensors")?]
}
WhichModel::W4b
Expand All @@ -282,7 +311,11 @@ fn main() -> Result<()> {
| WhichModel::W14b
| WhichModel::W72b
| WhichModel::W2_72b
| WhichModel::MoeA27b => {
| WhichModel::MoeA27b
| WhichModel::W3_1_7b
| WhichModel::W3_4b
| WhichModel::W3_8b
| WhichModel::W3MoeA3b => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
},
Expand All @@ -304,6 +337,14 @@ fn main() -> Result<()> {
let config: ConfigMoe = serde_json::from_slice(&std::fs::read(config_file)?)?;
Model::Moe(ModelMoe::new(&config, vb)?)
}
WhichModel::W3_0_6b | WhichModel::W3_1_7b | WhichModel::W3_4b | WhichModel::W3_8b => {
let config: Config3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
Model::Base3(Model3::new(&config, vb)?)
}
WhichModel::W3MoeA3b => {
let config: ConfigMoe3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
Model::Moe3(ModelMoe3::new(&config, vb)?)
}
_ => {
let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
Model::Base(ModelBase::new(&config, vb)?)
Expand Down
2 changes: 2 additions & 0 deletions candle-transformers/src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ pub mod quantized_stable_lm;
pub mod quantized_t5;
pub mod qwen2;
pub mod qwen2_moe;
pub mod qwen3;
pub mod qwen3_moe;
pub mod recurrent_gemma;
pub mod repvgg;
pub mod resnet;
Expand Down
Loading
Loading