|
10 | 10 | """
|
11 | 11 |
|
12 | 12 | import autoray as ar
|
| 13 | + |
13 | 14 | import quimb.tensor as qtn
|
14 | 15 |
|
15 | 16 | from .bp_common import (
|
16 | 17 | BeliefPropagationCommon,
|
| 18 | + combine_local_contractions, |
17 | 19 | compute_all_index_marginals_from_messages,
|
18 | 20 | contract_hyper_messages,
|
19 | 21 | initialize_hyper_messages,
|
@@ -266,7 +268,7 @@ def _normalize_and_insert(key, new_m, max_mdiff):
|
266 | 268 | mdiff = self._distance_fn(old_m, new_m)
|
267 | 269 |
|
268 | 270 | if self.damping:
|
269 |
| - new_m = self.fn_damping(old_m, new_m) |
| 271 | + new_m = self._damping_fn(old_m, new_m) |
270 | 272 |
|
271 | 273 | # # post-damp distance
|
272 | 274 | # mdiff = self._distance_fn(old_m, new_m)
|
@@ -341,6 +343,101 @@ def contract(self, strip_exponent=False, check_zero=True):
|
341 | 343 | backend=self.backend,
|
342 | 344 | )
|
343 | 345 |
|
| 346 | + def normalize_messages(self): |
| 347 | + """Normalize all messages such that the 'region contraction' of a |
| 348 | + single hyper index is 1. |
| 349 | + """ |
| 350 | + for ind, tids in self.tn.ind_map.items(): |
| 351 | + ms = [self.messages[tid, ind] for tid in tids] |
| 352 | + overlap = qtn.array_contract(ms, [(0,) for _ in ms], []) |
| 353 | + overlap **= 1 / len(ms) |
| 354 | + for tid, m in zip(tids, ms): |
| 355 | + self.messages[tid, ind] = m / overlap |
| 356 | + |
| 357 | + def get_cluster(self, r, virtual=True, autocomplete=True): |
| 358 | + """Get the tensor network of a region ``r``, with all boundary |
| 359 | + messages attached. |
| 360 | +
|
| 361 | + Parameters |
| 362 | + ---------- |
| 363 | + r : sequence of int or str |
| 364 | + The region to get, given as a sequence of indices or tensor ids. |
| 365 | + virtual : bool, optional |
| 366 | + Whether the view the original tensors (`virtual=True`, the default) |
| 367 | + or take copies (`virtual=False`). |
| 368 | + autocomplete : bool, optional |
| 369 | + Whether to automatically include all indices attached to the |
| 370 | + tensors in the region, or just the ones given in ``r``. |
| 371 | +
|
| 372 | + Returns |
| 373 | + ------- |
| 374 | + TensorNetwork |
| 375 | + """ |
| 376 | + rtids = set() |
| 377 | + rinds = set() |
| 378 | + for x in r: |
| 379 | + if isinstance(x, int): |
| 380 | + rtids.add(x) |
| 381 | + if autocomplete: |
| 382 | + rinds.update(self.tn.tensor_map[x].inds) |
| 383 | + else: |
| 384 | + rinds.add(x) |
| 385 | + |
| 386 | + tnr = self.tn._select_tids(rtids, virtual=virtual) |
| 387 | + for ind in rinds: |
| 388 | + # attach all messages coming from tensors outside the cluster |
| 389 | + for ntid in self.tn.ind_map[ind]: |
| 390 | + if ntid not in rtids: |
| 391 | + tnr |= qtn.Tensor( |
| 392 | + data=self.messages[ntid, ind], inds=(ind,) |
| 393 | + ) |
| 394 | + |
| 395 | + return tnr |
| 396 | + |
| 397 | + def contract_gloop_expand( |
| 398 | + self, |
| 399 | + gloops=None, |
| 400 | + strip_exponent=False, |
| 401 | + check_zero=True, |
| 402 | + optimize="auto-hq", |
| 403 | + progbar=False, |
| 404 | + **contract_otps, |
| 405 | + ): |
| 406 | + from .regions import RegionGraph |
| 407 | + |
| 408 | + # if we normalized messages we can ignore all index-only regions |
| 409 | + self.normalize_messages() |
| 410 | + |
| 411 | + if isinstance(gloops, int): |
| 412 | + gloops = tuple(self.tn.gen_gloops(gloops)) |
| 413 | + |
| 414 | + rg = RegionGraph(gloops) |
| 415 | + |
| 416 | + if progbar: |
| 417 | + import tqdm |
| 418 | + |
| 419 | + regions = tqdm.tqdm(rg.regions) |
| 420 | + else: |
| 421 | + regions = rg.regions |
| 422 | + |
| 423 | + zvals = [] |
| 424 | + for r in regions: |
| 425 | + # XXX: autoreduce intersecting clusters to gloops? |
| 426 | + cr = rg.get_count(r) |
| 427 | + # either we autocomplete above or we do it here per region |
| 428 | + tnr = self.get_cluster(r, autocomplete=True) |
| 429 | + zr = tnr.contract( |
| 430 | + output_inds=(), optimize=optimize, **contract_otps |
| 431 | + ) |
| 432 | + zvals.append((zr, cr)) |
| 433 | + |
| 434 | + return combine_local_contractions( |
| 435 | + zvals, |
| 436 | + strip_exponent=strip_exponent, |
| 437 | + check_zero=check_zero, |
| 438 | + backend=self.backend, |
| 439 | + ) |
| 440 | + |
344 | 441 |
|
345 | 442 | def contract_hd1bp(
|
346 | 443 | tn,
|
|
0 commit comments