Skip to content

Commit f7c2993

Browse files
author
Kevin Chang
committed
ffssn
1 parent 36ee769 commit f7c2993

File tree

3 files changed

+300
-129
lines changed

3 files changed

+300
-129
lines changed

bindsnet/network/topology_features.py

Lines changed: 298 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import torch.nn as nn
1111
import bindsnet.learning
1212

13-
1413
class AbstractFeature(ABC):
1514
# language=rst
1615
"""
@@ -938,3 +937,301 @@ def __init__(
938937
super().__init__(name, parent_feature)
939938

940939
self.sub_feature = self.parent.update
940+
941+
942+
943+
944+
class ForwardForwardWeight(AbstractFeature):
945+
"""
946+
Forward-Forward learning weight feature for MulticompartmentConnection.
947+
948+
Implements the Forward-Forward algorithm with surrogate gradients, enabling
949+
layer-wise learning without backpropagation through time. This feature adds:
950+
- Arctangent surrogate gradient computation
951+
- Membrane potential tracking for goodness scores
952+
- Forward-Forward loss computation capabilities
953+
954+
Compatible with the MCC architecture and composable with other features.
955+
"""
956+
957+
def __init__(
958+
self,
959+
spike_threshold: float = 1.0,
960+
alpha: float = 2.0,
961+
alpha_loss: float = 0.6,
962+
dt: float = 1.0,
963+
**kwargs
964+
):
965+
"""
966+
Initialize Forward-Forward weight feature.
967+
968+
Args:
969+
spike_threshold: Threshold for spike generation and surrogate gradient
970+
alpha: Arctangent surrogate gradient steepness parameter
971+
alpha_loss: Forward-Forward loss threshold parameter
972+
dt: Time step size for membrane potential integration
973+
**kwargs: Additional arguments passed to parent WeightFeature
974+
"""
975+
super().__init__(**kwargs)
976+
977+
self.spike_threshold = spike_threshold
978+
self.alpha = alpha
979+
self.alpha_loss = alpha_loss
980+
self.dt = dt
981+
982+
# Membrane potential state for goodness computation
983+
self.v_membrane = None
984+
985+
def reset_state(self):
986+
"""Reset membrane potential state."""
987+
self.v_membrane = None
988+
989+
def forward(
990+
self,
991+
s: torch.Tensor,
992+
connection: 'MulticompartmentConnection',
993+
**kwargs
994+
) -> torch.Tensor:
995+
"""
996+
Forward pass through weight feature with surrogate gradients.
997+
998+
This method integrates with the MCC forward pass pipeline.
999+
1000+
Args:
1001+
s: Input spikes [batch_size, source_neurons]
1002+
connection: Parent MulticompartmentConnection instance
1003+
**kwargs: Additional arguments from MCC forward pass
1004+
1005+
Returns:
1006+
Weighted synaptic input with surrogate gradient computation
1007+
"""
1008+
# Get connection weights (handled by parent MCC)
1009+
w = connection.w
1010+
1011+
# Compute synaptic input: I = s * W
1012+
synaptic_input = torch.mm(s.float(), w)
1013+
1014+
# Track this for goodness score computation if needed
1015+
if hasattr(self, '_track_activity') and self._track_activity:
1016+
self._last_synaptic_input = synaptic_input.detach()
1017+
1018+
return synaptic_input
1019+
1020+
def compute_spikes_with_surrogate(
1021+
self,
1022+
synaptic_input: torch.Tensor,
1023+
target_layer: 'AbstractPopulation'
1024+
) -> torch.Tensor:
1025+
"""
1026+
Generate spikes with surrogate gradients from synaptic input.
1027+
1028+
This method should be called after the weight forward pass to convert
1029+
synaptic input to spikes using the Forward-Forward surrogate gradient.
1030+
1031+
Args:
1032+
synaptic_input: Weighted input [batch_size, target_neurons]
1033+
target_layer: Target neuron population
1034+
1035+
Returns:
1036+
Spikes with surrogate gradients [batch_size, target_neurons]
1037+
"""
1038+
# Initialize or update membrane potential
1039+
if self.v_membrane is None:
1040+
self.v_membrane = torch.zeros_like(synaptic_input)
1041+
1042+
# Integrate synaptic input (simple Euler integration)
1043+
self.v_membrane = self.v_membrane + synaptic_input * self.dt
1044+
1045+
# Generate spikes using arctangent surrogate gradient
1046+
spikes = ArctangentSurrogate.apply(
1047+
self.v_membrane,
1048+
self.spike_threshold,
1049+
self.alpha
1050+
)
1051+
1052+
# Optional: Reset membrane potential where spikes occurred
1053+
# self.v_membrane = self.v_membrane * (1 - spikes)
1054+
1055+
return spikes
1056+
1057+
def compute_goodness_score(self, spike_activity: torch.Tensor) -> torch.Tensor:
1058+
"""
1059+
Compute Forward-Forward goodness score from spike activity.
1060+
1061+
Args:
1062+
spike_activity: Spike traces [batch_size, time_steps, neurons] or
1063+
spike counts [batch_size, neurons]
1064+
1065+
Returns:
1066+
Goodness scores [batch_size]
1067+
"""
1068+
if spike_activity.dim() == 3:
1069+
# Sum over time dimension if time traces provided
1070+
spike_counts = torch.sum(spike_activity, dim=1) # [batch_size, neurons]
1071+
else:
1072+
spike_counts = spike_activity # Already summed
1073+
1074+
# Forward-Forward goodness: mean squared spike activity
1075+
goodness = torch.mean(spike_counts ** 2, dim=1) # [batch_size]
1076+
1077+
return goodness
1078+
1079+
def compute_ff_loss(
1080+
self,
1081+
goodness_pos: torch.Tensor,
1082+
goodness_neg: torch.Tensor
1083+
) -> torch.Tensor:
1084+
"""
1085+
Compute Forward-Forward contrastive loss.
1086+
1087+
Loss = log(1 + exp(-g_pos + α)) + log(1 + exp(g_neg - α))
1088+
1089+
Args:
1090+
goodness_pos: Goodness scores for positive samples [batch_size]
1091+
goodness_neg: Goodness scores for negative samples [batch_size]
1092+
1093+
Returns:
1094+
Forward-Forward loss [batch_size]
1095+
"""
1096+
# Positive loss: encourage high goodness for true labels
1097+
loss_pos = torch.log(1 + torch.exp(-goodness_pos + self.alpha_loss))
1098+
1099+
# Negative loss: encourage low goodness for false labels
1100+
loss_neg = torch.log(1 + torch.exp(goodness_neg - self.alpha_loss))
1101+
1102+
return loss_pos + loss_neg
1103+
1104+
def get_feature_info(self) -> dict:
1105+
"""Get information about this Forward-Forward feature."""
1106+
return {
1107+
'feature_type': 'ForwardForwardWeight',
1108+
'spike_threshold': self.spike_threshold,
1109+
'alpha_surrogate': self.alpha,
1110+
'alpha_loss': self.alpha_loss,
1111+
'dt': self.dt,
1112+
'surrogate_function': 'arctangent',
1113+
'compatible_with': ['MCC', 'other_weight_features']
1114+
}
1115+
1116+
1117+
class ArctangentSurrogate(torch.autograd.Function):
1118+
"""
1119+
Arctangent surrogate gradient function for Forward-Forward training.
1120+
1121+
Forward pass: spikes = (membrane_potential >= threshold)
1122+
Backward pass: gradient = 1 / (α * |membrane_potential - threshold| + 1)
1123+
1124+
This enables gradient-based learning in spiking neural networks by
1125+
providing a smooth approximation of the non-differentiable spike function.
1126+
"""
1127+
1128+
@staticmethod
1129+
def forward(
1130+
ctx,
1131+
membrane_potential: torch.Tensor,
1132+
threshold: float,
1133+
alpha: float
1134+
) -> torch.Tensor:
1135+
"""
1136+
Forward pass: generate binary spikes.
1137+
1138+
Args:
1139+
membrane_potential: Neuron membrane potentials
1140+
threshold: Spike threshold
1141+
alpha: Surrogate gradient steepness parameter
1142+
1143+
Returns:
1144+
Binary spike tensor (0 or 1)
1145+
"""
1146+
# Save tensors and parameters for backward pass
1147+
ctx.save_for_backward(membrane_potential)
1148+
ctx.threshold = threshold
1149+
ctx.alpha = alpha
1150+
1151+
# Generate spikes (heaviside step function)
1152+
spikes = (membrane_potential >= threshold).float()
1153+
1154+
return spikes
1155+
1156+
@staticmethod
1157+
def backward(
1158+
ctx,
1159+
grad_output: torch.Tensor
1160+
) -> Tuple[torch.Tensor, None, None]:
1161+
"""
1162+
Backward pass: compute surrogate gradients.
1163+
1164+
Uses arctangent-based surrogate: 1 / (α * |v - threshold| + 1)
1165+
1166+
Args:
1167+
grad_output: Gradient from subsequent layers
1168+
1169+
Returns:
1170+
Tuple of (grad_membrane_potential, None, None)
1171+
"""
1172+
membrane_potential, = ctx.saved_tensors
1173+
threshold = ctx.threshold
1174+
alpha = ctx.alpha
1175+
1176+
# Compute arctangent surrogate gradient
1177+
# grad = 1 / (α * |v - v_th| + 1)
1178+
surrogate_grad = 1.0 / (alpha * torch.abs(membrane_potential - threshold) + 1.0)
1179+
1180+
# Apply chain rule with incoming gradients
1181+
grad_membrane_potential = grad_output * surrogate_grad
1182+
1183+
# Return gradients (only for first argument)
1184+
return grad_membrane_potential, None, None
1185+
1186+
1187+
# Add this helper function to create FF-enabled MCC connections
1188+
def create_ff_connection(
1189+
source: 'AbstractPopulation',
1190+
target: 'AbstractPopulation',
1191+
w: Optional[torch.Tensor] = None,
1192+
spike_threshold: float = 1.0,
1193+
alpha: float = 2.0,
1194+
alpha_loss: float = 0.6,
1195+
dt: float = 1.0,
1196+
**mcc_kwargs
1197+
) -> 'MulticompartmentConnection':
1198+
"""
1199+
Helper function to create MulticompartmentConnection with ForwardForwardWeight feature.
1200+
1201+
Args:
1202+
source: Source neuron population
1203+
target: Target neuron population
1204+
w: Connection weights (if None, will be initialized)
1205+
spike_threshold: FF spike threshold
1206+
alpha: FF surrogate gradient parameter
1207+
alpha_loss: FF loss threshold parameter
1208+
dt: Time step size
1209+
**mcc_kwargs: Additional arguments for MulticompartmentConnection
1210+
1211+
Returns:
1212+
MCC with ForwardForwardWeight feature attached
1213+
"""
1214+
from bindsnet.network.topology import MulticompartmentConnection
1215+
1216+
# Create ForwardForwardWeight feature
1217+
ff_feature = ForwardForwardWeight(
1218+
spike_threshold=spike_threshold,
1219+
alpha=alpha,
1220+
alpha_loss=alpha_loss,
1221+
dt=dt
1222+
)
1223+
1224+
# Initialize weights if not provided
1225+
if w is None:
1226+
w = 0.1 * torch.randn(source.n, target.n)
1227+
1228+
# Create MCC with FF feature
1229+
connection = MulticompartmentConnection(
1230+
source=source,
1231+
target=target,
1232+
w=w,
1233+
features=[ff_feature],
1234+
**mcc_kwargs
1235+
)
1236+
1237+
return connection

0 commit comments

Comments
 (0)