
Commit 471d455

Sparsity for MulticompartmentConnection
1 parent d1d3e42 commit 471d455

File tree

bindsnet/network/topology.py
bindsnet/network/topology_features.py

2 files changed: +91 -23 lines changed
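Both changed files gate the new behaviour behind an optional sparse flag and rely on PyTorch's sparse COO tensor support. A minimal sketch of the torch primitives the commit leans on (to_sparse, to_dense, indices/values), with toy values:

    import torch

    w = torch.tensor([[0.0, 0.3],
                      [0.5, 0.0]])
    sw = w.to_sparse().coalesce()   # COO format: only nonzero entries are stored

    print(sw.indices())             # positions of stored entries, shape [2, nnz]
    print(sw.values())              # stored values, shape [nnz]
    print(sw.to_dense())            # back to an ordinary dense tensor

    # Reshaping ops such as view() are not implemented for sparse COO tensors,
    # which is why the connection code below densifies before calling view().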

bindsnet/network/topology.py

Lines changed: 9 additions & 2 deletions
@@ -446,12 +446,19 @@ def compute(self, s: torch.Tensor) -> torch.Tensor:
 
         # Sum signals for each of the output/terminal neurons
         # |out_signal| = [batch_size, target.n]
-        out_signal = conn_spikes.view(s.size(0), self.source.n, self.target.n).sum(1)
+        if conn_spikes.size() != torch.Size([s.size(0), self.source.n, self.target.n]):
+            if conn_spikes.is_sparse:
+                conn_spikes = conn_spikes.to_dense()
+            conn_spikes = conn_spikes.view(s.size(0), self.source.n, self.target.n)
+        out_signal = conn_spikes.sum(1)
 
         if self.traces:
             self.activity = out_signal
 
-        return out_signal.view(s.size(0), *self.target.shape)
+        if out_signal.size() != torch.Size([s.size(0)] + self.target.shape):
+            return out_signal.view(s.size(0), *self.target.shape)
+        else:
+            return out_signal
 
     def compute_window(self, s: torch.Tensor) -> torch.Tensor:
         # language=rst
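The new branch in compute only reshapes when conn_spikes does not already have the expected [batch_size, source.n, target.n] shape, densifying first because view is not defined for sparse COO tensors. A standalone toy sketch of that fallback (made-up sizes, not the connection object itself):

    import torch

    batch, src_n, tgt_n = 2, 3, 4
    conn_spikes = (torch.rand(batch, src_n * tgt_n) > 0.8).float().to_sparse()

    if conn_spikes.size() != torch.Size([batch, src_n, tgt_n]):
        if conn_spikes.is_sparse:
            conn_spikes = conn_spikes.to_dense()   # view() is unsupported on sparse tensors
        conn_spikes = conn_spikes.view(batch, src_n, tgt_n)

    out_signal = conn_spikes.sum(1)                # summed input per target neuron: [batch, tgt_n]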

bindsnet/network/topology_features.py

Lines changed: 82 additions & 21 deletions
@@ -31,6 +31,7 @@ def __init__(
         enforce_polarity: Optional[bool] = False,
         decay: float = 0.0,
         parent_feature=None,
+        sparse: Optional[bool] = False,
         **kwargs,
     ) -> None:
         # language=rst
@@ -47,6 +48,7 @@ def __init__(
             dimension
         :param decay: Constant multiple to decay weights by on each iteration
         :param parent_feature: Parent feature to inherit :code:`value` from
+        :param sparse: Should :code:`value` parameter be sparse tensor or not
         """
 
         #### Initialize class variables ####
@@ -61,6 +63,7 @@ def __init__(
         self.reduction = reduction
         self.decay = decay
         self.parent_feature = parent_feature
+        self.sparse = sparse
         self.kwargs = kwargs
 
         ## Backend ##
@@ -119,6 +122,10 @@ def __init__(
         self.assert_valid_range()
         if value is not None:
             self.assert_feature_in_range()
+        if self.sparse:
+            self.value = self.value.to_sparse()
+            assert not getattr(self, 'enforce_polarity', False), \
+                "enforce_polarity isn't supported for sparse tensors"
 
     @abstractmethod
     def reset_state_variables(self) -> None:
@@ -161,7 +168,10 @@ def prime_feature(self, connection, device, **kwargs) -> None:
 
         # Check if values/norms are the correct shape
         if isinstance(self.value, torch.Tensor):
-            assert tuple(self.value.shape) == (connection.source.n, connection.target.n)
+            if self.sparse:
+                assert tuple(self.value.shape[1:]) == (connection.source.n, connection.target.n)
+            else:
+                assert tuple(self.value.shape) == (connection.source.n, connection.target.n)
 
         if self.norm is not None and isinstance(self.norm, torch.Tensor):
             assert self.norm.shape[0] == connection.target.n
@@ -214,9 +224,15 @@ def normalize(self) -> None:
         """
 
         if self.norm is not None:
-            abs_sum = self.value.sum(0).unsqueeze(0)
-            abs_sum[abs_sum == 0] = 1.0
-            self.value *= self.norm / abs_sum
+            if self.sparse:
+                abs_sum = self.value.sum(1).to_dense()
+                abs_sum[abs_sum == 0] = 1.0
+                abs_sum = abs_sum.unsqueeze(1).expand(-1, *self.value.shape[1:])
+                self.value = self.value * (self.norm / abs_sum)
+            else:
+                abs_sum = self.value.sum(0).unsqueeze(0)
+                abs_sum[abs_sum == 0] = 1.0
+                self.value *= self.norm / abs_sum
 
     def degrade(self) -> None:
         # language=rst
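The dense branch normalizes each target column of value so it sums to norm; the sparse branch does the same per column but sums over dim 1, since a sparse value carries a leading batch dimension, and densifies the column sums before dividing. A small dense-only sketch of the column normalization with toy numbers:

    import torch

    w = torch.tensor([[1.0, 2.0, 0.0],
                      [3.0, 2.0, 0.0]])   # [source_n, target_n]
    norm = 1.0

    abs_sum = w.sum(0).unsqueeze(0)       # per-target column sums: [[4., 4., 0.]]
    abs_sum[abs_sum == 0] = 1.0           # avoid dividing empty columns by zero
    w *= norm / abs_sum                   # every nonempty column now sums to norm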
@@ -299,11 +315,17 @@ def assert_feature_in_range(self):
 
     def assert_valid_shape(self, source_shape, target_shape, f):
         # Multidimensional feat
-        if len(f.shape) > 1:
-            assert f.shape == (
+        if (not self.sparse and len(f.shape) > 1) or (self.sparse and len(f.shape[1:]) > 1):
+            if self.sparse:
+                f_shape = f.shape[1:]
+                expected = ('batch_size', source_shape, target_shape)
+            else:
+                f_shape = f.shape
+                expected = (source_shape, target_shape)
+            assert f_shape == (
                 source_shape,
                 target_shape,
-            ), f"Feature {self.name} has an incorrect shape of {f.shape}. Should be of shape {(source_shape, target_shape)}"
+            ), f"Feature {self.name} has an incorrect shape of {f.shape}. Should be of shape {expected}"
         # Else assume scalar, which is a valid shape
 
 
@@ -319,6 +341,7 @@ def __init__(
         reduction: Optional[callable] = None,
         decay: float = 0.0,
         parent_feature=None,
+        sparse: Optional[bool] = False
     ) -> None:
         # language=rst
         """
@@ -336,6 +359,7 @@ def __init__(
             dimension
         :param decay: Constant multiple to decay weights by on each iteration
         :param parent_feature: Parent feature to inherit :code:`value` from
+        :param sparse: Should :code:`value` parameter be sparse tensor or not
         """
 
         ### Assertions ###
@@ -349,10 +373,25 @@ def __init__(
             reduction=reduction,
             decay=decay,
             parent_feature=parent_feature,
+            sparse=sparse
+        )
+
+    def sparse_bernoulli(self):
+        values = torch.bernoulli(self.value.values())
+        mask = values != 0
+        indices = self.value.indices()[:, mask]
+        non_zero = values[mask]
+        return torch.sparse_coo_tensor(
+            indices,
+            non_zero,
+            self.value.size()
         )
 
     def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
-        return conn_spikes * torch.bernoulli(self.value)
+        if self.sparse:
+            return conn_spikes * self.sparse_bernoulli()
+        else:
+            return conn_spikes * torch.bernoulli(self.value)
 
     def reset_state_variables(self) -> None:
         pass
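sparse_bernoulli draws Bernoulli samples only for the stored (nonzero) probabilities of the COO tensor and rebuilds a sparse tensor from the entries that survived the draw. A standalone sketch of the same idea on a toy tensor (assumes the tensor is coalesced, since indices()/values() require that):

    import torch

    probs = torch.tensor([[0.9, 0.0],
                          [0.0, 0.5]]).to_sparse().coalesce()

    sampled = torch.bernoulli(probs.values())        # draw only for stored entries
    keep = sampled != 0                              # drop entries whose draw was 0
    drawn = torch.sparse_coo_tensor(
        probs.indices()[:, keep], sampled[keep], probs.size()
    )
    print(drawn.to_dense())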
@@ -395,12 +434,14 @@ def __init__(
         self,
         name: str,
         value: Union[torch.Tensor, float, int] = None,
+        sparse: Optional[bool] = False
     ) -> None:
         # language=rst
         """
         Boolean mask which determines whether or not signals are allowed to traverse certain synapses.
         :param name: Name of the feature
         :param value: Boolean mask. :code:`True` means a signal can pass, :code:`False` means the synapse is impassable
+        :param sparse: Should :code:`value` parameter be sparse tensor or not
         """
 
         ### Assertions ###
@@ -419,11 +460,9 @@ def __init__(
         super().__init__(
             name=name,
             value=value,
+            sparse=sparse
         )
 
-        self.name = name
-        self.value = value
-
     def compute(self, conn_spikes) -> torch.Tensor:
         return conn_spikes * self.value
 
@@ -505,6 +544,7 @@ def __init__(
         reduction: Optional[callable] = None,
         enforce_polarity: Optional[bool] = False,
         decay: float = 0.0,
+        sparse: Optional[bool] = False
     ) -> None:
         # language=rst
         """
@@ -523,6 +563,7 @@ def __init__(
             dimension
         :param enforce_polarity: Will prevent synapses from changing signs if :code:`True`
         :param decay: Constant multiple to decay weights by on each iteration
+        :param sparse: Should :code:`value` parameter be sparse tensor or not
         """
 
         self.norm_frequency = norm_frequency
@@ -536,6 +577,7 @@ def __init__(
             nu=nu,
             reduction=reduction,
             decay=decay,
+            sparse=sparse
         )
 
     def reset_state_variables(self) -> None:
@@ -589,6 +631,7 @@ def __init__(
         value: Union[torch.Tensor, float, int] = None,
         range: Optional[Sequence[float]] = None,
         norm: Optional[Union[torch.Tensor, float, int]] = None,
+        sparse: Optional[bool] = False
     ) -> None:
         # language=rst
         """
@@ -598,13 +641,15 @@ def __init__(
         :param range: Range of acceptable values for the :code:`value` parameter
         :param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each sample
             and after the value has been updated by the learning rule (if there is one)
+        :param sparse: Should :code:`value` parameter be sparse tensor or not
         """
 
         super().__init__(
             name=name,
             value=value,
             range=[-torch.inf, +torch.inf] if range is None else range,
             norm=norm,
+            sparse=sparse
         )
 
     def reset_state_variables(self) -> None:
@@ -629,15 +674,17 @@ def __init__(
         name: str,
         value: Union[torch.Tensor, float, int] = None,
         range: Optional[Sequence[float]] = None,
+        sparse: Optional[bool] = False
     ) -> None:
         # language=rst
         """
         Adds scalars to signals
         :param name: Name of the feature
         :param value: Values to scale signals by
+        :param sparse: Should :code:`value` parameter be sparse tensor or not
         """
 
-        super().__init__(name=name, value=value, range=range)
+        super().__init__(name=name, value=value, range=range, sparse=sparse)
 
     def reset_state_variables(self) -> None:
         pass
@@ -666,6 +713,7 @@ def __init__(
         value: Union[torch.Tensor, float, int] = None,
         degrade_function: callable = None,
         parent_feature: Optional[AbstractFeature] = None,
+        sparse: Optional[bool] = False
     ) -> None:
         # language=rst
         """
@@ -676,10 +724,11 @@ def __init__(
         :param degrade_function: Callable function which takes a single argument (:code:`value`) and returns a tensor or
             constant to be *subtracted* from the propagating spikes.
         :param parent_feature: Parent feature with desired :code:`value` to inherit
+        :param sparse: Should :code:`value` parameter be sparse tensor or not
         """
 
         # Note: parent_feature will override value. See abstract constructor
-        super().__init__(name=name, value=value, parent_feature=parent_feature)
+        super().__init__(name=name, value=value, parent_feature=parent_feature, sparse=sparse)
 
         self.degrade_function = degrade_function
 
@@ -698,6 +747,7 @@ def __init__(
         ann_values: Union[list, tuple] = None,
         const_update_rate: float = 0.1,
         const_decay: float = 0.001,
+        sparse: Optional[bool] = False
     ) -> None:
         # language=rst
         """
@@ -708,6 +758,7 @@ def __init__(
         :param value: Values to be use to build an initial mask for the synapses.
         :param const_update_rate: The mask upatate rate of the ANN decision.
         :param const_decay: The spontaneous activation of the synapses.
+        :param sparse: Should :code:`value` parameter be sparse tensor or not
         """
 
         # Define the ANN
@@ -743,16 +794,18 @@ def forward(self, x):
         self.const_update_rate = const_update_rate
         self.const_decay = const_decay
 
-        super().__init__(name=name, value=value)
+        super().__init__(name=name, value=value, sparse=sparse)
 
     def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
 
         # Update the spike buffer
         if self.start_counter == False or conn_spikes.sum() > 0:
             self.start_counter = True
-            self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = (
-                conn_spikes.flatten()
-            )
+            if self.sparse:
+                flat_conn_spikes = conn_spikes.to_dense().flatten()
+            else:
+                flat_conn_spikes = conn_spikes.flatten()
+            self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = flat_conn_spikes
         self.counter += 1
 
         # Update the masks
@@ -767,6 +820,8 @@ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
 
         # self.mask = torch.clamp(self.mask, -1, 1)
         self.value = (self.mask > 0).float()
+        if self.sparse:
+            self.value = self.value.to_sparse()
 
         return conn_spikes * self.value
 
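The spike buffer written in compute above acts as a fixed-width ring buffer: each call writes the flattened (densified, if sparse) spikes into column counter % width, overwriting the oldest entry. A toy sketch of that indexing, using illustrative local names rather than the class attributes:

    import torch

    n_synapses, width = 6, 4
    spike_buffer = torch.zeros(n_synapses, width)    # rolling window of recent spikes
    counter = 0

    for step in range(10):
        conn_spikes = (torch.rand(2, 3) > 0.7).float().to_sparse()
        flat = conn_spikes.to_dense().flatten() if conn_spikes.is_sparse else conn_spikes.flatten()
        spike_buffer[:, counter % width] = flat      # overwrite the oldest column
        counter += 1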
@@ -788,6 +843,7 @@ def __init__(
         ann_values: Union[list, tuple] = None,
         const_update_rate: float = 0.1,
         const_decay: float = 0.01,
+        sparse: Optional[bool] = False
     ) -> None:
         # language=rst
         """
@@ -798,6 +854,7 @@ def __init__(
         :param value: Values to be use to build an initial mask for the synapses.
         :param const_update_rate: The mask upatate rate of the ANN decision.
         :param const_decay: The spontaneous activation of the synapses.
+        :param sparse: Should :code:`value` parameter be sparse tensor or not
        """
 
         # Define the ANN
@@ -833,16 +890,18 @@ def forward(self, x):
         self.const_update_rate = const_update_rate
         self.const_decay = const_decay
 
-        super().__init__(name=name, value=value)
+        super().__init__(name=name, value=value, sparse=sparse)
 
     def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
 
         # Update the spike buffer
         if self.start_counter == False or conn_spikes.sum() > 0:
             self.start_counter = True
-            self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = (
-                conn_spikes.flatten()
-            )
+            if self.sparse:
+                flat_conn_spikes = conn_spikes.to_dense().flatten()
+            else:
+                flat_conn_spikes = conn_spikes.flatten()
+            self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = flat_conn_spikes
         self.counter += 1
 
         # Update the masks
@@ -857,6 +916,8 @@ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
 
         # self.mask = torch.clamp(self.mask, -1, 1)
         self.value = (self.mask > 0).float()
+        if self.sparse:
+            self.value = self.value.to_sparse()
 
         return conn_spikes * self.value
 