
Commit 1fdfb58

Authored by greenrazer, maximizemaxwell, and LaurentMazare

Updating Add qwen3 (PR 2903) to use HF weights (#2930)
* add Qwen3.rs
* fixed compile error
* attempting to get PR 2903 working with qwen weights
* different qwen variants working
* added moe model
* clippy
* added additional eos token
* translated Korean comments to English as well as I can
* removed specialized Qwen3RmsNorm and replaced with generic Candle RmsNorm
* replaced custom repeat_kv implementation with candle's repeat_kv implementation
* replaced linear with linear_b in attention initialization
* replaced custom kv_cache implementation with candle kv_cache
* style
* replaced explicit broadcast add with normal add in decoder layer
* removed keeping the Rotary embedding layer in the model struct
* used tie_word_embeddings bool from config instead of relying on existence of weights for lm head in CausalLM
* removed duplicate code from qwen3_moe
* removed sliding window from qwen3 attention
* removed MoE code
* removed unused option
* Fixed Typo
  Co-authored-by: Laurent Mazare <[email protected]>
* fixed tie word embeddings to use the correct embedding weights instead of the opposite

---------

Co-authored-by: Max <[email protected]>
Co-authored-by: Laurent Mazare <[email protected]>

1 parent cd96fa8 · commit 1fdfb58
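
The last bullet in the message fixes weight tying: when tie_word_embeddings is true in the HF config, the output projection must reuse the token-embedding matrix rather than load a separate lm_head tensor. The new qwen3.rs itself is not part of the diff excerpt below, so what follows is only a minimal sketch of the usual pattern; Config and build_heads are illustrative names, not the PR's actual code.

use candle::Result;
use candle_nn::{embedding, linear_no_bias, Embedding, Linear, VarBuilder};

// Hypothetical subset of the HF config, limited to the fields used here.
struct Config {
    hidden_size: usize,
    vocab_size: usize,
    tie_word_embeddings: bool,
}

// Build the token embedding and lm_head, tying their weights when requested.
fn build_heads(cfg: &Config, vb: VarBuilder) -> Result<(Embedding, Linear)> {
    let embed_tokens = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
    let lm_head = if cfg.tie_word_embeddings {
        // Tied case: the output projection is the embedding matrix itself.
        // Picking the wrong side of this branch was the bug fixed above.
        Linear::new(embed_tokens.embeddings().clone(), None)
    } else {
        linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
    };
    Ok((embed_tokens, lm_head))
}

Reusing the (vocab_size, hidden_size) embedding matrix works directly because candle's Linear stores its weight as (out_features, in_features).
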

File tree

3 files changed (+421, -3 lines)

* candle-examples/examples/qwen/main.rs
* candle-transformers/src/models/mod.rs
* candle-transformers/src/models/qwen3.rs

candle-examples/examples/qwen/main.rs

Lines changed: 33 additions & 3 deletions
@@ -9,6 +9,7 @@ use clap::Parser;
 
 use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
 use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
+use candle_transformers::models::qwen3::{Config as Config3, ModelForCausalLM as Model3};
 
 use candle::{DType, Device, Tensor};
 use candle_examples::token_output_stream::TokenOutputStream;
@@ -20,13 +21,15 @@ use tokenizers::Tokenizer;
 enum Model {
     Base(ModelBase),
     Moe(ModelMoe),
+    Base3(Model3),
 }
 
 impl Model {
     fn forward(&mut self, xs: &Tensor, s: usize) -> candle::Result<Tensor> {
         match self {
             Self::Moe(ref mut m) => m.forward(xs, s),
             Self::Base(ref mut m) => m.forward(xs, s),
+            Self::Base3(ref mut m) => m.forward(xs, s),
         }
     }
 }
@@ -85,6 +88,10 @@ impl TextGeneration {
             Some(token) => token,
             None => anyhow::bail!("cannot find the <|endoftext|> token"),
         };
+        let eos_token2 = match self.tokenizer.get_token("<|im_end|>") {
+            Some(token) => token,
+            None => anyhow::bail!("cannot find the <|im_end|> token"),
+        };
         let start_gen = std::time::Instant::now();
         for index in 0..sample_len {
             let context_size = if index > 0 { 1 } else { tokens.len() };
@@ -107,7 +114,7 @@ impl TextGeneration {
             let next_token = self.logits_processor.sample(&logits)?;
             tokens.push(next_token);
             generated_tokens += 1;
-            if next_token == eos_token {
+            if next_token == eos_token || next_token == eos_token2 {
                 break;
             }
             if let Some(t) = self.tokenizer.next_token(next_token)? {
@@ -152,6 +159,14 @@ enum WhichModel {
     W2_7b,
     #[value(name = "2-72b")]
     W2_72b,
+    #[value(name = "3-0.6b")]
+    W3_0_6b,
+    #[value(name = "3-1.7b")]
+    W3_1_7b,
+    #[value(name = "3-4b")]
+    W3_4b,
+    #[value(name = "3-8b")]
+    W3_8b,
 }
 
 #[derive(Parser, Debug)]
@@ -254,6 +269,10 @@ fn main() -> Result<()> {
                 WhichModel::W14b => ("1.5", "14B"),
                 WhichModel::W72b => ("1.5", "72B"),
                 WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"),
+                WhichModel::W3_0_6b => ("3", "0.6B"),
+                WhichModel::W3_1_7b => ("3", "1.7B"),
+                WhichModel::W3_4b => ("3", "4B"),
+                WhichModel::W3_8b => ("3", "8B"),
             };
             format!("Qwen/Qwen{version}-{size}")
         }
@@ -273,7 +292,11 @@ fn main() -> Result<()> {
             .map(std::path::PathBuf::from)
             .collect::<Vec<_>>(),
         None => match args.model {
-            WhichModel::W0_5b | WhichModel::W2_0_5b | WhichModel::W2_1_5b | WhichModel::W1_8b => {
+            WhichModel::W0_5b
+            | WhichModel::W2_0_5b
+            | WhichModel::W2_1_5b
+            | WhichModel::W1_8b
+            | WhichModel::W3_0_6b => {
                 vec![repo.get("model.safetensors")?]
             }
             WhichModel::W4b
@@ -282,7 +305,10 @@ fn main() -> Result<()> {
             | WhichModel::W14b
             | WhichModel::W72b
             | WhichModel::W2_72b
-            | WhichModel::MoeA27b => {
+            | WhichModel::MoeA27b
+            | WhichModel::W3_1_7b
+            | WhichModel::W3_4b
+            | WhichModel::W3_8b => {
                 candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
             }
         },
@@ -304,6 +330,10 @@ fn main() -> Result<()> {
             let config: ConfigMoe = serde_json::from_slice(&std::fs::read(config_file)?)?;
            Model::Moe(ModelMoe::new(&config, vb)?)
         }
+        WhichModel::W3_0_6b | WhichModel::W3_1_7b | WhichModel::W3_4b | WhichModel::W3_8b => {
+            let config: Config3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
+            Model::Base3(Model3::new(&config, vb)?)
+        }
         _ => {
             let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
             Model::Base(ModelBase::new(&config, vb)?)
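One cleanup in the commit message drops a hand-rolled repeat_kv in favor of the helper already shipped with candle-transformers, which expands grouped key/value heads to match the number of query heads. Below is a sketch of what such a helper does, assuming the usual (batch, heads, seq, head_dim) layout; the real one lives in candle_transformers::utils and may differ in detail.

use candle::{Result, Tensor};

// Expand each of the n_kv key/value heads into n_rep copies so a model using
// grouped-query attention can run standard attention: (b, n_kv, s, d) ->
// (b, n_kv * n_rep, s, d), with each kv head's copies laid out contiguously,
// matching query head q -> kv head q / n_rep.
fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        return Ok(xs);
    }
    let (b, n_kv, seq_len, head_dim) = xs.dims4()?;
    Tensor::cat(&vec![&xs; n_rep], 2)?.reshape((b, n_kv * n_rep, seq_len, head_dim))
}
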
candle-transformers/src/models/mod.rs

Lines changed: 1 addition & 0 deletions
@@ -97,6 +97,7 @@ pub mod quantized_stable_lm;
 pub mod quantized_t5;
 pub mod qwen2;
 pub mod qwen2_moe;
+pub mod qwen3;
 pub mod recurrent_gemma;
 pub mod repvgg;
 pub mod resnet;

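The commit message likewise swaps a custom kv_cache for candle's. Here is a minimal sketch of how an attention layer can lean on candle_nn's kv_cache module, assuming a KvCache::new(dim, max_seq_len) / append(&k, &v) style API; CachedAttention is an illustrative name, not the struct in qwen3.rs.

use candle::{Result, Tensor};
use candle_nn::kv_cache::KvCache;

// Illustrative wrapper: the cache concatenates each step's keys/values along
// the sequence axis (dim 2 for [batch, heads, seq, head_dim] tensors) and
// hands back everything accumulated so far.
struct CachedAttention {
    kv_cache: KvCache,
}

impl CachedAttention {
    fn new(max_seq_len: usize) -> Self {
        Self {
            kv_cache: KvCache::new(2, max_seq_len),
        }
    }

    // Append this step's k/v and get the full sequences to attend over.
    fn cached_kv(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        self.kv_cache.append(k, v)
    }

    // Clear between prompts so stale state does not leak across generations.
    fn reset(&mut self) {
        self.kv_cache.reset()
    }
}

Since append returns the concatenation of everything cached so far, the caller attends over the full prefix without tracking sequence offsets itself.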