Skip to content

Commit 6fb622c

Browse files
authored
Merge pull request #22 from aidos-lab/example-gnn-training-er
Examples gnn training, data distribution visualization.
2 parents 5b8e3d6 + 7b576c2 commit 6fb622c

File tree

9 files changed

+890
-5
lines changed

9 files changed

+890
-5
lines changed

.github/workflows/create_docs.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ jobs:
2424
pip install torch==2.2.0
2525
pip install numpy
2626
pip install git+https://github.com/Lezcano/geotorch/
27-
pip install pdoc
2827
2928
- name: Install mantra
3029
run: |

docs/source/conf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# -- General configuration ---------------------------------------------------
1414
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
1515

16-
extensions = ["sphinx.ext.autodoc", "sphinx_rtd_theme", "myst_parser"]
16+
extensions = ["sphinx.ext.autodoc", "sphinx_rtd_theme", "myst_parser","nbsphinx"]
1717

1818
myst_enable_extensions = [
1919
"amsmath",

docs/source/index.rst

+10-2
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,18 @@
22

33
.. toctree::
44
:hidden:
5-
:caption: Modules
6-
5+
:caption: Modules
6+
77
datasets
88

9+
.. toctree::
10+
:hidden:
11+
:caption: Examples:
12+
13+
notebooks/train_gnn.ipynb
14+
15+
16+
917
.. toctree::
1018
:hidden:
1119
:caption: Licence

docs/source/notebooks/train_gnn.ipynb

+224
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Training a GNN on the Mantra Dataset\n",
8+
"\n",
9+
"In this tutorial, we provide an example use-case for the mantra dataset. We show \n",
10+
"how to train a GNN to predict the orientability based on random node features. \n",
11+
"\n",
12+
"The `torch-geometric` interface for the MANTRA dataset can be installed with \n",
13+
"pip via the command \n",
14+
"```{python}\n",
15+
"pip install mantra\n",
16+
"```\n",
17+
"\n",
18+
"As a preprocessing step we apply three transforms to the base dataset.\n",
19+
"Since the dataset does not have intrinsic coordinates attached to the vertices, \n",
20+
"we first have to create a transform that generates random node features.\n",
21+
"Each manifold in MANTRA comes as a list of triples, where the integers in each \n",
22+
"triple are vertex id's. The starting id in each manifold is $1$ and has to be \n",
23+
"converted to a torch-geometric compliant $0$-based index.\n",
24+
"GNN's are typically trained on graphs and the FaceToEdge transform converts our\n",
25+
"manifold to a graph. \n",
26+
"\n",
27+
"For each of the transforms we use a single class and are succesively applied to\n",
28+
"form the final transformed dataset. "
29+
]
30+
},
31+
{
32+
"cell_type": "code",
33+
"execution_count": 1,
34+
"metadata": {},
35+
"outputs": [],
36+
"source": [
37+
"# Load all required packages. \n",
38+
"\n",
39+
"import torch \n",
40+
"import torch.nn.functional as F\n",
41+
"from torch import nn\n",
42+
"from torch.utils.data import random_split\n",
43+
"\n",
44+
"from torchvision.transforms import Compose\n",
45+
"\n",
46+
"from torch_geometric.loader import DataLoader\n",
47+
"from torch_geometric.transforms import Compose, FaceToEdge\n",
48+
"\n",
49+
"from torch_geometric.nn import GCNConv, global_mean_pool\n",
50+
"\n",
51+
"# Load the mantra dataset\n",
52+
"from mantra.datasets import ManifoldTriangulations\n",
53+
"\n",
54+
"class NodeIndex: \n",
55+
" def __call__(self,data):\n",
56+
" '''\n",
57+
" In the base dataset, the vertex start index is 1 and is provided as a\n",
58+
" list. The transform converts the list to a tensor and changes the start\n",
59+
" index to 0, in compliance with torch-geometric. \n",
60+
" '''\n",
61+
" data.face = torch.tensor(data.triangulation ).T- 1\n",
62+
" return data\n",
63+
"\n",
64+
"\n",
65+
"class RandomNodeFeatures: \n",
66+
" def __call__(self,data):\n",
67+
" \"\"\"\n",
68+
" We create an 8-dimensional vector with random numbers for each vertex. \n",
69+
" Often the coordinates of the graph or triangulation are tightly coupled \n",
70+
" with the structure of the graph, an assumtion we hope to tackle.\n",
71+
" \"\"\"\n",
72+
" data.x = torch.rand(size=(data.face.max()+1,8))\n",
73+
" return data\n",
74+
"\n",
75+
"\n",
76+
"# Instantiate the dataset. Following the `torch-geometric` API, we download the \n",
77+
"# dataset into the root directory. \n",
78+
"dataset = ManifoldTriangulations(root=\"./data\", manifold=\"2\", version=\"latest\",\n",
79+
" transform=Compose([\n",
80+
" NodeIndex(),\n",
81+
" RandomNodeFeatures(),\n",
82+
" FaceToEdge(remove_faces=True),\n",
83+
" ]\n",
84+
" )\n",
85+
" )\n"
86+
]
87+
},
88+
{
89+
"cell_type": "code",
90+
"execution_count": 2,
91+
"metadata": {},
92+
"outputs": [],
93+
"source": [
94+
"train_dataset, test_dataset = random_split(\n",
95+
" dataset,\n",
96+
" [0.8,0.2\n",
97+
" ],\n",
98+
" ) # type: ignore\n",
99+
"\n",
100+
"train_dataloader = DataLoader(train_dataset,batch_size=32)\n",
101+
"test_dataloader = DataLoader(test_dataset,batch_size=32)\n"
102+
]
103+
},
104+
{
105+
"cell_type": "code",
106+
"execution_count": 3,
107+
"metadata": {},
108+
"outputs": [],
109+
"source": [
110+
"class GCN(nn.Module):\n",
111+
" def __init__(self):\n",
112+
" super().__init__()\n",
113+
"\n",
114+
" self.conv_input = GCNConv(\n",
115+
" 8, 16\n",
116+
" )\n",
117+
" self.final_linear = nn.Linear(\n",
118+
" 16, 1\n",
119+
" )\n",
120+
"\n",
121+
" def forward(self, batch):\n",
122+
" x, edge_index, batch = batch.x, batch.edge_index, batch.batch\n",
123+
" \n",
124+
" # 1. Obtain node embeddings\n",
125+
" x = self.conv_input(x, edge_index)\n",
126+
" # 2. Readout layer\n",
127+
" x = global_mean_pool(x, batch) # [batch_size, hidden_channels]\n",
128+
" # 3. Apply a final classifier\n",
129+
" x = F.dropout(x, p=0.5, training=self.training)\n",
130+
" x = self.final_linear(x)\n",
131+
" return x"
132+
]
133+
},
134+
{
135+
"cell_type": "code",
136+
"execution_count": 4,
137+
"metadata": {},
138+
"outputs": [
139+
{
140+
"name": "stdout",
141+
"output_type": "stream",
142+
"text": [
143+
"Epoch 0, 0.2743515074253082\n",
144+
"Epoch 1, 0.24504387378692627\n",
145+
"Epoch 2, 0.2461807280778885\n",
146+
"Epoch 3, 0.24599741399288177\n",
147+
"Epoch 4, 0.2461780607700348\n",
148+
"Epoch 5, 0.24923910200595856\n",
149+
"Epoch 6, 0.24623213708400726\n",
150+
"Epoch 7, 0.24637295305728912\n",
151+
"Epoch 8, 0.24762295186519623\n",
152+
"Epoch 9, 0.24508829414844513\n"
153+
]
154+
}
155+
],
156+
"source": [
157+
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
158+
"model = GCN().to(device)\n",
159+
"optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n",
160+
"loss_fn = nn.BCEWithLogitsLoss()\n",
161+
"\n",
162+
"model.train()\n",
163+
"for epoch in range(10):\n",
164+
" for batch in train_dataloader: \n",
165+
" batch.orientable = batch.orientable.to(torch.float)\n",
166+
" batch.to(device)\n",
167+
" optimizer.zero_grad()\n",
168+
" out = model(batch)\n",
169+
" loss = loss_fn(out.squeeze(), batch.orientable)\n",
170+
" loss.backward()\n",
171+
" optimizer.step()\n",
172+
" print(f\"Epoch {epoch}, {loss.item()}\")\n"
173+
]
174+
},
175+
{
176+
"cell_type": "code",
177+
"execution_count": 5,
178+
"metadata": {},
179+
"outputs": [
180+
{
181+
"name": "stdout",
182+
"output_type": "stream",
183+
"text": [
184+
"Accuracy: 0.0825\n"
185+
]
186+
}
187+
],
188+
"source": [
189+
"correct = 0\n",
190+
"total = 0\n",
191+
"model.eval()\n",
192+
"for testbatch in test_dataloader: \n",
193+
" testbatch.to(device)\n",
194+
" pred = model(testbatch)\n",
195+
" correct += ((pred.squeeze() < 0) == testbatch.orientable).sum()\n",
196+
" total += len(testbatch)\n",
197+
"\n",
198+
"acc = int(correct) / int(total)\n",
199+
"print(f'Accuracy: {acc:.4f}')"
200+
]
201+
}
202+
],
203+
"metadata": {
204+
"kernelspec": {
205+
"display_name": "Python 3",
206+
"language": "python",
207+
"name": "python3"
208+
},
209+
"language_info": {
210+
"codemirror_mode": {
211+
"name": "ipython",
212+
"version": 3
213+
},
214+
"file_extension": ".py",
215+
"mimetype": "text/x-python",
216+
"name": "python",
217+
"nbconvert_exporter": "python",
218+
"pygments_lexer": "ipython3",
219+
"version": "3.10.11"
220+
}
221+
},
222+
"nbformat": 4,
223+
"nbformat_minor": 2
224+
}

0 commit comments

Comments
 (0)