Skip to content

Commit c4a6d89

Browse files
committed
Fix unit tests broken by upstream gpytorch changes
cornellius-gp/gpytorch#2082 changed the shape of the mean constant (and did some other things). These changes fix the resulting unit test breakages.
1 parent 05e4069 commit c4a6d89

File tree

4 files changed

+15
-17
lines changed

4 files changed

+15
-17
lines changed

ax/models/tests/test_alebo.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def testALEBOGP(self):
8989

9090
# Batch
9191
Uvec_b = m.covar_module.base_kernel.Uvec.repeat(5, 1)
92-
mean_b = m.mean_module.constant.repeat(5, 1)
92+
mean_b = m.mean_module.constant.repeat(5)
9393
output_scale_b = m.covar_module.raw_outputscale.repeat(5)
9494
m_b = get_batch_model(
9595
B=B,
@@ -132,7 +132,7 @@ def testALEBOGP(self):
132132
{
133133
"covar_module.base_kernel.Uvec",
134134
"covar_module.raw_outputscale",
135-
"mean_module.constant",
135+
"mean_module.raw_constant",
136136
"covar_module.raw_outputscale_constraint.lower_bound",
137137
"covar_module.raw_outputscale_constraint.upper_bound",
138138
},
@@ -151,7 +151,7 @@ def testALEBOGP(self):
151151
{
152152
"covar_module.base_kernel.Uvec",
153153
"covar_module.raw_outputscale",
154-
"mean_module.constant",
154+
"mean_module.raw_constant",
155155
"covar_module.raw_outputscale_constraint.lower_bound",
156156
"covar_module.raw_outputscale_constraint.upper_bound",
157157
},

ax/models/tests/test_botorch_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ def test_BotorchModel(self, dtype=torch.float, cuda=False):
457457

458458
# Test loading state dict
459459
true_state_dict = {
460-
"mean_module.constant": [3.5004],
460+
"mean_module.raw_constant": 3.5004,
461461
"covar_module.raw_outputscale": 2.2438,
462462
"covar_module.base_kernel.raw_lengthscale": [
463463
[-0.9274, -0.9274, -0.9274]
@@ -489,7 +489,7 @@ def test_BotorchModel(self, dtype=torch.float, cuda=False):
489489
self.assertTrue(torch.equal(true_state_dict[k], v))
490490

491491
# Test for some change in model parameters & buffer for refit_model=True
492-
true_state_dict["mean_module.constant"] += 0.1
492+
true_state_dict["mean_module.raw_constant"] += 0.1
493493
true_state_dict["covar_module.raw_outputscale"] += 0.1
494494
true_state_dict["covar_module.base_kernel.raw_lengthscale"] += 0.1
495495
model = get_and_fit_model(

ax/models/tests/test_fully_bayesian.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def test_FullyBayesianBotorchModel(self, dtype=torch.float, cuda=False):
348348
self.assertEqual(len(model.model.models), 2)
349349
m1, m2 = model.model.models[0], model.model.models[1]
350350
# Mean
351-
self.assertEqual(m1.mean_module.constant.shape, (4, 1))
351+
self.assertEqual(m1.mean_module.constant.shape, (4,))
352352
self.assertFalse(
353353
torch.isclose(
354354
m1.mean_module.constant, m2.mean_module.constant
@@ -376,11 +376,9 @@ def test_FullyBayesianBotorchModel(self, dtype=torch.float, cuda=False):
376376
device = torch.device("cuda") if cuda else torch.device("cpu")
377377
objective_weights = torch.tensor([1.0, 0.0], dtype=dtype, device=device)
378378
objective_transform = get_objective_weights_transform(objective_weights)
379-
infeasible_cost = torch.tensor(
380-
get_infeasible_cost(
379+
infeasible_cost = get_infeasible_cost(
381380
X=Xs1[0], model=model.model, objective=objective_transform
382-
)
383-
)
381+
).detach().clone()
384382
expected_infeasible_cost = -1 * torch.min(
385383
objective_transform(
386384
model.model.posterior(Xs1[0]).mean
@@ -841,7 +839,7 @@ def test_FullyBayesianBotorchModelPyro(self, dtype=torch.double, cuda=False):
841839
)
842840
self.assertEqual(
843841
m.mean_module.constant.shape,
844-
torch.Size([4, 1]),
842+
torch.Size([4]),
845843
)
846844
if use_input_warping:
847845
self.assertTrue(hasattr(m, "input_transform"))

ax/models/torch/alebo.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def f(x):
331331

332332
# Sample only Uvec; leave mean and output scale fixed.
333333
assert list(property_dict.keys()) == [
334-
"model.mean_module.constant",
334+
"model.mean_module.raw_constant",
335335
"model.covar_module.raw_outputscale",
336336
"model.covar_module.base_kernel.Uvec",
337337
]
@@ -347,7 +347,7 @@ def f(x):
347347
nsamp, *attrs.shape
348348
)
349349
# Get the other properties into batch mode
350-
mean_constant_batch = mll.model.mean_module.constant.repeat(nsamp, 1)
350+
mean_constant_batch = mll.model.mean_module.constant.repeat(nsamp)
351351
output_scale_batch = mll.model.covar_module.raw_outputscale.repeat(nsamp)
352352
return Uvec_batch, mean_constant_batch, output_scale_batch
353353

@@ -383,10 +383,10 @@ def get_batch_model(
383383
)
384384
m_b.train()
385385
# Set mean constant
386-
# pyre-fixme[16]: `Optional` has no attribute `constant`.
387-
m_b.mean_module.constant.requires_grad_(False)
388-
m_b.mean_module.constant.copy_(mean_constant_batch)
389-
m_b.mean_module.constant.requires_grad_(True)
386+
# pyre-fixme[16]: `Optional` has no attribute `raw_constant`.
387+
m_b.mean_module.raw_constant.requires_grad_(False)
388+
m_b.mean_module.raw_constant.copy_(mean_constant_batch)
389+
m_b.mean_module.raw_constant.requires_grad_(True)
390390
# Set output scale
391391
m_b.covar_module.raw_outputscale.requires_grad_(False)
392392
m_b.covar_module.raw_outputscale.copy_(output_scale_batch)

0 commit comments

Comments
 (0)