@@ -31,6 +31,7 @@ def __init__(
31
31
enforce_polarity : Optional [bool ] = False ,
32
32
decay : float = 0.0 ,
33
33
parent_feature = None ,
34
+ sparse : Optional [bool ] = False ,
34
35
** kwargs ,
35
36
) -> None :
36
37
# language=rst
@@ -47,6 +48,7 @@ def __init__(
47
48
dimension
48
49
:param decay: Constant multiple to decay weights by on each iteration
49
50
:param parent_feature: Parent feature to inherit :code:`value` from
51
+ :param sparse: Whether the :code:`value` parameter should be a sparse tensor
50
52
"""
51
53
52
54
#### Initialize class variables ####
@@ -61,6 +63,7 @@ def __init__(
61
63
self .reduction = reduction
62
64
self .decay = decay
63
65
self .parent_feature = parent_feature
66
+ self .sparse = sparse
64
67
self .kwargs = kwargs
65
68
66
69
## Backend ##
@@ -119,6 +122,10 @@ def __init__(
119
122
self .assert_valid_range ()
120
123
if value is not None :
121
124
self .assert_feature_in_range ()
125
+ if self .sparse :
126
+ self .value = self .value .to_sparse ()
127
+ assert not getattr (self , 'enforce_polarity' , False ), \
128
+ "enforce_polarity isn't supported for sparse tensors"
122
129
123
130
@abstractmethod
124
131
def reset_state_variables (self ) -> None :
@@ -161,7 +168,10 @@ def prime_feature(self, connection, device, **kwargs) -> None:
161
168
162
169
# Check if values/norms are the correct shape
163
170
if isinstance (self .value , torch .Tensor ):
164
- assert tuple (self .value .shape ) == (connection .source .n , connection .target .n )
171
+ if self .sparse :
172
+ assert tuple (self .value .shape [1 :]) == (connection .source .n , connection .target .n )
173
+ else :
174
+ assert tuple (self .value .shape ) == (connection .source .n , connection .target .n )
165
175
166
176
if self .norm is not None and isinstance (self .norm , torch .Tensor ):
167
177
assert self .norm .shape [0 ] == connection .target .n
@@ -214,9 +224,15 @@ def normalize(self) -> None:
214
224
"""
215
225
216
226
if self .norm is not None :
217
- abs_sum = self .value .sum (0 ).unsqueeze (0 )
218
- abs_sum [abs_sum == 0 ] = 1.0
219
- self .value *= self .norm / abs_sum
227
+ if self .sparse :
228
+ abs_sum = self .value .sum (1 ).to_dense ()
229
+ abs_sum [abs_sum == 0 ] = 1.0
230
+ abs_sum = abs_sum .unsqueeze (1 ).expand (- 1 , * self .value .shape [1 :])
231
+ self .value = self .value * (self .norm / abs_sum )
232
+ else :
233
+ abs_sum = self .value .sum (0 ).unsqueeze (0 )
234
+ abs_sum [abs_sum == 0 ] = 1.0
235
+ self .value *= self .norm / abs_sum
220
236
221
237
def degrade (self ) -> None :
222
238
# language=rst
@@ -299,11 +315,17 @@ def assert_feature_in_range(self):
299
315
300
316
def assert_valid_shape (self , source_shape , target_shape , f ):
301
317
# Multidimensional feat
302
- if len (f .shape ) > 1 :
303
- assert f .shape == (
318
+ if (not self .sparse and len (f .shape ) > 1 ) or (self .sparse and len (f .shape [1 :]) > 1 ):
319
+ if self .sparse :
320
+ f_shape = f .shape [1 :]
321
+ expected = ('batch_size' , source_shape , target_shape )
322
+ else :
323
+ f_shape = f .shape
324
+ expected = (source_shape , target_shape )
325
+ assert f_shape == (
304
326
source_shape ,
305
327
target_shape ,
306
- ), f"Feature { self .name } has an incorrect shape of { f .shape } . Should be of shape { ( source_shape , target_shape ) } "
328
+ ), f"Feature { self .name } has an incorrect shape of { f .shape } . Should be of shape { expected } "
307
329
# Else assume scalar, which is a valid shape
308
330
309
331
@@ -319,6 +341,7 @@ def __init__(
319
341
reduction : Optional [callable ] = None ,
320
342
decay : float = 0.0 ,
321
343
parent_feature = None ,
344
+ sparse : Optional [bool ] = False
322
345
) -> None :
323
346
# language=rst
324
347
"""
@@ -336,6 +359,7 @@ def __init__(
336
359
dimension
337
360
:param decay: Constant multiple to decay weights by on each iteration
338
361
:param parent_feature: Parent feature to inherit :code:`value` from
362
+ :param sparse: Whether the :code:`value` parameter should be a sparse tensor
339
363
"""
340
364
341
365
### Assertions ###
@@ -349,10 +373,25 @@ def __init__(
349
373
reduction = reduction ,
350
374
decay = decay ,
351
375
parent_feature = parent_feature ,
376
+ sparse = sparse
377
+ )
378
+
379
+ def sparse_bernoulli (self ):
380
+ values = torch .bernoulli (self .value .values ())
381
+ mask = values != 0
382
+ indices = self .value .indices ()[:, mask ]
383
+ non_zero = values [mask ]
384
+ return torch .sparse_coo_tensor (
385
+ indices ,
386
+ non_zero ,
387
+ self .value .size ()
352
388
)
353
389
354
390
def compute (self , conn_spikes ) -> Union [torch .Tensor , float , int ]:
355
- return conn_spikes * torch .bernoulli (self .value )
391
+ if self .sparse :
392
+ return conn_spikes * self .sparse_bernoulli ()
393
+ else :
394
+ return conn_spikes * torch .bernoulli (self .value )
356
395
357
396
def reset_state_variables (self ) -> None :
358
397
pass
@@ -395,12 +434,14 @@ def __init__(
395
434
self ,
396
435
name : str ,
397
436
value : Union [torch .Tensor , float , int ] = None ,
437
+ sparse : Optional [bool ] = False
398
438
) -> None :
399
439
# language=rst
400
440
"""
401
441
Boolean mask which determines whether or not signals are allowed to traverse certain synapses.
402
442
:param name: Name of the feature
403
443
:param value: Boolean mask. :code:`True` means a signal can pass, :code:`False` means the synapse is impassable
444
+ :param sparse: Whether the :code:`value` parameter should be a sparse tensor
404
445
"""
405
446
406
447
### Assertions ###
@@ -419,11 +460,9 @@ def __init__(
419
460
super ().__init__ (
420
461
name = name ,
421
462
value = value ,
463
+ sparse = sparse
422
464
)
423
465
424
- self .name = name
425
- self .value = value
426
-
427
466
def compute (self , conn_spikes ) -> torch .Tensor :
428
467
return conn_spikes * self .value
429
468
@@ -505,6 +544,7 @@ def __init__(
505
544
reduction : Optional [callable ] = None ,
506
545
enforce_polarity : Optional [bool ] = False ,
507
546
decay : float = 0.0 ,
547
+ sparse : Optional [bool ] = False
508
548
) -> None :
509
549
# language=rst
510
550
"""
@@ -523,6 +563,7 @@ def __init__(
523
563
dimension
524
564
:param enforce_polarity: Will prevent synapses from changing signs if :code:`True`
525
565
:param decay: Constant multiple to decay weights by on each iteration
566
+ :param sparse: Whether the :code:`value` parameter should be a sparse tensor
526
567
"""
527
568
528
569
self .norm_frequency = norm_frequency
@@ -536,6 +577,7 @@ def __init__(
536
577
nu = nu ,
537
578
reduction = reduction ,
538
579
decay = decay ,
580
+ sparse = sparse
539
581
)
540
582
541
583
def reset_state_variables (self ) -> None :
@@ -589,6 +631,7 @@ def __init__(
589
631
value : Union [torch .Tensor , float , int ] = None ,
590
632
range : Optional [Sequence [float ]] = None ,
591
633
norm : Optional [Union [torch .Tensor , float , int ]] = None ,
634
+ sparse : Optional [bool ] = False
592
635
) -> None :
593
636
# language=rst
594
637
"""
@@ -598,13 +641,15 @@ def __init__(
598
641
:param range: Range of acceptable values for the :code:`value` parameter
599
642
:param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each sample
600
643
and after the value has been updated by the learning rule (if there is one)
644
+ :param sparse: Whether the :code:`value` parameter should be a sparse tensor
601
645
"""
602
646
603
647
super ().__init__ (
604
648
name = name ,
605
649
value = value ,
606
650
range = [- torch .inf , + torch .inf ] if range is None else range ,
607
651
norm = norm ,
652
+ sparse = sparse
608
653
)
609
654
610
655
def reset_state_variables (self ) -> None :
@@ -629,15 +674,17 @@ def __init__(
629
674
name : str ,
630
675
value : Union [torch .Tensor , float , int ] = None ,
631
676
range : Optional [Sequence [float ]] = None ,
677
+ sparse : Optional [bool ] = False
632
678
) -> None :
633
679
# language=rst
634
680
"""
635
681
Adds scalars to signals
636
682
:param name: Name of the feature
637
683
:param value: Values to scale signals by
684
+ :param sparse: Whether the :code:`value` parameter should be a sparse tensor
638
685
"""
639
686
640
- super ().__init__ (name = name , value = value , range = range )
687
+ super ().__init__ (name = name , value = value , range = range , sparse = sparse )
641
688
642
689
def reset_state_variables (self ) -> None :
643
690
pass
@@ -666,6 +713,7 @@ def __init__(
666
713
value : Union [torch .Tensor , float , int ] = None ,
667
714
degrade_function : callable = None ,
668
715
parent_feature : Optional [AbstractFeature ] = None ,
716
+ sparse : Optional [bool ] = False
669
717
) -> None :
670
718
# language=rst
671
719
"""
@@ -676,10 +724,11 @@ def __init__(
676
724
:param degrade_function: Callable function which takes a single argument (:code:`value`) and returns a tensor or
677
725
constant to be *subtracted* from the propagating spikes.
678
726
:param parent_feature: Parent feature with desired :code:`value` to inherit
727
+ :param sparse: Whether the :code:`value` parameter should be a sparse tensor
679
728
"""
680
729
681
730
# Note: parent_feature will override value. See abstract constructor
682
- super ().__init__ (name = name , value = value , parent_feature = parent_feature )
731
+ super ().__init__ (name = name , value = value , parent_feature = parent_feature , sparse = sparse )
683
732
684
733
self .degrade_function = degrade_function
685
734
@@ -698,6 +747,7 @@ def __init__(
698
747
ann_values : Union [list , tuple ] = None ,
699
748
const_update_rate : float = 0.1 ,
700
749
const_decay : float = 0.001 ,
750
+ sparse : Optional [bool ] = False
701
751
) -> None :
702
752
# language=rst
703
753
"""
@@ -708,6 +758,7 @@ def __init__(
708
758
:param value: Values to be used to build an initial mask for the synapses .
709
759
:param const_update_rate: The mask update rate of the ANN decision .
710
760
:param const_decay: The spontaneous activation of the synapses.
761
+ :param sparse: Whether the :code:`value` parameter should be a sparse tensor
711
762
"""
712
763
713
764
# Define the ANN
@@ -743,16 +794,18 @@ def forward(self, x):
743
794
self .const_update_rate = const_update_rate
744
795
self .const_decay = const_decay
745
796
746
- super ().__init__ (name = name , value = value )
797
+ super ().__init__ (name = name , value = value , sparse = sparse )
747
798
748
799
def compute (self , conn_spikes ) -> Union [torch .Tensor , float , int ]:
749
800
750
801
# Update the spike buffer
751
802
if self .start_counter == False or conn_spikes .sum () > 0 :
752
803
self .start_counter = True
753
- self .spike_buffer [:, self .counter % self .spike_buffer .shape [1 ]] = (
754
- conn_spikes .flatten ()
755
- )
804
+ if self .sparse :
805
+ flat_conn_spikes = conn_spikes .to_dense ().flatten ()
806
+ else :
807
+ flat_conn_spikes = conn_spikes .flatten ()
808
+ self .spike_buffer [:, self .counter % self .spike_buffer .shape [1 ]] = flat_conn_spikes
756
809
self .counter += 1
757
810
758
811
# Update the masks
@@ -767,6 +820,8 @@ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
767
820
768
821
# self.mask = torch.clamp(self.mask, -1, 1)
769
822
self .value = (self .mask > 0 ).float ()
823
+ if self .sparse :
824
+ self .value = self .value .to_sparse ()
770
825
771
826
return conn_spikes * self .value
772
827
@@ -788,6 +843,7 @@ def __init__(
788
843
ann_values : Union [list , tuple ] = None ,
789
844
const_update_rate : float = 0.1 ,
790
845
const_decay : float = 0.01 ,
846
+ sparse : Optional [bool ] = False
791
847
) -> None :
792
848
# language=rst
793
849
"""
@@ -798,6 +854,7 @@ def __init__(
798
854
:param value: Values to be used to build an initial mask for the synapses .
799
855
:param const_update_rate: The mask update rate of the ANN decision .
800
856
:param const_decay: The spontaneous activation of the synapses.
857
+ :param sparse: Whether the :code:`value` parameter should be a sparse tensor
801
858
"""
802
859
803
860
# Define the ANN
@@ -833,16 +890,18 @@ def forward(self, x):
833
890
self .const_update_rate = const_update_rate
834
891
self .const_decay = const_decay
835
892
836
- super ().__init__ (name = name , value = value )
893
+ super ().__init__ (name = name , value = value , sparse = sparse )
837
894
838
895
def compute (self , conn_spikes ) -> Union [torch .Tensor , float , int ]:
839
896
840
897
# Update the spike buffer
841
898
if self .start_counter == False or conn_spikes .sum () > 0 :
842
899
self .start_counter = True
843
- self .spike_buffer [:, self .counter % self .spike_buffer .shape [1 ]] = (
844
- conn_spikes .flatten ()
845
- )
900
+ if self .sparse :
901
+ flat_conn_spikes = conn_spikes .to_dense ().flatten ()
902
+ else :
903
+ flat_conn_spikes = conn_spikes .flatten ()
904
+ self .spike_buffer [:, self .counter % self .spike_buffer .shape [1 ]] = flat_conn_spikes
846
905
self .counter += 1
847
906
848
907
# Update the masks
@@ -857,6 +916,8 @@ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
857
916
858
917
# self.mask = torch.clamp(self.mask, -1, 1)
859
918
self .value = (self .mask > 0 ).float ()
919
+ if self .sparse :
920
+ self .value = self .value .to_sparse ()
860
921
861
922
return conn_spikes * self .value
862
923
0 commit comments