Commit e25c574

GAT option for GCN
1 parent ce2fb8f commit e25c574

2 files changed: +50 −5

topognn/layers.py (+33 −1)

@@ -2,7 +2,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch_geometric.nn import GCNConv, GINConv
+from torch_geometric.nn import GCNConv, GINConv, GATConv
 from torch_scatter import scatter
 from torch_persistent_homology.persistent_homology_cpu import compute_persistence_homology_batched_mt
 from topognn.data_utils import remove_duplicate_edges
@@ -67,6 +67,38 @@ def forward(self, x, edge_index, **kwargs):
         return self.dropout(h)
 
 
+class GATLayer(nn.Module):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        activation,
+        dropout,
+        batch_norm,
+        num_heads,
+        residual=True,
+        train_eps=False,
+        **kwargs
+    ):
+        super().__init__()
+
+
+        self.activation = activation
+        self.residual = residual
+        self.dropout = nn.Dropout(dropout)
+        self.batchnorm = nn.BatchNorm1d(
+            out_features * num_heads) if batch_norm else nn.Identity()
+
+        self.conv = GATConv(in_features, out_features, heads=num_heads, dropout=dropout)
+
+    def forward(self, x, edge_index, **kwargs):
+        h = self.conv(x, edge_index)
+        h = self.batchnorm(h)
+        h = self.activation(h)
+        if self.residual:
+            h = h + x
+        return self.dropout(h)
+
 class DeepSetLayer(nn.Module):
     """Simple equivariant deep set layer."""
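Not part of the commit — a minimal shape sanity check for the new GATLayer, sketched under the assumption that torch and torch_geometric are installed. Because GATConv defaults to concat=True, a layer with heads=num_heads maps (N, in_features) to (N, out_features * num_heads), so the residual h + x only type-checks when in_features == out_features * num_heads, which is exactly how the model below wires it:

import torch
import torch.nn.functional as F
from topognn.layers import GATLayer  # available after this commit

hidden_dim, num_heads = 16, 8

# in_features must equal out_features * num_heads for the residual to work.
layer = GATLayer(
    in_features=hidden_dim * num_heads,
    out_features=hidden_dim,
    activation=F.relu,
    dropout=0.1,
    batch_norm=True,
    num_heads=num_heads,
)

x = torch.randn(10, hidden_dim * num_heads)        # 10 nodes (arbitrary)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # 3 directed edges
out = layer(x, edge_index)
assert out.shape == (10, hidden_dim * num_heads)   # heads are concatenated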

topognn/models.py (+17 −4)

@@ -11,7 +11,7 @@
 from topognn import Tasks
 from topognn.cli_utils import str2bool, int_or_none
-from topognn.layers import GCNLayer, GINLayer, SimpleSetTopoLayer, fake_persistence_computation
+from topognn.layers import GCNLayer, GINLayer, GATLayer, SimpleSetTopoLayer, fake_persistence_computation
 from topognn.metrics import WeightedAccuracy
 from topognn.data_utils import remove_duplicate_edges
 from torch_persistent_homology.persistent_homology_cpu import compute_persistence_homology_batched_mt
@@ -529,24 +529,35 @@ def forward(self,x, **kwargs):
 
 class LargerGCNModel(pl.LightningModule):
     def __init__(self, hidden_dim, depth, num_node_features, num_classes, task,
-                 lr=0.001, dropout_p=0.2, GIN=False, batch_norm=False,
+                 lr=0.001, dropout_p=0.2, GIN=False, GAT=False, batch_norm=False,
                  residual=False, train_eps=True, save_filtration=False,
                  add_mlp=False, weight_decay=0., **kwargs):
         super().__init__()
         self.save_hyperparameters()
-        self.embedding = torch.nn.Linear(num_node_features, hidden_dim)
         self.save_filtration = save_filtration
 
+        num_heads = 1
+
         if GIN:
             def build_gnn_layer():
                 return GINLayer(in_features=hidden_dim, out_features=hidden_dim, train_eps=train_eps, activation=F.relu, batch_norm=batch_norm, dropout=dropout_p, **kwargs)
             graph_pooling_operation = global_add_pool
+
+        elif GAT:
+            num_heads = 8
+            def build_gnn_layer():
+                return GATLayer(in_features=hidden_dim * num_heads, out_features=hidden_dim, train_eps=train_eps, activation=F.relu, batch_norm=batch_norm, dropout=dropout_p, num_heads=num_heads, **kwargs)
+            graph_pooling_operation = global_mean_pool
+
         else:
             def build_gnn_layer():
                 return GCNLayer(
                     hidden_dim, hidden_dim, F.relu, dropout_p, batch_norm)
             graph_pooling_operation = global_mean_pool
 
+
+        self.embedding = torch.nn.Linear(num_node_features, hidden_dim * num_heads)
+
         layers = [build_gnn_layer() for _ in range(depth)]
 
         if add_mlp:
@@ -567,7 +578,7 @@ def fake_pool(x, batch):
         if (kwargs.get("dim1",False) and ("dim1_out_dim" in kwargs.keys()) and ( not kwargs.get("fake",False))):
             dim_before_class = hidden_dim + kwargs["dim1_out_dim"] #SimpleTopoGNN with dim1
         else:
-            dim_before_class = hidden_dim
+            dim_before_class = hidden_dim * num_heads
 
         self.classif = torch.nn.Sequential(
             nn.Linear(dim_before_class, hidden_dim // 2),
@@ -633,6 +644,7 @@ def configure_optimizers(self):
         return [optimizer], [scheduler]
 
     def forward(self, data):
+
         x, edge_index = data.x, data.edge_index
         x = self.embedding(x)
 
@@ -722,6 +734,7 @@ def add_model_specific_args(cls, parent):
         parser.add_argument("--min_lr", type=float, default=0.00001)
         parser.add_argument("--dropout_p", type=float, default=0.0)
         parser.add_argument('--GIN', type=str2bool, default=False)
+        parser.add_argument('--GAT', type=str2bool, default=False)
         parser.add_argument('--train_eps', type=str2bool, default=True)
         parser.add_argument('--batch_norm', type=str2bool, default=True)
         parser.add_argument('--residual', type=str2bool, default=True)
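For context, a hypothetical way to exercise the new option — the script name and the Tasks member are assumptions, not something this commit shows; only the --GAT flag and the constructor arguments come from the diff above. With GAT=True the model sets num_heads = 8, so the embedding projects node features to hidden_dim * 8 and the classifier input grows to match.

# CLI (str2bool accepts the usual true/false spellings; the script name is assumed):
#   python train_model.py --GAT True --batch_norm True --dropout_p 0.2

# Direct construction, mirroring the constructor signature in the diff:
from topognn import Tasks  # assumes the Tasks enum exposes GRAPH_CLASSIFICATION
from topognn.models import LargerGCNModel

model = LargerGCNModel(
    hidden_dim=32,
    depth=3,
    num_node_features=7,
    num_classes=2,
    task=Tasks.GRAPH_CLASSIFICATION,
    GAT=True,          # selects GATLayer; num_heads is set to 8 internally
    batch_norm=True,
    dropout_p=0.2,
)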
