Skip to content

fix custom dtype_value_context setting #2132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions gpytorch/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@ def _set_value(cls, float_value, double_value, half_value):
if half_value is not None:
cls._global_half_value = half_value

def __init__(self, float=None, double=None, half=None):
self._orig_float_value = self.__class__.value()
self._instance_float_value = float
self._orig_double_value = self.__class__.value()
self._instance_double_value = double
self._orig_half_value = self.__class__.value()
self._instance_half_value = half
def __init__(self, float_value=None, double_value=None, half_value=None):
self._orig_float_value = self.__class__.value(torch.float)
self._instance_float_value = float_value if float_value is not None else self._orig_float_value
self._orig_double_value = self.__class__.value(torch.double)
self._instance_double_value = double_value if double_value is not None else self._orig_double_value
self._orig_half_value = self.__class__.value(torch.half)
self._instance_half_value = half_value if half_value is not None else self._orig_half_value

def __enter__(
self,
Expand Down
34 changes: 34 additions & 0 deletions test/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import unittest

import torch

from gpytorch import settings
from gpytorch.test.base_test_case import BaseTestCase

Expand All @@ -16,3 +18,35 @@ def test_feature_flag(self):
with settings.fast_pred_var(False):
self.assertFalse(settings.fast_pred_var.is_default())
self.assertFalse(settings.fast_pred_var.on())

def test_dtype_value_context(self):
# test custom settings
x = torch.zeros(1, dtype=torch.float)
with settings.min_fixed_noise(float_value=0.1, double_value=0.2, half_value=0.3):
self.assertEqual(settings.min_fixed_noise.value(x), 0.1)
self.assertEqual(settings.min_fixed_noise.value(x.double()), 0.2)
self.assertEqual(settings.min_fixed_noise.value(x.half()), 0.3)
# test defaults are restored
self.assertEqual(
settings.min_fixed_noise.value(x),
settings.min_fixed_noise._global_float_value,
)
self.assertEqual(
settings.min_fixed_noise.value(x.double()),
settings.min_fixed_noise._global_double_value,
)
self.assertEqual(
settings.min_fixed_noise.value(x.half()),
settings.min_fixed_noise._global_half_value,
)
# test setting one dtype
with settings.min_fixed_noise(double_value=0.2):
self.assertEqual(settings.min_fixed_noise.value(x.double()), 0.2)
self.assertEqual(
settings.min_fixed_noise.value(x),
settings.min_fixed_noise._global_float_value,
)
self.assertEqual(
settings.min_fixed_noise.value(x.half()),
settings.min_fixed_noise._global_half_value,
)