@@ -7,8 +7,8 @@ extern crate accelerate_src;
7
7
8
8
use candle:: quantized:: GgmlType ;
9
9
use candle:: { CpuStorage , Device , Layout , Module , Result , Shape , Tensor , D } ;
10
+ use candle_nn:: Activation ;
10
11
use clap:: { Parser , Subcommand } ;
11
-
12
12
const CHECK_CONV2D : bool = false ;
13
13
14
14
trait Benchmark {
@@ -21,6 +21,54 @@ trait Benchmark {
21
21
const ITERS : usize ;
22
22
}
23
23
24
+ struct GluActivation ;
25
+ impl Benchmark for GluActivation {
26
+ type PreProcessData = Tensor ;
27
+ type RunResult = Tensor ;
28
+
29
+ fn preprocess ( ) -> Result < Self :: PreProcessData > {
30
+ Tensor :: randn ( 0f32 , 1. , ( 1024 , 2048 ) , & Device :: Cpu )
31
+ }
32
+
33
+ fn run_one ( data : & Self :: PreProcessData ) -> Result < Self :: RunResult > {
34
+ Activation :: Glu . forward ( data)
35
+ }
36
+
37
+ const ITERS : usize = 100 ;
38
+ }
39
+
40
+ struct GeGluActivation ;
41
+ impl Benchmark for GeGluActivation {
42
+ type PreProcessData = Tensor ;
43
+ type RunResult = Tensor ;
44
+
45
+ fn preprocess ( ) -> Result < Self :: PreProcessData > {
46
+ Tensor :: randn ( 0f32 , 1. , ( 1024 , 2048 ) , & Device :: Cpu )
47
+ }
48
+
49
+ fn run_one ( data : & Self :: PreProcessData ) -> Result < Self :: RunResult > {
50
+ Activation :: GeGlu . forward ( data)
51
+ }
52
+
53
+ const ITERS : usize = 100 ;
54
+ }
55
+
56
+ struct ReGluActivation ;
57
+ impl Benchmark for ReGluActivation {
58
+ type PreProcessData = Tensor ;
59
+ type RunResult = Tensor ;
60
+
61
+ fn preprocess ( ) -> Result < Self :: PreProcessData > {
62
+ Tensor :: randn ( 0f32 , 1. , ( 1024 , 2048 ) , & Device :: Cpu )
63
+ }
64
+
65
+ fn run_one ( data : & Self :: PreProcessData ) -> Result < Self :: RunResult > {
66
+ Activation :: ReGlu . forward ( data)
67
+ }
68
+
69
+ const ITERS : usize = 100 ;
70
+ }
71
+
24
72
struct Im2Col {
25
73
h_k : usize ,
26
74
w_k : usize ,
@@ -313,6 +361,9 @@ enum Task {
313
361
Softmax ,
314
362
SoftmaxLastDim ,
315
363
Cat ,
364
+ GluActivation ,
365
+ GeGluActivation ,
366
+ ReGluActivation ,
316
367
}
317
368
318
369
#[ derive( Parser , Debug ) ]
@@ -338,6 +389,9 @@ fn main() -> Result<()> {
338
389
Task :: SoftmaxLastDim => run :: < SoftmaxLastDim > ( args. iters ) ?,
339
390
Task :: Qmatmul => run :: < QMatMul > ( args. iters ) ?,
340
391
Task :: Cat => run :: < Cat > ( args. iters ) ?,
392
+ Task :: GluActivation => run :: < GluActivation > ( args. iters ) ?,
393
+ Task :: GeGluActivation => run :: < GeGluActivation > ( args. iters ) ?,
394
+ Task :: ReGluActivation => run :: < ReGluActivation > ( args. iters ) ?,
341
395
}
342
396
Ok ( ( ) )
343
397
}
0 commit comments