@@ -291,7 +291,9 @@ class Slicing(EnzymeJaxTest):
291
291
def setUp (self ):
292
292
dim = 3
293
293
self .ins = [jnp .array (range (dim ), dtype = jnp .float32 ).reshape (1 , dim , 1 )]
294
- self .dins = [jnp .array ([i * i for i in range (dim )], dtype = jnp .float32 ).reshape (1 , dim , 1 )]
294
+ self .dins = [
295
+ jnp .array ([i * i for i in range (dim )], dtype = jnp .float32 ).reshape (1 , dim , 1 )
296
+ ]
295
297
self .douts = [jnp .array ([i * i for i in range (dim )], dtype = jnp .float32 )]
296
298
297
299
def nomlir (x ):
@@ -311,16 +313,24 @@ def setUp(self):
311
313
dim = 12
312
314
self .ins = [jnp .array (range (dim ), dtype = jnp .float32 )]
313
315
self .dins = [jnp .array ([i * i for i in range (dim )], dtype = jnp .float32 )]
314
- self .douts = [jnp .array ([i * i for i in range (2 * dim )], dtype = jnp .float32 ).reshape ((2 , dim ))]
316
+ self .douts = [
317
+ jnp .array ([i * i for i in range (2 * dim )], dtype = jnp .float32 ).reshape (
318
+ (2 , dim )
319
+ )
320
+ ]
315
321
316
322
def nomlir (x ):
317
- return [(name , a ) for (name , a ) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA" ]
323
+ return [
324
+ (name , a )
325
+ for (name , a ) in x
326
+ if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
327
+ ]
318
328
319
329
self .revfilter = nomlir
320
330
321
331
def f (x ):
322
332
toconv2 = jnp .ones ((dim , dim ))
323
- k = jnp .einsum (' jk,k->j' , toconv2 , x )
333
+ k = jnp .einsum (" jk,k->j" , toconv2 , x )
324
334
kcl = jnp .zeros ((1 , dim ))
325
335
h = jnp .reshape (k , (1 , dim ))
326
336
kcl = jnp .append (kcl , h , axis = 0 )
@@ -329,15 +339,24 @@ def f(x):
329
339
self .fn = f
330
340
self .name = "activitymismatch"
331
341
342
+
332
343
class GenDot (EnzymeJaxTest ):
333
344
def setUp (self ):
334
345
dim = 12
335
346
self .ins = [jnp .array (range (dim ), dtype = jnp .float32 )]
336
347
self .dins = [jnp .array ([i * i for i in range (dim )], dtype = jnp .float32 )]
337
- self .douts = [jnp .array ([i * i for i in range (2 * dim )], dtype = jnp .float32 ).reshape ((2 , dim ))]
348
+ self .douts = [
349
+ jnp .array ([i * i for i in range (2 * dim )], dtype = jnp .float32 ).reshape (
350
+ (2 , dim )
351
+ )
352
+ ]
338
353
339
354
def nomlir (x ):
340
- return [(name , a ) for (name , a ) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA" ]
355
+ return [
356
+ (name , a )
357
+ for (name , a ) in x
358
+ if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
359
+ ]
341
360
342
361
self .revfilter = nomlir
343
362
@@ -349,7 +368,7 @@ def f(x):
349
368
k = jnp .reshape (jnp .einsum ("ijk,ik -> ij" , toconv2 , k_tmp ), (dim ,))
350
369
351
370
kcl = jnp .zeros ((1 , dim ))
352
-
371
+
353
372
h = jnp .reshape (k , (1 , dim ))
354
373
kcl = jnp .append (kcl , h , axis = 0 )
355
374
return kcl
@@ -361,12 +380,22 @@ def f(x):
361
380
class Concat (EnzymeJaxTest ):
362
381
def setUp (self ):
363
382
dim = 12
364
- self .ins = [jnp .array (range (dim ), dtype = jnp .float32 ), 10 * jnp .array (range (dim ), dtype = jnp .float32 )]
365
- self .dins = [jnp .array ([i * i for i in range (dim )], dtype = jnp .float32 ), jnp .array ([i * i * i / 3. for i in range (dim )], dtype = jnp .float32 )]
366
- self .douts = [jnp .array ([i * i for i in range (2 * dim )], dtype = jnp .float32 )]
383
+ self .ins = [
384
+ jnp .array (range (dim ), dtype = jnp .float32 ),
385
+ 10 * jnp .array (range (dim ), dtype = jnp .float32 ),
386
+ ]
387
+ self .dins = [
388
+ jnp .array ([i * i for i in range (dim )], dtype = jnp .float32 ),
389
+ jnp .array ([i * i * i / 3.0 for i in range (dim )], dtype = jnp .float32 ),
390
+ ]
391
+ self .douts = [jnp .array ([i * i for i in range (2 * dim )], dtype = jnp .float32 )]
367
392
368
393
def nomlir (x ):
369
- return [(name , a ) for (name , a ) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA" ]
394
+ return [
395
+ (name , a )
396
+ for (name , a ) in x
397
+ if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
398
+ ]
370
399
371
400
self .revfilter = nomlir
372
401
@@ -376,5 +405,6 @@ def f(x, y):
376
405
self .fn = f
377
406
self .name = "Concat"
378
407
408
+
379
409
if __name__ == "__main__" :
380
410
absltest .main ()
0 commit comments