@@ -70,13 +70,9 @@ def harness(self, name, in_fn, ins, dins, douts):
70
70
71
71
print (
72
72
name + " JaX Primal: " ,
73
- timeit .Timer (
74
- primalstr ,
75
- globals = {
76
- "fn" : rfn_jax ,
77
- }
78
- | primalins ,
79
- ).timeit (number )
73
+ timeit .Timer (primalstr , globals = {"fn" : rfn_jax ,} | primalins ,).timeit (
74
+ number
75
+ )
80
76
/ number ,
81
77
)
82
78
@@ -97,13 +93,7 @@ def harness(self, name, in_fn, ins, dins, douts):
97
93
fwdins = primalins | {("din" + str (i )): dins [0 ] for i in range (len (dins ))}
98
94
print (
99
95
name + " JaX Fwd: " ,
100
- timeit .Timer (
101
- fwdstr ,
102
- globals = {
103
- "fwd" : fwd_jax ,
104
- }
105
- | fwdins ,
106
- ).timeit (number )
96
+ timeit .Timer (fwdstr , globals = {"fwd" : fwd_jax ,} | fwdins ,).timeit (number )
107
97
/ number ,
108
98
)
109
99
@@ -124,13 +114,7 @@ def harness(self, name, in_fn, ins, dins, douts):
124
114
125
115
print (
126
116
name + " JaX Rev: " ,
127
- timeit .Timer (
128
- revstr ,
129
- globals = {
130
- "rev" : rev_jax ,
131
- }
132
- | revins ,
133
- ).timeit (number )
117
+ timeit .Timer (revstr , globals = {"rev" : rev_jax ,} | revins ,).timeit (number )
134
118
/ number ,
135
119
)
136
120
@@ -148,11 +132,7 @@ def harness(self, name, in_fn, ins, dins, douts):
148
132
name ,
149
133
") Primal: " ,
150
134
timeit .Timer (
151
- primalstr ,
152
- globals = {
153
- "fn" : rfn_enzyme ,
154
- }
155
- | primalins ,
135
+ primalstr , globals = {"fn" : rfn_enzyme ,} | primalins ,
156
136
).timeit (number )
157
137
/ number ,
158
138
)
@@ -174,13 +154,9 @@ def harness(self, name, in_fn, ins, dins, douts):
174
154
name + " EnzymeMLIR(" ,
175
155
name ,
176
156
") Fwd: " ,
177
- timeit .Timer (
178
- fwdstr ,
179
- globals = {
180
- "fwd" : fwd_enzyme ,
181
- }
182
- | fwdins ,
183
- ).timeit (number )
157
+ timeit .Timer (fwdstr , globals = {"fwd" : fwd_enzyme ,} | fwdins ,).timeit (
158
+ number
159
+ )
184
160
/ number ,
185
161
)
186
162
@@ -198,13 +174,9 @@ def harness(self, name, in_fn, ins, dins, douts):
198
174
name + " EnzymeMLIR(" ,
199
175
name ,
200
176
") Rev: " ,
201
- timeit .Timer (
202
- revstr ,
203
- globals = {
204
- "rev" : rev_enzyme ,
205
- }
206
- | revins ,
207
- ).timeit (number )
177
+ timeit .Timer (revstr , globals = {"rev" : rev_enzyme ,} | revins ,).timeit (
178
+ number
179
+ )
208
180
/ number ,
209
181
)
210
182
@@ -291,7 +263,9 @@ class Slicing(EnzymeJaxTest):
291
263
def setUp (self ):
292
264
dim = 3
293
265
self .ins = [jnp .array (range (dim ), dtype = jnp .float32 ).reshape (1 , dim , 1 )]
294
- self .dins = [jnp .array ([i * i for i in range (dim )], dtype = jnp .float32 ).reshape (1 , dim , 1 )]
266
+ self .dins = [
267
+ jnp .array ([i * i for i in range (dim )], dtype = jnp .float32 ).reshape (1 , dim , 1 )
268
+ ]
295
269
self .douts = [jnp .array ([i * i for i in range (dim )], dtype = jnp .float32 )]
296
270
297
271
def nomlir (x ):
@@ -311,16 +285,24 @@ def setUp(self):
311
285
dim = 12
312
286
self .ins = [jnp .array (range (dim ), dtype = jnp .float32 )]
313
287
self .dins = [jnp .array ([i * i for i in range (dim )], dtype = jnp .float32 )]
314
- self .douts = [jnp .array ([i * i for i in range (2 * dim )], dtype = jnp .float32 ).reshape ((2 , dim ))]
288
+ self .douts = [
289
+ jnp .array ([i * i for i in range (2 * dim )], dtype = jnp .float32 ).reshape (
290
+ (2 , dim )
291
+ )
292
+ ]
315
293
316
294
def nomlir (x ):
317
- return [(name , a ) for (name , a ) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA" ]
295
+ return [
296
+ (name , a )
297
+ for (name , a ) in x
298
+ if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
299
+ ]
318
300
319
301
self .revfilter = nomlir
320
302
321
303
def f (x ):
322
304
toconv2 = jnp .ones ((dim , dim ))
323
- k = jnp .einsum (' jk,k->j' , toconv2 , x )
305
+ k = jnp .einsum (" jk,k->j" , toconv2 , x )
324
306
kcl = jnp .zeros ((1 , dim ))
325
307
h = jnp .reshape (k , (1 , dim ))
326
308
kcl = jnp .append (kcl , h , axis = 0 )
@@ -329,15 +311,24 @@ def f(x):
329
311
self .fn = f
330
312
self .name = "activitymismatch"
331
313
314
+
332
315
class GenDot (EnzymeJaxTest ):
333
316
def setUp (self ):
334
317
dim = 12
335
318
self .ins = [jnp .array (range (dim ), dtype = jnp .float32 )]
336
319
self .dins = [jnp .array ([i * i for i in range (dim )], dtype = jnp .float32 )]
337
- self .douts = [jnp .array ([i * i for i in range (2 * dim )], dtype = jnp .float32 ).reshape ((2 , dim ))]
320
+ self .douts = [
321
+ jnp .array ([i * i for i in range (2 * dim )], dtype = jnp .float32 ).reshape (
322
+ (2 , dim )
323
+ )
324
+ ]
338
325
339
326
def nomlir (x ):
340
- return [(name , a ) for (name , a ) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA" ]
327
+ return [
328
+ (name , a )
329
+ for (name , a ) in x
330
+ if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
331
+ ]
341
332
342
333
self .revfilter = nomlir
343
334
@@ -349,7 +340,7 @@ def f(x):
349
340
k = jnp .reshape (jnp .einsum ("ijk,ik -> ij" , toconv2 , k_tmp ), (dim ,))
350
341
351
342
kcl = jnp .zeros ((1 , dim ))
352
-
343
+
353
344
h = jnp .reshape (k , (1 , dim ))
354
345
kcl = jnp .append (kcl , h , axis = 0 )
355
346
return kcl
@@ -361,12 +352,22 @@ def f(x):
361
352
class Concat (EnzymeJaxTest ):
362
353
def setUp (self ):
363
354
dim = 12
364
- self .ins = [jnp .array (range (dim ), dtype = jnp .float32 ), 10 * jnp .array (range (dim ), dtype = jnp .float32 )]
365
- self .dins = [jnp .array ([i * i for i in range (dim )], dtype = jnp .float32 ), jnp .array ([i * i * i / 3. for i in range (dim )], dtype = jnp .float32 )]
366
- self .douts = [jnp .array ([i * i for i in range (2 * dim )], dtype = jnp .float32 )]
355
+ self .ins = [
356
+ jnp .array (range (dim ), dtype = jnp .float32 ),
357
+ 10 * jnp .array (range (dim ), dtype = jnp .float32 ),
358
+ ]
359
+ self .dins = [
360
+ jnp .array ([i * i for i in range (dim )], dtype = jnp .float32 ),
361
+ jnp .array ([i * i * i / 3.0 for i in range (dim )], dtype = jnp .float32 ),
362
+ ]
363
+ self .douts = [jnp .array ([i * i for i in range (2 * dim )], dtype = jnp .float32 )]
367
364
368
365
def nomlir (x ):
369
- return [(name , a ) for (name , a ) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA" ]
366
+ return [
367
+ (name , a )
368
+ for (name , a ) in x
369
+ if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
370
+ ]
370
371
371
372
self .revfilter = nomlir
372
373
@@ -376,5 +377,6 @@ def f(x, y):
376
377
self .fn = f
377
378
self .name = "Concat"
378
379
380
+
379
381
if __name__ == "__main__" :
380
382
absltest .main ()
0 commit comments