2
2
3
3
import unittest
4
4
5
+ import torch
6
+
5
7
from gpytorch import settings
6
8
from gpytorch .test .base_test_case import BaseTestCase
7
9
@@ -16,3 +18,35 @@ def test_feature_flag(self):
16
18
with settings .fast_pred_var (False ):
17
19
self .assertFalse (settings .fast_pred_var .is_default ())
18
20
self .assertFalse (settings .fast_pred_var .on ())
21
+
22
+ def test_dtype_value_context (self ):
23
+ # test custom settings
24
+ x = torch .zeros (1 , dtype = torch .float )
25
+ with settings .min_fixed_noise (float_value = 0.1 , double_value = 0.2 , half_value = 0.3 ):
26
+ self .assertEqual (settings .min_fixed_noise .value (x ), 0.1 )
27
+ self .assertEqual (settings .min_fixed_noise .value (x .double ()), 0.2 )
28
+ self .assertEqual (settings .min_fixed_noise .value (x .half ()), 0.3 )
29
+ # test defaults are restored
30
+ self .assertEqual (
31
+ settings .min_fixed_noise .value (x ),
32
+ settings .min_fixed_noise ._global_float_value ,
33
+ )
34
+ self .assertEqual (
35
+ settings .min_fixed_noise .value (x .double ()),
36
+ settings .min_fixed_noise ._global_double_value ,
37
+ )
38
+ self .assertEqual (
39
+ settings .min_fixed_noise .value (x .half ()),
40
+ settings .min_fixed_noise ._global_half_value ,
41
+ )
42
+ # test setting one dtype
43
+ with settings .min_fixed_noise (double_value = 0.2 ):
44
+ self .assertEqual (settings .min_fixed_noise .value (x .double ()), 0.2 )
45
+ self .assertEqual (
46
+ settings .min_fixed_noise .value (x ),
47
+ settings .min_fixed_noise ._global_float_value ,
48
+ )
49
+ self .assertEqual (
50
+ settings .min_fixed_noise .value (x .half ()),
51
+ settings .min_fixed_noise ._global_half_value ,
52
+ )
0 commit comments