Skip to content

Commit 0af5f5f

Browse files
committed
contract_{tags|cumulative|compressed} support strip_exponent kwarg
1 parent 44083cf commit 0af5f5f

File tree

3 files changed

+241
-27
lines changed

3 files changed

+241
-27
lines changed

docs/changelog.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ Release notes for `quimb`.
2222
- add [`qu.plot_multi_series_zoom`](quimb.utils_plot.plot_multi_series_zoom) for plotting multiple series with a zoomed inset, useful for various convergence plots such as BP
2323
- add `info` option to [`tn.gauge_all_simple`](quimb.tensor.tensor_core.TensorNetwork.gauge_all_simple) for tracking extra information such as number of iterations and max gauge diffs
2424
- [`Tensor.gate`](quimb.tensor.tensor_core.Tensor.gate): add `transposed` option
25-
- [`TensorNetwork.contract`](quimb.tensor.tensor_core.TensorNetwork.contract): add `strip_exponent` option for return the mantissa and exponent (log10) separately.
25+
- [`TensorNetwork.contract`](quimb.tensor.tensor_core.TensorNetwork.contract): add `strip_exponent` option for returning the mantissa and exponent (log10) separately. Compatible with [`contract_tags`](quimb.tensor.tensor_core.TensorNetwork.contract_tags), [`contract_cumulative`](quimb.tensor.tensor_core.TensorNetwork.contract_cumulative), [`contract_compressed`](quimb.tensor.tensor_core.TensorNetwork.contract_compressed) sub modes.
2626

2727
**Bug fixes:**
2828

quimb/tensor/tensor_core.py

