Skip to content

KNN lifting (Pointcloud to Graph) with New dataset MNIST #19

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions .github/workflows/test_codebase.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,22 @@ name: "Testing Codebase"
on:
workflow_dispatch:
push:
branches: [main,github-actions-test]
branches: [main, github-actions-test]
paths-ignore:
- 'docs/**'
- 'README.md'
- 'LICENSE.txt'
- '.gitignore'
- "docs/**"
- "README.md"
- "LICENSE.txt"
- ".gitignore"

pull_request:
branches: [main]
paths-ignore:
- 'docs/**'
- 'README.md'
- 'LICENSE.txt'
- '.gitignore'
- "docs/**"
- "README.md"
- "LICENSE.txt"
- ".gitignore"

jobs:

pytest:
runs-on: ${{ matrix.os }}

Expand All @@ -39,12 +38,12 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
cache: "pip"
cache-dependency-path: '**/pyproject.toml'
cache-dependency-path: "**/pyproject.toml"

- name: Install PyTorch ${{ matrix.torch-version }}+cpu
run: |
pip install --upgrade pip setuptools wheel
pip install torch==${{ matrix.torch-version}} --extra-index-url https://download.pytorch.org/whl/cpu
pip install torch==${{ matrix.torch-version}} torchvision --extra-index-url https://download.pytorch.org/whl/cpu
pip install torch-scatter -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html
pip install torch-sparse -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html
pip install torch-cluster -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html
Expand All @@ -60,4 +59,4 @@ jobs:
with:
token: ${{ secrets.CODECOV_TOKEN }}
file: coverage.xml
fail_ci_if_error: false
fail_ci_if_error: false
20 changes: 10 additions & 10 deletions .github/workflows/test_tutorials.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@ name: "Testing Tutorials"
on:
workflow_dispatch:
push:
branches: [main,github-actions-test]
branches: [main, github-actions-test]
paths-ignore:
- 'docs/**'
- 'README.md'
- 'LICENSE.txt'
- '.gitignore'
- "docs/**"
- "README.md"
- "LICENSE.txt"
- ".gitignore"

pull_request:
branches: [main]
paths-ignore:
- 'docs/**'
- 'README.md'
- 'LICENSE.txt'
- '.gitignore'
- "docs/**"
- "README.md"
- "LICENSE.txt"
- ".gitignore"

# Disable debugger's warnings from nbconvert in test_tutorials.py
env:
Expand Down Expand Up @@ -47,7 +47,7 @@ jobs:
- name: Install PyTorch ${{ matrix.torch-version }}+cpu
run: |
pip install --upgrade pip setuptools wheel
pip install torch==${{ matrix.torch-version}} --extra-index-url https://download.pytorch.org/whl/cpu
pip install torch==${{ matrix.torch-version}} torchvision --extra-index-url https://download.pytorch.org/whl/cpu
pip install torch-scatter -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html
pip install torch-sparse -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html
pip install torch-cluster -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Then:
4. Install torch, torch-scatter, torch-sparse with or without CUDA depending on your needs.

```bash
pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/${CUDA}
pip install torch==2.0.1 torchvision --extra-index-url https://download.pytorch.org/whl/${CUDA}
pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+${CUDA}.html
pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+${CUDA}.html
```
Expand Down
13 changes: 13 additions & 0 deletions configs/datasets/MNIST.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
data_domain: point_cloud
data_type: MNIST
data_name: MNIST
data_dir: datasets/${data_domain}/${data_type}
#data_split_dir: ${oc.env:PROJECT_ROOT}/datasets/data_splits/${data_name}

# Dataset parameters
num_features: 1
num_classes: 10
task: classification
loss_type: cross_entropy
monitor_metric: accuracy
task_level: graph
4 changes: 4 additions & 0 deletions configs/models/graph/gcn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
in_channels: null # This will be set by the dataset
hidden_channels: 32
out_channels: null # This will be set by the dataset
n_layers: 2
5 changes: 5 additions & 0 deletions configs/transforms/liftings/pointcloud2graph/knn_lifting.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
transform_type: "lifting"
transform_name: "GraphKNNLifting"
k_value: 5
loop: True
feature_lifting: ProjectionSum
7 changes: 7 additions & 0 deletions modules/data/load/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
load_hypergraph_pickle_dataset,
load_manual_graph,
load_simplicial_dataset,
mnist_to_pointcloud,
)


Expand Down Expand Up @@ -108,6 +109,12 @@ def load(self) -> torch_geometric.data.Dataset:
data = load_manual_graph()
dataset = CustomDataset([data], self.data_dir)

elif self.parameters.data_name in ["MNIST"]:
data = mnist_to_pointcloud(self.parameters)
dataset = CustomDataset(
data, self.data_dir, data_name=self.parameters.data_name
)

