 from typing import Optional
 import torch
 from torch import nn
-from torchmdnet.models.utils import act_class_mapping, GatedEquivariantBlock, scatter
+from torchmdnet.models.utils import (
+    act_class_mapping,
+    GatedEquivariantBlock,
+    scatter,
+    MLP,
+)
 from torchmdnet.utils import atomic_masses
 from torchmdnet.extensions import is_current_stream_capturing
 from warnings import warn
@@ -60,24 +65,23 @@ def __init__(
         allow_prior_model=True,
         reduce_op="sum",
         dtype=torch.float,
+        **kwargs
     ):
         super(Scalar, self).__init__(
             allow_prior_model=allow_prior_model, reduce_op=reduce_op
         )
-        act_class = act_class_mapping[activation]
-        self.output_network = nn.Sequential(
-            nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),
-            act_class(),
-            nn.Linear(hidden_channels // 2, 1, dtype=dtype),
+        self.output_network = MLP(
+            in_channels=hidden_channels,
+            out_channels=1,
+            hidden_channels=hidden_channels // 2,
+            activation=activation,
+            num_layers=kwargs.get("num_layers", 0),
+            dtype=dtype,
         )
-
         self.reset_parameters()
 
     def reset_parameters(self):
-        nn.init.xavier_uniform_(self.output_network[0].weight)
-        self.output_network[0].bias.data.fill_(0)
-        nn.init.xavier_uniform_(self.output_network[2].weight)
-        self.output_network[2].bias.data.fill_(0)
+        self.output_network.reset_parameters()
 
     def pre_reduce(self, x, v: Optional[torch.Tensor], z, pos, batch):
         return self.output_network(x)
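For context (not part of the diff): a minimal sketch of the equivalence this refactor relies on. It assumes that MLP with num_layers=0, the default picked up via kwargs.get("num_layers", 0), reduces to the same Linear -> activation -> Linear stack that the removed nn.Sequential built; the actual MLP implementation lives in torchmdnet.models.utils and is not shown in this diff.

# Sketch only; MLP's internals are assumed, its call signature is taken from the diff above.
import torch
from torch import nn
from torchmdnet.models.utils import MLP

hidden_channels = 128

# Head removed by this diff
old_head = nn.Sequential(
    nn.Linear(hidden_channels, hidden_channels // 2),
    nn.SiLU(),
    nn.Linear(hidden_channels // 2, 1),
)

# Head added by this diff (num_layers=0 mirrors kwargs.get("num_layers", 0))
new_head = MLP(
    in_channels=hidden_channels,
    out_channels=1,
    hidden_channels=hidden_channels // 2,
    activation="silu",
    num_layers=0,
    dtype=torch.float,
)

x = torch.randn(10, hidden_channels)
# Both heads map (n_atoms, hidden_channels) -> (n_atoms, 1)
assert old_head(x).shape == new_head(x).shape == (10, 1)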
@@ -91,10 +95,13 @@ def __init__(
         allow_prior_model=True,
         reduce_op="sum",
         dtype=torch.float,
+        **kwargs
     ):
         super(EquivariantScalar, self).__init__(
             allow_prior_model=allow_prior_model, reduce_op=reduce_op
         )
+        if kwargs.get("num_layers", 0) > 0:
+            warn("num_layers is not used in EquivariantScalar")
         self.output_network = nn.ModuleList(
             [
                 GatedEquivariantBlock(
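Also not part of the diff: a small sketch of what the new guard is for. The equivariant head has its own fixed structure (a ModuleList of GatedEquivariantBlocks), so a num_layers option arriving through **kwargs is accepted, meaning construction does not fail, but it only raises a warning. Hypothetical usage, assuming this file is torchmdnet/models/output_modules.py:

import warnings
from torchmdnet.models.output_modules import EquivariantScalar

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    head = EquivariantScalar(hidden_channels=128, num_layers=2)  # extra kwarg is tolerated
# The head is built normally; the option is simply reported as unused.
assert any("num_layers is not used" in str(w.message) for w in caught)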
@@ -125,14 +132,20 @@ def pre_reduce(self, x, v, z, pos, batch):
 
 class DipoleMoment(Scalar):
     def __init__(
-        self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float
+        self,
+        hidden_channels,
+        activation="silu",
+        reduce_op="sum",
+        dtype=torch.float,
+        **kwargs
     ):
         super(DipoleMoment, self).__init__(
             hidden_channels,
             activation,
             allow_prior_model=False,
             reduce_op=reduce_op,
             dtype=dtype,
+            **kwargs
         )
         atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
         self.register_buffer("atomic_mass", atomic_mass)
@@ -152,14 +165,20 @@ def post_reduce(self, x):
 
 class EquivariantDipoleMoment(EquivariantScalar):
     def __init__(
-        self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float
+        self,
+        hidden_channels,
+        activation="silu",
+        reduce_op="sum",
+        dtype=torch.float,
+        **kwargs
     ):
         super(EquivariantDipoleMoment, self).__init__(
             hidden_channels,
             activation,
             allow_prior_model=False,
             reduce_op=reduce_op,
             dtype=dtype,
+            **kwargs
         )
         atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
         self.register_buffer("atomic_mass", atomic_mass)
@@ -180,27 +199,31 @@ def post_reduce(self, x):
 
 class ElectronicSpatialExtent(OutputModel):
     def __init__(
-        self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float
+        self,
+        hidden_channels,
+        activation="silu",
+        reduce_op="sum",
+        dtype=torch.float,
+        **kwargs
     ):
         super(ElectronicSpatialExtent, self).__init__(
             allow_prior_model=False, reduce_op=reduce_op
         )
-        act_class = act_class_mapping[activation]
-        self.output_network = nn.Sequential(
-            nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),
-            act_class(),
-            nn.Linear(hidden_channels // 2, 1, dtype=dtype),
+        self.output_network = MLP(
+            in_channels=hidden_channels,
+            out_channels=1,
+            hidden_channels=hidden_channels // 2,
+            activation=activation,
+            num_layers=kwargs.get("num_layers", 0),
+            dtype=dtype,
         )
         atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
         self.register_buffer("atomic_mass", atomic_mass)
 
         self.reset_parameters()
 
     def reset_parameters(self):
-        nn.init.xavier_uniform_(self.output_network[0].weight)
-        self.output_network[0].bias.data.fill_(0)
-        nn.init.xavier_uniform_(self.output_network[2].weight)
-        self.output_network[2].bias.data.fill_(0)
+        self.output_network.reset_parameters()
 
     def pre_reduce(self, x, v: Optional[torch.Tensor], z, pos, batch):
         x = self.output_network(x)
@@ -219,14 +242,20 @@ class EquivariantElectronicSpatialExtent(ElectronicSpatialExtent):
 
 class EquivariantVectorOutput(EquivariantScalar):
     def __init__(
-        self, hidden_channels, activation="silu", reduce_op="sum", dtype=torch.float
+        self,
+        hidden_channels,
+        activation="silu",
+        reduce_op="sum",
+        dtype=torch.float,
+        **kwargs
     ):
         super(EquivariantVectorOutput, self).__init__(
             hidden_channels,
             activation,
             allow_prior_model=False,
             reduce_op="sum",
             dtype=dtype,
+            **kwargs
         )
 
     def pre_reduce(self, x, v, z, pos, batch):
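Finally, a usage sketch of why every output model now accepts **kwargs: an extra option such as num_layers can be forwarded uniformly to whichever head is selected; heads built on the MLP forward it (where it presumably sets the number of hidden layers), subclasses pass it through, and the equivariant heads warn and ignore it. The constructor arguments follow this diff; importing from torchmdnet.models.output_modules is an assumption about the file being edited.

from torchmdnet.models.output_modules import Scalar, DipoleMoment

# Scalar forwards num_layers to its MLP head, adding hidden layers to the readout.
scalar_head = Scalar(hidden_channels=128, activation="silu", num_layers=2)

# DipoleMoment subclasses Scalar and now passes **kwargs through, so the same
# option reaches the underlying MLP without DipoleMoment declaring it explicitly.
dipole_head = DipoleMoment(hidden_channels=128, activation="silu", num_layers=2)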