-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvalidate.py
256 lines (231 loc) · 6.9 KB
/
validate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
import argparse
import time
from functools import partial
import e3nn_jax as e3nn
import haiku as hk
import jax
import jax.numpy as jnp
import wandb
from experiments import setup_data, train
from segnn_jax import SEGNN, weight_balanced_irreps
# PRNG key built from a fixed seed (1337); it is passed as the first argument
# to train() further down in this script.
key = jax.random.PRNGKey(1337)
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    Options are grouped as: run parameters, nbody parameters, qm9 parameters,
    model parameters and wandb parameters.

    Returns:
        The configured ``argparse.ArgumentParser``. ``--dataset`` is required:
        the rest of the script only knows how to configure itself for one of
        the listed dataset choices.
    """
    parser = argparse.ArgumentParser()

    # Run parameters
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
        help="Batch size (number of graphs).",
    )
    parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate")
    parser.add_argument(
        "--lr-scheduling",
        action="store_true",
        help="Use learning rate scheduling",
    )
    parser.add_argument(
        "--weight-decay", type=float, default=1e-12, help="Weight decay"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["qm9", "charged", "gravity"],
        # Without a dataset no irreps/loss branch is selected and the run
        # fails much later with an unrelated error, so reject it up front.
        required=True,
        help="Dataset name",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=3000,
        help="Maximum number of samples in nbody dataset",
    )
    parser.add_argument(
        "--val-freq",
        type=int,
        default=10,
        help="Evaluation frequency (number of epochs)",
    )

    # nbody parameters
    parser.add_argument(
        "--target",
        type=str,
        default="pos",
        help="Target. e.g. pos, force (gravity), alpha (qm9)",
    )
    parser.add_argument(
        "--neighbours",
        type=int,
        default=20,
        help="Number of connected nearest neighbours",
    )
    parser.add_argument(
        "--n-bodies",
        type=int,
        default=5,
        help="Number of bodies in the dataset",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="small",
        choices=["small", "default", "small_out_dist"],
        help="Name of nbody data partition: default (200 steps), small (1000 steps)",
    )

    # qm9 parameters
    parser.add_argument(
        "--radius",
        type=float,
        default=2.0,
        help="Radius (Angstrom) between which atoms to add links.",
    )
    parser.add_argument(
        "--feature-type",
        type=str,
        default="one_hot",
        choices=["one_hot", "cormorant", "gilmer"],
        help="Type of input feature",
    )

    # Model parameters
    parser.add_argument(
        "--units", type=int, default=64, help="Number of values in the hidden layers"
    )
    parser.add_argument(
        "--lmax-hidden",
        type=int,
        default=1,
        help="Max degree of hidden representations.",
    )
    parser.add_argument(
        "--lmax-attributes",
        type=int,
        default=1,
        help="Max degree of geometric attribute embedding",
    )
    parser.add_argument(
        "--layers", type=int, default=7, help="Number of message passing layers"
    )
    parser.add_argument(
        "--blocks", type=int, default=2, help="Number of layers in steerable MLPs."
    )
    parser.add_argument(
        "--norm",
        type=str,
        default="none",
        choices=["instance", "batch", "none"],
        help="Normalisation type",
    )
    parser.add_argument(
        "--double-precision",
        action="store_true",
        help="Use double precision in model",
    )
    parser.add_argument(
        "--scn",
        action="store_true",
        help="Train SEGNN with the eSCN optimization",
    )

    # wandb parameters
    parser.add_argument(
        "--wandb",
        action="store_true",
        help="Activate weights and biases logging",
    )
    parser.add_argument(
        "--wandb-project",
        type=str,
        default="segnn",
        help="Weights and biases project",
    )
    parser.add_argument(
        "--wandb-entity",
        type=str,
        default="",
        help="Weights and biases entity",
    )
    return parser


def main() -> None:
    """Parse the command line, build the SEGNN model and launch training.

    Steps, in order: parse arguments, set JAX precision, optionally start a
    wandb run, attach the dataset-specific irreps and task to ``args``, build
    the hidden irreps, wrap the model with Haiku, load the data, pick the loss
    functions and call ``train``.
    """
    args = build_parser().parse_args()

    # if specified set jax in double precision
    jax.config.update("jax_enable_x64", args.double_precision)

    # connect to wandb
    if args.wandb:
        # Run name: <project>_<dataset>_<target>_<unix timestamp>
        wandb_name = "_".join(
            [
                args.wandb_project,
                args.dataset,
                args.target,
                str(int(time.time())),
            ]
        )
        wandb.init(
            project=args.wandb_project,
            name=wandb_name,
            config=args,
            entity=args.wandb_entity,
        )

    # feature representations
    if args.dataset == "qm9":
        args.task = "graph"
        if args.feature_type == "one_hot":
            args.node_irreps = e3nn.Irreps("5x0e")
        elif args.feature_type == "cormorant":
            args.node_irreps = e3nn.Irreps("15x0e")
        elif args.feature_type == "gilmer":
            args.node_irreps = e3nn.Irreps("11x0e")
        args.output_irreps = e3nn.Irreps("1x0e")
        args.additional_message_irreps = e3nn.Irreps("1x0e")
        assert not args.scn, "eSCN not implemented for qm9"
    elif args.dataset in ["charged", "gravity"]:
        args.task = "node"
        args.node_irreps = e3nn.Irreps("2x1o + 1x0e")
        args.output_irreps = e3nn.Irreps("1x1o")
        args.additional_message_irreps = e3nn.Irreps("2x0e")

    # Create hidden irreps
    if not args.scn:
        attr_irreps = e3nn.Irreps.spherical_harmonics(args.lmax_attributes)
    else:
        # The option is --lmax-attributes, so the Namespace attribute is
        # `lmax_attributes` (argparse prefix matching does not apply to
        # attribute access).
        # NOTE(review): the "y" suffix presumably selects the
        # spherical-harmonics parity for that degree; confirm against e3nn_jax.
        attr_irreps = e3nn.Irrep(f"{args.lmax_attributes}y")

    hidden_irreps = weight_balanced_irreps(
        scalar_units=args.units,
        irreps_right=attr_irreps,
        use_sh=(not args.scn),
        lmax=args.lmax_hidden,
    )

    # Replace the boolean flag by the layer-type string given to SEGNN.
    args.o3_layer = "scn" if args.scn else "tpl"
    del args.scn

    # build model
    def segnn(x):
        return SEGNN(
            hidden_irreps=hidden_irreps,
            output_irreps=args.output_irreps,
            num_layers=args.layers,
            task=args.task,
            pool="avg",
            blocks_per_layer=args.blocks,
            norm=args.norm,
            o3_layer=args.o3_layer,
        )(x)

    segnn = hk.without_apply_rng(hk.transform_with_state(segnn))

    loader_train, loader_val, loader_test, graph_transform, eval_trn = setup_data(args)

    # Both dataset families use the same wrapper and differ only in the
    # criterion and the extra keyword arguments.
    from experiments.train import loss_fn_wrapper

    if args.dataset == "qm9":

        def _mae(p, t):
            # Element-wise absolute error.
            return jnp.abs(p - t)

        train_loss = partial(loss_fn_wrapper, criterion=_mae)
        eval_loss = partial(loss_fn_wrapper, criterion=_mae, eval_trn=eval_trn)
    elif args.dataset in ["charged", "gravity"]:

        def _mse(p, t):
            # Element-wise squared error.
            return jnp.power(p - t, 2)

        train_loss = partial(loss_fn_wrapper, criterion=_mse, do_mask=False)
        eval_loss = partial(loss_fn_wrapper, criterion=_mse, do_mask=False)

    train(
        key,
        segnn,
        loader_train,
        loader_val,
        loader_test,
        train_loss,
        eval_loss,
        graph_transform,
        args,
    )


if __name__ == "__main__":
    main()