Skip to content

Commit 6ccb807

Browse files
committed
BP: add some loop functions
1 parent a2a7805 commit 6ccb807

File tree

3 files changed

+711
-48
lines changed

3 files changed

+711
-48
lines changed

quimb/tensor/belief_propagation/bp_common.py

+43
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
import math
23
import operator
34

45
import autoray as ar
@@ -806,3 +807,45 @@ def auto_add_indices(tn, regions):
806807
new_r.update(t.inds)
807808
new_regions.append(frozenset(new_r))
808809
return new_regions
810+
811+
812+
def process_loop_series_expansion_weights(
813+
weights,
814+
mantissa=1.0,
815+
exponent=0.0,
816+
multi_excitation_correct=True,
817+
maxiter_correction=100,
818+
tol_correction=1e-14,
819+
strip_exponent=False,
820+
return_all=False,
821+
):
822+
"""Assuming a normalized BP fixed point, take a series of loop weights, and
823+
iteratively compute the free energy by requiring self-consistency with
824+
exponential suppression factors. See https://arxiv.org/abs/2409.03108.
825+
"""
826+
# this is the single exictation approximation
827+
f_uncorrected = -sum(weights.values())
828+
829+
if multi_excitation_correct:
830+
# iteratively compute a self consistent free energy
831+
fold = float("inf")
832+
f = f_uncorrected
833+
for _ in range(maxiter_correction):
834+
f = -sum(
835+
wl * math.exp(len(gloop) * f) for gloop, wl in weights.items()
836+
)
837+
if abs(f - fold) < tol_correction:
838+
break
839+
fold = f
840+
else:
841+
f = f_uncorrected
842+
843+
if return_all:
844+
return {gloop: math.exp(len(gloop) * f) for gloop in weights}
845+
846+
mantissa = mantissa * (1 - f)
847+
848+
if strip_exponent:
849+
return mantissa, exponent
850+
851+
return mantissa * 10**exponent

quimb/tensor/belief_propagation/d1bp.py

+108
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
BeliefPropagationCommon,
1919
combine_local_contractions,
2020
normalize_message_pair,
21+
process_loop_series_expansion_weights,
2122
)
2223
from .hd1bp import (
2324
compute_all_tensor_messages_tree,
@@ -303,6 +304,98 @@ def get_cluster(self, tids):
303304
t.vector_reduce_(ix, self.messages[ix, tid])
304305
return tnr
305306

307+
def get_cluster_excited(self, tids):
308+
"""Get the local tensor network for ``tids`` with BP messages inserted
309+
on the boundary and excitation projectors inserted on the inner bonds.
310+
See https://arxiv.org/abs/2409.03108 for more details.
311+
"""
312+
stn = self.tn._select_tids(tids, virtual=False)
313+
314+
for ix, tids in tuple(stn.ind_map.items()):
315+
if ix in stn._inner_inds:
316+
# insert inner excitation projector
317+
tidl, tidr = tids
318+
ml = self.messages[ix, tidl]
319+
mr = self.messages[ix, tidr]
320+
# form outer product
321+
p0 = ar.do("einsum", "i,j->ij", ml, mr)
322+
# subtract from identity
323+
pe = ar.do("eye", ar.do("shape", p0)[0]) - p0
324+
# absorb into one of tensors
325+
stn.tensor_map[tidr].gate_(pe, ix)
326+
else:
327+
# insert boundary message
328+
tid, = tids
329+
m = self.messages[ix, tid]
330+
t = stn.tensor_map[tid]
331+
t.vector_reduce_(ix, m)
332+
333+
return stn
334+
335+
def contract_loop_series_expansion(
336+
self,
337+
gloops=None,
338+
multi_excitation_correct=True,
339+
tol_correction=1e-12,
340+
maxiter_correction=100,
341+
strip_exponent=False,
342+
optimize="auto-hq",
343+
**contract_opts,
344+
):
345+
"""Contract the tensor network using the same procedure as
346+
in https://arxiv.org/abs/2409.03108 - "Loop Series Expansions for
347+
Tensor Networks".
348+
349+
Parameters
350+
----------
351+
gloops : int or iterable of tuples, optional
352+
The gloop sizes to use. If an integer, then generate all gloop
353+
sizes up to this size. If a tuple, then use these gloops.
354+
multi_excitation_correct : bool, optional
355+
Whether to use the multi-excitation correction. If ``True``, then
356+
the free energy is refined iteratively until self consistent.
357+
tol_correction : float, optional
358+
The tolerance for the multi-excitation correction.
359+
maxiter_correction : int, optional
360+
The maximum number of iterations for the multi-excitation
361+
correction.
362+
strip_exponent : bool, optional
363+
Whether to strip the exponent from the final result. If ``True``
364+
then the returned result is ``(mantissa, exponent)``.
365+
optimize : str or PathOptimizer, optional
366+
The path optimizer to use when contracting the messages.
367+
contract_opts
368+
Other options supplied to ``TensorNetwork.contract``.
369+
"""
370+
self.normalize_message_pairs()
371+
# accrues BP estimate into self.sign and self.exponent
372+
self.normalize_tensors()
373+
374+
if isinstance(gloops, int):
375+
gloops = tuple(self.tn.gen_gloops(max_size=gloops))
376+
else:
377+
gloops = tuple(gloops)
378+
379+
weights = {}
380+
for gloop in gloops:
381+
# get local tensor network with boundary
382+
# messages and exctiation projectors
383+
etn = self.get_cluster_excited(gloop)
384+
# contract it to get local weight!
385+
weights[tuple(gloop)] = etn.contract(
386+
optimize=optimize, **contract_opts
387+
)
388+
389+
return process_loop_series_expansion_weights(
390+
weights,
391+
mantissa=self.sign,
392+
exponent=self.exponent,
393+
multi_excitation_correct=multi_excitation_correct,
394+
tol_correction=tol_correction,
395+
maxiter_correction=maxiter_correction,
396+
strip_exponent=strip_exponent,
397+
)
398+
306399
def local_tensor_contract(self, tid):
307400
"""Contract the messages around tensor ``tid``."""
308401
t = self.tn.tensor_map[tid]
@@ -413,8 +506,12 @@ def contract_gloop_expand(
413506
strip_exponent=False,
414507
check_zero=True,
415508
optimize="auto-hq",
509+
combine="prod",
416510
**contract_opts,
417511
):
512+
"""Contract the tensor network using generalized loop cluster
513+
expansion.
514+
"""
418515
from .regions import RegionGraph
419516

420517
if isinstance(gloops, int):
@@ -428,6 +525,11 @@ def contract_gloop_expand(
428525
else:
429526
gloops = tuple(gloops)
430527

528+
if combine == "sum":
529+
# make sure each contraction has the same BP-scaled environment
530+
self.normalize_message_pairs()
531+
self.normalize_tensors()
532+
431533
rg = RegionGraph(gloops, autocomplete=autocomplete)
432534

433535
zvals = []
@@ -438,6 +540,12 @@ def contract_gloop_expand(
438540

439541
zvals.append((zr, c))
440542

543+
if combine == "sum":
544+
mantissa = self.sign * sum(zr * cr for zr, cr in zvals)
545+
if strip_exponent:
546+
return mantissa, self.exponent
547+
return mantissa * 10 ** self.exponent
548+
441549
return combine_local_contractions(
442550
zvals,
443551
backend=self.backend,

0 commit comments

Comments
 (0)