else:
raise NotImplementedError(
f"Dataset {self.parameters.data_name} not implemented"
Expand Down
17 changes: 17 additions & 0 deletions modules/data/utils/concat2geometric_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,31 @@


class ConcatToGeometricDataset(Dataset):
r"""Concatenates list of PyTorch Geometric Data objects to form a dataset."""

def __init__(self, concat_dataset):
r"""Initializes the dataset.

Parameters
----------
concat_dataset : list
List of PyTorch Geometric Data objects to be concatenated.
"""
super().__init__()
self.concat_dataset = concat_dataset

def len(self):
r"""Returns the length of the dataset."""
return len(self.concat_dataset)

def get(self, idx):
r"""Returns the PyTorch Geometric Data object at the specified index.

Parameters
----------
idx : int
Index of the data object to be returned.
"""
data = self.concat_dataset[idx]

x = data.x.float()
Expand Down
6 changes: 5 additions & 1 deletion modules/data/utils/custom_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@


class CustomDataset(torch_geometric.data.InMemoryDataset):
def __init__(self, data_list, data_dir, transform=None):
def __init__(self, data_list, data_dir, data_name="Custom", transform=None):
self.data_list = data_list
self.data_name = data_name
super().__init__(data_dir, transform)
self.load(self.processed_paths[0])

Expand All @@ -13,3 +14,6 @@ def processed_file_names(self):

def process(self):
self.save(self.data_list, self.processed_paths[0])

def __repr__(self):
return f"{self.data_name}({len(self)})"
51 changes: 43 additions & 8 deletions modules/data/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ def get_complex_connectivity(complex, max_rank, signed=False):
)
except ValueError: # noqa: PERF203
if connectivity_info == "incidence":
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
else:
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity["shape"] = practical_shape
return connectivity
Expand Down Expand Up @@ -421,3 +421,38 @@ def make_hash(o):
hash_as_hex = sha1.hexdigest()
# Convert the hex back to int and restrict it to the relevant int range
return int(hash_as_hex, 16) % 4294967295


def mnist_to_pointcloud(cfg, threshold=0.0):
r"""Convert MNIST dataset into point cloud.

Parameters
----------
dataset : torchvision.datasets.MNIST
MNIST dataset.
threshold : float
Threshold value to filter out pixel values.

Returns
-------
list
List containing the point cloud.
"""
from torchvision.datasets import MNIST

dataset = MNIST(root=cfg["data_dir"], download=True)
Xs = dataset.data
ys = dataset.targets

point_clouds = []
for i, sample in enumerate(Xs):
sample = sample.where(sample > threshold, 0)
values = sample[sample > threshold].float()
values = values.unsqueeze(1)
x, y = (sample > 0).nonzero(as_tuple=True)
pos = torch.stack((x, y), dim=1).float()

data = Data(x=values, pos=pos, y=ys[i])
point_clouds.append(data)

return point_clouds
80 changes: 80 additions & 0 deletions modules/models/graph/gcn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import torch
from torch import Tensor
from torch_geometric.nn.models import GCN
from torch_geometric.utils import scatter


def global_mean_pool(x, batch=None, size=None) -> Tensor:
r"""Returns batch-wise graph-level-outputs by averaging node features
across the node dimension.

For a single graph :math:`\mathcal{G}_i`, its output is computed by

.. math::
\mathbf{r}_i = \frac{1}{N_i} \sum_{n=1}^{N_i} \mathbf{x}_n.

Functional method of the
:class:`~torch_geometric.nn.aggr.MeanAggregation` module.

Parameters
----------
x : torch.Tensor
Node feature matrix :math:`\mathbf{X}`.
batch : torch.Tensor, optional
The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`,
which assigns each node to a specific example.
size : int, optional
The number of examples :math:`B`. Automatically calculated if not given.
"""
dim = -1 if isinstance(x, Tensor) and x.dim() == 1 else -2

if batch is None:
return x.mean(dim=dim, keepdim=x.dim() <= 2)
return scatter(x, batch, dim=dim, dim_size=size, reduce="mean")


class GCNModel(torch.nn.Module):
r"""A simple GCN model that runs over graph data.
Note that some parameters are defined by the considered dataset.

Parameters
----------
model_config : Dict | DictConfig
Model configuration.
dataset_config : Dict | DictConfig
Dataset configuration.
"""

def __init__(self, model_config, dataset_config):
in_channels = (
dataset_config["num_features"]
if isinstance(dataset_config["num_features"], int)
else dataset_config["num_features"][0]
)
hidden_channels = model_config["hidden_channels"]
out_channels = dataset_config["num_classes"]
n_layers = model_config["n_layers"]
super().__init__()
self.base_model = GCN(
in_channels=in_channels,
hidden_channels=hidden_channels,
out_channels=out_channels,
num_layers=n_layers,
)
self.pool = global_mean_pool

def forward(self, data):
r"""Forward pass of the model.

Parameters
----------
data : torch_geometric.data.Data
Input data.

Returns
-------
torch.Tensor
Output tensor.
"""
z = self.base_model(data.x, data.edge_index)
return self.pool(z)
3 changes: 3 additions & 0 deletions modules/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
SimplicialCliqueLifting,
)
from modules.transforms.liftings.pointcloud2graph.knn_lifting import GraphKNNLifting

TRANSFORMS = {
# Graph -> Hypergraph
Expand All @@ -31,6 +32,8 @@
"OneHotDegreeFeatures": OneHotDegreeFeatures,
"NodeFeaturesToFloat": NodeFeaturesToFloat,
"KeepOnlyConnectedComponent": KeepOnlyConnectedComponent,
# Point Cloud -> Graph
"GraphKNNLifting": GraphKNNLifting,
}


Expand Down
Loading
Loading