+133-26
Original file line numberDiff line numberDiff line change
@@ -1355,6 +1355,7 @@ def maybe_unwrap(
13551355
t,
13561356
preserve_tensor_network=False,
13571357
preserve_tensor=False,
1358+
strip_exponent=False,
13581359
equalize_norms=False,
13591360
output_inds=None,
13601361
):
@@ -1373,6 +1374,9 @@ def maybe_unwrap(
13731374
preserve_tensor : bool, optional
13741375
If ``True``, then don't unwrap a ``Tensor`` to a scalar even if it has
13751376
no indices.
1377+
strip_exponent : bool, optional
1378+
If ``True``, then return the overall exponent of the contraction, in
1379+
log10, as well as the 'mantissa' tensor or scalar.
13761380
equalize_norms : bool, optional
13771381
If ``True``, then equalize the norms of all tensors in the tensor
13781382
network before unwrapping.
@@ -1381,29 +1385,50 @@ def maybe_unwrap(
13811385
13821386
Returns
13831387
-------
1384-
TensorNetwork, Tensor or Number
1388+
TensorNetwork, Tensor or scalar
13851389
"""
1390+
exponent = 0.0
1391+
13861392
if isinstance(t, TensorNetwork):
13871393
if equalize_norms is True:
1388-
# this also redistributes the any collected norm exponent
1389-
t.equalize_norms_()
1394+
if strip_exponent:
1395+
# accumulate into the exponent
1396+
t.equalize_norms_(1.0)
1397+
else:
1398+
# this also redistributes any collected norm exponent
1399+
t.equalize_norms_()
13901400

13911401
if preserve_tensor_network or (t.num_tensors != 1):
13921402
return t
13931403

1404+
if strip_exponent:
1405+
# extract from tn
1406+
exponent += t.exponent
1407+
13941408
# else get the single tensor
13951409
(t,) = t.tensor_map.values()
13961410

1411+
# now we have Tensor
13971412
if output_inds is not None and t.inds != output_inds:
13981413
t.transpose_(*output_inds)
13991414

1415+
if strip_exponent:
1416+
tnorm = t.norm()
1417+
t /= tnorm
1418+
exponent += do("log10", tnorm)
1419+
14001420
if preserve_tensor or t.ndim != 0:
14011421
# return as a tensor
1402-
return t
1422+
result = t
1423+
else:
1424+
# else return as a scalar, maybe dropping imaginary part
1425+
result = maybe_realify_scalar(t.data)
14031426

1404-
# else return as a scalar, maybe dropping imaginary part
1405-
return maybe_realify_scalar(t.data)
1427+
if strip_exponent:
1428+
# return mantissa and exponent separately
1429+
return result, exponent
14061430

1431+
return result
14071432

14081433
# --------------------------------------------------------------------------- #
14091434
# Tensor Class #
@@ -7531,7 +7556,10 @@ def gauge_simple_insert(
75317556
remove : bool, optional
75327557
Whether to remove the gauges from the store after inserting them.
75337558
smudge : float, optional
7534-
A small value to add to the gauge vectors to avoid singularities.
7559+
A small value to add to the gauge vectors to avoid singularities
7560+
when inserting.
7561+
power : float, optional
7562+
A power to raise the gauge vectors to when inserting.
75357563
75367564
Returns
75377565
-------
@@ -7611,6 +7639,10 @@ def gauge_simple_temp(
76117639
The store of gauge bonds, the keys being indices and the values
76127640
being the vectors. Only bonds present in this dictionary will be
76137641
gauged.
7642+
smudge : float, optional
7643+
A small value to add to the gauge vectors to avoid singularities.
7644+
power : float, optional
7645+
A power to raise the gauge vectors to when inserting.
76147646
ungauge_outer : bool, optional
76157647
Whether to ungauge the outer bonds.
76167648
ungauge_inner : bool, optional
@@ -7679,7 +7711,8 @@ def _contract_compressed_tid_sequence(
76797711
compress_matrices=True,
76807712
compress_exclude=None,
76817713
compress_opts=None,
7682-
equalize_norms=False,
7714+
strip_exponent=False,
7715+
equalize_norms="auto",
76837716
gauges=None,
76847717
gauge_smudge=1e-6,
76857718
callback_pre_contract=None,
@@ -7719,6 +7752,11 @@ def _contract_compressed_tid_sequence(
77197752
gauges = True
77207753
canonize_distance = 0
77217754

7755+
if equalize_norms == "auto":
7756+
# if we are going to extract exponent at end, assume we
7757+
# should do it throughout the computation as well
7758+
equalize_norms = strip_exponent
7759+
77227760
if gauges is True:
77237761
gauges = {}
77247762
if gauge_boundary_only:
@@ -7939,6 +7977,7 @@ def _compress_neighbors(tid, t, d):
79397977
tn,
79407978
preserve_tensor_network=inplace,
79417979
preserve_tensor=preserve_tensor,
7980+
strip_exponent=strip_exponent,
79427981
equalize_norms=equalize_norms,
79437982
output_inds=output_inds,
79447983
)
@@ -8083,7 +8122,8 @@ def contract_compressed(
80838122
compress_matrices=True,
80848123
compress_exclude=None,
80858124
compress_opts=None,
8086-
equalize_norms=False,
8125+
strip_exponent=False,
8126+
equalize_norms="auto",
80878127
gauges=None,
80888128
gauge_smudge=1e-6,
80898129
callback_pre_contract=None,
@@ -8279,6 +8319,7 @@ def contract_compressed(
82798319
compress_span=compress_span,
82808320
compress_matrices=compress_matrices,
82818321
compress_exclude=compress_exclude,
8322+
strip_exponent=strip_exponent,
82828323
equalize_norms=equalize_norms,
82838324
gauges=gauges,
82848325
gauge_smudge=gauge_smudge,
@@ -8652,6 +8693,8 @@ def contract_tags(
86528693
optimize=None,
86538694
get=None,
86548695
backend=None,
8696+
strip_exponent=False,
8697+
equalize_norms="auto",
86558698
preserve_tensor=False,
86568699
inplace=False,
86578700
**contract_opts,
@@ -8701,6 +8744,14 @@ def contract_tags(
87018744
backend : {'auto', 'numpy', 'jax', 'cupy', 'tensorflow', ...}, optional
87028745
Which backend to use to perform the contraction. Supplied to
87038746
`cotengra`.
8747+
strip_exponent : bool, optional
8748+
Whether to strip an overall exponent, log10, from the *final*
8749+
contraction. If a TensorNetwork is returned, this exponent is
8750+
accumulated in the `exponent` attribute. If a Tensor or scalar is
8751+
returned, the exponent is returned separately.
8752+
equalize_norms : bool or "auto", optional
8753+
Whether to equalize the norms of the tensors *during* the
8754+
contraction. By default ("auto") this follows `strip_exponent`.
87048755
preserve_tensor : bool, optional
87058756
Whether to return a tensor regardless of whether the output object
87068757
is a scalar (has no indices) or not.
@@ -8731,6 +8782,11 @@ def contract_tags(
87318782
"(Change this to a no-op maybe?)"
87328783
)
87338784

8785+
if equalize_norms == "auto":
8786+
# if we are going to extract exponent at end, assume we
8787+
# should do it throughout the computation as well
8788+
equalize_norms = strip_exponent
8789+
87348790
# whether we should let tensor_contract return a raw scalar
87358791
preserve_tensor = preserve_tensor or inplace or (tn.num_tensors >= 1)
87368792

@@ -8740,16 +8796,46 @@ def contract_tags(
87408796
optimize=optimize,
87418797
get=get,
87428798
backend=backend,
8799+
strip_exponent=equalize_norms,
87438800
preserve_tensor=preserve_tensor,
87448801
**contract_opts,
87458802
)
87468803

8804+
if equalize_norms:
8805+
# exponent already returned separately
8806+
t, exponent = t
8807+
elif strip_exponent:
8808+
# explicitly remove exponent now
8809+
if isinstance(t, Tensor):
8810+
tnorm = t.norm()
8811+
else:
8812+
# already scalar
8813+
tnorm = do("abs", t)
8814+
8815+
t /= tnorm
8816+
exponent = do("log10", tnorm)
8817+
else:
8818+
exponent = None
8819+
87478820
if (tn.num_tensors == 0) and (not inplace):
87488821
# contracted all down to single tensor or scalar -> return it
87498822
# (apart from if inplace -> we want to keep the tensor network)
8823+
if exponent is not None:
8824+
if strip_exponent:
8825+
# return separately
8826+
return t, exponent
8827+
8828+
# scale by stripped exponent directly
8829+
t = t * 10**exponent
8830+
87508831
return t
87518832

87528833
tn.add_tensor(t, virtual=True)
8834+
8835+
if exponent is not None:
8836+
# scale by stripped exponent lazily
8837+
tn.exponent += exponent
8838+
87538839
return tn
87548840

87558841
contract_tags_ = functools.partialmethod(contract_tags, inplace=True)
@@ -8766,7 +8852,7 @@ def contract(
87668852
strip_exponent=False,
87678853
exponent=True,
87688854
inplace=False,
8769-
**opts,
8855+
**kwargs,
87708856
):
87718857
"""Contract some, or all, of the tensors in this network. This method
87728858
dispatches to ``contract_tags``, ``contract_structured``, or
@@ -8828,7 +8914,7 @@ def contract(
88288914
inplace : bool, optional
88298915
Whether to perform the contraction inplace. This is only valid
88308916
if not all tensors are contracted (which doesn't produce a TN).
8831-
opts
8917+
kwargs
88328918
Passed to :func:`~quimb.tensor.tensor_core.tensor_contract`,
88338919
:meth:`~quimb.tensor.tensor_core.TensorNetwork.contract_compressed`
88348920
.
@@ -8844,36 +8930,44 @@ def contract(
88448930
contract_tags, contract_cumulative
88458931
"""
88468932
# for visibility we put these in the function signature
8847-
opts["output_inds"] = output_inds
8848-
opts["optimize"] = optimize
8849-
opts["get"] = get
8850-
opts["backend"] = backend
8851-
opts["preserve_tensor"] = preserve_tensor
8933+
kwargs["output_inds"] = output_inds
8934+
kwargs["optimize"] = optimize
8935+
kwargs["get"] = get
8936+
kwargs["backend"] = backend
8937+
kwargs["preserve_tensor"] = preserve_tensor
88528938

88538939
all_tags = (tags is all) or (tags is ...)
88548940

88558941
if max_bond is not None:
88568942
if not all_tags:
88578943
raise NotImplementedError
8858-
if opts.pop("get", None) is not None:
8944+
if kwargs.pop("get", None) is not None:
8945+
raise NotImplementedError
8946+
if kwargs.pop("backend", None) is not None:
88598947
raise NotImplementedError
8860-
if opts.pop("backend", None) is not None:
8948+
if exponent is not True:
88618949
raise NotImplementedError
88628950

88638951
return self.contract_compressed(
88648952
max_bond=max_bond,
8953+
strip_exponent=strip_exponent,
88658954
inplace=inplace,
8866-
**opts,
8955+
**kwargs,
88678956
)
88688957

88698958
# this checks whether certain TN classes have a manually specified
88708959
# contraction pattern (e.g. 1D along the line)
88718960
if self._CONTRACT_STRUCTURED:
8961+
8962+
if exponent is not True:
8963+
raise NotImplementedError
8964+
88728965
if (tags is ...) or isinstance(tags, slice):
88738966
return self.contract_structured(
88748967
tags,
8968+
strip_exponent=strip_exponent,
88758969
inplace=inplace,
8876-
**opts,
8970+
**kwargs,
88778971
)
88788972

88798973
# contracting everything to single output
@@ -8888,14 +8982,15 @@ def contract(
88888982
*self.tensor_map.values(),
88898983
strip_exponent=strip_exponent,
88908984
exponent=exponent,
8891-
**opts
8985+
**kwargs
88928986
)
88938987

88948988
# contract some or all tensors, but keeping tensor network
88958989
return self.contract_tags(
88968990
tags,
8991+
strip_exponent=strip_exponent,
88978992
inplace=inplace,
8898-
**opts
8993+
**kwargs
88998994
)
89008995

89018996
contract_ = functools.partialmethod(contract, inplace=True)
@@ -8905,9 +9000,10 @@ def contract_cumulative(
89059000
tags_seq,
89069001
output_inds=None,
89079002
preserve_tensor=False,
8908-
equalize_norms=False,
9003+
strip_exponent=False,
9004+
equalize_norms="auto",
89099005
inplace=False,
8910-
**opts,
9006+
**contract_opts,
89119007
):
89129008
"""Cumulative contraction of tensor network. Contract the first set of
89139009
tags, then that set with the next set, then both of those with the next
@@ -8927,7 +9023,7 @@ def contract_cumulative(
89279023
is a scalar (has no indices) or not.
89289024
inplace : bool, optional
89299025
Whether to perform the contraction inplace.
8930-
opts
9026+
contract_opts
89319027
Passed to :func:`~quimb.tensor.tensor_core.tensor_contract`.
89329028
89339029
Returns
@@ -8943,12 +9039,22 @@ def contract_cumulative(
89439039
tn = self if inplace else self.copy()
89449040
c_tags = oset()
89459041

9042+
if equalize_norms == "auto":
9043+
# if we are going to extract exponent at end, assume we
9044+
# should do it throughout the computation as well
9045+
equalize_norms = strip_exponent
9046+
89469047
for tags in tags_seq:
89479048
# accumulate tags from each contractions
89489049
c_tags |= tags_to_oset(tags)
89499050

89509051
# perform the next contraction
8951-
tn.contract_tags_(c_tags, which="any", **opts)
9052+
tn.contract_tags_(
9053+
c_tags,
9054+
which="any",
9055+
equalize_norms=equalize_norms,
9056+
**contract_opts
9057+
)
89529058

89539059
if tn.num_tensors == 1:
89549060
# nothing more to contract
@@ -8958,6 +9064,7 @@ def contract_cumulative(
89589064
tn,
89599065
preserve_tensor_network=inplace,
89609066
preserve_tensor=preserve_tensor,
9067+
strip_exponent=strip_exponent,
89619068
equalize_norms=equalize_norms,
89629069
output_inds=output_inds,
89639070
)

0 commit comments

Comments
 (0)