Skip to content

Commit 40c7df4

Browse files
committed
gloops, sloops + bp updates
1 parent 76d0148 commit 40c7df4

File tree

11 files changed

+1246
-362
lines changed

11 files changed

+1246
-362
lines changed

quimb/experimental/belief_propagation/hd1gbp.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,10 @@ def iterate(self, tol=5e-6):
173173
# check for convergence
174174
try:
175175
m_old = self.messages[parent, child]
176-
mdiff = self._distance_fn(m.data, m_old.data)
176+
# XXX: need to handle index alignment here to compare
177+
# using _distance_fn:
178+
# mdiff = self._distance_fn(m_old.data, m.data)
179+
mdiff = (m_old - m).norm()
177180
except KeyError:
178181
mdiff = 1.0
179182
max_mdiff = max(mdiff, max_mdiff)
@@ -186,14 +189,17 @@ def iterate(self, tol=5e-6):
186189
# note that the raw, undamped `new_messages` are used in the
187190
# denominator of the message computations, and so kept 'as is'
188191
for pair in self.new_messages:
189-
if self.damping == 0.0 or pair not in self.messages:
192+
if pair not in self.messages:
193+
# no old message yet
190194
self.messages[pair] = self.new_messages[pair]
191195
else:
192-
self.messages[pair] = self.fn_damping(
196+
self.messages[pair] = self._damping_fn(
193197
self.messages[pair],
194198
self.new_messages[pair],
195199
)
196200

201+
# self.new_messages.clear()
202+
197203
return {
198204
"nconv": nconv,
199205
"ncheck": ncheck,

quimb/tensor/belief_propagation/bp_common.py

+49-4
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,21 @@ def damping(self):
129129
@damping.setter
130130
def damping(self, damping):
131131
if callable(damping):
132-
self.fn_damping = self._damping = damping
132+
self._damping_fn = self._damping = damping
133133
else:
134134
self._damping = damping
135135

136-
def fn_damping(old, new):
137-
return damping * old + (1 - damping) * new
136+
if damping == 0.0:
138137

139-
self.fn_damping = fn_damping
138+
def _damping_fn(old, new):
139+
return new
140+
141+
else:
142+
143+
def _damping_fn(old, new):
144+
return damping * old + (1 - damping) * new
145+
146+
self._damping_fn = _damping_fn
140147

141148
@property
142149
def normalize(self):
@@ -416,6 +423,29 @@ def mdiff(self):
416423
except IndexError:
417424
return float("nan")
418425

426+
def iterate(self, tol=1e-6):
427+
"""Perform a single iteration of belief propagation. Subclasses should
428+
implement this method, returning either `max_mdiff` or a dictionary
429+
containing `max_mdiff` and any other relevant information:
430+
431+
{
432+
"nconv": nconv,
433+
"ncheck": ncheck,
434+
"max_mdiff": max_mdiff,
435+
}
436+
437+
"""
438+
raise NotImplementedError
439+
440+
def contract(
441+
self,
442+
strip_exponent=False,
443+
check_zero=True,
444+
**kwargs,
445+
):
446+
"""Contract the tensor network and return the resulting value."""
447+
raise NotImplementedError
448+
419449
def __repr__(self):
420450
return f"{self.__class__.__name__}(n={self.n}, mdiff={self.mdiff:.3g})"
421451

@@ -761,3 +791,18 @@ def create_lazy_community_edge_map(tn, site_tags=None, rank_simplify=True):
761791
pass
762792

763793
return edges, neighbors, local_tns, touch_map
794+
795+
796+
def auto_add_indices(tn, regions):
797+
"""Make sure all indices incident to any tensor in each region are
798+
included in the region.
799+
"""
800+
new_regions = []
801+
for r in regions:
802+
new_r = set(r)
803+
tids = [x for x in new_r if isinstance(x, int)]
804+
for tid in tids:
805+
t = tn.tensor_map[tid]
806+
new_r.update(t.inds)
807+
new_regions.append(frozenset(new_r))
808+
return new_regions

quimb/tensor/belief_propagation/d1bp.py

+12-14
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def _update_m(key, new_m):
180180
mdiff = self._distance_fn(old_m, new_m)
181181

182182
if self.damping:
183-
new_m = self.fn_damping(old_m, new_m)
183+
new_m = self._damping_fn(old_m, new_m)
184184

185185
# # post-damp distance
186186
# mdiff = self._distance_fn(old_m, new_m)
@@ -341,7 +341,7 @@ def contract(
341341

342342
return combine_local_contractions(
343343
zvals,
344-
self.backend,
344+
backend=self.backend,
345345
strip_exponent=strip_exponent,
346346
check_zero=check_zero,
347347
mantissa=self.sign,
@@ -406,9 +406,9 @@ def contract_with_loops(
406406
exponent=self.exponent,
407407
)
408408

409-
def contract_cluster_expansion(
409+
def contract_gloop_expand(
410410
self,
411-
clusters=None,
411+
gloops=None,
412412
autocomplete=True,
413413
strip_exponent=False,
414414
check_zero=True,
@@ -417,20 +417,18 @@ def contract_cluster_expansion(
417417
):
418418
from .regions import RegionGraph
419419

420-
if isinstance(clusters, int):
421-
max_cluster_size = clusters
422-
clusters = None
420+
if isinstance(gloops, int):
421+
max_size = gloops
422+
gloops = None
423423
else:
424-
max_cluster_size = None
424+
max_size = None
425425

426-
if clusters is None:
427-
clusters = tuple(
428-
self.tn.gen_regions(max_region_size=max_cluster_size)
429-
)
426+
if gloops is None:
427+
gloops = tuple(self.tn.gen_gloops(max_size=max_size))
430428
else:
431-
clusters = tuple(clusters)
429+
gloops = tuple(gloops)
432430

433-
rg = RegionGraph(clusters, autocomplete=autocomplete)
431+
rg = RegionGraph(gloops, autocomplete=autocomplete)
434432

435433
zvals = []
436434
for r in rg.regions:

quimb/tensor/belief_propagation/d2bp.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def _update_m(key, new_m):
227227
mdiff = self._distance_fn(old_m, new_m)
228228

229229
if self.damping:
230-
new_m = self.fn_damping(old_m, new_m)
230+
new_m = self._damping_fn(old_m, new_m)
231231

232232
# # post-damp distance
233233
# mdiff = self._distance_fn(old_m, new_m)
@@ -381,9 +381,9 @@ def contract(
381381
check_zero=check_zero,
382382
)
383383

384-
def contract_cluster_expansion(
384+
def contract_gloop_expand(
385385
self,
386-
clusters=None,
386+
gloops=None,
387387
autocomplete=True,
388388
optimize="auto-hq",
389389
strip_exponent=False,
@@ -394,20 +394,20 @@ def contract_cluster_expansion(
394394
):
395395
self.normalize_message_pairs()
396396

397-
if isinstance(clusters, int):
398-
max_cluster_size = clusters
399-
clusters = None
397+
if isinstance(gloops, int):
398+
max_size = gloops
399+
gloops = None
400400
else:
401-
max_cluster_size = None
401+
max_size = None
402402

403-
if clusters is None:
404-
clusters = tuple(
405-
self.tn.gen_regions(max_region_size=max_cluster_size)
403+
if gloops is None:
404+
gloops = tuple(
405+
self.tn.gen_gloops(max_size=max_size)
406406
)
407407
else:
408-
clusters = tuple(clusters)
408+
gloops = tuple(gloops)
409409

410-
rg = RegionGraph(clusters, autocomplete=autocomplete)
410+
rg = RegionGraph(gloops, autocomplete=autocomplete)
411411

412412
for tid in self.tn.tensor_map:
413413
rg.add_region([tid])

quimb/tensor/belief_propagation/hd1bp.py

+98-1
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@
1010
"""
1111

1212
import autoray as ar
13+
1314
import quimb.tensor as qtn
1415

1516
from .bp_common import (
1617
BeliefPropagationCommon,
18+
combine_local_contractions,
1719
compute_all_index_marginals_from_messages,
1820
contract_hyper_messages,
1921
initialize_hyper_messages,
@@ -266,7 +268,7 @@ def _normalize_and_insert(key, new_m, max_mdiff):
266268
mdiff = self._distance_fn(old_m, new_m)
267269

268270
if self.damping:
269-
new_m = self.fn_damping(old_m, new_m)
271+
new_m = self._damping_fn(old_m, new_m)
270272

271273
# # post-damp distance
272274
# mdiff = self._distance_fn(old_m, new_m)
@@ -341,6 +343,101 @@ def contract(self, strip_exponent=False, check_zero=True):
341343
backend=self.backend,
342344
)
343345

346+
def normalize_messages(self):
347+
"""Normalize all messages such that the 'region contraction' of a
348+
single hyper index is 1.
349+
"""
350+
for ind, tids in self.tn.ind_map.items():
351+
ms = [self.messages[tid, ind] for tid in tids]
352+
overlap = qtn.array_contract(ms, [(0,) for _ in ms], [])
353+
overlap **= 1 / len(ms)
354+
for tid, m in zip(tids, ms):
355+
self.messages[tid, ind] = m / overlap
356+
357+
def get_cluster(self, r, virtual=True, autocomplete=True):
358+
"""Get the tensor network of a region ``r``, with all boundary
359+
messages attached.
360+
361+
Parameters
362+
----------
363+
r : sequence of int or str
364+
The region to get, given as a sequence of indices or tensor ids.
365+
virtual : bool, optional
366+
Whether the view the original tensors (`virtual=True`, the default)
367+
or take copies (`virtual=False`).
368+
autocomplete : bool, optional
369+
Whether to automatically include all indices attached to the
370+
tensors in the region, or just the ones given in ``r``.
371+
372+
Returns
373+
-------
374+
TensorNetwork
375+
"""
376+
rtids = set()
377+
rinds = set()
378+
for x in r:
379+
if isinstance(x, int):
380+
rtids.add(x)
381+
if autocomplete:
382+
rinds.update(self.tn.tensor_map[x].inds)
383+
else:
384+
rinds.add(x)
385+
386+
tnr = self.tn._select_tids(rtids, virtual=virtual)
387+
for ind in rinds:
388+
# attach all messages coming from tensors outside the cluster
389+
for ntid in self.tn.ind_map[ind]:
390+
if ntid not in rtids:
391+
tnr |= qtn.Tensor(
392+
data=self.messages[ntid, ind], inds=(ind,)
393+
)
394+
395+
return tnr
396+
397+
def contract_gloop_expand(
398+
self,
399+
gloops=None,
400+
strip_exponent=False,
401+
check_zero=True,
402+
optimize="auto-hq",
403+
progbar=False,
404+
**contract_otps,
405+
):
406+
from .regions import RegionGraph
407+
408+
# if we normalized messages we can ignore all index-only regions
409+
self.normalize_messages()
410+
411+
if isinstance(gloops, int):
412+
gloops = tuple(self.tn.gen_gloops(gloops))
413+
414+
rg = RegionGraph(gloops)
415+
416+
if progbar:
417+
import tqdm
418+
419+
regions = tqdm.tqdm(rg.regions)
420+
else:
421+
regions = rg.regions
422+
423+
zvals = []
424+
for r in regions:
425+
# XXX: autoreduce intersecting clusters to gloops?
426+
cr = rg.get_count(r)
427+
# either we autocomplete above or we do it here per region
428+
tnr = self.get_cluster(r, autocomplete=True)
429+
zr = tnr.contract(
430+
output_inds=(), optimize=optimize, **contract_otps
431+
)
432+
zvals.append((zr, cr))
433+
434+
return combine_local_contractions(
435+
zvals,
436+
strip_exponent=strip_exponent,
437+
check_zero=check_zero,
438+
backend=self.backend,
439+
)
440+
344441

345442
def contract_hd1bp(
346443
tn,

quimb/tensor/belief_propagation/l1bp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def _update_m(key, data):
178178
mdiff = self._distance_fn(data, tm.data)
179179

180180
if self.damping:
181-
data = self.fn_damping(data, tm.data)
181+
data = self._damping_fn(data, tm.data)
182182

183183
# # post-damp distance
184184
# mdiff = self._distance_fn(data, tm.data)

quimb/tensor/belief_propagation/l2bp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def _update_m(key, data):
246246
mdiff = self._distance_fn(data, tm.data)
247247

248248
if self.damping:
249-
data = self.fn_damping(data, tm.data)
249+
data = self._damping_fn(data, tm.data)
250250

251251
# # post-damp distance
252252
# mdiff = self._distance_fn(data, tm.data)

0 commit comments

Comments
 (0)