Skip to content

Commit b4772e2

Browse files
authored
blacken code (#359)
1 parent 2be8620 commit b4772e2

24 files changed

+369
-203
lines changed

benchmarks/inference.py

+1
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def benchmark_pdb(pdb_file, **kwargs):
122122
"2L emb 64": {"num_layers": 2, "embedding_dimension": 64},
123123
}
124124

125+
125126
def benchmark_all():
126127
timings = {}
127128
for pdb_file in os.listdir("systems"):

benchmarks/neighbors.py

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import Optional
1111
from torch_cluster import radius_graph
1212

13+
1314
class Distance(nn.Module):
1415
def __init__(
1516
self,

examples/openmm-integration.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
import openmm
66
import openmmtorch
77
except ImportError:
8-
raise ImportError("Please install OpenMM and OpenMM-Torch (you can use conda install -c conda-forge openmm openmm-torch)")
8+
raise ImportError(
9+
"Please install OpenMM and OpenMM-Torch (you can use conda install -c conda-forge openmm openmm-torch)"
10+
)
911

1012
import sys
1113
import torch
@@ -34,9 +36,9 @@ def __init__(self, embeddings, model):
3436
def forward(self, positions):
3537
# OpenMM works with nanometer positions and kilojoule per mole energies
3638
# Depending on the model, you might need to convert the units
37-
positions = positions.to(torch.float32) * 10.0 # nm -> A
39+
positions = positions.to(torch.float32) * 10.0 # nm -> A
3840
energy = self.model(z=self.embeddings, pos=positions)[0]
39-
return energy * 96.4916 # eV -> kJ/mol
41+
return energy * 96.4916 # eV -> kJ/mol
4042

4143

4244
pdb = PDBFile("../benchmarks/systems/chignolin.pdb")
@@ -54,9 +56,11 @@ def forward(self, positions):
5456
for atom in pdb.topology.atoms():
5557
system.addParticle(atom.element.mass)
5658
system.addForce(torch_force)
57-
integrator = LangevinMiddleIntegrator(298.15*kelvin, 1/picosecond, 2*femtosecond)
58-
platform = Platform.getPlatformByName('CPU')
59+
integrator = LangevinMiddleIntegrator(298.15 * kelvin, 1 / picosecond, 2 * femtosecond)
60+
platform = Platform.getPlatformByName("CPU")
5961
simulation = Simulation(pdb.topology, system, integrator, platform)
6062
simulation.context.setPositions(pdb.positions)
61-
simulation.reporters.append(StateDataReporter(sys.stdout, 1, step=True, potentialEnergy=True, temperature=True))
63+
simulation.reporters.append(
64+
StateDataReporter(sys.stdout, 1, step=True, potentialEnergy=True, temperature=True)
65+
)
6266
simulation.step(10)

tests/test_datasets.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,10 @@ def test_hdf5_with_and_without_caching(num_files, tile_embed, batch_size, tmpdir
328328

329329
for sample_cached, sample in zip(dl_cached, dl):
330330
assert np.allclose(sample_cached.pos, sample.pos), "Sample has incorrect coords"
331-
assert np.allclose(sample_cached.z, sample.z), "Sample has incorrect atom numbers"
331+
assert np.allclose(
332+
sample_cached.z, sample.z
333+
), "Sample has incorrect atom numbers"
332334
assert np.allclose(sample_cached.y, sample.y), "Sample has incorrect energy"
333-
assert np.allclose(sample_cached.neg_dy, sample.neg_dy), "Sample has incorrect forces"
335+
assert np.allclose(
336+
sample_cached.neg_dy, sample.neg_dy
337+
), "Sample has incorrect forces"

tests/test_mdcath.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,7 @@ def test_mdcath_args(tmpdir, skipframes, batch_size, pdb_list):
167167
data.flush()
168168
data.close()
169169

170-
dataset = MDCATH(
171-
root=tmpdir, skip_frames=skipframes, pdb_list=pdb_list
172-
)
170+
dataset = MDCATH(root=tmpdir, skip_frames=skipframes, pdb_list=pdb_list)
173171
dl = DataLoader(
174172
dataset,
175173
batch_size=batch_size,

tests/test_priors.py

+65-22
Original file line numberDiff line numberDiff line change
@@ -38,74 +38,115 @@ def test_atomref(model_name, enable_atomref):
3838

3939
# check if the output of both models differs by the expected atomref contribution
4040
if enable_atomref:
41-
expected_offset = scatter(dataset.get_atomref().squeeze()[z], batch).unsqueeze(1)
41+
expected_offset = scatter(dataset.get_atomref().squeeze()[z], batch).unsqueeze(
42+
1
43+
)
4244
else:
4345
expected_offset = 0
4446
torch.testing.assert_close(x_atomref, x_no_atomref + expected_offset)
4547

48+
4649
@mark.parametrize("trainable", [True, False])
4750
def test_atomref_trainable(trainable):
4851
dataset = DummyDataset(has_atomref=True)
4952
atomref = Atomref(max_z=100, dataset=dataset, trainable=trainable)
5053
assert atomref.atomref.weight.requires_grad == trainable
5154

55+
5256
def test_learnableatomref():
5357
atomref = LearnableAtomref(max_z=100)
5458
assert atomref.atomref.weight.requires_grad == True
5559

60+
5661
def test_zbl():
57-
pos = torch.tensor([[1.0, 0.0, 0.0], [2.5, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, -1.0]], dtype=torch.float32) # Atom positions in Bohr
62+
pos = torch.tensor(
63+
[[1.0, 0.0, 0.0], [2.5, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
64+
dtype=torch.float32,
65+
) # Atom positions in Bohr
5866
types = torch.tensor([0, 1, 2, 1], dtype=torch.long) # Atom types
59-
atomic_number = torch.tensor([1, 6, 8], dtype=torch.int8) # Mapping of atom types to atomic numbers
67+
atomic_number = torch.tensor(
68+
[1, 6, 8], dtype=torch.int8
69+
) # Mapping of atom types to atomic numbers
6070
distance_scale = 5.29177210903e-11 # Convert Bohr to meters
61-
energy_scale = 1000.0/6.02214076e23 # Convert kJ/mol to Joules
71+
energy_scale = 1000.0 / 6.02214076e23 # Convert kJ/mol to Joules
6272

6373
# Use the ZBL class to compute the energy.
6474

65-
zbl = ZBL(10.0, 5, atomic_number, distance_scale=distance_scale, energy_scale=energy_scale)
66-
energy = zbl.post_reduce(torch.zeros((1,)), types, pos, torch.zeros_like(types), None, {})[0]
75+
zbl = ZBL(
76+
10.0, 5, atomic_number, distance_scale=distance_scale, energy_scale=energy_scale
77+
)
78+
energy = zbl.post_reduce(
79+
torch.zeros((1,)), types, pos, torch.zeros_like(types), None, {}
80+
)[0]
6781

6882
# Compare to the expected value.
6983

7084
def compute_interaction(pos1, pos2, z1, z2):
71-
delta = pos1-pos2
85+
delta = pos1 - pos2
7286
r = torch.sqrt(torch.dot(delta, delta))
73-
x = r / (0.8854/(z1**0.23 + z2**0.23))
74-
phi = 0.1818*torch.exp(-3.2*x) + 0.5099*torch.exp(-0.9423*x) + 0.2802*torch.exp(-0.4029*x) + 0.02817*torch.exp(-0.2016*x)
75-
cutoff = 0.5*(torch.cos(r*torch.pi/10.0) + 1.0)
76-
return cutoff*phi*(138.935/5.29177210903e-2)*z1*z2/r
87+
x = r / (0.8854 / (z1**0.23 + z2**0.23))
88+
phi = (
89+
0.1818 * torch.exp(-3.2 * x)
90+
+ 0.5099 * torch.exp(-0.9423 * x)
91+
+ 0.2802 * torch.exp(-0.4029 * x)
92+
+ 0.02817 * torch.exp(-0.2016 * x)
93+
)
94+
cutoff = 0.5 * (torch.cos(r * torch.pi / 10.0) + 1.0)
95+
return cutoff * phi * (138.935 / 5.29177210903e-2) * z1 * z2 / r
7796

7897
expected = 0
7998
for i in range(len(pos)):
8099
for j in range(i):
81-
expected += compute_interaction(pos[i], pos[j], atomic_number[types[i]], atomic_number[types[j]])
100+
expected += compute_interaction(
101+
pos[i], pos[j], atomic_number[types[i]], atomic_number[types[j]]
102+
)
82103
torch.testing.assert_close(expected, energy, rtol=1e-4, atol=1e-4)
83104

105+
84106
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
85107
def test_coulomb(dtype):
86-
pos = torch.tensor([[0.5, 0.0, 0.0], [1.5, 0.0, 0.0], [0.8, 0.8, 0.0], [0.0, 0.0, -0.4]], dtype=dtype) # Atom positions in nm
108+
pos = torch.tensor(
109+
[[0.5, 0.0, 0.0], [1.5, 0.0, 0.0], [0.8, 0.8, 0.0], [0.0, 0.0, -0.4]],
110+
dtype=dtype,
111+
) # Atom positions in nm
87112
charge = torch.tensor([0.2, -0.1, 0.8, -0.9], dtype=dtype) # Partial charges
88113
types = torch.tensor([0, 1, 2, 1], dtype=torch.long) # Atom types
89114
distance_scale = 1e-9 # Convert nm to meters
90-
energy_scale = 1000.0/6.02214076e23 # Convert kJ/mol to Joules
115+
energy_scale = 1000.0 / 6.02214076e23 # Convert kJ/mol to Joules
91116
lower_switch_distance = 0.9
92117
upper_switch_distance = 1.3
93118

94119
# Use the Coulomb class to compute the energy.
95120

96-
coulomb = Coulomb(lower_switch_distance, upper_switch_distance, 5, distance_scale=distance_scale, energy_scale=energy_scale)
97-
energy = coulomb.post_reduce(torch.zeros((1,)), types, pos, torch.zeros_like(types), extra_args={'partial_charges':charge})[0]
121+
coulomb = Coulomb(
122+
lower_switch_distance,
123+
upper_switch_distance,
124+
5,
125+
distance_scale=distance_scale,
126+
energy_scale=energy_scale,
127+
)
128+
energy = coulomb.post_reduce(
129+
torch.zeros((1,)),
130+
types,
131+
pos,
132+
torch.zeros_like(types),
133+
extra_args={"partial_charges": charge},
134+
)[0]
98135

99136
# Compare to the expected value.
100137

101138
def compute_interaction(pos1, pos2, z1, z2):
102-
delta = pos1-pos2
139+
delta = pos1 - pos2
103140
r = torch.sqrt(torch.dot(delta, delta))
104141
if r < lower_switch_distance:
105142
return 0
106-
energy = 138.935*z1*z2/r
143+
energy = 138.935 * z1 * z2 / r
107144
if r < upper_switch_distance:
108-
energy *= 0.5-0.5*torch.cos(torch.pi*(r-lower_switch_distance)/(upper_switch_distance-lower_switch_distance))
145+
energy *= 0.5 - 0.5 * torch.cos(
146+
torch.pi
147+
* (r - lower_switch_distance)
148+
/ (upper_switch_distance - lower_switch_distance)
149+
)
109150
return energy
110151

111152
expected = 0
@@ -120,10 +161,12 @@ def test_multiple_priors(dtype):
120161
# Create a model from a config file.
121162

122163
dataset = DummyDataset(has_atomref=True)
123-
config_file = join(dirname(__file__), 'priors.yaml')
124-
args = load_example_args('equivariant-transformer', config_file=config_file, dtype=dtype)
164+
config_file = join(dirname(__file__), "priors.yaml")
165+
args = load_example_args(
166+
"equivariant-transformer", config_file=config_file, dtype=dtype
167+
)
125168
prior_models = create_prior_models(args, dataset)
126-
args['prior_args'] = [p.get_init_args() for p in prior_models]
169+
args["prior_args"] = [p.get_init_args() for p in prior_models]
127170
model = LNNP(args, prior_model=prior_models)
128171
priors = model.model.prior_model
129172

tests/utils.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs):
1212
if config_file is None:
1313
if model_name == "tensornet":
14-
config_file = join(dirname(dirname(__file__)), "examples", "TensorNet-QM9.yaml")
14+
config_file = join(
15+
dirname(dirname(__file__)), "examples", "TensorNet-QM9.yaml"
16+
)
1517
else:
1618
config_file = join(dirname(dirname(__file__)), "examples", "ET-QM9.yaml")
1719
with open(config_file, "r") as f:
@@ -84,7 +86,7 @@ def _get_atomref(self):
8486
return self.atomref
8587

8688
DummyDataset.get_atomref = _get_atomref
87-
self.atomic_number = torch.arange(max(atom_types)+1)
89+
self.atomic_number = torch.arange(max(atom_types) + 1)
8890
self.distance_scale = 1.0
8991
self.energy_scale = 1.0
9092

torchmdnet/data.py

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from torchmdnet.models.utils import scatter
1515
import warnings
1616

17+
1718
class DataModule(LightningDataModule):
1819
"""A LightningDataModule for loading datasets from the torchmdnet.datasets module.
1920

torchmdnet/datasets/ani.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,12 @@ def raw_file_names(self):
4848
def get_atomref(self, max_z=100):
4949
"""Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior.
5050
51-
Args:
52-
max_z (int): Maximum atomic number
51+
Args:
52+
max_z (int): Maximum atomic number
5353
54-
Returns:
55-
torch.Tensor: Atomic energy reference values for each element in the dataset.
56-
"""
54+
Returns:
55+
torch.Tensor: Atomic energy reference values for each element in the dataset.
56+
"""
5757
refs = pt.zeros(max_z)
5858
for key, val in self._ELEMENT_ENERGIES.items():
5959
refs[key] = val * self.HARTREE_TO_EV

torchmdnet/datasets/comp6.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,12 @@ def raw_url(self):
6262
def get_atomref(self, max_z=100):
6363
"""Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior.
6464
65-
Args:
66-
max_z (int): Maximum atomic number
65+
Args:
66+
max_z (int): Maximum atomic number
6767
68-
Returns:
69-
torch.Tensor: Atomic energy reference values for each element in the dataset.
70-
"""
68+
Returns:
69+
torch.Tensor: Atomic energy reference values for each element in the dataset.
70+
"""
7171
refs = pt.zeros(max_z)
7272
for key, val in self._ELEMENT_ENERGIES.items():
7373
refs[key] = val * self.HARTREE_TO_EV
@@ -142,6 +142,7 @@ def raw_url_name(self):
142142
def raw_file_names(self):
143143
return ["ani_md_bench.h5"]
144144

145+
145146
class DrugBank(COMP6Base):
146147
"""
147148
DrugBank Benchmark. This benchmark is developed through a subsampling of the
@@ -247,7 +248,7 @@ def __init__(
247248

248249
self.subsets = [
249250
DS(root, transform, pre_transform, pre_filter)
250-
for DS in (ANIMD, DrugBank, GDB07to09, GDB10to13, Tripeptides, S66X8)
251+
for DS in (ANIMD, DrugBank, GDB07to09, GDB10to13, Tripeptides, S66X8)
251252
]
252253

253254
self.num_samples = sum(len(subset) for subset in self.subsets)
@@ -347,12 +348,12 @@ def sample_iter(self, mol_ids=False):
347348
def get_atomref(self, max_z=100):
348349
"""Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior.
349350
350-
Args:
351-
max_z (int): Maximum atomic number
351+
Args:
352+
max_z (int): Maximum atomic number
352353
353-
Returns:
354-
torch.Tensor: Atomic energy reference values for each element in the dataset.
355-
"""
354+
Returns:
355+
torch.Tensor: Atomic energy reference values for each element in the dataset.
356+
"""
356357
refs = pt.zeros(max_z)
357358
for key, val in self._ELEMENT_ENERGIES.items():
358359
refs[key] = val * self.HARTREE_TO_EV

0 commit comments

Comments
 (0)