Molecule Ring & Close Atoms Lifting (Graph to Combinatorial) #16

Merged
merged 62 commits into from
Feb 15, 2025
Commits
62ce27f
update
bertranMiquel May 13, 2024
e8a35dc
First ring-based commit
bertranMiquel Jun 13, 2024
1385b4e
Merge remote-tracking branch 'upstream/main'
bertranMiquel Jun 13, 2024
82c6b25
Readme update
bertranMiquel Jun 13, 2024
4be4aa6
mol ring lifting correctation
bertranMiquel Jun 13, 2024
02ea2bf
Correct more errors
bertranMiquel Jun 13, 2024
a83f484
Fix lifting details
bertranMiquel Jun 13, 2024
cd936f7
Hope it is the last one
bertranMiquel Jun 13, 2024
431e246
Deleting other test since they are using other test dataset
bertranMiquel Jun 13, 2024
ae365b5
git ignore modification
bertranMiquel Jun 13, 2024
b201b54
minor modifications
bertranMiquel Jun 14, 2024
917799b
Update ring_lifting.py
bertranMiquel Jun 14, 2024
a7534a8
Ring lifting modifications
bertranMiquel Jun 15, 2024
b884a92
Update dependencies
bertranMiquel Jun 18, 2024
43c3582
First try of rings + close_atoms. Difficulties on lifting the topology.
bertranMiquel Jun 19, 2024
75fc664
Update with tests
bertranMiquel Jun 20, 2024
827d15d
First try of combinatorial complexes.
bertranMiquel Jun 20, 2024
1625f93
Ending combinatorial implementation
bertranMiquel Jun 20, 2024
dd9cc50
Final combinatorial implementation.
bertranMiquel Jun 20, 2024
7f6c207
Readme update
bertranMiquel Jun 20, 2024
7ba809f
Solve errors
bertranMiquel Jun 20, 2024
790347b
Fixing errors
bertranMiquel Jun 20, 2024
4eff037
utils solve
bertranMiquel Jun 20, 2024
5404c3d
utils update
bertranMiquel Jun 20, 2024
e9c93af
errors solved
bertranMiquel Jun 20, 2024
850c7ee
Applied --fix option
bertranMiquel Jun 20, 2024
bc3941d
Reduce dataset
bertranMiquel Jun 20, 2024
9e063c1
Update incidence hyperedge matrix
bertranMiquel Jun 21, 2024
6c22934
Update README.md
bertranMiquel Jun 26, 2024
c4c31f7
Adding attributes
bertranMiquel Jul 1, 2024
e6d923e
Merge remote-tracking branch 'origin/main' into funcional
bertranMiquel Jul 1, 2024
d8d50b5
Finish attributes
bertranMiquel Jul 2, 2024
28ab854
ruff modifications
bertranMiquel Jul 2, 2024
0418e52
add ring sanity check
bertranMiquel Jul 2, 2024
2a48f01
check sanitizer
bertranMiquel Jul 2, 2024
2793dfb
Notebook modify
bertranMiquel Jul 8, 2024
d52cb57
add test and combinatorial nn
bertranMiquel Jul 8, 2024
16d664d
test + combinatorial nn
bertranMiquel Jul 8, 2024
634eba8
Change load manual test data
bertranMiquel Jul 8, 2024
e8240dd
Solve test
bertranMiquel Jul 8, 2024
309af34
Solve test adding position to manual data
bertranMiquel Jul 8, 2024
86e557e
Insert asserts
bertranMiquel Jul 8, 2024
7f05c0e
Rectify asserts
bertranMiquel Jul 8, 2024
b4b8a50
Change edge definition
bertranMiquel Jul 8, 2024
f78dc96
Change hyperedge definition
bertranMiquel Jul 8, 2024
4a32fad
add logging
bertranMiquel Jul 8, 2024
f83527e
ruff checks
bertranMiquel Jul 8, 2024
29b068f
Update hyperedges with non-repeated close atoms
bertranMiquel Jul 8, 2024
d55f661
Update incidence hyperedges
bertranMiquel Jul 9, 2024
fee249c
Update incidence hyperedges
bertranMiquel Jul 9, 2024
57877df
Print incidence hyperedge
bertranMiquel Jul 9, 2024
bf1a015
Sort hyperedges
bertranMiquel Jul 9, 2024
e312401
Update incidence hyperedge in test
bertranMiquel Jul 9, 2024
3b938e1
Update incidence hyperedge sorting
bertranMiquel Jul 9, 2024
794e71d
New test
bertranMiquel Jul 9, 2024
c14a227
ruff fixes
bertranMiquel Jul 9, 2024
5e12c61
Revert "ruff fixes"
bertranMiquel Jul 9, 2024
57befdb
ruff fixes
bertranMiquel Jul 9, 2024
7d3ea20
Refine code
bertranMiquel Jul 9, 2024
65fb3ad
Update load manual mol name
bertranMiquel Jul 9, 2024
3ab3aae
Change load manual prot to mol name
bertranMiquel Jul 9, 2024
a689099
Merge branch 'main' into rings_close_atoms
gbg141 Feb 15, 2025
90 changes: 17 additions & 73 deletions README.md
@@ -1,78 +1,22 @@
# ICML Topological Deep Learning Challenge 2024: Beyond the Graph Domain
Official repository for the Topological Deep Learning Challenge 2024, jointly organized by [TAG-DS](https://www.tagds.com) & PyT-Team and hosted by the [Geometry-grounded Representation Learning and Generative Modeling (GRaM) Workshop](https://gram-workshop.github.io) at ICML 2024.
# Molecule Ring & Close Atoms Lifting (Graph to Combinatorial)
This notebook imports the QM9 dataset and lifts each molecular graph to a combinatorial complex. A neural network is then run on the lifted data.

## Relevant Information
- The deadline is **July 12th, 2024 (Anywhere on Earth)**. Participants are welcome to modify their submissions until this time.
- Please, check out the [main webpage of the challenge](https://pyt-team.github.io/packs/challenge.html) for the full description of the competition (motivation, submission requirements, evaluation, etc.)
Using [QM9 dataset](https://paperswithcode.com/dataset/qm9), we implement a lifting from a molecule graph to a combinatorial complex based on two points:
- The ring information of the molecule. Rings will be represented as 2-cells in the combinatorial complex.
- The distance between atoms in the molecule. Pairwise distances between atoms are computed, and atoms closer than a predefined threshold are considered close and grouped together. These clusters are introduced as hyperedges in the combinatorial complex.
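
The distance-based grouping can be sketched as follows. This is a minimal illustration, not the code of this pull request: the function name `close_atom_hyperedges` is hypothetical, and grouping by connected components of the "closer than the threshold" relation is an assumption about how clusters are formed.

```python
from itertools import combinations
from math import dist


def close_atom_hyperedges(positions, threshold):
    """Group atoms whose pairwise distance is below `threshold`.

    Assumption: clusters are the connected components of the
    'closer than threshold' relation; singletons are dropped.
    """
    parent = list(range(len(positions)))

    def find(i):
        # Union-find lookup with path halving.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    for i, j in combinations(range(len(positions)), 2):
        if dist(positions[i], positions[j]) < threshold:
            parent[find(i)] = find(j)

    groups = {}
    for i in range(len(positions)):
        groups.setdefault(find(i), []).append(i)
    return sorted(g for g in groups.values() if len(g) > 1)


# Toy coordinates (not taken from QM9): three atoms in a chain, one far away.
positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.0, 0.0), (5.0, 0.0, 0.0)]
hyperedges = close_atom_hyperedges(positions, threshold=1.5)
print(hyperedges)  # [[0, 1, 2]]
```

Atoms 0 and 2 are 2.0 apart, above the threshold, yet they end up in the same hyperedge because both are close to atom 1. Whether that transitive behavior is desirable depends on the intended meaning of "close".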

## Brief Description
The main purpose of the challenge is to further expand the current scope and impact of Topological Deep Learning (TDL), enabling the exploration of its applicability in new contexts and scenarios. To do so, we propose participants to design and implement lifting mappings between different data structures and topological domains (point-clouds, graphs, hypergraphs, simplicial/cell/combinatorial complexes), potentially bridging the gap between TDL and all kinds of existing datasets.
To the best of our knowledge, this is the first implementation of a molecule as a combinatorial complex, combining both hypergraphs and cell complexes.

Here, the elements are the following:
- **Nodes**: Atoms in the molecule.
- **Edges**: Bonds between atoms.
- **Hyperedges**: Clusters of atoms that are close to each other.
- **2-cells**: Rings in the molecule.
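
As a rough illustration of how rings become 2-cells, the sketch below extracts one cycle per non-tree bond of a breadth-first spanning tree. It is an assumption-laden stand-in: the helper `fundamental_rings` is hypothetical, fundamental cycles are not guaranteed to be the smallest rings a cheminformatics toolkit would report, and the bond list is the heavy-atom skeleton of the manual test molecule added in this pull request.

```python
from collections import deque


def fundamental_rings(num_atoms, bonds):
    """Return one cycle (sorted atom indices) per non-tree bond.

    Assumes `bonds` lists each undirected bond exactly once.
    """
    adjacency = {i: [] for i in range(num_atoms)}
    for u, v in bonds:
        adjacency[u].append(v)
        adjacency[v].append(u)

    # Breadth-first spanning forest, stored as child -> parent.
    parent = {}
    for root in range(num_atoms):
        if root in parent:
            continue
        parent[root] = None
        queue = deque([root])
        while queue:
            u = queue.popleft()
            for v in adjacency[u]:
                if v not in parent:
                    parent[v] = u
                    queue.append(v)

    def path_to_root(node):
        path = []
        while node is not None:
            path.append(node)
            node = parent[node]
        return path

    tree_edges = {frozenset((v, p)) for v, p in parent.items() if p is not None}
    rings = []
    for u, v in bonds:
        if frozenset((u, v)) in tree_edges:
            continue
        # Each non-tree bond closes exactly one cycle through the
        # lowest common ancestor of its two endpoints.
        path_u, path_v = path_to_root(u), path_to_root(v)
        ancestors_u = set(path_u)
        lca = next(a for a in path_v if a in ancestors_u)
        ring = path_u[: path_u.index(lca) + 1] + path_v[: path_v.index(lca)]
        rings.append(sorted(ring))
    return sorted(rings)


# Heavy atoms 0..5 of the manual test molecule (hydrogens omitted).
bonds = [(0, 1), (1, 2), (1, 3), (1, 5), (2, 3), (3, 4), (4, 5)]
rings = fundamental_rings(6, bonds)

# Cells grouped by element type: atoms, bonds, and rings as 2-cells.
cells = {"nodes": list(range(6)), "edges": bonds, "2-cells": rings}
print(cells["2-cells"])  # [[1, 2, 3], [1, 3, 4, 5]]
```

For this molecule the two cycles are the fused three- and four-membered rings that share the bond between atoms 1 and 3.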

## General Guidelines
Everyone can participate and participation is free --only principal PyT-Team developers are excluded. It is sufficient to:
- Send a valid Pull Request (i.e. passing all tests) before the deadline.
- Respect Submission Requirements (see below).
Additionally, attributes inspired by those used in [(Battiloro et al., 2024)](https://arxiv.org/abs/2405.15429) are incorporated into the elements, enhancing the representation of the molecule.
The attributes are:
- **Node**: Atom type, atomic number, and chirality.
- **Edge**: Bond type, conjugation and stereochemistry.
- **Rings**: Ring size, aromaticity, heteroatoms, saturation, hydrophobicity, electrophilicity, nucleophilicity, and polarity.
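
Two of the ring attributes listed above, ring size and heteroatom count, depend only on which atoms belong to the ring, so they can be sketched without a chemistry toolkit. The function `ring_attributes` and the dictionary keys are hypothetical; the remaining attributes (aromaticity, saturation, polarity, and so on) need chemical perception and are not covered here. The rings and atomic numbers are those of the heavy atoms in the manual test molecule.

```python
def ring_attributes(ring, atomic_numbers):
    """Compute ring size and heteroatom count for one ring.

    A heteroatom is taken to be any ring atom that is neither
    carbon (6) nor hydrogen (1).
    """
    heteroatoms = [a for a in ring if atomic_numbers[a] not in (1, 6)]
    return {"size": len(ring), "num_heteroatoms": len(heteroatoms)}


# Atomic numbers of heavy atoms 0..5: O, C, C, C, O, C.
atomic_numbers = [8, 6, 6, 6, 8, 6]
rings = [[1, 2, 3], [1, 3, 4, 5]]

attrs = [ring_attributes(ring, atomic_numbers) for ring in rings]
print(attrs)
# [{'size': 3, 'num_heteroatoms': 0}, {'size': 4, 'num_heteroatoms': 1}]
```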

Teams are accepted, and there is no restriction on the number of team members. An acceptable Pull Request automatically subscribes a participant/team to the challenge.

We encourage participants to start submitting their Pull Request early on, as this helps addressing potential issues with the code. Moreover, earlier Pull Requests will be given priority consideration in the case of multiple submissions of similar quality implementing the same lifting.

A Pull Request should contain no more than one lifting. However, there is no restriction on the number of submissions (Pull Requests) per participant/team.

## Basic Setup
To develop on your machine, here are some tips.

First, we recommend using Python 3.11.3, which is the python version used to run the unit-tests.

For example, create a conda environment:
```bash
conda create -n topox python=3.11.3
conda activate topox
```

Then:

1. Clone a copy of tmx from source:

```bash
git clone git@github.com:pyt-team/challenge-icml-2024.git
cd challenge-icml-2024
```

2. Install tmx in editable mode:

```bash
pip install -e '.[all]'
```
**Notes:**
- Requires pip >= 21.3. Refer: [PEP 660](https://peps.python.org/pep-0660/).
- On Windows, use `pip install -e .[all]` instead (without quotes around `[all]`).

4. Install torch, torch-scatter, torch-sparse with or without CUDA depending on your needs.

```bash
pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/${CUDA}
pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+${CUDA}.html
pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+${CUDA}.html
```

where `${CUDA}` should be replaced by either `cpu`, `cu102`, `cu113`, or `cu115` depending on your PyTorch installation (`torch.version.cuda`).

5. Ensure that you have a working tmx installation by running the entire test suite with

```bash
pytest
```

In case an error occurs, please first check if all sub-packages ([`torch-scatter`](https://github.com/rusty1s/pytorch_scatter), [`torch-sparse`](https://github.com/rusty1s/pytorch_sparse), [`torch-cluster`](https://github.com/rusty1s/pytorch_cluster) and [`torch-spline-conv`](https://github.com/rusty1s/pytorch_spline_conv)) are on its latest reported version.

6. Install pre-commit hooks:

```bash
pre-commit install
```

## Questions

Feel free to contact us through GitHub issues on this repository, or through the [Geometry and Topology in Machine Learning slack](https://tda-in-ml.slack.com/join/shared_invite/enQtOTIyMTIyNTYxMTM2LTA2YmQyZjVjNjgxZWYzMDUyODY5MjlhMGE3ZTI1MzE4NjI2OTY0MmUyMmQ3NGE0MTNmMzNiMTViMjM2MzE4OTc#/). Alternatively, you can contact us via mail at any of these accounts: [email protected], [email protected].
This pull request is done under the team formed by: Bertran Miquel Oliver, Manel Gil Sorribes, Alexis Molina
14 changes: 14 additions & 0 deletions configs/datasets/QM9.yaml
@@ -0,0 +1,14 @@
data_domain: graph
data_type: QM9
data_name: QM9
data_dir: datasets/${data_domain}/${data_type}
#data_split_dir: ${oc.env:PROJECT_ROOT}/datasets/data_splits/${data_name}

# Dataset parameters
num_features: 11
num_classes: 1
task: regression
loss_type: mse
monitor_metric: mae
task_level: graph

12 changes: 12 additions & 0 deletions configs/datasets/manual_mol.yaml
@@ -0,0 +1,12 @@
data_domain: graph
data_type: toy_dataset
data_name: manual_mol
data_dir: datasets/${data_domain}/${data_type}

# Dataset parameters
num_features: 1
num_classes: 2
task: classification
loss_type: cross_entropy
monitor_metric: accuracy
task_level: node
5 changes: 5 additions & 0 deletions configs/models/combinatorial/hmc.yaml
@@ -0,0 +1,5 @@
in_channels: null # This will be set by the dataset
hidden_channels: 32
out_channels: null # This will be set by the dataset
n_layers: 2
negative_slope: 0.2
@@ -0,0 +1,7 @@
transform_type: 'lifting'
transform_name: "CombinatorialRingCloseAtomsLifting"
max_cell_length: null
preserve_edge_attr: True
feature_lifting: ProjectionSum

threshold_distance: 1.5
25 changes: 25 additions & 0 deletions modules/data/load/loaders.py
Expand Up @@ -5,6 +5,9 @@
import torch_geometric
from omegaconf import DictConfig

# RDKit imports; RDLogger is used below to silence RDKit warnings
from rdkit import Chem, RDLogger

from modules.data.load.base import AbstractLoader
from modules.data.utils.concat2geometric_dataset import (
    ConcatToGeometricDataset,
Expand All @@ -14,9 +17,12 @@
    load_cell_complex_dataset,
    load_hypergraph_pickle_dataset,
    load_manual_graph,
    load_manual_mol,
    load_simplicial_dataset,
)

RDLogger.DisableLog("rdApp.*")


class GraphLoader(AbstractLoader):
    r"""Loader for graph datasets.
Expand All @@ -31,6 +37,15 @@ def __init__(self, parameters: DictConfig):
        super().__init__(parameters)
        self.parameters = parameters

    def is_valid_smiles(self, smiles):
        """Check if a SMILES string is valid using RDKit."""
        mol = Chem.MolFromSmiles(smiles)
        return mol is not None

    def filter_qm9_dataset(self, dataset):
        """Filter the QM9 dataset to remove invalid SMILES strings."""
        return [data for data in dataset if self.is_valid_smiles(data.smiles)]

    def load(self) -> torch_geometric.data.Dataset:
        r"""Load graph dataset.

Expand Down Expand Up @@ -108,10 +123,20 @@ def load(self) -> torch_geometric.data.Dataset:
            dataset = datasets[0] + datasets[1] + datasets[2]
            dataset = ConcatToGeometricDataset(dataset)

        elif self.parameters.data_name == "QM9":
            dataset = torch_geometric.datasets.QM9(root=root_data_dir)
            # Filter the QM9 dataset to remove invalid SMILES strings
            valid_dataset = self.filter_qm9_dataset(dataset)
            dataset = CustomDataset(valid_dataset, self.data_dir)

        elif self.parameters.data_name in ["manual"]:
            data = load_manual_graph()
            dataset = CustomDataset([data], self.data_dir)

        elif self.parameters.data_name in ["manual_rings"]:
            data = load_manual_mol()
            dataset = CustomDataset([data], self.data_dir)

        else:
            raise NotImplementedError(
                f"Dataset {self.parameters.data_name} not implemented"
Expand Down
98 changes: 98 additions & 0 deletions modules/data/utils/utils.py
Expand Up @@ -333,6 +333,104 @@ def load_manual_graph():
y=torch.tensor(y),
)

def load_manual_mol():
    """Create a manual molecular graph for testing the ring implementation.

    The data are those of molecule 471 of the QM9 dataset.
    """
    # Define the vertices
    vertices = list(range(12))
    y = torch.tensor(
        [[ 2.2569e+00, 4.5920e+01, -6.3076e+00, 1.9211e+00, 8.2287e+00,
           4.6414e+02, 2.6121e+00, -8.3351e+03, -8.3349e+03, -8.3349e+03,
          -8.3359e+03, 2.0187e+01, -4.8740e+01, -4.9057e+01, -4.9339e+01,
          -4.5375e+01, 6.5000e+00, 3.8560e+00, 3.0122e+00]]
    )

    # Define the edges (each bond is listed in both directions)
    edges = [
        [0, 1],
        [0, 6],
        [1, 0],
        [1, 2],
        [1, 3],
        [1, 5],
        [2, 1],
        [2, 3],
        [2, 7],
        [2, 8],
        [3, 1],
        [3, 2],
        [3, 4],
        [3, 9],
        [4, 3],
        [4, 5],
        [5, 1],
        [5, 4],
        [5, 10],
        [5, 11],
        [6, 0],
        [7, 2],
        [8, 2],
        [9, 3],
        [10, 5],
        [11, 5],
    ]

    # Add the SMILES string
    smiles = "[H]O[C@@]12C([H])([H])O[C@]1([H])C2([H])[H]"

    # # Create a graph
    # G = nx.Graph()

    # # Add vertices
    # G.add_nodes_from(vertices)

    # # Add edges
    # G.add_edges_from(edges)

    # G.to_undirected()
    # edge_list = torch.Tensor(list(G.edges())).T.long()

    # Node features, one row per atom
    x = [
        [0.0, 0.0, 0.0, 1.0, 0.0, 8.0, 0.0, 0.0, 0.0, 0.0, 1.0],
        [0.0, 1.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0, 0.0, 2.0],
        [0.0, 1.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0, 0.0, 1.0],
        [0.0, 0.0, 0.0, 1.0, 0.0, 8.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0, 0.0, 2.0],
        [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    ]

    # 3D atom positions
    pos = torch.tensor(
        [
            [-0.0520, 1.4421, 0.0438],
            [-0.0146, 0.0641, 0.0278],
            [-0.2878, -0.7834, -1.1968],
            [-1.1365, -0.9394, 0.0399],
            [-0.4768, -1.7722, 0.9962],
            [ 0.6009, -0.8025, 1.1266],
            [ 0.6168, 1.7721, -0.5660],
            [-0.7693, -0.2348, -2.0014],
            [ 0.3816, -1.5834, -1.5029],
            [-2.2159, -0.8594, 0.0798],
            [ 1.5885, -1.2463, 0.9538],
            [ 0.5680, -0.3171, 2.1084]
        ]
    )

    assert len(x) == len(vertices)
    assert len(pos) == len(vertices)

    return torch_geometric.data.Data(
        x=torch.tensor(x).float(),
        edge_index=torch.tensor(edges).T.long(),
        num_nodes=len(vertices),
        # y is already a tensor; re-wrapping it in torch.tensor() raises a UserWarning
        y=y,
        smiles=smiles,
        pos=pos
    )

def get_Planetoid_pyg(cfg):
r"""Loads Planetoid graph datasets from torch_geometric.
Expand Down
89 changes: 89 additions & 0 deletions modules/models/combinatorial/hmc.py
@@ -0,0 +1,89 @@
import torch
from topomodelx.nn.combinatorial.hmc import HMC


class HMCModel(torch.nn.Module):
    r"""HMC model that runs over combinatorial complexes (CCs).

    Parameters
    ----------
    model_config : Dict | DictConfig
        Model configuration.
    dataset_config : Dict | DictConfig
        Dataset configuration.
    """

    def __init__(self, model_config, dataset_config):
        super().__init__()
        in_channels = (
            dataset_config["num_features"]
            if isinstance(dataset_config["num_features"], int)
            else dataset_config["num_features"][0]
        )
        negative_slope = model_config["negative_slope"]
        hidden_channels = model_config["hidden_channels"]
        out_channels = dataset_config["num_classes"]
        n_layers = model_config["n_layers"]
        channels_per_layer = []

        for layer in range(n_layers):
            in_channels_l = []
            int_channels_l = []
            out_channels_l = []

            for _ in range(3):  # only 3 ranks
                # First layer behavior
                if layer == 0:
                    in_channels_l.append(in_channels)
                else:
                    in_channels_l.append(hidden_channels)
                int_channels_l.append(hidden_channels)
                out_channels_l.append(hidden_channels)

            channels_per_layer.append((in_channels_l, int_channels_l, out_channels_l))

        self.base_model = HMC(
            channels_per_layer=channels_per_layer,
            negative_slope=negative_slope
        )
        # One linear head per rank
        self.linear_0 = torch.nn.Linear(hidden_channels, out_channels)
        self.linear_1 = torch.nn.Linear(hidden_channels, out_channels)
        self.linear_2 = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, data):
        r"""Forward pass of the model.

        Parameters
        ----------
        data : torch_geometric.data.Data
            Input data.

        Returns
        -------
        tuple of torch.Tensor
            Output tensors for ranks 0, 1 and 2.
        """
        x_0 = data.x_0
        x_1 = data.x_1
        x_2 = data.x_2
        adj_0 = data["adjacency_0"]
        adj_1 = data["adjacency_1"]
        adj_2 = data["adjacency_2"]
        inc_1 = data["incidence_1"]
        inc_2 = data["incidence_2"]

        x_0, x_1, x_2 = self.base_model(
            x_0,
            x_1,
            x_2,
            adj_0,
            adj_1,
            adj_2,
            inc_1,
            inc_2,
        )

        x_0 = self.linear_0(x_0)
        x_1 = self.linear_1(x_1)
        x_2 = self.linear_2(x_2)
        return (x_0, x_1, x_2)
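
To make the shape of `channels_per_layer` concrete, the following stand-alone sketch rebuilds the same nested structure without importing torch or topomodelx. The helper name `build_channels_per_layer` and its `n_ranks` argument are hypothetical; the values 11, 32 and 2 are taken from `num_features` in `QM9.yaml` and from `hidden_channels` and `n_layers` in `hmc.yaml`.

```python
def build_channels_per_layer(in_channels, hidden_channels, n_layers, n_ranks=3):
    """Return one (input, intermediate, output) channel triple per layer.

    Each entry of a triple is a list with one width per rank. Only the
    first layer sees the raw feature width; later layers see the hidden width.
    """
    channels_per_layer = []
    for layer in range(n_layers):
        first = in_channels if layer == 0 else hidden_channels
        channels_per_layer.append(
            (
                [first] * n_ranks,
                [hidden_channels] * n_ranks,
                [hidden_channels] * n_ranks,
            )
        )
    return channels_per_layer


channels = build_channels_per_layer(in_channels=11, hidden_channels=32, n_layers=2)
for layer_index, triple in enumerate(channels):
    print(layer_index, triple)
# 0 ([11, 11, 11], [32, 32, 32], [32, 32, 32])
# 1 ([32, 32, 32], [32, 32, 32], [32, 32, 32])
```

Note that the same input width is used for all three ranks in the first layer, which presumes that the lifted features on nodes, edges and 2-cells share one dimensionality.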