@@ -1355,6 +1355,7 @@ def maybe_unwrap(
1355
1355
t ,
1356
1356
preserve_tensor_network = False ,
1357
1357
preserve_tensor = False ,
1358
+ strip_exponent = False ,
1358
1359
equalize_norms = False ,
1359
1360
output_inds = None ,
1360
1361
):
@@ -1373,6 +1374,9 @@ def maybe_unwrap(
1373
1374
preserve_tensor : bool, optional
1374
1375
If ``True``, then don't unwrap a ``Tensor`` to a scalar even if it has
1375
1376
no indices.
1377
+ strip_exponent : bool, optional
1378
+ If ``True``, then return the overall exponent of the contraction, in
1379
+ log10, as well as the 'mantissa' tensor or scalar.
1376
1380
equalize_norms : bool, optional
1377
1381
If ``True``, then equalize the norms of all tensors in the tensor
1378
1382
network before unwrapping.
@@ -1381,29 +1385,50 @@ def maybe_unwrap(
1381
1385
1382
1386
Returns
1383
1387
-------
1384
- TensorNetwork, Tensor or Number
1388
+ TensorNetwork, Tensor or scalar
1385
1389
"""
1390
+ exponent = 0.0
1391
+
1386
1392
if isinstance (t , TensorNetwork ):
1387
1393
if equalize_norms is True :
1388
- # this also redistributes the any collected norm exponent
1389
- t .equalize_norms_ ()
1394
+ if strip_exponent :
1395
+ # accumulate into the exponent
1396
+ t .equalize_norms_ (1.0 )
1397
+ else :
1398
+ # this also redistributes any collected norm exponent
1399
+ t .equalize_norms_ ()
1390
1400
1391
1401
if preserve_tensor_network or (t .num_tensors != 1 ):
1392
1402
return t
1393
1403
1404
+ if strip_exponent :
1405
+ # extract from tn
1406
+ exponent += t .exponent
1407
+
1394
1408
# else get the single tensor
1395
1409
(t ,) = t .tensor_map .values ()
1396
1410
1411
+ # now we have Tensor
1397
1412
if output_inds is not None and t .inds != output_inds :
1398
1413
t .transpose_ (* output_inds )
1399
1414
1415
+ if strip_exponent :
1416
+ tnorm = t .norm ()
1417
+ t /= tnorm
1418
+ exponent += do ("log10" , tnorm )
1419
+
1400
1420
if preserve_tensor or t .ndim != 0 :
1401
1421
# return as a tensor
1402
- return t
1422
+ result = t
1423
+ else :
1424
+ # else return as a scalar, maybe dropping imaginary part
1425
+ result = maybe_realify_scalar (t .data )
1403
1426
1404
- # else return as a scalar, maybe dropping imaginary part
1405
- return maybe_realify_scalar (t .data )
1427
+ if strip_exponent :
1428
+ # return mantissa and exponent separately
1429
+ return result , exponent
1406
1430
1431
+ return result
1407
1432
1408
1433
# --------------------------------------------------------------------------- #
1409
1434
# Tensor Class #
@@ -7531,7 +7556,10 @@ def gauge_simple_insert(
7531
7556
remove : bool, optional
7532
7557
Whether to remove the gauges from the store after inserting them.
7533
7558
smudge : float, optional
7534
- A small value to add to the gauge vectors to avoid singularities.
7559
+ A small value to add to the gauge vectors to avoid singularities
7560
+ when inserting.
7561
+ power : float, optional
7562
+ A power to raise the gauge vectors to when inserting.
7535
7563
7536
7564
Returns
7537
7565
-------
@@ -7611,6 +7639,10 @@ def gauge_simple_temp(
7611
7639
The store of gauge bonds, the keys being indices and the values
7612
7640
being the vectors. Only bonds present in this dictionary will be
7613
7641
gauged.
7642
+ smudge : float, optional
7643
+ A small value to add to the gauge vectors to avoid singularities.
7644
+ power : float, optional
7645
+ A power to raise the gauge vectors to when inserting.
7614
7646
ungauge_outer : bool, optional
7615
7647
Whether to ungauge the outer bonds.
7616
7648
ungauge_inner : bool, optional
@@ -7679,7 +7711,8 @@ def _contract_compressed_tid_sequence(
7679
7711
compress_matrices = True ,
7680
7712
compress_exclude = None ,
7681
7713
compress_opts = None ,
7682
- equalize_norms = False ,
7714
+ strip_exponent = False ,
7715
+ equalize_norms = "auto" ,
7683
7716
gauges = None ,
7684
7717
gauge_smudge = 1e-6 ,
7685
7718
callback_pre_contract = None ,
@@ -7719,6 +7752,11 @@ def _contract_compressed_tid_sequence(
7719
7752
gauges = True
7720
7753
canonize_distance = 0
7721
7754
7755
+ if equalize_norms == "auto" :
7756
+ # if we are going to extract exponent at end, assume we
7757
+ # should do it throughout the computation as well
7758
+ equalize_norms = strip_exponent
7759
+
7722
7760
if gauges is True :
7723
7761
gauges = {}
7724
7762
if gauge_boundary_only :
@@ -7939,6 +7977,7 @@ def _compress_neighbors(tid, t, d):
7939
7977
tn ,
7940
7978
preserve_tensor_network = inplace ,
7941
7979
preserve_tensor = preserve_tensor ,
7980
+ strip_exponent = strip_exponent ,
7942
7981
equalize_norms = equalize_norms ,
7943
7982
output_inds = output_inds ,
7944
7983
)
@@ -8083,7 +8122,8 @@ def contract_compressed(
8083
8122
compress_matrices = True ,
8084
8123
compress_exclude = None ,
8085
8124
compress_opts = None ,
8086
- equalize_norms = False ,
8125
+ strip_exponent = False ,
8126
+ equalize_norms = "auto" ,
8087
8127
gauges = None ,
8088
8128
gauge_smudge = 1e-6 ,
8089
8129
callback_pre_contract = None ,
@@ -8279,6 +8319,7 @@ def contract_compressed(
8279
8319
compress_span = compress_span ,
8280
8320
compress_matrices = compress_matrices ,
8281
8321
compress_exclude = compress_exclude ,
8322
+ strip_exponent = strip_exponent ,
8282
8323
equalize_norms = equalize_norms ,
8283
8324
gauges = gauges ,
8284
8325
gauge_smudge = gauge_smudge ,
@@ -8652,6 +8693,8 @@ def contract_tags(
8652
8693
optimize = None ,
8653
8694
get = None ,
8654
8695
backend = None ,
8696
+ strip_exponent = False ,
8697
+ equalize_norms = "auto" ,
8655
8698
preserve_tensor = False ,
8656
8699
inplace = False ,
8657
8700
** contract_opts ,
@@ -8701,6 +8744,14 @@ def contract_tags(
8701
8744
backend : {'auto', 'numpy', 'jax', 'cupy', 'tensorflow', ...}, optional
8702
8745
Which backend to use to perform the contraction. Supplied to
8703
8746
`cotengra`.
8747
+ strip_exponent : bool, optional
8748
+ Whether to strip an overall exponent, log10, from the *final*
8749
+ contraction. If a TensorNetwork is returned, this exponent is
8750
+ accumulated in the `exponent` attribute. If a Tensor or scalar is
8751
+ returned, the exponent is returned separately.
8752
+ equalize_norms : bool or "auto", optional
8753
+ Whether to equalize the norms of the tensors *during* the
8754
+ contraction. By default ("auto") this follows `strip_exponent`.
8704
8755
preserve_tensor : bool, optional
8705
8756
Whether to return a tensor regardless of whether the output object
8706
8757
is a scalar (has no indices) or not.
@@ -8731,6 +8782,11 @@ def contract_tags(
8731
8782
"(Change this to a no-op maybe?)"
8732
8783
)
8733
8784
8785
+ if equalize_norms == "auto" :
8786
+ # if we are going to extract exponent at end, assume we
8787
+ # should do it throughout the computation as well
8788
+ equalize_norms = strip_exponent
8789
+
8734
8790
# whether we should let tensor_contract return a raw scalar
8735
8791
preserve_tensor = preserve_tensor or inplace or (tn .num_tensors >= 1 )
8736
8792
@@ -8740,16 +8796,46 @@ def contract_tags(
8740
8796
optimize = optimize ,
8741
8797
get = get ,
8742
8798
backend = backend ,
8799
+ strip_exponent = equalize_norms ,
8743
8800
preserve_tensor = preserve_tensor ,
8744
8801
** contract_opts ,
8745
8802
)
8746
8803
8804
+ if equalize_norms :
8805
+ # exponent already returned separately
8806
+ t , exponent = t
8807
+ elif strip_exponent :
8808
+ # explicitly remove exponent now
8809
+ if isinstance (t , Tensor ):
8810
+ tnorm = t .norm ()
8811
+ else :
8812
+ # already scalar
8813
+ tnorm = do ("abs" , t )
8814
+
8815
+ t /= tnorm
8816
+ exponent = do ("log10" , tnorm )
8817
+ else :
8818
+ exponent = None
8819
+
8747
8820
if (tn .num_tensors == 0 ) and (not inplace ):
8748
8821
# contracted all down to single tensor or scalar -> return it
8749
8822
# (apart from if inplace -> we want to keep the tensor network)
8823
+ if exponent is not None :
8824
+ if strip_exponent :
8825
+ # return separately
8826
+ return t , exponent
8827
+
8828
+ # scale by stripped exponent directly
8829
+ t = t * 10 ** exponent
8830
+
8750
8831
return t
8751
8832
8752
8833
tn .add_tensor (t , virtual = True )
8834
+
8835
+ if exponent is not None :
8836
+ # scale by stripped exponent lazily
8837
+ tn .exponent += exponent
8838
+
8753
8839
return tn
8754
8840
8755
8841
contract_tags_ = functools .partialmethod (contract_tags , inplace = True )
@@ -8766,7 +8852,7 @@ def contract(
8766
8852
strip_exponent = False ,
8767
8853
exponent = True ,
8768
8854
inplace = False ,
8769
- ** opts ,
8855
+ ** kwargs ,
8770
8856
):
8771
8857
"""Contract some, or all, of the tensors in this network. This method
8772
8858
dispatches to ``contract_tags``, ``contract_structured``, or
@@ -8828,7 +8914,7 @@ def contract(
8828
8914
inplace : bool, optional
8829
8915
Whether to perform the contraction inplace. This is only valid
8830
8916
if not all tensors are contracted (which doesn't produce a TN).
8831
- opts
8917
+ kwargs
8832
8918
Passed to :func:`~quimb.tensor.tensor_core.tensor_contract`,
8833
8919
:meth:`~quimb.tensor.tensor_core.TensorNetwork.contract_compressed`
8834
8920
.
@@ -8844,36 +8930,44 @@ def contract(
8844
8930
contract_tags, contract_cumulative
8845
8931
"""
8846
8932
# for visibility we put these in the function signature
8847
- opts ["output_inds" ] = output_inds
8848
- opts ["optimize" ] = optimize
8849
- opts ["get" ] = get
8850
- opts ["backend" ] = backend
8851
- opts ["preserve_tensor" ] = preserve_tensor
8933
+ kwargs ["output_inds" ] = output_inds
8934
+ kwargs ["optimize" ] = optimize
8935
+ kwargs ["get" ] = get
8936
+ kwargs ["backend" ] = backend
8937
+ kwargs ["preserve_tensor" ] = preserve_tensor
8852
8938
8853
8939
all_tags = (tags is all ) or (tags is ...)
8854
8940
8855
8941
if max_bond is not None :
8856
8942
if not all_tags :
8857
8943
raise NotImplementedError
8858
- if opts .pop ("get" , None ) is not None :
8944
+ if kwargs .pop ("get" , None ) is not None :
8945
+ raise NotImplementedError
8946
+ if kwargs .pop ("backend" , None ) is not None :
8859
8947
raise NotImplementedError
8860
- if opts . pop ( "backend" , None ) is not None :
8948
+ if exponent is not True :
8861
8949
raise NotImplementedError
8862
8950
8863
8951
return self .contract_compressed (
8864
8952
max_bond = max_bond ,
8953
+ strip_exponent = strip_exponent ,
8865
8954
inplace = inplace ,
8866
- ** opts ,
8955
+ ** kwargs ,
8867
8956
)
8868
8957
8869
8958
# this checks whether certain TN classes have a manually specified
8870
8959
# contraction pattern (e.g. 1D along the line)
8871
8960
if self ._CONTRACT_STRUCTURED :
8961
+
8962
+ if exponent is not True :
8963
+ raise NotImplementedError
8964
+
8872
8965
if (tags is ...) or isinstance (tags , slice ):
8873
8966
return self .contract_structured (
8874
8967
tags ,
8968
+ strip_exponent = strip_exponent ,
8875
8969
inplace = inplace ,
8876
- ** opts ,
8970
+ ** kwargs ,
8877
8971
)
8878
8972
8879
8973
# contracting everything to single output
@@ -8888,14 +8982,15 @@ def contract(
8888
8982
* self .tensor_map .values (),
8889
8983
strip_exponent = strip_exponent ,
8890
8984
exponent = exponent ,
8891
- ** opts
8985
+ ** kwargs
8892
8986
)
8893
8987
8894
8988
# contract some or all tensors, but keeping tensor network
8895
8989
return self .contract_tags (
8896
8990
tags ,
8991
+ strip_exponent = strip_exponent ,
8897
8992
inplace = inplace ,
8898
- ** opts
8993
+ ** kwargs
8899
8994
)
8900
8995
8901
8996
contract_ = functools .partialmethod (contract , inplace = True )
@@ -8905,9 +9000,10 @@ def contract_cumulative(
8905
9000
tags_seq ,
8906
9001
output_inds = None ,
8907
9002
preserve_tensor = False ,
8908
- equalize_norms = False ,
9003
+ strip_exponent = False ,
9004
+ equalize_norms = "auto" ,
8909
9005
inplace = False ,
8910
- ** opts ,
9006
+ ** contract_opts ,
8911
9007
):
8912
9008
"""Cumulative contraction of tensor network. Contract the first set of
8913
9009
tags, then that set with the next set, then both of those with the next
@@ -8927,7 +9023,7 @@ def contract_cumulative(
8927
9023
is a scalar (has no indices) or not.
8928
9024
inplace : bool, optional
8929
9025
Whether to perform the contraction inplace.
8930
- opts
9026
+ contract_opts
8931
9027
Passed to :func:`~quimb.tensor.tensor_core.tensor_contract`.
8932
9028
8933
9029
Returns
@@ -8943,12 +9039,22 @@ def contract_cumulative(
8943
9039
tn = self if inplace else self .copy ()
8944
9040
c_tags = oset ()
8945
9041
9042
+ if equalize_norms == "auto" :
9043
+ # if we are going to extract exponent at end, assume we
9044
+ # should do it throughout the computation as well
9045
+ equalize_norms = strip_exponent
9046
+
8946
9047
for tags in tags_seq :
8947
9048
# accumulate tags from each contraction
8948
9049
c_tags |= tags_to_oset (tags )
8949
9050
8950
9051
# perform the next contraction
8951
- tn .contract_tags_ (c_tags , which = "any" , ** opts )
9052
+ tn .contract_tags_ (
9053
+ c_tags ,
9054
+ which = "any" ,
9055
+ equalize_norms = equalize_norms ,
9056
+ ** contract_opts
9057
+ )
8952
9058
8953
9059
if tn .num_tensors == 1 :
8954
9060
# nothing more to contract
@@ -8958,6 +9064,7 @@ def contract_cumulative(
8958
9064
tn ,
8959
9065
preserve_tensor_network = inplace ,
8960
9066
preserve_tensor = preserve_tensor ,
9067
+ strip_exponent = strip_exponent ,
8961
9068
equalize_norms = equalize_norms ,
8962
9069
output_inds = output_inds ,
8963
9070
)
0 commit comments