
Support new arch of GLM4 models #2991

Open: wants to merge 4 commits into base: main
1 change: 1 addition & 0 deletions Cargo.toml
@@ -76,6 +76,7 @@ ug-metal = "0.4.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "1.1.1", default-features = false }
metal = { version = "0.27.0", features = ["mps"]}
either = { version = "1.13.0", features = ["serde"] }

[profile.release-with-debug]
inherits = "release"
1 change: 1 addition & 0 deletions candle-examples/Cargo.toml
@@ -37,6 +37,7 @@ symphonia = { version = "0.5.3", features = ["all"], optional = true }
tokenizers = { workspace = true, features = ["onig"] }
cpal = { version = "0.15.2", optional = true }
pdf2image = { version = "0.1.2" , optional = true}
either = { workspace = true }

[dev-dependencies]
anyhow = { workspace = true }
52 changes: 52 additions & 0 deletions candle-examples/examples/glm4/README.md
@@ -0,0 +1,52 @@
## GLM4
GLM-4-9B-0414 is built on a new architecture in the GLM-4 series developed by Zhipu AI. It is not compatible with earlier GLM-4 releases such as THUDM/glm-4-9b because the model architecture and internal implementation differ, so you must explicitly specify the correct model type when loading it (via `--which`, see below); using the wrong configuration can lead to initialization errors or runtime failures.

### GLM4-0414 Arch:

- [GLM4-0414 Collection](https://huggingface.co/collections/THUDM/glm-4-0414-67f3cbcb34dd9d252707cb2e)
- [GLM-4-9B-0414 Weight](https://huggingface.co/THUDM/GLM-4-9B-0414)

### Old GLM4 Arch:

- [GitHub](https://github.com/THUDM/GLM4)
- [GLM-4-9B Weight](https://huggingface.co/THUDM/glm-4-9b)

### Running with CUDA
Use `--which` to select between the two architectures.

```bash
cargo run --example glm4 --release --features cuda -- --which "glm4-new" --model-id THUDM/GLM-4-9B-0414 --prompt "How are you today?"
cargo run --example glm4 --release --features cuda -- --which "glm4-old" --model-id THUDM/glm-4-9b --prompt "How are you today?"
```

### Running with local file (CUDA)

```bash
cargo run --example glm4 --release --features cuda -- --which "glm4-new" --weight-path /path/GLM-4-9B-0414 --prompt "How are you today?"
cargo run --example glm4 --release --features cuda -- --which "glm4-old" --weight-path /path/glm-4-9b --prompt "How are you today?"
```

### Running with local file (Metal)

```bash
cargo run --example glm4 --release --features metal -- --which "glm4-new" --weight-path /path/GLM-4-9B-0414 --prompt "How are you today?"
cargo run --example glm4 --release --features metal -- --which "glm4-old" --weight-path /path/glm-4-9b --prompt "How are you today?"
```

### Running with CPU
```bash
cargo run --example glm4 --release -- --cpu --which "glm4-new" --model-id THUDM/GLM-4-9B-0414 --prompt "How are you today?"
```

### Output Example (GLM-4-9B-0414)
```
avx: true, neon: false, simd128: false, f16c: true
temp: 0.80 repeat-penalty: 1.20 repeat-last-n: 64
retrieved the files in 158.728989ms
loaded the model in 3.714556129s
starting the inference loop
How are you today?
I'm just a computer program, so I don't have feelings or emotions. But thank you for asking! How can I assist you today?

31 tokens generated (28.77 token/s)
```
54 changes: 0 additions & 54 deletions candle-examples/examples/glm4/README.org

This file was deleted.

130 changes: 104 additions & 26 deletions candle-examples/examples/glm4/main.rs
@@ -1,22 +1,54 @@
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::glm4::*;
use candle_transformers::models::glm4::{Config as ConfigOld, Model as ModelOld, TokenID};
use candle_transformers::models::glm4_new::{Config as ConfigNew, ModelForCausalLM as ModelNew};

use clap::Parser;
use either::Either;
use hf_hub::{Repo, RepoType};
use tokenizers::Tokenizer;

enum Model {
Old(ModelOld),
New(ModelNew),
}

impl Model {
fn forward(&mut self, input_ids: &Tensor, pos: usize) -> candle::Result<Tensor> {
match self {
Self::Old(m) => m.forward(input_ids),
Self::New(m) => m.forward(input_ids, pos),
}
}
}

#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "glm4-old")]
GLM4Old,
#[value(name = "glm4-new")]
GLM4New,
}

struct TextGeneration {
model: Model,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
args: Args,
dtype: DType,
eos_tokens: Vec<u32>,
}

impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(model: Model, tokenizer: Tokenizer, args: Args, device: &Device, dtype: DType) -> Self {
fn new(
model: Model,
tokenizer: Tokenizer,
args: Args,
device: &Device,
eos_tokens: Vec<u32>,
) -> Self {
let logits_processor =
LogitsProcessor::new(args.seed, Some(args.temperature), Some(args.top_p));
Self {
@@ -25,7 +57,7 @@ impl TextGeneration {
logits_processor,
args,
device: device.clone(),
dtype,
eos_tokens,
}
}

@@ -34,10 +66,12 @@ impl TextGeneration {
let args = &self.args;
println!("starting the inference loop");

let tokens = self
.tokenizer
.encode(args.prompt.to_string(), true)
.expect("tokens error");
let prompt = format!(
"[gMASK]<sop><|user|>\n{}<|assistant|>",
args.prompt.to_string()
);

let tokens = self.tokenizer.encode(prompt, true).expect("tokens error");
if tokens.is_empty() {
panic!("Empty prompts are not supported in the chatglm model.")
}
@@ -50,10 +84,7 @@ impl TextGeneration {
print!("{}", &args.prompt);
std::io::stdout().flush()?;
}
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => panic!("cannot find the endoftext token"),
};

let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;

@@ -62,10 +93,15 @@

for index in 0..args.sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = match self.model {
Model::Old(_) => logits.squeeze(0)?.to_dtype(DType::F32)?,
Model::New(_) => logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?,
};

let logits = if args.repeat_penalty == 1. {
logits
} else {
@@ -80,7 +116,7 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
if self.eos_tokens.contains(&next_token) {
break;
}
let token = self
@@ -158,6 +194,13 @@ struct Args {
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,

/// Selects the model architecture: GLM4-Old (e.g. THUDM/glm-4-9b) or GLM4-New (e.g. GLM-4-9B-0414).
/// This argument is required because the two architectures are incompatible:
/// loading a GLM4-New checkpoint with the old configuration (or vice versa)
/// can cause a runtime panic during model execution.
#[arg(long)]
which: Which,
}

fn main() -> anyhow::Result<()> {
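The doc comment on `--which` above stresses that the flag has no default and must be supplied. As a standalone illustration (not part of the PR, assuming clap 4 with the `derive` feature that candle-examples already uses), this is how the value-enum names map onto the `Which` variants and how a missing or unknown value is rejected at argument-parsing time instead of panicking later inside the model code:

```rust
// Standalone sketch, not part of the PR. Requires clap 4 with the "derive" feature.
use clap::{Parser, ValueEnum};

#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
enum Which {
    // Explicit names keep the CLI spelling stable regardless of variant names.
    #[value(name = "glm4-old")]
    GLM4Old,
    #[value(name = "glm4-new")]
    GLM4New,
}

#[derive(Parser, Debug)]
struct Args {
    /// Architecture selector; deliberately no default value.
    #[arg(long)]
    which: Which,
}

fn main() {
    // Parsing succeeds only when --which is given one of the two registered names.
    let args = Args::parse_from(["glm4", "--which", "glm4-new"]);
    assert_eq!(args.which, Which::GLM4New);

    // Omitting --which (or passing an unknown name) fails at parse time.
    assert!(Args::try_parse_from(["glm4"]).is_err());
}
```

Registering explicit `#[value(name = ...)]` strings keeps the CLI spelling (`glm4-old` / `glm4-new`) fixed even if the enum variants are later renamed.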
@@ -186,19 +229,23 @@ fn main() -> anyhow::Result<()> {

let model_id = match args.model_id.as_ref() {
Some(model_id) => model_id.to_string(),
None => "THUDM/glm-4-9b".to_string(),
None => match args.which {
Which::GLM4Old => "THUDM/glm-4-9b".to_string(),
Which::GLM4New => "THUDM/GLM-4-9B-0414".to_string(),
},
};
let revision = match args.revision.as_ref() {
Some(rev) => rev.to_string(),
None => "main".to_string(),
};
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = match args.tokenizer.as_ref() {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("THUDM/codegeex4-all-9b".to_string())
.get("tokenizer.json")
.map_err(anyhow::Error::msg)?,
let tokenizer_filename = match (args.weight_path.as_ref(), args.tokenizer.as_ref()) {
(Some(_), Some(file)) => std::path::PathBuf::from(file),
(None, Some(file)) => std::path::PathBuf::from(file),
(Some(path), None) => {
std::path::PathBuf::from(std::path::Path::new(path).join("tokenizer.json"))
}
(None, None) => repo.get("tokenizer.json")?,
};
let config_filename = match &args.weight_path {
Some(path) => std::path::Path::new(path).join("config.json"),
@@ -216,19 +263,50 @@ fn main() -> anyhow::Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");

let start = std::time::Instant::now();
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;

let (model, eos_token_id) = match args.which {
Which::GLM4Old => {
let config: ConfigOld = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let model = ModelOld::new(&config, vb)?;
(Model::Old(model), config.eos_token_id)
}
Which::GLM4New => {
let config: ConfigNew = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let model = ModelNew::new(&config, vb)?;
(Model::New(model), config.eos_token_id)
}
};

let mut eos_tokens = Vec::new();
match eos_token_id {
TokenID(Either::Left(Some(eos))) => {
eos_tokens.push(eos);
}
TokenID(Either::Right(Some(eos_vec))) => {
eos_tokens.extend(eos_vec);
}
_ => {
let eos_token = match args.which {
Which::GLM4Old => "<|endoftext|>",
Which::GLM4New => "<|user|>",
};
match tokenizer.get_vocab(true).get(eos_token) {
Some(token) => eos_tokens.push(*token),
None => panic!("cannot find the {eos_token} token"),
};
}
}

println!("loaded the model in {:?}", start.elapsed());

let mut pipeline = TextGeneration::new(model, tokenizer, args, &device, dtype);
let mut pipeline = TextGeneration::new(model, tokenizer, args, &device, eos_tokens);
pipeline.run()?;
Ok(())
}
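The `TokenID` match above exists because GLM4-family `config.json` files may give `eos_token_id` either as a single integer or as a list of integers; the example flattens whichever form it finds into `eos_tokens`, falling back to a tokenizer vocabulary lookup when the config omits the field. As a standalone sketch of the same int-or-list shape (not the PR's actual `TokenID` type, and with made-up token ids), a serde untagged enum behaves the same way:

```rust
// Standalone sketch, not the PR's TokenID type. Requires serde (derive feature) and serde_json.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum EosTokenId {
    Single(u32),
    Multiple(Vec<u32>),
}

#[derive(Debug, Deserialize)]
struct MinimalConfig {
    // Some configs omit the field entirely, hence the Option.
    eos_token_id: Option<EosTokenId>,
}

// Flatten whichever form the config used into a list of stop-token ids.
fn eos_tokens(cfg: &MinimalConfig) -> Vec<u32> {
    match &cfg.eos_token_id {
        Some(EosTokenId::Single(id)) => vec![*id],
        Some(EosTokenId::Multiple(ids)) => ids.clone(),
        // The example falls back to a tokenizer lookup in this case.
        None => Vec::new(),
    }
}

fn main() -> serde_json::Result<()> {
    // Illustrative ids only, not taken from a real GLM4 config.
    let old: MinimalConfig = serde_json::from_str(r#"{ "eos_token_id": 2 }"#)?;
    assert_eq!(eos_tokens(&old), vec![2]);

    let new: MinimalConfig = serde_json::from_str(r#"{ "eos_token_id": [2, 3] }"#)?;
    assert_eq!(eos_tokens(&new), vec![2, 3]);
    Ok(())
}
```

In the PR itself the shape is carried by a newtype over `either::Either`, presumably relying on `either`'s serde support, which is why the workspace `Cargo.toml` now enables that crate's `serde` feature.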
1 change: 1 addition & 0 deletions candle-transformers/Cargo.toml
@@ -24,6 +24,7 @@ serde = { workspace = true }
serde_json = { workspace = true }
serde_plain = { workspace = true }
tracing = { workspace = true }
either = { workspace = true }

[features]
default = []