3
3
4
4
from __future__ import annotations
5
5
6
- from typing import Tuple
6
+ from typing import Tuple , Union
7
7
8
8
import pytest
9
9
import torch
@@ -127,7 +127,7 @@ def simulator(theta):
127
127
model = model , num_transforms = 2 , dtype = torch .float32
128
128
)
129
129
)
130
- train_kwargs = dict (force_first_round_loss = True )
130
+ train_kwargs = dict ()
131
131
elif method == NLE :
132
132
kwargs = dict (
133
133
density_estimator = likelihood_nn (
@@ -152,9 +152,12 @@ def simulator(theta):
152
152
x = simulator (theta ).to (data_device )
153
153
theta = theta .to (data_device )
154
154
155
- estimator = inferer . append_simulations ( theta , x , data_device = data_device ). train (
156
- training_batch_size = 100 , max_num_epochs = max_num_epochs , ** train_kwargs
155
+ data_kwargs = (
156
+ dict ( proposal = proposals [ - 1 ]) if method in [ NPE_A , NPE_C ] else dict ()
157
157
)
158
+ estimator = inferer .append_simulations (
159
+ theta , x , data_device = data_device , ** data_kwargs
160
+ ).train (max_num_epochs = max_num_epochs , ** train_kwargs )
158
161
159
162
# mcmc cases
160
163
if sampling_method in ["slice_np" , "slice_np_vectorized" , "nuts_pymc" ]:
@@ -436,3 +439,29 @@ def test_boxuniform_device_handling(arg_device, device):
436
439
low = zeros (1 ).to (arg_device ), high = ones (1 ).to (arg_device ), device = device
437
440
)
438
441
NPE_C (prior = prior , device = arg_device )
442
+
443
+
444
+ @pytest .mark .gpu
445
+ @pytest .mark .parametrize ("method" , [NPE_A , NPE_C ])
446
+ @pytest .mark .parametrize ("device" , ["cpu" , "gpu" ])
447
+ def test_multiround_mdn_training_on_device (method : Union [NPE_A , NPE_C ], device : str ):
448
+ num_dim = 2
449
+ num_rounds = 2
450
+ num_simulations = 100
451
+ device = process_device ("gpu" )
452
+ prior = BoxUniform (- torch .ones (num_dim ), torch .ones (num_dim ), device = device )
453
+ simulator = diagonal_linear_gaussian
454
+
455
+ estimator = "mdn_snpe_a" if method == NPE_A else "mdn"
456
+
457
+ trainer = method (prior , density_estimator = estimator , device = device )
458
+
459
+ theta = prior .sample ((num_simulations ,))
460
+ x = simulator (theta )
461
+
462
+ proposal = prior
463
+ for _ in range (num_rounds ):
464
+ trainer .append_simulations (theta , x , proposal = proposal ).train (max_num_epochs = 2 )
465
+ proposal = trainer .build_posterior ().set_default_x (torch .zeros (num_dim ))
466
+ theta = proposal .sample ((num_simulations ,))
467
+ x = simulator (theta )
0 commit comments