Commit 34514ee

Merge branch 'bertranMiquel-rings_clt push origin mainose_atoms'
2 parents 21e1580 + 3d184c4 commit 34514ee

16 files changed: +1227 -79 lines changed

README.md

Lines changed: 17 additions & 73 deletions
@@ -1,78 +1,22 @@
1-
# ICML Topological Deep Learning Challenge 2024: Beyond the Graph Domain
2-
Official repository for the Topological Deep Learning Challenge 2024, jointly organized by [TAG-DS](https://www.tagds.com) & PyT-Team and hosted by the [Geometry-grounded Representation Learning and Generative Modeling (GRaM) Workshop](https://gram-workshop.github.io) at ICML 2024.
1+
# Molecule Ring & Close Atoms Lifting (Graph to Combinatorial)
2+
This notebook imports the QM9 dataset and lifts each molecule from its graph representation to a combinatorial complex. A neural network is then run on the lifted data.
33

4-
## Relevant Information
5-
- The deadline is **July 12th, 2024 (Anywhere on Earth)**. Participants are welcome to modify their submissions until this time.
6-
- Please, check out the [main webpage of the challenge](https://pyt-team.github.io/packs/challenge.html) for the full description of the competition (motivation, submission requirements, evaluation, etc.)
4+
Using the [QM9 dataset](https://paperswithcode.com/dataset/qm9), we implement a lifting from a molecular graph to a combinatorial complex based on two sources of information:
5+
- The ring information of the molecule. Rings are represented as 2-cells in the combinatorial complex.
6+
- The distances between atoms in the molecule. Pairwise distances are computed, and atoms closer than a predefined threshold are considered close and grouped together. These clusters are introduced as hyperedges in the combinatorial complex.
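
The close-atoms rule above can be sketched in plain Python. This is a minimal illustration, not the lifting shipped in this commit: the function name `close_atom_clusters`, the strict `<` comparison, and the choice to merge close pairs transitively (connected components via union-find) are all assumptions, since the README does not specify how close atoms are grouped. The threshold value 1.5 is the `threshold_distance` from the lifting config.

```python
import math
from itertools import combinations


def close_atom_clusters(pos, threshold):
    """Group atoms whose pairwise distance is below `threshold`.

    Assumption: closeness is merged transitively, so each cluster is a
    connected component of the "closer than threshold" relation.
    Singleton atoms are dropped because they form no hyperedge.
    """
    parent = list(range(len(pos)))

    def find(i):
        # Union-find lookup with path halving.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    for i, j in combinations(range(len(pos)), 2):
        if math.dist(pos[i], pos[j]) < threshold:
            parent[find(i)] = find(j)

    groups = {}
    for i in range(len(pos)):
        groups.setdefault(find(i), []).append(i)
    return sorted(g for g in groups.values() if len(g) > 1)


# Toy coordinates (hypothetical, not taken from QM9).
toy_pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.0, 0.0), (10.0, 0.0, 0.0)]
clusters = close_atom_clusters(toy_pos, threshold=1.5)
print(clusters)  # [[0, 1, 2]]
```

Atoms 0 and 2 are 2.0 apart, yet they end up in the same cluster because both are close to atom 1. A lifting that requires every pair in a cluster to be close would need a different grouping rule.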
77

8-
## Brief Description
9-
The main purpose of the challenge is to further expand the current scope and impact of Topological Deep Learning (TDL), enabling the exploration of its applicability in new contexts and scenarios. To do so, we propose participants to design and implement lifting mappings between different data structures and topological domains (point-clouds, graphs, hypergraphs, simplicial/cell/combinatorial complexes), potentially bridging the gap between TDL and all kinds of existing datasets.
8+
To the best of our knowledge, this is the first implementation of a molecule as a combinatorial complex that combines both hypergraph and cell-complex structure.
109

10+
The elements of the complex are the following:
11+
- **Nodes**: Atoms in the molecule.
12+
- **Edges**: Bonds between atoms.
13+
- **Hyperedges**: Clusters of atoms that are close to each other.
14+
- **2-cells**: Rings in the molecule.
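
To make the ring-to-2-cell mapping concrete, here is a small sketch that recovers rings from a bond list using a breadth-first spanning tree: every bond that is not in the tree closes exactly one ring. The function `fundamental_rings` is hypothetical and written for illustration only; the actual lifting may rely on a chemistry toolkit such as RDKit for ring perception, and a fundamental cycle basis is not guaranteed to match chemically preferred ring sets on every molecule. The bond list is the undirected form of the edges in `load_manual_mol` from this commit.

```python
from collections import deque


def fundamental_rings(num_nodes, bonds):
    """Return one ring (sorted atom indices) per non-tree bond of a BFS forest."""
    adj = {i: set() for i in range(num_nodes)}
    for u, v in bonds:
        adj[u].add(v)
        adj[v].add(u)

    parent = {}
    seen_chords = set()
    rings = []

    def path_to_root(n):
        path = []
        while n is not None:
            path.append(n)
            n = parent[n]
        return path

    for root in range(num_nodes):
        if root in parent:
            continue
        parent[root] = None
        queue = deque([root])
        while queue:
            u = queue.popleft()
            for v in sorted(adj[u]):
                if v not in parent:
                    parent[v] = u  # tree bond
                    queue.append(v)
                elif parent[u] != v and frozenset((u, v)) not in seen_chords:
                    # Non-tree bond: it closes a ring through the lowest
                    # common ancestor of u and v in the BFS tree.
                    seen_chords.add(frozenset((u, v)))
                    pu, pv = path_to_root(u), path_to_root(v)
                    common = set(pu) & set(pv)
                    lca = next(n for n in pu if n in common)
                    ring = [n for n in pu if n not in common]
                    ring += [n for n in pv if n not in common]
                    ring.append(lca)
                    rings.append(sorted(ring))
    return rings


# Undirected bonds of the 12-atom test molecule defined in load_manual_mol.
bonds = [(0, 1), (0, 6), (1, 2), (1, 3), (1, 5), (2, 3), (2, 7), (2, 8),
         (3, 4), (3, 9), (4, 5), (5, 10), (5, 11)]
rings = fundamental_rings(12, bonds)
print(rings)  # [[1, 2, 3], [1, 3, 4, 5]]
```

For this molecule the sketch finds a three-membered and a four-membered ring, each of which would become one 2-cell.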
1115

12-
## General Guidelines
13-
Everyone can participate and participation is free --only principal PyT-Team developers are excluded. It is sufficient to:
14-
- Send a valid Pull Request (i.e. passing all tests) before the deadline.
15-
- Respect Submission Requirements (see below).
16+
Additionally, attributes inspired by those used in [(Battiloro et al., 2024)](https://arxiv.org/abs/2405.15429) are incorporated into the elements, enhancing the representation of the molecule.
17+
The attributes are:
18+
- **Node**: Atom type, atomic number, and chirality.
19+
- **Edge**: Bond type, conjugation and stereochemistry.
20+
- **Ring**: Ring size, aromaticity, heteroatoms, saturation, hydrophobicity, electrophilicity, nucleophilicity, and polarity.
1621

17-
Teams are accepted, and there is no restriction on the number of team members. An acceptable Pull Request automatically subscribes a participant/team to the challenge.
18-
19-
We encourage participants to start submitting their Pull Request early on, as this helps addressing potential issues with the code. Moreover, earlier Pull Requests will be given priority consideration in the case of multiple submissions of similar quality implementing the same lifting.
20-
21-
A Pull Request should contain no more than one lifting. However, there is no restriction on the number of submissions (Pull Requests) per participant/team.
22-
23-
## Basic Setup
24-
To develop on your machine, here are some tips.
25-
26-
First, we recommend using Python 3.11.3, which is the python version used to run the unit-tests.
27-
28-
For example, create a conda environment:
29-
```bash
30-
conda create -n topox python=3.11.3
31-
conda activate topox
32-
```
33-
34-
Then:
35-
36-
1. Clone a copy of tmx from source:
37-
38-
```bash
39-
git clone git@github.com:pyt-team/challenge-icml-2024.git
40-
cd challenge-icml-2024
41-
```
42-
43-
2. Install tmx in editable mode:
44-
45-
```bash
46-
pip install -e '.[all]'
47-
```
48-
**Notes:**
49-
- Requires pip >= 21.3. Refer: [PEP 660](https://peps.python.org/pep-0660/).
50-
- On Windows, use `pip install -e .[all]` instead (without quotes around `[all]`).
51-
52-
4. Install torch, torch-scatter, torch-sparse with or without CUDA depending on your needs.
53-
54-
```bash
55-
pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/${CUDA}
56-
pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+${CUDA}.html
57-
pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+${CUDA}.html
58-
```
59-
60-
where `${CUDA}` should be replaced by either `cpu`, `cu102`, `cu113`, or `cu115` depending on your PyTorch installation (`torch.version.cuda`).
61-
62-
5. Ensure that you have a working tmx installation by running the entire test suite with
63-
64-
```bash
65-
pytest
66-
```
67-
68-
In case an error occurs, please first check if all sub-packages ([`torch-scatter`](https://github.com/rusty1s/pytorch_scatter), [`torch-sparse`](https://github.com/rusty1s/pytorch_sparse), [`torch-cluster`](https://github.com/rusty1s/pytorch_cluster) and [`torch-spline-conv`](https://github.com/rusty1s/pytorch_spline_conv)) are on its latest reported version.
69-
70-
6. Install pre-commit hooks:
71-
72-
```bash
73-
pre-commit install
74-
```
75-
76-
## Questions
77-
78-
Feel free to contact us through GitHub issues on this repository, or through the [Geometry and Topology in Machine Learning slack](https://tda-in-ml.slack.com/join/shared_invite/enQtOTIyMTIyNTYxMTM2LTA2YmQyZjVjNjgxZWYzMDUyODY5MjlhMGE3ZTI1MzE4NjI2OTY0MmUyMmQ3NGE0MTNmMzNiMTViMjM2MzE4OTc#/). Alternatively, you can contact us via mail at any of these accounts: [email protected], [email protected].
22+
This pull request is submitted by the team formed by Bertran Miquel Oliver, Manel Gil Sorribes, and Alexis Molina.

configs/datasets/QM9.yaml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
data_domain: graph
2+
data_type: QM9
3+
data_name: QM9
4+
data_dir: datasets/${data_domain}/${data_type}
5+
#data_split_dir: ${oc.env:PROJECT_ROOT}/datasets/data_splits/${data_name}
6+
7+
# Dataset parameters
8+
num_features: 11
9+
num_classes: 1
10+
task: regression
11+
loss_type: mse
12+
monitor_metric: mae
13+
task_level: graph
14+

configs/datasets/manual_mol.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
data_domain: graph
2+
data_type: toy_dataset
3+
data_name: manual_mol
4+
data_dir: datasets/${data_domain}/${data_type}
5+
6+
# Dataset parameters
7+
num_features: 1
8+
num_classes: 2
9+
task: classification
10+
loss_type: cross_entropy
11+
monitor_metric: accuracy
12+
task_level: node

configs/models/combinatorial/hmc.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
in_channels: null # This will be set by the dataset
2+
hidden_channels: 32
3+
out_channels: null # This will be set by the dataset
4+
n_layers: 2
5+
negative_slope: 0.2
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
transform_type: 'lifting'
2+
transform_name: "CombinatorialRingCloseAtomsLifting"
3+
max_cell_length: null
4+
preserve_edge_attr: True
5+
feature_lifting: ProjectionSum
6+
7+
threshold_distance: 1.5

modules/data/load/loaders.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
import torch_geometric
66
from omegaconf import DictConfig
77

8+
# Silence RDKit warnings (logging is disabled via RDLogger.DisableLog after the imports)
9+
from rdkit import Chem, RDLogger
10+
811
from modules.data.load.base import AbstractLoader
912
from modules.data.utils.concat2geometric_dataset import (
1013
ConcatToGeometricDataset,
@@ -14,9 +17,12 @@
1417
load_cell_complex_dataset,
1518
load_hypergraph_pickle_dataset,
1619
load_manual_graph,
20+
load_manual_mol,
1721
load_simplicial_dataset,
1822
)
1923

24+
RDLogger.DisableLog("rdApp.*")
25+
2026

2127
class GraphLoader(AbstractLoader):
2228
r"""Loader for graph datasets.
@@ -31,6 +37,15 @@ def __init__(self, parameters: DictConfig):
3137
super().__init__(parameters)
3238
self.parameters = parameters
3339

40+
def is_valid_smiles(self, smiles):
41+
"""Check if a SMILES string is valid using RDKit."""
42+
mol = Chem.MolFromSmiles(smiles)
43+
return mol is not None
44+
45+
def filter_qm9_dataset(self, dataset):
46+
"""Filter the QM9 dataset to remove invalid SMILES strings."""
47+
return [data for data in dataset if self.is_valid_smiles(data.smiles)]
48+
3449
def load(self) -> torch_geometric.data.Dataset:
3550
r"""Load graph dataset.
3651
@@ -108,10 +123,20 @@ def load(self) -> torch_geometric.data.Dataset:
108123
dataset = datasets[0] + datasets[1] + datasets[2]
109124
dataset = ConcatToGeometricDataset(dataset)
110125

126+
elif self.parameters.data_name == "QM9":
127+
dataset = torch_geometric.datasets.QM9(root=root_data_dir)
128+
# Filter the QM9 dataset to remove invalid SMILES strings
129+
valid_dataset = self.filter_qm9_dataset(dataset)
130+
dataset = CustomDataset(valid_dataset, self.data_dir)
131+
111132
elif self.parameters.data_name in ["manual"]:
112133
data = load_manual_graph()
113134
dataset = CustomDataset([data], self.data_dir)
114135

136+
elif self.parameters.data_name in ["manual_mol", "manual_rings"]:
137+
data = load_manual_mol()
138+
dataset = CustomDataset([data], self.data_dir)
139+
115140
else:
116141
raise NotImplementedError(
117142
f"Dataset {self.parameters.data_name} not implemented"

modules/data/utils/utils.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,104 @@ def load_manual_graph():
333333
y=torch.tensor(y),
334334
)
335335

336+
def load_manual_mol():
337+
"""Create a manual molecular graph for testing the ring lifting.
338+
The graph corresponds to molecule 471 of the QM9 dataset."""
339+
# Define the vertices
340+
vertices = list(range(12))
341+
y = torch.tensor([[ 2.2569e+00, 4.5920e+01, -6.3076e+00, 1.9211e+00, 8.2287e+00,
342+
4.6414e+02, 2.6121e+00, -8.3351e+03, -8.3349e+03, -8.3349e+03,
343+
-8.3359e+03, 2.0187e+01, -4.8740e+01, -4.9057e+01, -4.9339e+01,
344+
-4.5375e+01, 6.5000e+00, 3.8560e+00, 3.0122e+00]])
345+
346+
# Define the edges
347+
edges = [
348+
[0, 1],
349+
[0, 6],
350+
[1, 0],
351+
[1, 2],
352+
[1, 3],
353+
[1, 5],
354+
[2, 1],
355+
[2, 3],
356+
[2, 7],
357+
[2, 8],
358+
[3, 1],
359+
[3, 2],
360+
[3, 4],
361+
[3, 9],
362+
[4, 3],
363+
[4, 5],
364+
[5, 1],
365+
[5, 4],
366+
[5, 10],
367+
[5, 11],
368+
[6, 0],
369+
[7, 2],
370+
[8, 2],
371+
[9, 3],
372+
[10, 5],
373+
[11, 5],
374+
]
375+
376+
# Add the SMILES string
377+
smiles = "[H]O[C@@]12C([H])([H])O[C@]1([H])C2([H])[H]"
378+
379+
# # Create a graph
380+
# G = nx.Graph()
381+
382+
# # Add vertices
383+
# G.add_nodes_from(vertices)
384+
385+
# # Add edges
386+
# G.add_edges_from(edges)
387+
388+
# G.to_undirected()
389+
# edge_list = torch.Tensor(list(G.edges())).T.long()
390+
391+
x = [
392+
[0.0, 0.0, 0.0, 1.0, 0.0, 8.0, 0.0, 0.0, 0.0, 0.0, 1.0],
393+
[0.0, 1.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0],
394+
[0.0, 1.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0, 0.0, 2.0],
395+
[0.0, 1.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0, 0.0, 1.0],
396+
[0.0, 0.0, 0.0, 1.0, 0.0, 8.0, 0.0, 0.0, 0.0, 0.0, 0.0],
397+
[0.0, 1.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0, 0.0, 2.0],
398+
[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
399+
[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
400+
[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
401+
[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
402+
[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
403+
[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
404+
]
405+
406+
pos = torch.tensor(
407+
[
408+
[-0.0520, 1.4421, 0.0438],
409+
[-0.0146, 0.0641, 0.0278],
410+
[-0.2878, -0.7834, -1.1968],
411+
[-1.1365, -0.9394, 0.0399],
412+
[-0.4768, -1.7722, 0.9962],
413+
[ 0.6009, -0.8025, 1.1266],
414+
[ 0.6168, 1.7721, -0.5660],
415+
[-0.7693, -0.2348, -2.0014],
416+
[ 0.3816, -1.5834, -1.5029],
417+
[-2.2159, -0.8594, 0.0798],
418+
[ 1.5885, -1.2463, 0.9538],
419+
[ 0.5680, -0.3171, 2.1084]
420+
]
421+
)
422+
423+
assert len(x) == len(vertices)
424+
assert len(pos) == len(vertices)
425+
426+
return torch_geometric.data.Data(
427+
x=torch.tensor(x).float(),
428+
edge_index=torch.tensor(edges).T.long(),
429+
num_nodes=len(vertices),
430+
y=y,  # y is already a tensor; re-wrapping it with torch.tensor() triggers a copy warning
431+
smiles=smiles,
432+
pos=pos
433+
)
336434
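
The edge list in `load_manual_mol` stores every bond twice, once per direction, which is how `torch_geometric` represents an undirected graph in `edge_index`. A hand-written list like this is easy to get wrong, so a quick symmetry check can help. The helper `undirected_bonds` below is a hypothetical sketch in plain Python, not part of this commit.

```python
def undirected_bonds(directed_edges):
    """Validate that each edge has its reverse and return the unique bonds."""
    directed = {(u, v) for u, v in directed_edges}
    missing = sorted((u, v) for (u, v) in directed if (v, u) not in directed)
    if missing:
        raise ValueError(f"edges without a reverse copy: {missing}")
    return sorted({(min(u, v), max(u, v)) for u, v in directed})


# Same 26 directed edges as in load_manual_mol.
edges = [
    [0, 1], [0, 6], [1, 0], [1, 2], [1, 3], [1, 5], [2, 1], [2, 3], [2, 7],
    [2, 8], [3, 1], [3, 2], [3, 4], [3, 9], [4, 3], [4, 5], [5, 1], [5, 4],
    [5, 10], [5, 11], [6, 0], [7, 2], [8, 2], [9, 3], [10, 5], [11, 5],
]
bonds = undirected_bonds(edges)
print(len(edges), len(bonds))  # 26 13
```

The 13 unique bonds are what a lifting would treat as the rank-1 cells (edges) of the complex.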

337435
def get_Planetoid_pyg(cfg):
338436
r"""Loads Planetoid graph datasets from torch_geometric.

modules/models/combinatorial/hmc.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import torch
2+
from topomodelx.nn.combinatorial.hmc import HMC
3+
4+
5+
class HMCModel(torch.nn.Module):
6+
r"""HMC model that runs over combinatorial complexes (CCC).
7+
8+
Parameters
9+
----------
10+
model_config : Dict | DictConfig
11+
Model configuration.
12+
dataset_config : Dict | DictConfig
13+
Dataset configuration.
14+
"""
15+
16+
def __init__(self, model_config, dataset_config):
17+
in_channels = (
18+
dataset_config["num_features"]
19+
if isinstance(dataset_config["num_features"], int)
20+
else dataset_config["num_features"][0]
21+
)
22+
negative_slope = model_config["negative_slope"]
23+
hidden_channels = model_config["hidden_channels"]
24+
out_channels = dataset_config["num_classes"]
25+
n_layers = model_config["n_layers"]
26+
super().__init__()
27+
channels_per_layer = []
28+
29+
for layer in range(n_layers):
30+
in_channels_l = []
31+
int_channels_l = []
32+
out_channels_l = []
33+
34+
for _ in range(3): # only 3 ranks
35+
# First layer behavior
36+
if layer == 0:
37+
in_channels_l.append(in_channels)
38+
else:
39+
in_channels_l.append(hidden_channels)
40+
int_channels_l.append(hidden_channels)
41+
out_channels_l.append(hidden_channels)
42+
43+
channels_per_layer.append((in_channels_l, int_channels_l, out_channels_l))
44+
45+
self.base_model = HMC(
46+
channels_per_layer=channels_per_layer,
47+
negative_slope=negative_slope
48+
)
49+
self.linear_0 = torch.nn.Linear(hidden_channels, out_channels)
50+
self.linear_1 = torch.nn.Linear(hidden_channels, out_channels)
51+
self.linear_2 = torch.nn.Linear(hidden_channels, out_channels)
52+
53+
def forward(self, data):
54+
r"""Forward pass of the model.
55+
56+
Parameters
57+
----------
58+
data : torch_geometric.data.Data
59+
Input data.
60+
61+
Returns
62+
-------
63+
torch.Tensor
64+
Output tensor.
65+
"""
66+
x_0 = data.x_0
67+
x_1 = data.x_1
68+
x_2 = data.x_2
69+
adj_0 = data["adjacency_0"]
70+
adj_1 = data["adjacency_1"]
71+
adj_2 = data["adjacency_2"]
72+
inc_1 = data["incidence_1"]
73+
inc_2 = data["incidence_2"]
74+
75+
x_0, x_1, x_2 = self.base_model(
76+
x_0,
77+
x_1,
78+
x_2,
79+
adj_0,
80+
adj_1,
81+
adj_2,
82+
inc_1,
83+
inc_2,
84+
)
85+
86+
x_0 = self.linear_0(x_0)
87+
x_1 = self.linear_1(x_1)
88+
x_2 = self.linear_2(x_2)
89+
return (x_0, x_1, x_2)
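
The nested loops in `HMCModel.__init__` build one `(in, intermediate, out)` triple of channel lists per layer, with one entry per rank. The standalone sketch below reproduces only that bookkeeping so the resulting structure can be inspected without installing `topomodelx`. The helper name `build_channels_per_layer` is hypothetical, and the values come from the configs in this commit (`num_features: 11`, `hidden_channels: 32`, `n_layers: 2`).

```python
def build_channels_per_layer(in_channels, hidden_channels, n_layers, n_ranks=3):
    """Mirror the channel bookkeeping done in HMCModel.__init__."""
    channels_per_layer = []
    for layer in range(n_layers):
        # Only the first layer sees the raw feature dimension.
        first = in_channels if layer == 0 else hidden_channels
        channels_per_layer.append(
            (
                [first] * n_ranks,            # input channels per rank
                [hidden_channels] * n_ranks,  # intermediate channels per rank
                [hidden_channels] * n_ranks,  # output channels per rank
            )
        )
    return channels_per_layer


layers = build_channels_per_layer(in_channels=11, hidden_channels=32, n_layers=2)
for in_c, mid_c, out_c in layers:
    print(in_c, mid_c, out_c)
# [11, 11, 11] [32, 32, 32] [32, 32, 32]
# [32, 32, 32] [32, 32, 32] [32, 32, 32]
```

Note that the same `in_channels` is used for all three ranks in the first layer, so the model as written assumes node, edge, and 2-cell features share one input dimension.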
