Skip to content

Commit 02a5b31

Browse files
authored
Merge branch 'main' into line
2 parents 9fcf718 + d8ef506 commit 02a5b31

File tree

20 files changed

+1318
-186
lines changed

20 files changed

+1318
-186
lines changed

.github/workflows/lint.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,6 @@ jobs:
1111
runs-on: ubuntu-latest
1212
steps:
1313
- uses: actions/checkout@v3
14-
- uses: chartboost/ruff-action@v1
14+
- uses: chartboost/ruff-action@v1
15+
with:
16+
src: './modules'

.github/workflows/test_codebase.yml

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,34 +27,31 @@ jobs:
2727
fail-fast: false
2828
matrix:
2929
os: [ubuntu-latest]
30-
python-version: ["3.10", "3.11"]
31-
torch-version: [2.0.1]
32-
include:
33-
- torch-version: 2.0.1
30+
python-version: [3.11.3]
3431

3532
steps:
36-
- uses: actions/checkout@v3
37-
- name: Set up Python ${{ matrix.python-version }}
38-
uses: actions/setup-python@v4
33+
- uses: actions/checkout@v4
34+
- name: Set up Python ${{matrix.python-version}}
35+
uses: actions/setup-python@v5
3936
with:
40-
python-version: ${{ matrix.python-version }}
41-
cache: "pip"
42-
cache-dependency-path: '**/pyproject.toml'
37+
python-version: ${{matrix.python-version}}
4338

44-
- name: Install PyTorch ${{ matrix.torch-version }}+cpu
45-
run: |
46-
pip install --upgrade pip setuptools wheel
47-
pip install torch==${{ matrix.torch-version}} --extra-index-url https://download.pytorch.org/whl/cpu
48-
pip install torch-scatter -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html
49-
pip install torch-sparse -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html
50-
pip install torch-cluster -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html
51-
pip show pip
52-
- name: Install main package
39+
- uses: actions/cache@v4
40+
with:
41+
path: ~/.cache/pip
42+
key: ${{matrix.os}}-${{matrix.python-version}}-${{ hashFiles('pyproject.toml') }}
43+
44+
- name: Install dependencies
5345
run: |
54-
pip install -e .[all]
55-
- name: Run tests for codebase [pytest]
46+
python -m pip install --upgrade pip
47+
pip install pytest
48+
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
49+
source env_setup.sh
50+
51+
- name: Test with pytest
5652
run: |
5753
pytest -n 2 --cov --cov-report=xml:coverage.xml test/transforms/feature_liftings test/transforms/liftings
54+
pytest test/tutorials/test_tutorials.py
5855
- name: Upload coverage
5956
uses: codecov/codecov-action@v3
6057
with:

.github/workflows/test_tutorials.yml

Lines changed: 0 additions & 59 deletions
This file was deleted.

.pre-commit-config.yaml

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,12 @@ repos:
1717
- id: trailing-whitespace
1818
- id: requirements-txt-fixer
1919

20-
- repo: https://github.com/psf/black
21-
rev: 23.3.0
20+
- repo: https://github.com/astral-sh/ruff-pre-commit
21+
rev: v0.4.4
2222
hooks:
23-
- id: black-jupyter
23+
- id: ruff-format
2424

25-
- repo: https://github.com/pycqa/isort
26-
rev: 5.12.0
27-
hooks:
28-
- id : isort
29-
args : ["--profile=black", "--filter-files"]
30-
31-
#- repo: https://github.com/asottile/blacken-docs
32-
# rev: 1.13.0
33-
# hooks:
34-
# - id: blacken-docs
35-
# additional_dependencies: [black==23.3.0]
36-
37-
# - repo: https://github.com/pycqa/flake8
38-
# rev: 6.0.0
25+
# - repo: https://github.com/numpy/numpydoc
26+
# rev: v1.6.0
3927
# hooks:
40-
# - id: flake8
41-
# additional_dependencies: [flake8-docstrings, Flake8-pyproject]
28+
# - id: numpydoc-validation
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
transform_type: 'lifting'
2+
transform_name: "SimplicialGraphInducedLifting"
3+
complex_dim: 3
4+
preserve_edge_attr: False
5+
signed: True
6+
feature_lifting: ProjectionSum
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
transform_type: 'lifting'
2+
transform_name: "SimplicialVietorisRipsLifting"
3+
complex_dim: 3
4+
preserve_edge_attr: False
5+
signed: True
6+
distance_threshold: 2.0
7+
feature_lifting: ProjectionSum

env_setup.sh

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#!/bin/bash -l
2+
pip install --upgrade pip
3+
pip install -e '.[all]'
4+
5+
# Note that not all combinations of torch and CUDA are available
6+
# See https://github.com/pyg-team/pyg-lib to check the configuration that works for you
7+
TORCH="2.3.0" # available options: 1.12.0, 1.13.0, 2.0.0, 2.1.0, 2.2.0, or 2.3.0
8+
CUDA="cpu" # if available, select the CUDA version suitable for your system
9+
# available options: cpu, cu102, cu113, cu116, cu117, cu118, or cu121
10+
pip install torch==${TORCH} --extra-index-url https://download.pytorch.org/whl/${CUDA}
11+
pip install pyg-lib torch-scatter torch-sparse torch-cluster -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
12+
13+
#pytest
14+
15+
pre-commit install

