Skip to content

Commit d8d5151

Browse files
authored
Merge pull request #2132 from sdaulton/min_noise
fix custom dtype_value_context setting
2 parents 78fae38 + 2e43196 commit d8d5151

File tree

2 files changed

+41
-7
lines changed

2 files changed

+41
-7
lines changed

gpytorch/settings.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,13 @@ def _set_value(cls, float_value, double_value, half_value):
5656
if half_value is not None:
5757
cls._global_half_value = half_value
5858

59-
def __init__(self, float=None, double=None, half=None):
60-
self._orig_float_value = self.__class__.value()
61-
self._instance_float_value = float
62-
self._orig_double_value = self.__class__.value()
63-
self._instance_double_value = double
64-
self._orig_half_value = self.__class__.value()
65-
self._instance_half_value = half
59+
def __init__(self, float_value=None, double_value=None, half_value=None):
60+
self._orig_float_value = self.__class__.value(torch.float)
61+
self._instance_float_value = float_value if float_value is not None else self._orig_float_value
62+
self._orig_double_value = self.__class__.value(torch.double)
63+
self._instance_double_value = double_value if double_value is not None else self._orig_double_value
64+
self._orig_half_value = self.__class__.value(torch.half)
65+
self._instance_half_value = half_value if half_value is not None else self._orig_half_value
6666

6767
def __enter__(
6868
self,

test/test_settings.py

+34
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import unittest
44

5+
import torch
6+
57
from gpytorch import settings
68
from gpytorch.test.base_test_case import BaseTestCase
79

@@ -16,3 +18,35 @@ def test_feature_flag(self):
1618
with settings.fast_pred_var(False):
1719
self.assertFalse(settings.fast_pred_var.is_default())
1820
self.assertFalse(settings.fast_pred_var.on())
21+
22+
def test_dtype_value_context(self):
23+
# test custom settings
24+
x = torch.zeros(1, dtype=torch.float)
25+
with settings.min_fixed_noise(float_value=0.1, double_value=0.2, half_value=0.3):
26+
self.assertEqual(settings.min_fixed_noise.value(x), 0.1)
27+
self.assertEqual(settings.min_fixed_noise.value(x.double()), 0.2)
28+
self.assertEqual(settings.min_fixed_noise.value(x.half()), 0.3)
29+
# test defaults are restored
30+
self.assertEqual(
31+
settings.min_fixed_noise.value(x),
32+
settings.min_fixed_noise._global_float_value,
33+
)
34+
self.assertEqual(
35+
settings.min_fixed_noise.value(x.double()),
36+
settings.min_fixed_noise._global_double_value,
37+
)
38+
self.assertEqual(
39+
settings.min_fixed_noise.value(x.half()),
40+
settings.min_fixed_noise._global_half_value,
41+
)
42+
# test setting one dtype
43+
with settings.min_fixed_noise(double_value=0.2):
44+
self.assertEqual(settings.min_fixed_noise.value(x.double()), 0.2)
45+
self.assertEqual(
46+
settings.min_fixed_noise.value(x),
47+
settings.min_fixed_noise._global_float_value,
48+
)
49+
self.assertEqual(
50+
settings.min_fixed_noise.value(x.half()),
51+
settings.min_fixed_noise._global_half_value,
52+
)

0 commit comments

Comments
 (0)