18
18
BeliefPropagationCommon ,
19
19
combine_local_contractions ,
20
20
normalize_message_pair ,
21
+ process_loop_series_expansion_weights ,
21
22
)
22
23
from .hd1bp import (
23
24
compute_all_tensor_messages_tree ,
@@ -303,6 +304,98 @@ def get_cluster(self, tids):
303
304
t .vector_reduce_ (ix , self .messages [ix , tid ])
304
305
return tnr
305
306
307
+ def get_cluster_excited (self , tids ):
308
+ """Get the local tensor network for ``tids`` with BP messages inserted
309
+ on the boundary and excitation projectors inserted on the inner bonds.
310
+ See https://arxiv.org/abs/2409.03108 for more details.
311
+ """
312
+ stn = self .tn ._select_tids (tids , virtual = False )
313
+
314
+ for ix , tids in tuple (stn .ind_map .items ()):
315
+ if ix in stn ._inner_inds :
316
+ # insert inner excitation projector
317
+ tidl , tidr = tids
318
+ ml = self .messages [ix , tidl ]
319
+ mr = self .messages [ix , tidr ]
320
+ # form outer product
321
+ p0 = ar .do ("einsum" , "i,j->ij" , ml , mr )
322
+ # subtract from identity
323
+ pe = ar .do ("eye" , ar .do ("shape" , p0 )[0 ]) - p0
324
+ # absorb into one of tensors
325
+ stn .tensor_map [tidr ].gate_ (pe , ix )
326
+ else :
327
+ # insert boundary message
328
+ tid , = tids
329
+ m = self .messages [ix , tid ]
330
+ t = stn .tensor_map [tid ]
331
+ t .vector_reduce_ (ix , m )
332
+
333
+ return stn
334
+
335
+ def contract_loop_series_expansion (
336
+ self ,
337
+ gloops = None ,
338
+ multi_excitation_correct = True ,
339
+ tol_correction = 1e-12 ,
340
+ maxiter_correction = 100 ,
341
+ strip_exponent = False ,
342
+ optimize = "auto-hq" ,
343
+ ** contract_opts ,
344
+ ):
345
+ """Contract the tensor network using the same procedure as
346
+ in https://arxiv.org/abs/2409.03108 - "Loop Series Expansions for
347
+ Tensor Networks".
348
+
349
+ Parameters
350
+ ----------
351
+ gloops : int or iterable of tuples, optional
352
+ The gloop sizes to use. If an integer, then generate all gloop
353
+ sizes up to this size. If a tuple, then use these gloops.
354
+ multi_excitation_correct : bool, optional
355
+ Whether to use the multi-excitation correction. If ``True``, then
356
+ the free energy is refined iteratively until self consistent.
357
+ tol_correction : float, optional
358
+ The tolerance for the multi-excitation correction.
359
+ maxiter_correction : int, optional
360
+ The maximum number of iterations for the multi-excitation
361
+ correction.
362
+ strip_exponent : bool, optional
363
+ Whether to strip the exponent from the final result. If ``True``
364
+ then the returned result is ``(mantissa, exponent)``.
365
+ optimize : str or PathOptimizer, optional
366
+ The path optimizer to use when contracting the messages.
367
+ contract_opts
368
+ Other options supplied to ``TensorNetwork.contract``.
369
+ """
370
+ self .normalize_message_pairs ()
371
+ # accrues BP estimate into self.sign and self.exponent
372
+ self .normalize_tensors ()
373
+
374
+ if isinstance (gloops , int ):
375
+ gloops = tuple (self .tn .gen_gloops (max_size = gloops ))
376
+ else :
377
+ gloops = tuple (gloops )
378
+
379
+ weights = {}
380
+ for gloop in gloops :
381
+ # get local tensor network with boundary
382
+ # messages and exctiation projectors
383
+ etn = self .get_cluster_excited (gloop )
384
+ # contract it to get local weight!
385
+ weights [tuple (gloop )] = etn .contract (
386
+ optimize = optimize , ** contract_opts
387
+ )
388
+
389
+ return process_loop_series_expansion_weights (
390
+ weights ,
391
+ mantissa = self .sign ,
392
+ exponent = self .exponent ,
393
+ multi_excitation_correct = multi_excitation_correct ,
394
+ tol_correction = tol_correction ,
395
+ maxiter_correction = maxiter_correction ,
396
+ strip_exponent = strip_exponent ,
397
+ )
398
+
306
399
def local_tensor_contract (self , tid ):
307
400
"""Contract the messages around tensor ``tid``."""
308
401
t = self .tn .tensor_map [tid ]
@@ -413,8 +506,12 @@ def contract_gloop_expand(
413
506
strip_exponent = False ,
414
507
check_zero = True ,
415
508
optimize = "auto-hq" ,
509
+ combine = "prod" ,
416
510
** contract_opts ,
417
511
):
512
+ """Contract the tensor network using generalized loop cluster
513
+ expansion.
514
+ """
418
515
from .regions import RegionGraph
419
516
420
517
if isinstance (gloops , int ):
@@ -428,6 +525,11 @@ def contract_gloop_expand(
428
525
else :
429
526
gloops = tuple (gloops )
430
527
528
+ if combine == "sum" :
529
+ # make sure each contraction has the same BP-scaled environment
530
+ self .normalize_message_pairs ()
531
+ self .normalize_tensors ()
532
+
431
533
rg = RegionGraph (gloops , autocomplete = autocomplete )
432
534
433
535
zvals = []
@@ -438,6 +540,12 @@ def contract_gloop_expand(
438
540
439
541
zvals .append ((zr , c ))
440
542
543
+ if combine == "sum" :
544
+ mantissa = self .sign * sum (zr * cr for zr , cr in zvals )
545
+ if strip_exponent :
546
+ return mantissa , self .exponent
547
+ return mantissa * 10 ** self .exponent
548
+
441
549
return combine_local_contractions (
442
550
zvals ,
443
551
backend = self .backend ,
0 commit comments