@@ -12,6 +12,6 @@
 from topognn import Tasks
 from topognn.cli_utils import str2bool, int_or_none
-from topognn.layers import GCNLayer, GINLayer, SimpleSetTopoLayer, fake_persistence_computation
+from topognn.layers import GCNLayer, GINLayer, GATLayer, SimpleSetTopoLayer, fake_persistence_computation
 from topognn.metrics import WeightedAccuracy
 from topognn.data_utils import remove_duplicate_edges
 from torch_persistent_homology.persistent_homology_cpu import compute_persistence_homology_batched_mt
@@ -529,24 +529,35 @@ def forward(self,x, **kwargs):
 
 class LargerGCNModel(pl.LightningModule):
     def __init__(self, hidden_dim, depth, num_node_features, num_classes, task,
-                 lr=0.001, dropout_p=0.2, GIN=False, batch_norm=False,
+                 lr=0.001, dropout_p=0.2, GIN=False, GAT=False, batch_norm=False,
                  residual=False, train_eps=True, save_filtration=False,
                  add_mlp=False, weight_decay=0., **kwargs):
         super().__init__()
         self.save_hyperparameters()
-        self.embedding = torch.nn.Linear(num_node_features, hidden_dim)
         self.save_filtration = save_filtration
 
+        num_heads = 1
+
         if GIN:
             def build_gnn_layer():
                 return GINLayer(in_features=hidden_dim, out_features=hidden_dim, train_eps=train_eps, activation=F.relu, batch_norm=batch_norm, dropout=dropout_p, **kwargs)
             graph_pooling_operation = global_add_pool
+
+        elif GAT:
+            num_heads = 8
+            def build_gnn_layer():
+                return GATLayer(in_features=hidden_dim * num_heads, out_features=hidden_dim, train_eps=train_eps, activation=F.relu, batch_norm=batch_norm, dropout=dropout_p, num_heads=num_heads, **kwargs)
+            graph_pooling_operation = global_mean_pool
+
         else:
             def build_gnn_layer():
                 return GCNLayer(
                     hidden_dim, hidden_dim, F.relu, dropout_p, batch_norm)
             graph_pooling_operation = global_mean_pool
 
+
+        self.embedding = torch.nn.Linear(num_node_features, hidden_dim * num_heads)
+
         layers = [build_gnn_layer() for _ in range(depth)]
 
         if add_mlp:
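
Note on the new branch: GATLayer's implementation is not part of this diff. Below is a minimal sketch of what the constructor call above implies, assuming the layer wraps torch_geometric's GATConv with concatenated heads, so each layer maps hidden_dim * num_heads features back to hidden_dim * num_heads, matching the widened embedding; the GIN-specific train_eps argument passed in the call would simply be absorbed by **kwargs. The real topognn.layers.GATLayer may differ.

    # Sketch only, under the assumptions stated above.
    import torch.nn as nn
    import torch.nn.functional as F
    from torch_geometric.nn import GATConv

    class GATLayer(nn.Module):
        def __init__(self, in_features, out_features, activation=F.relu,
                     batch_norm=False, dropout=0.0, num_heads=1, **kwargs):
            super().__init__()
            # concat=True makes the output dimension out_features * num_heads,
            # equal to in_features when in_features = hidden_dim * num_heads.
            self.conv = GATConv(in_features, out_features, heads=num_heads,
                                concat=True, dropout=dropout)
            self.bn = (nn.BatchNorm1d(out_features * num_heads)
                       if batch_norm else nn.Identity())
            self.activation = activation

        def forward(self, x, edge_index, **kwargs):
            return self.activation(self.bn(self.conv(x, edge_index)))

With concatenated heads, every layer preserves the widened feature size, which is why the embedding is moved below the branch and sized hidden_dim * num_heads.
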
@@ -567,7 +578,7 @@ def fake_pool(x, batch):
         if (kwargs.get("dim1", False) and ("dim1_out_dim" in kwargs.keys()) and (not kwargs.get("fake", False))):
             dim_before_class = hidden_dim + kwargs["dim1_out_dim"]  # SimpleTopoGNN with dim1
         else:
-            dim_before_class = hidden_dim
+            dim_before_class = hidden_dim * num_heads
 
         self.classif = torch.nn.Sequential(
             nn.Linear(dim_before_class, hidden_dim // 2),
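
Head bookkeeping behind this change: with concatenated attention heads, the pooled graph representation carries hidden_dim * num_heads channels rather than hidden_dim, so the classifier input must grow by the same factor. Illustrative arithmetic (hidden_dim value assumed, not from the diff):

    hidden_dim = 64   # illustrative value
    num_heads = 8     # set in the elif GAT branch above
    dim_before_class = hidden_dim * num_heads  # 512-dim pooled features
    # The classifier head then opens with nn.Linear(512, 32),
    # i.e. nn.Linear(dim_before_class, hidden_dim // 2).
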
@@ -633,6 +644,7 @@ def configure_optimizers(self):
         return [optimizer], [scheduler]
 
     def forward(self, data):
+
         x, edge_index = data.x, data.edge_index
         x = self.embedding(x)
 
@@ -722,6 +734,7 @@ def add_model_specific_args(cls, parent):
         parser.add_argument("--min_lr", type=float, default=0.00001)
         parser.add_argument("--dropout_p", type=float, default=0.0)
         parser.add_argument('--GIN', type=str2bool, default=False)
+        parser.add_argument('--GAT', type=str2bool, default=False)
         parser.add_argument('--train_eps', type=str2bool, default=True)
         parser.add_argument('--batch_norm', type=str2bool, default=True)
         parser.add_argument('--residual', type=str2bool, default=True)
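
A quick self-contained check of the new flag, using only names visible in this diff (str2bool is the helper imported at the top of the file):

    import argparse
    from topognn.cli_utils import str2bool

    parser = argparse.ArgumentParser()
    parser.add_argument('--GIN', type=str2bool, default=False)
    parser.add_argument('--GAT', type=str2bool, default=False)

    args = parser.parse_args(['--GAT', 'True'])
    assert args.GAT and not args.GIN  # selects the elif GAT branch in __init__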