@@ -431,6 +431,112 @@ def tensor_network_ag_sum(
431
431
return tna
432
432
433
433
434
+ def tensor_network_ag_gate (
435
+ self : "TensorNetworkGen" ,
436
+ G ,
437
+ where ,
438
+ contract = False ,
439
+ tags = None ,
440
+ propagate_tags = False ,
441
+ which = None ,
442
+ info = None ,
443
+ inplace = False ,
444
+ ** compress_opts ,
445
+ ):
446
+ r"""Apply a gate to this vector tensor network at sites ``where``. This is
447
+ essentially a wrapper around
448
+ :meth:`~quimb.tensor.tensor_core.TensorNetwork.gate_inds` apart from
449
+ ``where`` can be specified as a list of sites, and tags can be optionally,
450
+ intelligently propagated to the new gate tensor.
451
+
452
+ .. math::
453
+
454
+ | \psi \rangle \rightarrow G_\mathrm{where} | \psi \rangle
455
+
456
+ Parameters
457
+ ----------
458
+ G : array_ike
459
+ The gate array to apply, should match or be factorable into the shape
460
+ ``(*phys_dims, *phys_dims)``.
461
+ where : node or sequence[node]
462
+ The sites to apply the gate to.
463
+ contract : {False, True, 'split', 'reduce-split', 'split-gate',
464
+ 'swap-split-gate', 'auto-split-gate'}, optional
465
+ How to apply the gate, see
466
+ :meth:`~quimb.tensor.tensor_core.TensorNetwork.gate_inds`.
467
+ tags : str or sequence of str, optional
468
+ Tags to add to the new gate tensor.
469
+ propagate_tags : {False, True, 'register', 'sites'}, optional
470
+ Whether to propagate tags to the new gate tensor::
471
+
472
+ - False: no tags are propagated
473
+ - True: all tags are propagated
474
+ - 'register': only site tags corresponding to ``where`` are
475
+ added.
476
+ - 'sites': all site tags on the current sites are propgated,
477
+ resulting in a lightcone like tagging.
478
+
479
+ info : None or dict, optional
480
+ Used to store extra optional information such as the singular
481
+ values if not absorbed.
482
+ inplace : bool, optional
483
+ Whether to perform the gate operation inplace on the tensor network
484
+ or not.
485
+ compress_opts
486
+ Supplied to :func:`~quimb.tensor.tensor_core.tensor_split` for any
487
+ ``contract`` methods that involve splitting. Ignored otherwise.
488
+
489
+ Returns
490
+ -------
491
+ TensorNetworkGenVector
492
+
493
+ See Also
494
+ --------
495
+ TensorNetwork.gate_inds
496
+ """
497
+ check_opt ("propagate_tags" , propagate_tags , _VALID_GATE_PROPAGATE )
498
+
499
+ tn = self if inplace else self .copy ()
500
+
501
+ if which is None :
502
+ site_ind_fn = tn .site_ind
503
+ elif which == "upper" :
504
+ site_ind_fn = tn .upper_ind
505
+ elif which == "lower" :
506
+ site_ind_fn = tn .lower_ind
507
+ else :
508
+ raise ValueError ("`which` should be None, 'upper' or 'lower'." )
509
+
510
+ if not isinstance (where , (tuple , list )):
511
+ where = (where ,)
512
+ inds = tuple (map (site_ind_fn , where ))
513
+
514
+ # potentially add tags from current tensors to the new ones,
515
+ # only do this if we are lazily adding the gate tensor(s)
516
+ if (contract in _LAZY_GATE_CONTRACT ) and (
517
+ propagate_tags in (True , "sites" )
518
+ ):
519
+ old_tags = oset .union (* (t .tags for t in tn ._inds_get (* inds )))
520
+ if propagate_tags == "sites" :
521
+ old_tags = tn .filter_valid_site_tags (old_tags )
522
+
523
+ tags = tags_to_oset (tags )
524
+ tags .update (old_tags )
525
+
526
+ # perform the actual gating
527
+ tn .gate_inds_ (
528
+ G , inds , contract = contract , tags = tags , info = info , ** compress_opts
529
+ )
530
+
531
+ # possibly add tags based on where the gate was applied
532
+ if propagate_tags == "register" :
533
+ for ix , site in zip (inds , where ):
534
+ (t ,) = tn ._inds_get (ix )
535
+ t .add_tag (tn .site_tag (site ))
536
+
537
+ return tn
538
+
539
+
434
540
class TensorNetworkGen (TensorNetwork ):
435
541
"""A tensor network which notionally has a single tensor per 'site',
436
542
though these could be labelled arbitrarily could also be linked in an
@@ -1058,102 +1164,8 @@ def gate_with_op_lazy(self, A, transpose=False, inplace=False, **kwargs):
1058
1164
gate_with_op_lazy , inplace = True
1059
1165
)
1060
1166
1061
- def gate (
1062
- self ,
1063
- G ,
1064
- where ,
1065
- contract = False ,
1066
- tags = None ,
1067
- propagate_tags = False ,
1068
- info = None ,
1069
- inplace = False ,
1070
- ** compress_opts ,
1071
- ):
1072
- r"""Apply a gate to this vector tensor network at sites ``where``. This
1073
- is essentially a wrapper around
1074
- :meth:`~quimb.tensor.tensor_core.TensorNetwork.gate_inds` apart from
1075
- ``where`` can be specified as a list of sites, and tags can be
1076
- optionally, intelligently propagated to the new gate tensor.
1077
-
1078
- .. math::
1079
-
1080
- | \psi \rangle \rightarrow G_\mathrm{where} | \psi \rangle
1081
-
1082
- Parameters
1083
- ----------
1084
- G : array_ike
1085
- The gate array to apply, should match or be factorable into the
1086
- shape ``(*phys_dims, *phys_dims)``.
1087
- where : node or sequence[node]
1088
- The sites to apply the gate to.
1089
- contract : {False, True, 'split', 'reduce-split', 'split-gate',
1090
- 'swap-split-gate', 'auto-split-gate'}, optional
1091
- How to apply the gate, see
1092
- :meth:`~quimb.tensor.tensor_core.TensorNetwork.gate_inds`.
1093
- tags : str or sequence of str, optional
1094
- Tags to add to the new gate tensor.
1095
- propagate_tags : {False, True, 'register', 'sites'}, optional
1096
- Whether to propagate tags to the new gate tensor::
1097
-
1098
- - False: no tags are propagated
1099
- - True: all tags are propagated
1100
- - 'register': only site tags corresponding to ``where`` are
1101
- added.
1102
- - 'sites': all site tags on the current sites are propgated,
1103
- resulting in a lightcone like tagging.
1104
-
1105
- info : None or dict, optional
1106
- Used to store extra optional information such as the singular
1107
- values if not absorbed.
1108
- inplace : bool, optional
1109
- Whether to perform the gate operation inplace on the tensor network
1110
- or not.
1111
- compress_opts
1112
- Supplied to :func:`~quimb.tensor.tensor_core.tensor_split` for any
1113
- ``contract`` methods that involve splitting. Ignored otherwise.
1114
-
1115
- Returns
1116
- -------
1117
- TensorNetworkGenVector
1118
-
1119
- See Also
1120
- --------
1121
- TensorNetwork.gate_inds
1122
- """
1123
- check_opt ("propagate_tags" , propagate_tags , _VALID_GATE_PROPAGATE )
1124
-
1125
- tn = self if inplace else self .copy ()
1126
-
1127
- if not isinstance (where , (tuple , list )):
1128
- where = (where ,)
1129
- inds = tuple (map (tn .site_ind , where ))
1130
-
1131
- # potentially add tags from current tensors to the new ones,
1132
- # only do this if we are lazily adding the gate tensor(s)
1133
- if (contract in _LAZY_GATE_CONTRACT ) and (
1134
- propagate_tags in (True , "sites" )
1135
- ):
1136
- old_tags = oset .union (* (t .tags for t in tn ._inds_get (* inds )))
1137
- if propagate_tags == "sites" :
1138
- old_tags = tn .filter_valid_site_tags (old_tags )
1139
-
1140
- tags = tags_to_oset (tags )
1141
- tags .update (old_tags )
1142
-
1143
- # perform the actual gating
1144
- tn .gate_inds_ (
1145
- G , inds , contract = contract , tags = tags , info = info , ** compress_opts
1146
- )
1147
-
1148
- # possibly add tags based on where the gate was applied
1149
- if propagate_tags == "register" :
1150
- for ix , site in zip (inds , where ):
1151
- (t ,) = tn ._inds_get (ix )
1152
- t .add_tag (tn .site_tag (site ))
1153
-
1154
- return tn
1155
-
1156
- gate_ = functools .partialmethod (gate , inplace = True )
1167
+ gate = tensor_network_ag_gate
1168
+ gate_ = functools .partialmethod (tensor_network_ag_gate , inplace = True )
1157
1169
1158
1170
def gate_simple_ (
1159
1171
self ,
@@ -3032,6 +3044,19 @@ def phys_dim(self, site=None, which="upper"):
3032
3044
if which == "lower" :
3033
3045
return self [site ].ind_size (self .lower_ind (site ))
3034
3046
3047
+ gate_upper = functools .partialmethod (tensor_network_ag_gate , which = "upper" )
3048
+ gate_upper_ = functools .partialmethod (
3049
+ tensor_network_ag_gate ,
3050
+ which = "upper" ,
3051
+ inplace = True ,
3052
+ )
3053
+ gate_lower = functools .partialmethod (tensor_network_ag_gate , which = "lower" )
3054
+ gate_lower_ = functools .partialmethod (
3055
+ tensor_network_ag_gate ,
3056
+ which = "lower" ,
3057
+ inplace = True ,
3058
+ )
3059
+
3035
3060
def gate_upper_with_op_lazy (
3036
3061
self ,
3037
3062
A ,
0 commit comments