10
10
import torch .nn as nn
11
11
import bindsnet .learning
12
12
13
-
14
13
class AbstractFeature (ABC ):
15
14
# language=rst
16
15
"""
@@ -938,3 +937,301 @@ def __init__(
938
937
super ().__init__ (name , parent_feature )
939
938
940
939
self .sub_feature = self .parent .update
940
+
941
+
942
+
943
+
944
+ class ForwardForwardWeight (AbstractFeature ):
945
+ """
946
+ Forward-Forward learning weight feature for MulticompartmentConnection.
947
+
948
+ Implements the Forward-Forward algorithm with surrogate gradients, enabling
949
+ layer-wise learning without backpropagation through time. This feature adds:
950
+ - Arctangent surrogate gradient computation
951
+ - Membrane potential tracking for goodness scores
952
+ - Forward-Forward loss computation capabilities
953
+
954
+ Compatible with the MCC architecture and composable with other features.
955
+ """
956
+
957
+ def __init__ (
958
+ self ,
959
+ spike_threshold : float = 1.0 ,
960
+ alpha : float = 2.0 ,
961
+ alpha_loss : float = 0.6 ,
962
+ dt : float = 1.0 ,
963
+ ** kwargs
964
+ ):
965
+ """
966
+ Initialize Forward-Forward weight feature.
967
+
968
+ Args:
969
+ spike_threshold: Threshold for spike generation and surrogate gradient
970
+ alpha: Arctangent surrogate gradient steepness parameter
971
+ alpha_loss: Forward-Forward loss threshold parameter
972
+ dt: Time step size for membrane potential integration
973
+ **kwargs: Additional arguments passed to parent WeightFeature
974
+ """
975
+ super ().__init__ (** kwargs )
976
+
977
+ self .spike_threshold = spike_threshold
978
+ self .alpha = alpha
979
+ self .alpha_loss = alpha_loss
980
+ self .dt = dt
981
+
982
+ # Membrane potential state for goodness computation
983
+ self .v_membrane = None
984
+
985
+ def reset_state (self ):
986
+ """Reset membrane potential state."""
987
+ self .v_membrane = None
988
+
989
+ def forward (
990
+ self ,
991
+ s : torch .Tensor ,
992
+ connection : 'MulticompartmentConnection' ,
993
+ ** kwargs
994
+ ) -> torch .Tensor :
995
+ """
996
+ Forward pass through weight feature with surrogate gradients.
997
+
998
+ This method integrates with the MCC forward pass pipeline.
999
+
1000
+ Args:
1001
+ s: Input spikes [batch_size, source_neurons]
1002
+ connection: Parent MulticompartmentConnection instance
1003
+ **kwargs: Additional arguments from MCC forward pass
1004
+
1005
+ Returns:
1006
+ Weighted synaptic input with surrogate gradient computation
1007
+ """
1008
+ # Get connection weights (handled by parent MCC)
1009
+ w = connection .w
1010
+
1011
+ # Compute synaptic input: I = s * W
1012
+ synaptic_input = torch .mm (s .float (), w )
1013
+
1014
+ # Track this for goodness score computation if needed
1015
+ if hasattr (self , '_track_activity' ) and self ._track_activity :
1016
+ self ._last_synaptic_input = synaptic_input .detach ()
1017
+
1018
+ return synaptic_input
1019
+
1020
+ def compute_spikes_with_surrogate (
1021
+ self ,
1022
+ synaptic_input : torch .Tensor ,
1023
+ target_layer : 'AbstractPopulation'
1024
+ ) -> torch .Tensor :
1025
+ """
1026
+ Generate spikes with surrogate gradients from synaptic input.
1027
+
1028
+ This method should be called after the weight forward pass to convert
1029
+ synaptic input to spikes using the Forward-Forward surrogate gradient.
1030
+
1031
+ Args:
1032
+ synaptic_input: Weighted input [batch_size, target_neurons]
1033
+ target_layer: Target neuron population
1034
+
1035
+ Returns:
1036
+ Spikes with surrogate gradients [batch_size, target_neurons]
1037
+ """
1038
+ # Initialize or update membrane potential
1039
+ if self .v_membrane is None :
1040
+ self .v_membrane = torch .zeros_like (synaptic_input )
1041
+
1042
+ # Integrate synaptic input (simple Euler integration)
1043
+ self .v_membrane = self .v_membrane + synaptic_input * self .dt
1044
+
1045
+ # Generate spikes using arctangent surrogate gradient
1046
+ spikes = ArctangentSurrogate .apply (
1047
+ self .v_membrane ,
1048
+ self .spike_threshold ,
1049
+ self .alpha
1050
+ )
1051
+
1052
+ # Optional: Reset membrane potential where spikes occurred
1053
+ # self.v_membrane = self.v_membrane * (1 - spikes)
1054
+
1055
+ return spikes
1056
+
1057
+ def compute_goodness_score (self , spike_activity : torch .Tensor ) -> torch .Tensor :
1058
+ """
1059
+ Compute Forward-Forward goodness score from spike activity.
1060
+
1061
+ Args:
1062
+ spike_activity: Spike traces [batch_size, time_steps, neurons] or
1063
+ spike counts [batch_size, neurons]
1064
+
1065
+ Returns:
1066
+ Goodness scores [batch_size]
1067
+ """
1068
+ if spike_activity .dim () == 3 :
1069
+ # Sum over time dimension if time traces provided
1070
+ spike_counts = torch .sum (spike_activity , dim = 1 ) # [batch_size, neurons]
1071
+ else :
1072
+ spike_counts = spike_activity # Already summed
1073
+
1074
+ # Forward-Forward goodness: mean squared spike activity
1075
+ goodness = torch .mean (spike_counts ** 2 , dim = 1 ) # [batch_size]
1076
+
1077
+ return goodness
1078
+
1079
+ def compute_ff_loss (
1080
+ self ,
1081
+ goodness_pos : torch .Tensor ,
1082
+ goodness_neg : torch .Tensor
1083
+ ) -> torch .Tensor :
1084
+ """
1085
+ Compute Forward-Forward contrastive loss.
1086
+
1087
+ Loss = log(1 + exp(-g_pos + α)) + log(1 + exp(g_neg - α))
1088
+
1089
+ Args:
1090
+ goodness_pos: Goodness scores for positive samples [batch_size]
1091
+ goodness_neg: Goodness scores for negative samples [batch_size]
1092
+
1093
+ Returns:
1094
+ Forward-Forward loss [batch_size]
1095
+ """
1096
+ # Positive loss: encourage high goodness for true labels
1097
+ loss_pos = torch .log (1 + torch .exp (- goodness_pos + self .alpha_loss ))
1098
+
1099
+ # Negative loss: encourage low goodness for false labels
1100
+ loss_neg = torch .log (1 + torch .exp (goodness_neg - self .alpha_loss ))
1101
+
1102
+ return loss_pos + loss_neg
1103
+
1104
+ def get_feature_info (self ) -> dict :
1105
+ """Get information about this Forward-Forward feature."""
1106
+ return {
1107
+ 'feature_type' : 'ForwardForwardWeight' ,
1108
+ 'spike_threshold' : self .spike_threshold ,
1109
+ 'alpha_surrogate' : self .alpha ,
1110
+ 'alpha_loss' : self .alpha_loss ,
1111
+ 'dt' : self .dt ,
1112
+ 'surrogate_function' : 'arctangent' ,
1113
+ 'compatible_with' : ['MCC' , 'other_weight_features' ]
1114
+ }
1115
+
1116
+
1117
+ class ArctangentSurrogate (torch .autograd .Function ):
1118
+ """
1119
+ Arctangent surrogate gradient function for Forward-Forward training.
1120
+
1121
+ Forward pass: spikes = (membrane_potential >= threshold)
1122
+ Backward pass: gradient = 1 / (α * |membrane_potential - threshold| + 1)
1123
+
1124
+ This enables gradient-based learning in spiking neural networks by
1125
+ providing a smooth approximation of the non-differentiable spike function.
1126
+ """
1127
+
1128
+ @staticmethod
1129
+ def forward (
1130
+ ctx ,
1131
+ membrane_potential : torch .Tensor ,
1132
+ threshold : float ,
1133
+ alpha : float
1134
+ ) -> torch .Tensor :
1135
+ """
1136
+ Forward pass: generate binary spikes.
1137
+
1138
+ Args:
1139
+ membrane_potential: Neuron membrane potentials
1140
+ threshold: Spike threshold
1141
+ alpha: Surrogate gradient steepness parameter
1142
+
1143
+ Returns:
1144
+ Binary spike tensor (0 or 1)
1145
+ """
1146
+ # Save tensors and parameters for backward pass
1147
+ ctx .save_for_backward (membrane_potential )
1148
+ ctx .threshold = threshold
1149
+ ctx .alpha = alpha
1150
+
1151
+ # Generate spikes (heaviside step function)
1152
+ spikes = (membrane_potential >= threshold ).float ()
1153
+
1154
+ return spikes
1155
+
1156
+ @staticmethod
1157
+ def backward (
1158
+ ctx ,
1159
+ grad_output : torch .Tensor
1160
+ ) -> Tuple [torch .Tensor , None , None ]:
1161
+ """
1162
+ Backward pass: compute surrogate gradients.
1163
+
1164
+ Uses arctangent-based surrogate: 1 / (α * |v - threshold| + 1)
1165
+
1166
+ Args:
1167
+ grad_output: Gradient from subsequent layers
1168
+
1169
+ Returns:
1170
+ Tuple of (grad_membrane_potential, None, None)
1171
+ """
1172
+ membrane_potential , = ctx .saved_tensors
1173
+ threshold = ctx .threshold
1174
+ alpha = ctx .alpha
1175
+
1176
+ # Compute arctangent surrogate gradient
1177
+ # grad = 1 / (α * |v - v_th| + 1)
1178
+ surrogate_grad = 1.0 / (alpha * torch .abs (membrane_potential - threshold ) + 1.0 )
1179
+
1180
+ # Apply chain rule with incoming gradients
1181
+ grad_membrane_potential = grad_output * surrogate_grad
1182
+
1183
+ # Return gradients (only for first argument)
1184
+ return grad_membrane_potential , None , None
1185
+
1186
+
1187
+ # Add this helper function to create FF-enabled MCC connections
1188
+ def create_ff_connection (
1189
+ source : 'AbstractPopulation' ,
1190
+ target : 'AbstractPopulation' ,
1191
+ w : Optional [torch .Tensor ] = None ,
1192
+ spike_threshold : float = 1.0 ,
1193
+ alpha : float = 2.0 ,
1194
+ alpha_loss : float = 0.6 ,
1195
+ dt : float = 1.0 ,
1196
+ ** mcc_kwargs
1197
+ ) -> 'MulticompartmentConnection' :
1198
+ """
1199
+ Helper function to create MulticompartmentConnection with ForwardForwardWeight feature.
1200
+
1201
+ Args:
1202
+ source: Source neuron population
1203
+ target: Target neuron population
1204
+ w: Connection weights (if None, will be initialized)
1205
+ spike_threshold: FF spike threshold
1206
+ alpha: FF surrogate gradient parameter
1207
+ alpha_loss: FF loss threshold parameter
1208
+ dt: Time step size
1209
+ **mcc_kwargs: Additional arguments for MulticompartmentConnection
1210
+
1211
+ Returns:
1212
+ MCC with ForwardForwardWeight feature attached
1213
+ """
1214
+ from bindsnet .network .topology import MulticompartmentConnection
1215
+
1216
+ # Create ForwardForwardWeight feature
1217
+ ff_feature = ForwardForwardWeight (
1218
+ spike_threshold = spike_threshold ,
1219
+ alpha = alpha ,
1220
+ alpha_loss = alpha_loss ,
1221
+ dt = dt
1222
+ )
1223
+
1224
+ # Initialize weights if not provided
1225
+ if w is None :
1226
+ w = 0.1 * torch .randn (source .n , target .n )
1227
+
1228
+ # Create MCC with FF feature
1229
+ connection = MulticompartmentConnection (
1230
+ source = source ,
1231
+ target = target ,
1232
+ w = w ,
1233
+ features = [ff_feature ],
1234
+ ** mcc_kwargs
1235
+ )
1236
+
1237
+ return connection
0 commit comments