Skip to content

Commit 0ab3a14

Browse files
authored
fix #1343 device handling in mog_log_prob (#1356)
* fix device in mog_log_prob, add test, fix tests. * add cpu test for multiround mdn
1 parent a6a220d commit 0ab3a14

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

sbi/utils/sbiutils.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -841,7 +841,9 @@ def mog_log_prob(
841841

842842
# Split up evaluation into parts.
843843
weights = logits_pp - torch.logsumexp(logits_pp, dim=-1, keepdim=True)
844-
constant = -(output_dim / 2.0) * torch.log(torch.tensor([2 * pi]))
844+
constant = -(output_dim / 2.0) * torch.log(
845+
torch.tensor([2 * pi], device=theta.device)
846+
)
845847
log_det = 0.5 * torch.log(torch.det(precisions_pp))
846848
theta_minus_mean = theta.expand_as(means_pp) - means_pp
847849
exponent = -0.5 * batched_mixture_vmv(precisions_pp, theta_minus_mean)

tests/inference_on_device_test.py

+33-4
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from __future__ import annotations
55

6-
from typing import Tuple
6+
from typing import Tuple, Union
77

88
import pytest
99
import torch
@@ -127,7 +127,7 @@ def simulator(theta):
127127
model=model, num_transforms=2, dtype=torch.float32
128128
)
129129
)
130-
train_kwargs = dict(force_first_round_loss=True)
130+
train_kwargs = dict()
131131
elif method == NLE:
132132
kwargs = dict(
133133
density_estimator=likelihood_nn(
@@ -152,9 +152,12 @@ def simulator(theta):
152152
x = simulator(theta).to(data_device)
153153
theta = theta.to(data_device)
154154

155-
estimator = inferer.append_simulations(theta, x, data_device=data_device).train(
156-
training_batch_size=100, max_num_epochs=max_num_epochs, **train_kwargs
155+
data_kwargs = (
156+
dict(proposal=proposals[-1]) if method in [NPE_A, NPE_C] else dict()
157157
)
158+
estimator = inferer.append_simulations(
159+
theta, x, data_device=data_device, **data_kwargs
160+
).train(max_num_epochs=max_num_epochs, **train_kwargs)
158161

159162
# mcmc cases
160163
if sampling_method in ["slice_np", "slice_np_vectorized", "nuts_pymc"]:
@@ -436,3 +439,29 @@ def test_boxuniform_device_handling(arg_device, device):
436439
low=zeros(1).to(arg_device), high=ones(1).to(arg_device), device=device
437440
)
438441
NPE_C(prior=prior, device=arg_device)
442+
443+
444+
@pytest.mark.gpu
@pytest.mark.parametrize("method", [NPE_A, NPE_C])
@pytest.mark.parametrize("device", ["cpu", "gpu"])
def test_multiround_mdn_training_on_device(method: Union[NPE_A, NPE_C], device: str):
    """Train an MDN-based NPE method for two rounds on the requested device.

    Added together with the device-handling fix in ``mog_log_prob`` (#1343):
    from the second round on, the proposal is a posterior rather than the
    prior, so training has to evaluate the proposal on the training device.

    Args:
        method: Inference class under test (``NPE_A`` or ``NPE_C``).
        device: Device name from the parametrization (``"cpu"`` or ``"gpu"``).
    """
    num_dim = 2
    num_rounds = 2
    num_simulations = 100
    # Resolve the *parametrized* device. Hard-coding "gpu" here would make the
    # "cpu" case a silent duplicate of the "gpu" case.
    device = process_device(device)
    prior = BoxUniform(-torch.ones(num_dim), torch.ones(num_dim), device=device)
    simulator = diagonal_linear_gaussian

    # NPE_A uses its own MDN variant; NPE_C uses the generic MDN.
    estimator = "mdn_snpe_a" if method == NPE_A else "mdn"

    trainer = method(prior, density_estimator=estimator, device=device)

    theta = prior.sample((num_simulations,))
    x = simulator(theta)

    proposal = prior
    for _ in range(num_rounds):
        trainer.append_simulations(theta, x, proposal=proposal).train(max_num_epochs=2)
        # The posterior of this round becomes the proposal of the next one.
        proposal = trainer.build_posterior().set_default_x(torch.zeros(num_dim))
        theta = proposal.sample((num_simulations,))
        x = simulator(theta)

0 commit comments

Comments
 (0)