@@ -454,58 +454,63 @@ def erf(x):
   return tf_np.asarray(tf.math.erf(x.data))
 
 
-# Given lhs representation, rhs representation, contraction and batch dimensions,
-# compose the output representation.
-# e.g., ij, jk, (((1,), (0,)), ((), ())) -> ik
-#       aij, ajk, (((2,), (1,)), ((0,), (0,))) -> aik
-def compose_output_rep(lhs_rep, rhs_rep, lhs_contraction, rhs_contraction,
+def _minus(a, b):
+  return [x for x in a if x not in b]
+
+
+def _compose_output_rep(lhs_rep, rhs_rep, lhs_contraction, rhs_contraction,
                        lhs_batch, rhs_batch):
+  """Given lhs representation, rhs representation, contraction and batch dimensions,
+  compose the output representation.
+  e.g., ij, jk, (((1,), (0,)), ((), ())) -> ik
+        aij, ajk, (((2,), (1,)), ((0,), (0,))) -> aik
+  """
   output_rep = []
   for dim in lhs_batch:
     output_rep.append(lhs_rep[dim])
   for dim in rhs_batch:
     if rhs_rep[dim] not in output_rep:
       output_rep.append(rhs_rep[dim])
 
-  for i in range(len(lhs_rep)):
-    if i not in lhs_batch and i not in lhs_contraction:
-      output_rep.append(lhs_rep[i])
-  for i in range(len(rhs_rep)):
-    if i not in rhs_batch and i not in rhs_contraction:
-      output_rep.append(rhs_rep[i])
+  for i in _minus(range(len(lhs_rep)), lhs_batch + lhs_contraction):
+    output_rep.append(lhs_rep[i])
+  for i in _minus(range(len(rhs_rep)), rhs_batch + rhs_contraction):
+    output_rep.append(rhs_rep[i])
   return ''.join(output_rep)
 
 
-# If it is the general non-batched/single-batched matrix multiplication,
-# use the highly optimized kernel `tf.tensordot` to handle it.
-def non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction):
+def _non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction):
+  """If it is the general non-batched/single-batched matrix multiplication,
+  use the highly optimized kernel `tf.tensordot` to handle it.
+  """
   return tf.tensordot(lhs, rhs, axes=(list(lhs_contraction), list(rhs_contraction)))
 
 
-# An equivalent general dot operation as that in JAX -
-# <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dot_general.html>
-#
-# Although there is an implementation in TF XLA, avoid directly using XLA when
-# possible.
-#
-# e.g., non-batched: ij,jk->ik
-#       batched: ijk,ikl->ijl
 def tf_dot_general(lhs, rhs, dimension_numbers):
+  """An equivalent general dot operation as that in JAX -
+  <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dot_general.html>
 
-  char_list = list(string.ascii_lowercase)[8:]
-  lhs_dim, rhs_dim = len(lhs.shape), len(rhs.shape)
-  lhs_rep = char_list[:lhs_dim]
-  rhs_rep = char_list[lhs_dim:lhs_dim + rhs_dim]
+  Although there is an implementation in TF XLA, avoid directly using XLA when
+  possible.
+
+  e.g., non-batched: ij,jk->ik
+        batched: ijk,ikl->ijl
+  """
+  char_list = list(string.ascii_lowercase)
+  char_list = char_list[8:] + char_list[:8]
+  lhs_rank, rhs_rank = len(lhs.shape), len(rhs.shape)
+  lhs_rep = char_list[:lhs_rank]
+  rhs_rep = char_list[lhs_rank:lhs_rank + rhs_rank]
   contraction, batch = dimension_numbers
   lhs_contraction, rhs_contraction = contraction
   lhs_batch, rhs_batch = batch
 
   if len(lhs_batch) == 0 and len(rhs_batch) == 0:
-    return non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction)
+    return _non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction)
 
-  cond_a = lhs_dim == rhs_dim == 3
+  cond_a = lhs_rank == rhs_rank == 3
   cond_b = lhs_batch == (0,) and rhs_batch == (0,)
-  cond_c = lhs_contraction == (lhs_dim - 1,) and rhs_contraction == (1,)
+  cond_c = lhs_contraction == (2,) and rhs_contraction == (1,)
   if cond_a and cond_b and cond_c:
     return tf.linalg.matmul(lhs, rhs)
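For reference, a minimal sketch (illustrative only, assuming `_minus` and `_compose_output_rep` are in scope as defined above) that exercises the new helper on the two examples from its docstring:

# Illustrative check, not part of the patch: the docstring examples spelled out.
assert _compose_output_rep(list('ij'), list('jk'), (1,), (0,), (), ()) == 'ik'
assert _compose_output_rep(list('aij'), list('ajk'), (2,), (1,), (0,), (0,)) == 'aik'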
@@ -515,7 +520,7 @@ def tf_dot_general(lhs, rhs, dimension_numbers):
     if i < len(rhs_batch):
       rhs_rep[rhs_batch[i]] = lhs_rep[lhs_batch[i]]
 
-  output_rep = compose_output_rep(lhs_rep, rhs_rep, lhs_contraction,
+  output_rep = _compose_output_rep(lhs_rep, rhs_rep, lhs_contraction,
                                   rhs_contraction, lhs_batch, rhs_batch)
   equation = ''.join(lhs_rep) + ',' + ''.join(rhs_rep) + "->" + output_rep
   return tf.einsum(equation, lhs, rhs)
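A minimal usage sketch of `tf_dot_general` (an assumption for illustration only: `import string` and `import tensorflow as tf` are available as in the surrounding module, and the shapes are made up):

# Batched case: contract lhs dim 2 with rhs dim 1, batch over dim 0 of both.
# This satisfies cond_a/cond_b/cond_c and takes the tf.linalg.matmul fast path.
lhs = tf.random.uniform((4, 3, 5))
rhs = tf.random.uniform((4, 5, 2))
out = tf_dot_general(lhs, rhs, (((2,), (1,)), ((0,), (0,))))
print(out.shape)  # (4, 3, 2)

# Non-batched case: ordinary ij,jk->ik, handled by _non_batched_matmul via tf.tensordot.
a = tf.random.uniform((3, 5))
b = tf.random.uniform((5, 2))
print(tf_dot_general(a, b, (((1,), (0,)), ((), ()))).shape)  # (3, 2)

Anything that misses these two fast paths falls through to the `tf.einsum` call built from the equation that `_compose_output_rep` produces.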