
Commit d31355a

Author: DarrenZhang01
Make the revisions as suggested in #956
1 parent 08ad9f0 commit d31355a

2 files changed: +38 -33 lines


trax/tf_numpy/extensions/extensions.py (+35 -30)
@@ -454,58 +454,63 @@ def erf(x):
   return tf_np.asarray(tf.math.erf(x.data))
 
 
-# Given lhs representation, rhs representation, contraction and batch dimensions,
-# compose the output representation.
-# e.g., ij, jk, (((1,), (0,)), ((), ())) -> ik
-#       aij, ajk, (((2,), (1,)), ((0,), (0,))) -> aik
-def compose_output_rep(lhs_rep, rhs_rep, lhs_contraction, rhs_contraction,
+def _minus(a, b):
+  return [x for x in a if x not in b]
+
+
+def _compose_output_rep(lhs_rep, rhs_rep, lhs_contraction, rhs_contraction,
                         lhs_batch, rhs_batch):
+  """ Given lhs representation, rhs representation, contraction and batch dimensions,
+  compose the output representation.
+  e.g., ij, jk, (((1,), (0,)), ((), ())) -> ik
+        aij, ajk, (((2,), (1,)), ((0,), (0,))) -> aik
+  """
   output_rep = []
   for dim in lhs_batch:
     output_rep.append(lhs_rep[dim])
   for dim in rhs_batch:
     if rhs_rep[dim] not in output_rep:
       output_rep.append(rhs_rep[dim])
 
-  for i in range(len(lhs_rep)):
-    if i not in lhs_batch and i not in lhs_contraction:
-      output_rep.append(lhs_rep[i])
-  for i in range(len(rhs_rep)):
-    if i not in rhs_batch and i not in rhs_contraction:
-      output_rep.append(rhs_rep[i])
+  for i in _minus(range(len(lhs_rep)), lhs_batch + lhs_contraction):
+    output_rep.append(lhs_rep[i])
+  for i in _minus(range(len(rhs_rep)), rhs_batch + rhs_contraction):
+    output_rep.append(rhs_rep[i])
   return ''.join(output_rep)
 
 
-# If it is the general non-batched/single-batched matrix multiplication,
-# use the highly optimized kernel `tf.tensordot` to handle it.
-def non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction):
+def _non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction):
+  """ If it is the general non-batched/single-batched matrix multiplication,
+  use the highly optimized kernel `tf.tensordot` to handle it.
+  """
   return tf.tensordot(lhs, rhs, axes=(list(lhs_contraction), list(rhs_contraction)))
 
 
-# An equivalent general dot operation as that in JAX -
-# <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dot_general.html>
-#
-# Although there is an implementation in TF XLA, avoid directly using XLA when
-# possible.
-#
-# e.g., non-batched: ij,jk->ik
-#       batched: ijk,ikl->ijl
 def tf_dot_general(lhs, rhs, dimension_numbers):
+  """ An equivalent general dot operation as that in JAX -
+  <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dot_general.html>
 
-  char_list = list(string.ascii_lowercase)[8:]
-  lhs_dim, rhs_dim = len(lhs.shape), len(rhs.shape)
-  lhs_rep = char_list[:lhs_dim]
-  rhs_rep = char_list[lhs_dim:lhs_dim+rhs_dim]
+  Although there is an implementation in TF XLA, avoid directly using XLA when
+  possible.
+
+  e.g., non-batched: ij,jk->ik
+        batched: ijk,ikl->ijl
+  """
+  char_list = list(string.ascii_lowercase)
+  char_list = char_list[8:] + char_list[:8]
+  lhs_rank, rhs_rank = len(lhs.shape), len(rhs.shape)
+  lhs_rep = char_list[:lhs_rank]
+  rhs_rep = char_list[lhs_rank:lhs_rank+rhs_rank]
   contraction, batch = dimension_numbers
   lhs_contraction, rhs_contraction = contraction
   lhs_batch, rhs_batch = batch
 
   if len(lhs_batch) == 0 and len(rhs_batch) == 0:
-    return non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction)
+    return _non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction)
 
-  cond_a = lhs_dim == rhs_dim == 3
+  cond_a = lhs_rank == rhs_rank == 3
   cond_b = lhs_batch == (0,) and rhs_batch == (0,)
-  cond_c = lhs_contraction == (lhs_dim - 1,) and rhs_contraction == (1,)
+  cond_c = lhs_contraction == (2,) and rhs_contraction == (1,)
   if cond_a and cond_b and cond_c:
     return tf.linalg.matmul(lhs, rhs)
 
@@ -515,7 +520,7 @@ def tf_dot_general(lhs, rhs, dimension_numbers):
     if i < len(rhs_batch):
       rhs_rep[rhs_batch[i]] = lhs_rep[lhs_batch[i]]
 
-  output_rep = compose_output_rep(lhs_rep, rhs_rep, lhs_contraction,
+  output_rep = _compose_output_rep(lhs_rep, rhs_rep, lhs_contraction,
                                    rhs_contraction, lhs_batch, rhs_batch)
   equation = ''.join(lhs_rep) + ',' + ''.join(rhs_rep) + "->" + output_rep
   return tf.einsum(equation, lhs, rhs)
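
For context, a minimal sketch (not part of this commit) of the two code paths the file now separates into the tensordot fast path and the einsum fallback, written against plain TensorFlow; the shapes and variable names below are purely illustrative.

  import numpy as np
  import tensorflow as tf

  # Non-batched case: dimension_numbers (((1,), (0,)), ((), ())) is the plain
  # matrix product ij,jk -> ik, which the tf.tensordot path handles directly.
  lhs2 = tf.constant(np.random.rand(3, 4))
  rhs2 = tf.constant(np.random.rand(4, 5))
  non_batched = tf.tensordot(lhs2, rhs2, axes=([1], [0]))   # shape (3, 5)

  # Batched case: dimension_numbers (((2,), (1,)), ((0,), (0,))) contracts the
  # last axis of lhs with axis 1 of rhs and shares axis 0 as a batch dimension,
  # i.e. the docstring example aij, ajk -> aik, which maps to a tf.einsum call.
  lhs3 = tf.constant(np.random.rand(2, 3, 4))
  rhs3 = tf.constant(np.random.rand(2, 4, 5))
  batched = tf.einsum('aij,ajk->aik', lhs3, rhs3)           # shape (2, 3, 5)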

trax/tf_numpy/extensions/extensions_test.py (+3 -3)
@@ -408,9 +408,9 @@ def test_compose_output_rep(self, lhs, rhs, dims, result):
        "dims": (((4,), (1,)), ((0,), (0,)))},
   )
   def test_tf_dot_general(self, lhs_np, rhs_np, dims):
-    ans = lax.dot_general(jnp.array(lhs_np), jnp.array(rhs_np), dims)
-    result = extensions.tf_dot_general(tf.constant(lhs_np), tf.constant(rhs_np), dims)
-    self.assertTrue((result.numpy() == np.array(ans)).all())
+    ans = lax.dot_general(lhs_np, rhs_np, dims)
+    result = extensions.tf_dot_general(lhs_np, rhs_np, dims)
+    self.assertAllClose(result, np.array(ans))
 
 
   def testConv(self):
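
The updated test now feeds the raw NumPy arrays to both backends and compares with an approximate check. A minimal standalone sketch of that equivalence (assuming jax, numpy and tensorflow are available; shapes and tolerance are illustrative only):

  import numpy as np
  import tensorflow as tf
  from jax import lax

  lhs_np = np.random.rand(2, 3, 4).astype(np.float32)
  rhs_np = np.random.rand(2, 4, 5).astype(np.float32)
  dims = (((2,), (1,)), ((0,), (0,)))   # contract the 'j' axis, batch over 'a'

  # The JAX reference and the equivalent TF einsum agree only up to
  # floating-point tolerance, hence the switch to assertAllClose in the test.
  ans = np.array(lax.dot_general(lhs_np, rhs_np, dims))
  result = tf.einsum('aij,ajk->aik', lhs_np, rhs_np).numpy()
  np.testing.assert_allclose(ans, result, rtol=1e-5)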
