Skip to content

Commit f52a707

Browse files
committed
factor out tensor_network_ag_gate, add gate_upper+gate_lower
1 parent 422a89c commit f52a707

File tree

1 file changed

+121
-96
lines changed

1 file changed

+121
-96
lines changed

quimb/tensor/tensor_arbgeom.py

+121-96
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,112 @@ def tensor_network_ag_sum(
431431
return tna
432432

433433

434+
def tensor_network_ag_gate(
435+
self: "TensorNetworkGen",
436+
G,
437+
where,
438+
contract=False,
439+
tags=None,
440+
propagate_tags=False,
441+
which=None,
442+
info=None,
443+
inplace=False,
444+
**compress_opts,
445+
):
446+
r"""Apply a gate to this vector tensor network at sites ``where``. This is
447+
essentially a wrapper around
448+
:meth:`~quimb.tensor.tensor_core.TensorNetwork.gate_inds` apart from
449+
``where`` can be specified as a list of sites, and tags can be optionally,
450+
intelligently propagated to the new gate tensor.
451+
452+
.. math::
453+
454+
| \psi \rangle \rightarrow G_\mathrm{where} | \psi \rangle
455+
456+
Parameters
457+
----------
458+
G : array_ike
459+
The gate array to apply, should match or be factorable into the shape
460+
``(*phys_dims, *phys_dims)``.
461+
where : node or sequence[node]
462+
The sites to apply the gate to.
463+
contract : {False, True, 'split', 'reduce-split', 'split-gate',
464+
'swap-split-gate', 'auto-split-gate'}, optional
465+
How to apply the gate, see
466+
:meth:`~quimb.tensor.tensor_core.TensorNetwork.gate_inds`.
467+
tags : str or sequence of str, optional
468+
Tags to add to the new gate tensor.
469+
propagate_tags : {False, True, 'register', 'sites'}, optional
470+
Whether to propagate tags to the new gate tensor::
471+
472+
- False: no tags are propagated
473+
- True: all tags are propagated
474+
- 'register': only site tags corresponding to ``where`` are
475+
added.
476+
- 'sites': all site tags on the current sites are propgated,
477+
resulting in a lightcone like tagging.
478+
479+
info : None or dict, optional
480+
Used to store extra optional information such as the singular
481+
values if not absorbed.
482+
inplace : bool, optional
483+
Whether to perform the gate operation inplace on the tensor network
484+
or not.
485+
compress_opts
486+
Supplied to :func:`~quimb.tensor.tensor_core.tensor_split` for any
487+
``contract`` methods that involve splitting. Ignored otherwise.
488+
489+
Returns
490+
-------
491+
TensorNetworkGenVector
492+
493+
See Also
494+
--------
495+
TensorNetwork.gate_inds
496+
"""
497+
check_opt("propagate_tags", propagate_tags, _VALID_GATE_PROPAGATE)
498+
499+
tn = self if inplace else self.copy()
500+
501+
if which is None:
502+
site_ind_fn = tn.site_ind
503+
elif which == "upper":
504+
site_ind_fn = tn.upper_ind
505+
elif which == "lower":
506+
site_ind_fn = tn.lower_ind
507+
else:
508+
raise ValueError("`which` should be None, 'upper' or 'lower'.")
509+
510+
if not isinstance(where, (tuple, list)):
511+
where = (where,)
512+
inds = tuple(map(site_ind_fn, where))
513+
514+
# potentially add tags from current tensors to the new ones,
515+
# only do this if we are lazily adding the gate tensor(s)
516+
if (contract in _LAZY_GATE_CONTRACT) and (
517+
propagate_tags in (True, "sites")
518+
):
519+
old_tags = oset.union(*(t.tags for t in tn._inds_get(*inds)))
520+
if propagate_tags == "sites":
521+
old_tags = tn.filter_valid_site_tags(old_tags)
522+
523+
tags = tags_to_oset(tags)
524+
tags.update(old_tags)
525+
526+
# perform the actual gating
527+
tn.gate_inds_(
528+
G, inds, contract=contract, tags=tags, info=info, **compress_opts
529+
)
530+
531+
# possibly add tags based on where the gate was applied
532+
if propagate_tags == "register":
533+
for ix, site in zip(inds, where):
534+
(t,) = tn._inds_get(ix)
535+
t.add_tag(tn.site_tag(site))
536+
537+
return tn
538+
539+
434540
class TensorNetworkGen(TensorNetwork):
435541
"""A tensor network which notionally has a single tensor per 'site',
436542
though these could be labelled arbitrarily could also be linked in an
@@ -1058,102 +1164,8 @@ def gate_with_op_lazy(self, A, transpose=False, inplace=False, **kwargs):
10581164
gate_with_op_lazy, inplace=True
10591165
)
10601166

1061-
def gate(
1062-
self,
1063-
G,
1064-
where,
1065-
contract=False,
1066-
tags=None,
1067-
propagate_tags=False,
1068-
info=None,
1069-
inplace=False,
1070-
**compress_opts,
1071-
):
1072-
r"""Apply a gate to this vector tensor network at sites ``where``. This
1073-
is essentially a wrapper around
1074-
:meth:`~quimb.tensor.tensor_core.TensorNetwork.gate_inds` apart from
1075-
``where`` can be specified as a list of sites, and tags can be
1076-
optionally, intelligently propagated to the new gate tensor.
1077-
1078-
.. math::
1079-
1080-
| \psi \rangle \rightarrow G_\mathrm{where} | \psi \rangle
1081-
1082-
Parameters
1083-
----------
1084-
G : array_ike
1085-
The gate array to apply, should match or be factorable into the
1086-
shape ``(*phys_dims, *phys_dims)``.
1087-
where : node or sequence[node]
1088-
The sites to apply the gate to.
1089-
contract : {False, True, 'split', 'reduce-split', 'split-gate',
1090-
'swap-split-gate', 'auto-split-gate'}, optional
1091-
How to apply the gate, see
1092-
:meth:`~quimb.tensor.tensor_core.TensorNetwork.gate_inds`.
1093-
tags : str or sequence of str, optional
1094-
Tags to add to the new gate tensor.
1095-
propagate_tags : {False, True, 'register', 'sites'}, optional
1096-
Whether to propagate tags to the new gate tensor::
1097-
1098-
- False: no tags are propagated
1099-
- True: all tags are propagated
1100-
- 'register': only site tags corresponding to ``where`` are
1101-
added.
1102-
- 'sites': all site tags on the current sites are propgated,
1103-
resulting in a lightcone like tagging.
1104-
1105-
info : None or dict, optional
1106-
Used to store extra optional information such as the singular
1107-
values if not absorbed.
1108-
inplace : bool, optional
1109-
Whether to perform the gate operation inplace on the tensor network
1110-
or not.
1111-
compress_opts
1112-
Supplied to :func:`~quimb.tensor.tensor_core.tensor_split` for any
1113-
``contract`` methods that involve splitting. Ignored otherwise.
1114-
1115-
Returns
1116-
-------
1117-
TensorNetworkGenVector
1118-
1119-
See Also
1120-
--------
1121-
TensorNetwork.gate_inds
1122-
"""
1123-
check_opt("propagate_tags", propagate_tags, _VALID_GATE_PROPAGATE)
1124-
1125-
tn = self if inplace else self.copy()
1126-
1127-
if not isinstance(where, (tuple, list)):
1128-
where = (where,)
1129-
inds = tuple(map(tn.site_ind, where))
1130-
1131-
# potentially add tags from current tensors to the new ones,
1132-
# only do this if we are lazily adding the gate tensor(s)
1133-
if (contract in _LAZY_GATE_CONTRACT) and (
1134-
propagate_tags in (True, "sites")
1135-
):
1136-
old_tags = oset.union(*(t.tags for t in tn._inds_get(*inds)))
1137-
if propagate_tags == "sites":
1138-
old_tags = tn.filter_valid_site_tags(old_tags)
1139-
1140-
tags = tags_to_oset(tags)
1141-
tags.update(old_tags)
1142-
1143-
# perform the actual gating
1144-
tn.gate_inds_(
1145-
G, inds, contract=contract, tags=tags, info=info, **compress_opts
1146-
)
1147-
1148-
# possibly add tags based on where the gate was applied
1149-
if propagate_tags == "register":
1150-
for ix, site in zip(inds, where):
1151-
(t,) = tn._inds_get(ix)
1152-
t.add_tag(tn.site_tag(site))
1153-
1154-
return tn
1155-
1156-
gate_ = functools.partialmethod(gate, inplace=True)
1167+
gate = tensor_network_ag_gate
1168+
gate_ = functools.partialmethod(tensor_network_ag_gate, inplace=True)
11571169

11581170
def gate_simple_(
11591171
self,
@@ -3032,6 +3044,19 @@ def phys_dim(self, site=None, which="upper"):
30323044
if which == "lower":
30333045
return self[site].ind_size(self.lower_ind(site))
30343046

3047+
gate_upper = functools.partialmethod(tensor_network_ag_gate, which="upper")
3048+
gate_upper_ = functools.partialmethod(
3049+
tensor_network_ag_gate,
3050+
which="upper",
3051+
inplace=True,
3052+
)
3053+
gate_lower = functools.partialmethod(tensor_network_ag_gate, which="lower")
3054+
gate_lower_ = functools.partialmethod(
3055+
tensor_network_ag_gate,
3056+
which="lower",
3057+
inplace=True,
3058+
)
3059+
30353060
def gate_upper_with_op_lazy(
30363061
self,
30373062
A,

0 commit comments

Comments
 (0)