Skip to content

Commit 6fa9dfa

Browse files
committed
Tensor: align and broadcast binary ops
1 parent f3fd72f commit 6fa9dfa

File tree

4 files changed

+90
-60
lines changed

4 files changed

+90
-60
lines changed

docs/changelog.md

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Release notes for `quimb`.
1111

1212
**Enhancements:**
1313

14+
- [`Tensor`](quimb.tensor.tensor_core.Tensor): make binary operations (`+, -, *, /, **`) automatically align and broadcast indices. This would previously error.
1415
- [`MatrixProductState.measure`](quimb.tensor.tensor_1d.MatrixProductState.measure): add a `seed` kwarg
1516
- belief propagation, implement DIIS (direct inversion in the iterative subspace)
1617
- belief propagation, unify various aspects such as message normalization and distance.

docs/tensor-basics.ipynb

+56-50
Large diffs are not rendered by default.

quimb/tensor/tensor_core.py

+28-8
Original file line numberDiff line numberDiff line change
@@ -3300,15 +3300,35 @@ def COPY_tree_tensors(d, inds, tags=None, dtype=float, ssa_path=None):
33003300
def _make_promote_array_func(op, meth_name):
33013301
@functools.wraps(getattr(np.ndarray, meth_name))
33023302
def _promote_array_func(self, other):
3303-
"""Use standard array func, but make sure Tensor inds match."""
3303+
"""Use standard array func, but auto match up indices."""
33043304
if isinstance(other, Tensor):
3305-
if set(self.inds) != set(other.inds):
3306-
raise ValueError(
3307-
"The indicies of these two tensors do not "
3308-
f"match: {self.inds} != {other.inds}"
3309-
)
3310-
3311-
otherT = other.transpose(*self.inds)
3305+
# auto match up indices - i.e. broadcast dimensions
3306+
left_expand = []
3307+
right_expand = []
3308+
3309+
for ix in self.inds:
3310+
if ix not in other.inds:
3311+
right_expand.append(ix)
3312+
for ix in other.inds:
3313+
if ix not in self.inds:
3314+
left_expand.append(ix)
3315+
3316+
# new_ind is an inplace operation -> track if we need to copy
3317+
copied = False
3318+
for ix in left_expand:
3319+
if not copied:
3320+
self = self.copy()
3321+
copied = True
3322+
self.new_ind(ix, axis=-1)
3323+
3324+
copied = False
3325+
for ix in right_expand:
3326+
if not copied:
3327+
other = other.copy()
3328+
copied = True
3329+
other.new_ind(ix)
3330+
3331+
otherT = other.transpose(*self.inds, inplace=copied)
33123332

33133333
return Tensor(
33143334
data=op(self.data, otherT.data),

tests/test_tensor/test_tensor_core.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,11 @@ def test_tensor_tensor_arithmetic(self, op, mismatch):
126126
b = Tensor(np.random.rand(2, 3, 4), inds=[0, 1, 2], tags="red")
127127
if mismatch:
128128
b.modify(inds=(0, 1, 3))
129-
with pytest.raises(ValueError):
130-
op(a, b)
129+
c = op(a, b)
130+
assert_allclose(c.data, op(
131+
a.data.reshape(2, 3, 4, 1),
132+
b.data.reshape(2, 3, 1, 4))
133+
)
131134
else:
132135
c = op(a, b)
133136
assert_allclose(c.data, op(a.data, b.data))

0 commit comments

Comments
 (0)