diff --git a/gpytorch/settings.py b/gpytorch/settings.py index 01006bf00..653193d09 100644 --- a/gpytorch/settings.py +++ b/gpytorch/settings.py @@ -56,13 +56,13 @@ def _set_value(cls, float_value, double_value, half_value): if half_value is not None: cls._global_half_value = half_value - def __init__(self, float=None, double=None, half=None): - self._orig_float_value = self.__class__.value() - self._instance_float_value = float - self._orig_double_value = self.__class__.value() - self._instance_double_value = double - self._orig_half_value = self.__class__.value() - self._instance_half_value = half + def __init__(self, float_value=None, double_value=None, half_value=None): + self._orig_float_value = self.__class__.value(torch.float) + self._instance_float_value = float_value if float_value is not None else self._orig_float_value + self._orig_double_value = self.__class__.value(torch.double) + self._instance_double_value = double_value if double_value is not None else self._orig_double_value + self._orig_half_value = self.__class__.value(torch.half) + self._instance_half_value = half_value if half_value is not None else self._orig_half_value def __enter__( self, diff --git a/test/test_settings.py b/test/test_settings.py index 65904ccfe..baddda973 100644 --- a/test/test_settings.py +++ b/test/test_settings.py @@ -2,6 +2,8 @@ import unittest +import torch + from gpytorch import settings from gpytorch.test.base_test_case import BaseTestCase @@ -16,3 +18,35 @@ def test_feature_flag(self): with settings.fast_pred_var(False): self.assertFalse(settings.fast_pred_var.is_default()) self.assertFalse(settings.fast_pred_var.on()) + + def test_dtype_value_context(self): + # test custom settings + x = torch.zeros(1, dtype=torch.float) + with settings.min_fixed_noise(float_value=0.1, double_value=0.2, half_value=0.3): + self.assertEqual(settings.min_fixed_noise.value(x), 0.1) + self.assertEqual(settings.min_fixed_noise.value(x.double()), 0.2) + self.assertEqual(settings.min_fixed_noise.value(x.half()), 0.3) + # test defaults are restored + self.assertEqual( + settings.min_fixed_noise.value(x), + settings.min_fixed_noise._global_float_value, + ) + self.assertEqual( + settings.min_fixed_noise.value(x.double()), + settings.min_fixed_noise._global_double_value, + ) + self.assertEqual( + settings.min_fixed_noise.value(x.half()), + settings.min_fixed_noise._global_half_value, + ) + # test setting one dtype + with settings.min_fixed_noise(double_value=0.2): + self.assertEqual(settings.min_fixed_noise.value(x.double()), 0.2) + self.assertEqual( + settings.min_fixed_noise.value(x), + settings.min_fixed_noise._global_float_value, + ) + self.assertEqual( + settings.min_fixed_noise.value(x.half()), + settings.min_fixed_noise._global_half_value, + )