Commit 44083cf

add tn.contract(strip_exponent=True)

1 parent daceba2 commit 44083cf

13 files changed: +140 -49 lines

ci/requirements/py-base.yml (+1 -1)

@@ -2,7 +2,7 @@ channels:
   - conda-forge
 dependencies:
   - autoray>=0.6.12
-  - cotengra>=0.6.1
+  - cotengra>=0.7.1
   - coverage
   - cytoolz
   - joblib

ci/requirements/py-jax.yml (+1 -1)

@@ -2,7 +2,7 @@ channels:
   - conda-forge
 dependencies:
   - autoray>=0.6.12
-  - cotengra>=0.6.1
+  - cotengra>=0.7.1
   - coverage
   - cytoolz
   - joblib

ci/requirements/py-openblas.yml (+1 -1)

@@ -3,7 +3,7 @@ channels:
 dependencies:
   - autoray>=0.6.12
   - blas=*=openblas
-  - cotengra>=0.6.1
+  - cotengra>=0.7.1
   - coverage
   - cytoolz
   - joblib

ci/requirements/py-slepc.yml (+1 -1)

@@ -2,7 +2,7 @@ channels:
   - conda-forge
 dependencies:
   - autoray>=0.6.12
-  - cotengra>=0.6.1
+  - cotengra>=0.7.1
   - coverage
   - cytoolz
   - joblib

ci/requirements/py-tensorflow.yml (+1 -1)

@@ -2,7 +2,7 @@ channels:
   - conda-forge
 dependencies:
   - autoray>=0.7.0
-  - cotengra>=0.6.1
+  - cotengra>=0.7.1
   - coverage
   - cytoolz
   - joblib

ci/requirements/py-torch.yml (+1 -1)

@@ -3,7 +3,7 @@ channels:
   - conda-forge
 dependencies:
   - autoray>=0.6.12
-  - cotengra>=0.6.1
+  - cotengra>=0.7.1
   - coverage
   - cpuonly
   - cytoolz

docs/changelog.md (+3 -1)

@@ -8,6 +8,7 @@ Release notes for `quimb`.
 **Breaking Changes**
 
 - move belief propagation to `quimb.tensor.belief_propagation`
+- calling `tn.contract()` when a non-zero value has accrued in `tn.exponent` now automatically re-absorbs that exponent.
 
 **Enhancements:**
 
@@ -19,8 +20,9 @@ Release notes for `quimb`.
 - belief propagation, add a `contract_every` option.
 - HV1BP: vectorize both contraction and message initialization
 - add [`qu.plot_multi_series_zoom`](quimb.utils_plot.plot_multi_series_zoom) for plotting multiple series with a zoomed inset, useful for various convergence plots such as BP
-- add `info` option to [`tn.gauge_all_simple`](quimb.tensor.tensor_core.TensorNetwork.gauge_all_simple) for tracking extra information such as number of iterations and max gauge difffi
+- add `info` option to [`tn.gauge_all_simple`](quimb.tensor.tensor_core.TensorNetwork.gauge_all_simple) for tracking extra information such as number of iterations and max gauge diffs
 - [`Tensor.gate`](quimb.tensor.tensor_core.Tensor.gate): add `transposed` option
+- [`TensorNetwork.contract`](quimb.tensor.tensor_core.TensorNetwork.contract): add `strip_exponent` option for returning the mantissa and exponent (log10) separately.
 
 **Bug fixes:**
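A minimal usage sketch of the new changelog entry above (it reuses `TN_rand_reg` from the test added in this commit; the sketch itself is illustrative, not part of the commit):

import quimb.tensor as qtn

tn = qtn.TN_rand_reg(10, 3, 3, dtype=complex)

# plain contraction: a single scalar
z = tn.contract()

# new in this commit: mantissa and log10 exponent returned separately
mantissa, exponent = tn.contract(strip_exponent=True)

# the two forms represent the same value
assert abs(mantissa * 10**exponent - z) <= 1e-8 * abs(z)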

pyproject.toml (+9 -9)

@@ -27,14 +27,14 @@ keywords = [
 ]
 requires-python = ">=3.9"
 dependencies = [
-    "autoray >=0.6.12",
-    "cotengra >=0.6.1",
-    "cytoolz >=0.8.0",
-    "numba >=0.39",
-    "numpy >=1.17",
-    "psutil >=4.3.1",
-    "scipy >=1.0.0",
-    "tqdm >=4",
+    "autoray>=0.6.12",
+    "cotengra>=0.7.1",
+    "cytoolz>=0.8.0",
+    "numba>=0.39",
+    "numpy>=1.17",
+    "psutil>=4.3.1",
+    "scipy>=1.0.0",
+    "tqdm>=4",
 ]
 
@@ -56,7 +56,7 @@ tests = [
 docs = [
     "astroid<3.0.0",
     "autoray>=0.6.12",
-    "cotengra>=0.6.1",
+    "cotengra>=0.7.1",
     "doc2dash>=2.4.1",
     "furo",
     "ipython!=8.7.0",

quimb/tensor/contraction.py (+6 -1)

@@ -284,7 +284,12 @@ def array_contract(
     if backend is None:
         backend = get_contract_backend()
     return ctg.array_contract(
-        arrays, inputs, output, optimize=optimize, backend=backend, **kwargs
+        arrays,
+        inputs,
+        output,
+        optimize=optimize,
+        backend=backend,
+        **kwargs,
     )

quimb/tensor/tensor_core.py (+70 -11)

@@ -214,6 +214,8 @@ def tensor_contract(
     backend=None,
     preserve_tensor=False,
     drop_tags=False,
+    strip_exponent=False,
+    exponent=None,
     **contract_opts,
 ):
     """Contract a collection of tensors into a scalar or tensor, automatically
@@ -268,6 +270,11 @@ def tensor_contract(
     drop_tags : bool, optional
         Whether to drop all tags from the output tensor. By default the output
         tensor will keep the union of all tags from the input tensors.
+    strip_exponent : bool, optional
+        If `True`, return the exponent of the result, log10, as well as the
+        rescaled 'mantissa'. Useful for very large or small values.
+    exponent : float, optional
+        If supplied, a base exponent to add to the result exponent.
     contract_opts
         Passed to ``cotengra.array_contract``.
@@ -300,20 +307,40 @@ def tensor_contract(
         inds,
         inds_out,
         optimize=optimize,
+        strip_exponent=strip_exponent,
         backend=backend,
         **contract_opts,
     )
 
-    if not inds_out and not preserve_tensor:
-        return maybe_realify_scalar(data_out)
+    if strip_exponent:
+        # mantissa and exponent returned separately
+        data_out, result_exponent = data_out
+
+        if exponent is not None:
+            # custom base exponent supplied
+            result_exponent = result_exponent + exponent
 
-    if drop_tags:
-        tags_out = None
+    elif exponent is not None:
+        # custom exponent but not stripping, so we need to scale the result
+        data_out = data_out * 10**exponent
+
+    if not inds_out and not preserve_tensor:
+        # return a scalar, possibly casting to real
+        # but only if numpy with v. small imag part
+        result = maybe_realify_scalar(data_out)
     else:
-        # union of all tags
-        tags_out = oset_union(t.tags for t in tensors)
+        if drop_tags:
+            tags_out = None
+        else:
+            # union of all tags
+            tags_out = oset_union(t.tags for t in tensors)
+
+        result = Tensor(data=data_out, inds=inds_out, tags=tags_out)
 
-    return Tensor(data=data_out, inds=inds_out, tags=tags_out)
+    if strip_exponent:
+        return result, result_exponent
+
+    return result
 
 
 # generate a random base to avoid collisions on difference processes ...
@@ -8736,6 +8763,8 @@ def contract(
         backend=None,
         preserve_tensor=False,
         max_bond=None,
+        strip_exponent=False,
+        exponent=True,
         inplace=False,
         **opts,
     ):
@@ -8788,6 +8817,14 @@ def contract(
         preserve_tensor : bool, optional
            Whether to return a tensor regardless of whether the output object
            is a scalar (has no indices) or not.
+        strip_exponent : bool, optional
+            If contracting the entire tensor network, whether to strip a log10
+            exponent and return it separately. This is useful for very large
+            or small values.
+        exponent : float, optional
+            The current exponent to scale the whole contraction by. If ``True``
+            this is taken from `tn.exponent`. If `False` it is ignored. If a
+            float, this is the exponent to use.
         inplace : bool, optional
             Whether to perform the contraction inplace. This is only valid
             if not all tensors are contracted (which doesn't produce a TN).
@@ -8806,6 +8843,7 @@ def contract(
        --------
        contract_tags, contract_cumulative
        """
+        # for visibility we put these in the function signature
        opts["output_inds"] = output_inds
        opts["optimize"] = optimize
        opts["get"] = get
@@ -8823,21 +8861,42 @@ def contract(
                raise NotImplementedError
 
            return self.contract_compressed(
-                max_bond=max_bond, inplace=inplace, **opts
+                max_bond=max_bond,
+                inplace=inplace,
+                **opts,
            )
 
        # this checks whether certain TN classes have a manually specified
        # contraction pattern (e.g. 1D along the line)
        if self._CONTRACT_STRUCTURED:
            if (tags is ...) or isinstance(tags, slice):
-                return self.contract_structured(tags, inplace=inplace, **opts)
+                return self.contract_structured(
+                    tags,
+                    inplace=inplace,
+                    **opts,
+                )
 
        # contracting everything to single output
        if all_tags and not inplace:
-            return tensor_contract(*self.tensor_map.values(), **opts)
+
+            if exponent is True:
+                exponent = self.exponent
+            elif exponent is False:
+                exponent = 0.0
+
+            return tensor_contract(
+                *self.tensor_map.values(),
+                strip_exponent=strip_exponent,
+                exponent=exponent,
+                **opts
+            )
 
        # contract some or all tensors, but keeping tensor network
-        return self.contract_tags(tags, inplace=inplace, **opts)
+        return self.contract_tags(
+            tags,
+            inplace=inplace,
+            **opts
+        )
 
    contract_ = functools.partialmethod(contract, inplace=True)
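The exponent handling added above reduces to a simple rule: first resolve the base exponent (`True` means use the network's accrued `tn.exponent`, `False` means none), then either fold it into the returned log10 exponent, or rescale the dense result by `10**exponent`. A standalone sketch of the same semantics, with hypothetical helper names that are not part of quimb:

import math

def resolve_exponent(exponent, tn_exponent):
    # mirror of the resolution step in TensorNetwork.contract above
    if exponent is True:
        return tn_exponent
    if exponent is False:
        return 0.0
    return exponent

def finalize(data_out, strip_exponent=False, exponent=None):
    # mirror of the branching added to tensor_contract above
    if strip_exponent:
        # backend returned (mantissa, log10 exponent) separately
        mantissa, result_exponent = data_out
        if exponent is not None:
            # fold the base exponent into the returned one
            result_exponent = result_exponent + exponent
        return mantissa, result_exponent
    if exponent is not None:
        # not stripping, so rescale the dense result instead
        return data_out * 10**exponent
    return data_out

# both paths represent the same value:
m, e = finalize((1.23, 4.0), strip_exponent=True, exponent=resolve_exponent(True, 2.0))
assert math.isclose(m * 10**e, finalize(1.23e4, exponent=2.0))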

tests/test_tensor/test_belief_propagation/test_l2bp.py (+1 -1)

@@ -61,7 +61,7 @@ def test_contract_double_layer_tree_exact(dtype):
     norm_bp = qbp.contract_l2bp(tn, info=info, progbar=True)
     assert info["converged"]
 
-    assert norm_bp == pytest.approx(norm_ex, rel=1e-6)
+    assert norm_bp == pytest.approx(norm_ex, rel=5e-6)
 
 
 @pytest.mark.parametrize("dtype", ["float32", "complex64"])

tests/test_tensor/test_contract.py (+27 -1)

@@ -1,10 +1,36 @@
-import pytest
 import numpy as np
+import pytest
 
 import quimb.tensor as qtn
 from quimb.tensor.contraction import _CONTRACT_BACKEND, _TENSOR_LINOP_BACKEND
 
 
+def test_tensor_contract_strip_exponent():
+    tn = qtn.TN_rand_reg(10, 3, 3, dtype=complex)
+    z0 = tn.contract()
+    m1, e1 = qtn.tensor_contract(*tn, strip_exponent=True)
+    assert m1 * 10**e1 == pytest.approx(z0)
+    # test tn.exponent is reinserted
+    tn.equalize_norms_(value=1.0)
+    z2 = tn.contract()
+    assert z2 == pytest.approx(z0)
+    # test tn.exponent is reinserted with strip exponent
+    m3, e3 = tn.contract(strip_exponent=True)
+    assert m3 * 10**e3 == pytest.approx(z0)
+    # test tn.exponent is not reinserted when specified
+    z4 = tn.contract(exponent=False)
+    assert z4 != pytest.approx(z0)
+    # test tn.exponent is not reinserted when specified with strip exponent
+    m5, e5 = tn.contract(strip_exponent=True, exponent=False)
+    assert m5 * 10**e5 != pytest.approx(z0)
+    # test explicit exponent
+    z6 = tn.contract(exponent=tn.exponent)
+    assert z6 == pytest.approx(z0)
+    # test explicit exponent with strip exponent
+    m7, e7 = tn.contract(strip_exponent=True, exponent=tn.exponent)
+    assert m7 * 10**e7 == pytest.approx(z0)
+
+
 class TestContractOpts:
     def test_contract_strategy(self):
         assert qtn.get_contract_strategy() == "greedy"
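For context, the main use case the test exercises: after `equalize_norms_` the overall scale lives in `tn.exponent`, so the magnitude of the contracted value stays accessible even when forming it densely would over- or underflow a float. A hedged sketch using the same functions as the test above (values are random):

import math
import quimb.tensor as qtn

tn = qtn.TN_rand_reg(10, 3, 3, dtype=complex)

# rescale every tensor to norm 1, accruing the stripped scale
# into tn.exponent, exactly as in the test above
tn.equalize_norms_(value=1.0)

# keep mantissa and log10 exponent separate, so log10|Z| never
# requires forming a possibly overflowing 10**e
m, e = tn.contract(strip_exponent=True)
print("log10|Z| =", math.log10(abs(m)) + e)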
