Skip to content

Commit 8c0429a

Browse files
author
Artem Ryzhov
committed
Add GLU integration to CPU benchmarks and Phi-3 model
1 parent 82dec69 commit 8c0429a

File tree

2 files changed

+86
-3
lines changed

2 files changed

+86
-3
lines changed

candle-nn/examples/cpu_benchmarks.rs

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ extern crate accelerate_src;
77

88
use candle::quantized::GgmlType;
99
use candle::{CpuStorage, Device, Layout, Module, Result, Shape, Tensor, D};
10+
use candle_nn::Activation;
1011
use clap::{Parser, Subcommand};
11-
1212
const CHECK_CONV2D: bool = false;
1313

1414
trait Benchmark {
@@ -21,6 +21,54 @@ trait Benchmark {
2121
const ITERS: usize;
2222
}
2323

24+
struct GluActivation;
25+
impl Benchmark for GluActivation {
26+
type PreProcessData = Tensor;
27+
type RunResult = Tensor;
28+
29+
fn preprocess() -> Result<Self::PreProcessData> {
30+
Tensor::randn(0f32, 1., (1024, 2048), &Device::Cpu)
31+
}
32+
33+
fn run_one(data: &Self::PreProcessData) -> Result<Self::RunResult> {
34+
Activation::Glu.forward(data)
35+
}
36+
37+
const ITERS: usize = 100;
38+
}
39+
40+
struct GeGluActivation;
41+
impl Benchmark for GeGluActivation {
42+
type PreProcessData = Tensor;
43+
type RunResult = Tensor;
44+
45+
fn preprocess() -> Result<Self::PreProcessData> {
46+
Tensor::randn(0f32, 1., (1024, 2048), &Device::Cpu)
47+
}
48+
49+
fn run_one(data: &Self::PreProcessData) -> Result<Self::RunResult> {
50+
Activation::GeGlu.forward(data)
51+
}
52+
53+
const ITERS: usize = 100;
54+
}
55+
56+
struct ReGluActivation;
57+
impl Benchmark for ReGluActivation {
58+
type PreProcessData = Tensor;
59+
type RunResult = Tensor;
60+
61+
fn preprocess() -> Result<Self::PreProcessData> {
62+
Tensor::randn(0f32, 1., (1024, 2048), &Device::Cpu)
63+
}
64+
65+
fn run_one(data: &Self::PreProcessData) -> Result<Self::RunResult> {
66+
Activation::ReGlu.forward(data)
67+
}
68+
69+
const ITERS: usize = 100;
70+
}
71+
2472
struct Im2Col {
2573
h_k: usize,
2674
w_k: usize,
@@ -313,6 +361,9 @@ enum Task {
313361
Softmax,
314362
SoftmaxLastDim,
315363
Cat,
364+
GluActivation,
365+
GeGluActivation,
366+
ReGluActivation,
316367
}
317368

318369
#[derive(Parser, Debug)]
@@ -338,6 +389,9 @@ fn main() -> Result<()> {
338389
Task::SoftmaxLastDim => run::<SoftmaxLastDim>(args.iters)?,
339390
Task::Qmatmul => run::<QMatMul>(args.iters)?,
340391
Task::Cat => run::<Cat>(args.iters)?,
392+
Task::GluActivation => run::<GluActivation>(args.iters)?,
393+
Task::GeGluActivation => run::<GeGluActivation>(args.iters)?,
394+
Task::ReGluActivation => run::<ReGluActivation>(args.iters)?,
341395
}
342396
Ok(())
343397
}

candle-transformers/src/models/phi3.rs

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
2222
use crate::models::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
2323
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
24-
use candle_nn::VarBuilder;
24+
use candle_nn::{Activation, VarBuilder};
2525
use std::sync::Arc;
2626

2727
#[derive(Debug, Clone, serde::Deserialize)]
@@ -59,12 +59,41 @@ pub struct Config {
5959
#[serde(default)]
6060
pub tie_word_embeddings: bool,
6161
}
62-
6362
impl Config {
63+
pub fn mini_4k_instruct() -> Self {
64+
Self {
65+
vocab_size: 32064,
66+
hidden_act: Activation::GeGlu,
67+
hidden_size: 3072,
68+
intermediate_size: 8192,
69+
num_hidden_layers: 32,
70+
num_attention_heads: 32,
71+
num_key_value_heads: 32,
72+
rms_norm_eps: 1e-5,
73+
rope_theta: 10000.0,
74+
bos_token_id: Some(1),
75+
eos_token_id: Some(2),
76+
rope_scaling: None,
77+
max_position_embeddings: 4096,
78+
original_max_position_embeddings: None,
79+
partial_rotary_factor: None,
80+
tie_word_embeddings: false,
81+
}
82+
}
83+
84+
pub fn with_activation(mut self, activation: Activation) -> Self {
85+
self.hidden_act = activation;
86+
self
87+
}
6488
pub fn head_dim(&self) -> usize {
6589
self.hidden_size / self.num_attention_heads
6690
}
6791
}
92+
// impl Config {
93+
// pub fn head_dim(&self) -> usize {
94+
// self.hidden_size / self.num_attention_heads
95+
// }
96+
// }
6897

6998
#[derive(Debug, Clone)]
7099
pub struct RotaryEmbedding {

0 commit comments

Comments
 (0)