@@ -38,74 +38,115 @@ def test_atomref(model_name, enable_atomref):
38
38
39
39
# check if the output of both models differs by the expected atomref contribution
40
40
if enable_atomref :
41
- expected_offset = scatter (dataset .get_atomref ().squeeze ()[z ], batch ).unsqueeze (1 )
41
+ expected_offset = scatter (dataset .get_atomref ().squeeze ()[z ], batch ).unsqueeze (
42
+ 1
43
+ )
42
44
else :
43
45
expected_offset = 0
44
46
torch .testing .assert_close (x_atomref , x_no_atomref + expected_offset )
45
47
48
+
46
49
@mark .parametrize ("trainable" , [True , False ])
47
50
def test_atomref_trainable (trainable ):
48
51
dataset = DummyDataset (has_atomref = True )
49
52
atomref = Atomref (max_z = 100 , dataset = dataset , trainable = trainable )
50
53
assert atomref .atomref .weight .requires_grad == trainable
51
54
55
+
52
56
def test_learnableatomref ():
53
57
atomref = LearnableAtomref (max_z = 100 )
54
58
assert atomref .atomref .weight .requires_grad == True
55
59
60
+
56
61
def test_zbl ():
57
- pos = torch .tensor ([[1.0 , 0.0 , 0.0 ], [2.5 , 0.0 , 0.0 ], [1.0 , 1.0 , 0.0 ], [0.0 , 0.0 , - 1.0 ]], dtype = torch .float32 ) # Atom positions in Bohr
62
+ pos = torch .tensor (
63
+ [[1.0 , 0.0 , 0.0 ], [2.5 , 0.0 , 0.0 ], [1.0 , 1.0 , 0.0 ], [0.0 , 0.0 , - 1.0 ]],
64
+ dtype = torch .float32 ,
65
+ ) # Atom positions in Bohr
58
66
types = torch .tensor ([0 , 1 , 2 , 1 ], dtype = torch .long ) # Atom types
59
- atomic_number = torch .tensor ([1 , 6 , 8 ], dtype = torch .int8 ) # Mapping of atom types to atomic numbers
67
+ atomic_number = torch .tensor (
68
+ [1 , 6 , 8 ], dtype = torch .int8
69
+ ) # Mapping of atom types to atomic numbers
60
70
distance_scale = 5.29177210903e-11 # Convert Bohr to meters
61
- energy_scale = 1000.0 / 6.02214076e23 # Convert kJ/mol to Joules
71
+ energy_scale = 1000.0 / 6.02214076e23 # Convert kJ/mol to Joules
62
72
63
73
# Use the ZBL class to compute the energy.
64
74
65
- zbl = ZBL (10.0 , 5 , atomic_number , distance_scale = distance_scale , energy_scale = energy_scale )
66
- energy = zbl .post_reduce (torch .zeros ((1 ,)), types , pos , torch .zeros_like (types ), None , {})[0 ]
75
+ zbl = ZBL (
76
+ 10.0 , 5 , atomic_number , distance_scale = distance_scale , energy_scale = energy_scale
77
+ )
78
+ energy = zbl .post_reduce (
79
+ torch .zeros ((1 ,)), types , pos , torch .zeros_like (types ), None , {}
80
+ )[0 ]
67
81
68
82
# Compare to the expected value.
69
83
70
84
def compute_interaction (pos1 , pos2 , z1 , z2 ):
71
- delta = pos1 - pos2
85
+ delta = pos1 - pos2
72
86
r = torch .sqrt (torch .dot (delta , delta ))
73
- x = r / (0.8854 / (z1 ** 0.23 + z2 ** 0.23 ))
74
- phi = 0.1818 * torch .exp (- 3.2 * x ) + 0.5099 * torch .exp (- 0.9423 * x ) + 0.2802 * torch .exp (- 0.4029 * x ) + 0.02817 * torch .exp (- 0.2016 * x )
75
- cutoff = 0.5 * (torch .cos (r * torch .pi / 10.0 ) + 1.0 )
76
- return cutoff * phi * (138.935 / 5.29177210903e-2 )* z1 * z2 / r
87
+ x = r / (0.8854 / (z1 ** 0.23 + z2 ** 0.23 ))
88
+ phi = (
89
+ 0.1818 * torch .exp (- 3.2 * x )
90
+ + 0.5099 * torch .exp (- 0.9423 * x )
91
+ + 0.2802 * torch .exp (- 0.4029 * x )
92
+ + 0.02817 * torch .exp (- 0.2016 * x )
93
+ )
94
+ cutoff = 0.5 * (torch .cos (r * torch .pi / 10.0 ) + 1.0 )
95
+ return cutoff * phi * (138.935 / 5.29177210903e-2 ) * z1 * z2 / r
77
96
78
97
expected = 0
79
98
for i in range (len (pos )):
80
99
for j in range (i ):
81
- expected += compute_interaction (pos [i ], pos [j ], atomic_number [types [i ]], atomic_number [types [j ]])
100
+ expected += compute_interaction (
101
+ pos [i ], pos [j ], atomic_number [types [i ]], atomic_number [types [j ]]
102
+ )
82
103
torch .testing .assert_close (expected , energy , rtol = 1e-4 , atol = 1e-4 )
83
104
105
+
84
106
@pytest .mark .parametrize ("dtype" , [torch .float32 , torch .float64 ])
85
107
def test_coulomb (dtype ):
86
- pos = torch .tensor ([[0.5 , 0.0 , 0.0 ], [1.5 , 0.0 , 0.0 ], [0.8 , 0.8 , 0.0 ], [0.0 , 0.0 , - 0.4 ]], dtype = dtype ) # Atom positions in nm
108
+ pos = torch .tensor (
109
+ [[0.5 , 0.0 , 0.0 ], [1.5 , 0.0 , 0.0 ], [0.8 , 0.8 , 0.0 ], [0.0 , 0.0 , - 0.4 ]],
110
+ dtype = dtype ,
111
+ ) # Atom positions in nm
87
112
charge = torch .tensor ([0.2 , - 0.1 , 0.8 , - 0.9 ], dtype = dtype ) # Partial charges
88
113
types = torch .tensor ([0 , 1 , 2 , 1 ], dtype = torch .long ) # Atom types
89
114
distance_scale = 1e-9 # Convert nm to meters
90
- energy_scale = 1000.0 / 6.02214076e23 # Convert kJ/mol to Joules
115
+ energy_scale = 1000.0 / 6.02214076e23 # Convert kJ/mol to Joules
91
116
lower_switch_distance = 0.9
92
117
upper_switch_distance = 1.3
93
118
94
119
# Use the Coulomb class to compute the energy.
95
120
96
- coulomb = Coulomb (lower_switch_distance , upper_switch_distance , 5 , distance_scale = distance_scale , energy_scale = energy_scale )
97
- energy = coulomb .post_reduce (torch .zeros ((1 ,)), types , pos , torch .zeros_like (types ), extra_args = {'partial_charges' :charge })[0 ]
121
+ coulomb = Coulomb (
122
+ lower_switch_distance ,
123
+ upper_switch_distance ,
124
+ 5 ,
125
+ distance_scale = distance_scale ,
126
+ energy_scale = energy_scale ,
127
+ )
128
+ energy = coulomb .post_reduce (
129
+ torch .zeros ((1 ,)),
130
+ types ,
131
+ pos ,
132
+ torch .zeros_like (types ),
133
+ extra_args = {"partial_charges" : charge },
134
+ )[0 ]
98
135
99
136
# Compare to the expected value.
100
137
101
138
def compute_interaction (pos1 , pos2 , z1 , z2 ):
102
- delta = pos1 - pos2
139
+ delta = pos1 - pos2
103
140
r = torch .sqrt (torch .dot (delta , delta ))
104
141
if r < lower_switch_distance :
105
142
return 0
106
- energy = 138.935 * z1 * z2 / r
143
+ energy = 138.935 * z1 * z2 / r
107
144
if r < upper_switch_distance :
108
- energy *= 0.5 - 0.5 * torch .cos (torch .pi * (r - lower_switch_distance )/ (upper_switch_distance - lower_switch_distance ))
145
+ energy *= 0.5 - 0.5 * torch .cos (
146
+ torch .pi
147
+ * (r - lower_switch_distance )
148
+ / (upper_switch_distance - lower_switch_distance )
149
+ )
109
150
return energy
110
151
111
152
expected = 0
@@ -120,10 +161,12 @@ def test_multiple_priors(dtype):
120
161
# Create a model from a config file.
121
162
122
163
dataset = DummyDataset (has_atomref = True )
123
- config_file = join (dirname (__file__ ), 'priors.yaml' )
124
- args = load_example_args ('equivariant-transformer' , config_file = config_file , dtype = dtype )
164
+ config_file = join (dirname (__file__ ), "priors.yaml" )
165
+ args = load_example_args (
166
+ "equivariant-transformer" , config_file = config_file , dtype = dtype
167
+ )
125
168
prior_models = create_prior_models (args , dataset )
126
- args [' prior_args' ] = [p .get_init_args () for p in prior_models ]
169
+ args [" prior_args" ] = [p .get_init_args () for p in prior_models ]
127
170
model = LNNP (args , prior_model = prior_models )
128
171
priors = model .model .prior_model
129
172
0 commit comments