modules/data/load/loaders.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
from omegaconf import DictConfig
77

88
from modules.data.load.base import AbstractLoader
9-
from modules.data.utils.concat2geometric_dataset import ConcatToGeometricDataset
9+
from modules.data.utils.concat2geometric_dataset import (
10+
ConcatToGeometricDataset,
11+
)
1012
from modules.data.utils.custom_dataset import CustomDataset
1113
from modules.data.utils.utils import (
1214
load_cell_complex_dataset,
@@ -45,7 +47,9 @@ def load(self) -> torch_geometric.data.Dataset:
4547
root_folder = rootutils.find_root()
4648
root_data_dir = os.path.join(root_folder, self.parameters["data_dir"])
4749

48-
self.data_dir = os.path.join(root_data_dir, self.parameters["data_name"])
50+
self.data_dir = os.path.join(
51+
root_data_dir, self.parameters["data_name"]
52+
)
4953
if (
5054
self.parameters.data_name.lower() in ["cora", "citeseer", "pubmed"]
5155
and self.parameters.data_type == "cocitation"

modules/transforms/data_transform.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,33 @@
88
OneHotDegreeFeatures,
99
)
1010
from modules.transforms.feature_liftings.feature_liftings import ProjectionSum
11-
from modules.transforms.liftings.graph2cell.cycle_lifting import CellCycleLifting
11+
from modules.transforms.liftings.graph2cell.cycle_lifting import (
12+
CellCycleLifting,
13+
)
1214
from modules.transforms.liftings.graph2hypergraph.knn_lifting import (
1315
HypergraphKNNLifting,
1416
)
1517
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
1618
SimplicialCliqueLifting,
1719
)
20+
from modules.transforms.liftings.graph2simplicial.graph_induced_lifting import (
21+
SimplicialGraphInducedLifting,
22+
)
1823
from modules.transforms.liftings.graph2simplicial.line_lifting import (
1924
SimplicialLineLifting,
2025
)
26+
from modules.transforms.liftings.graph2simplicial.vietoris_rips_lifting import (
27+
SimplicialVietorisRipsLifting,
28+
)
2129

2230
TRANSFORMS = {
2331
# Graph -> Hypergraph
2432
"HypergraphKNNLifting": HypergraphKNNLifting,
2533
# Graph -> Simplicial Complex
2634
"SimplicialCliqueLifting": SimplicialCliqueLifting,
2735
"SimplicialLineLifting": SimplicialLineLifting,
36+
"SimplicialVietorisRipsLifting": SimplicialVietorisRipsLifting,
37+
"SimplicialGraphInducedLifting": SimplicialGraphInducedLifting,
2838
# Graph -> Cell Complex
2939
"CellCycleLifting": CellCycleLifting,
3040
# Feature Liftings
@@ -56,10 +66,14 @@ def __init__(self, transform_name, **kwargs):
5666
self.parameters = kwargs
5767

5868
self.transform = (
59-
TRANSFORMS[transform_name](**kwargs) if transform_name is not None else None
69+
TRANSFORMS[transform_name](**kwargs)
70+
if transform_name is not None
71+
else None
6072
)
6173

62-
def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data:
74+
def forward(
75+
self, data: torch_geometric.data.Data
76+
) -> torch_geometric.data.Data:
6377
"""Forward pass of the lifting.
6478
6579
Parameters

modules/transforms/liftings/graph2hypergraph/knn_lifting.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import torch
22
import torch_geometric
33

4-
from modules.transforms.liftings.graph2hypergraph.base import Graph2HypergraphLifting
4+
from modules.transforms.liftings.graph2hypergraph.base import (
5+
Graph2HypergraphLifting,
6+
)
57

68

79
class HypergraphKNNLifting(Graph2HypergraphLifting):
@@ -45,22 +47,30 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict:
4547
if self.loop:
4648
for i in range(num_nodes):
4749
if not torch.any(
48-
torch.all(data_lifted.edge_index == torch.tensor([[i, i]]).T, dim=0)
50+
torch.all(
51+
data_lifted.edge_index == torch.tensor([[i, i]]).T,
52+
dim=0,
53+
)
4954
):
5055
connected_nodes = data_lifted.edge_index[
5156
0, data_lifted.edge_index[1] == i
5257
]
5358
dists = torch.sqrt(
5459
torch.sum(
55-
(data.pos[connected_nodes] - data.pos[i].unsqueeze(0) ** 2),
60+
(
61+
data.pos[connected_nodes]
62+
- data.pos[i].unsqueeze(0) ** 2
63+
),
5664
dim=1,
5765
)
5866
)
5967
furthest = torch.argmax(dists)
6068
idx = torch.where(
6169
torch.all(
6270
data_lifted.edge_index
63-
== torch.tensor([[connected_nodes[furthest], i]]).T,
71+
== torch.tensor(
72+
[[connected_nodes[furthest], i]]
73+
).T,
6474
dim=0,
6575
)
6676
)[0]

0 commit comments

Comments
 (0)