
Commit bf329b9

Merge pull request #687 from BindsNET/hananel
CUDA update and deprecation python 3.8 welcome 3.11
2 parents 65fd024 + d8000ee commit bf329b9

5 files changed (+1232, -1040 lines)


.github/workflows/pythonpackage.yml

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@ jobs:
     strategy:
       max-parallel: 4
       matrix:
-        python-version: ["3.8", "3.9", "3.10"]
+        python-version: ["3.9", "3.10", "3.11"]

     steps:
       - uses: actions/checkout@v3

bindsnet/datasets/torchvision_wrapper.py

Lines changed: 2 additions & 2 deletions

@@ -1,7 +1,7 @@
 from typing import Dict, Optional

 import torch
-import torchvision
+from torchvision import datasets as torchDB

 from bindsnet.encoding import Encoder, NullEncoder

@@ -13,7 +13,7 @@ def create_torchvision_dataset_wrapper(ds_type):
     ``__getitem__``. This applies to all of the datasets inside of ``torchvision``.
     """
     if type(ds_type) == str:
-        ds_type = getattr(torchvision.datasets, ds_type)
+        ds_type = getattr(torchDB, ds_type)

     class TorchvisionDatasetWrapper(ds_type):
         __doc__ = (
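The change above only swaps the module-level import torchvision for from torchvision import datasets as torchDB, so string arguments are still resolved against torchvision.datasets. A minimal usage sketch of the factory touched by this hunk, assuming the generated wrapper forwards constructor arguments to the underlying torchvision dataset; the "MNIST" name and the root/train/download keywords are illustrative, not taken from this diff:

# Hypothetical usage of the factory shown above; dataset name and kwargs are illustrative.
from bindsnet.datasets.torchvision_wrapper import create_torchvision_dataset_wrapper

# After this commit the string is looked up via getattr(torchDB, "MNIST"),
# i.e. against torchvision.datasets under its new alias.
MNISTWrapper = create_torchvision_dataset_wrapper("MNIST")

# Assumed to pass constructor arguments through to torchvision.datasets.MNIST.
dataset = MNISTWrapper(root="./data", train=True, download=True)

Importing only the datasets submodule, rather than all of torchvision, is presumably the point of the alias; callers of the wrapper factory are unaffected.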

bindsnet/network/topology.py

Lines changed: 14 additions & 0 deletions

@@ -48,6 +48,8 @@ def __init__(
         :param Union[float, torch.Tensor] wmax: Maximum allowed value(s) on the connection weights. Single value, or
             tensor of same size as w
         :param float norm: Total weight per target neuron normalization.
+        :param Union[bool, torch.Tensor] Dales_rule: Whether to enforce Dale's rule. Input is a boolean tensor in the
+            weight shape, where True forces zero or positive values and False forces zero or negative values.
         """
         super().__init__()

@@ -88,6 +90,12 @@ def __init__(
             **kwargs,
         )

+        self.Dales_rule = kwargs.get("Dales_rule", None)
+        if self.Dales_rule is not None:
+            self.Dales_rule = Parameter(
+                torch.as_tensor(self.Dales_rule, dtype=torch.bool), requires_grad=False
+            )
+
     @abstractmethod
     def compute(self, s: torch.Tensor) -> None:
         # language=rst

@@ -117,6 +125,12 @@ def update(self, **kwargs) -> None:
         if mask is not None:
             self.w.masked_fill_(mask, 0)

+        if self.Dales_rule is not None:
+            # weights that are negative but should be positive are set to 0
+            self.w[self.w < 0 * self.Dales_rule.to(torch.float)] = 0
+            # weights that are positive but should be negative are set to 0
+            self.w[self.w > 0 * 1 - self.Dales_rule.to(torch.float)] = 0
+
     @abstractmethod
     def reset_state_variables(self) -> None:
         # language=rst
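Per the docstring added above, Dales_rule is a boolean mask with the same shape as w: True entries are meant to stay zero or positive (excitatory), False entries zero or negative (inhibitory), and the mask is registered as a non-trainable Parameter in __init__. A standalone sketch of that clamping intent, written with explicit boolean masks rather than the arithmetic form used in the committed update() hunk; the tensor names and shapes below are illustrative only:

import torch

# Illustrative stand-ins for self.w and self.Dales_rule on a connection.
w = torch.randn(4, 4)                  # connection weights
dales_rule = torch.rand(4, 4) > 0.5    # True -> keep >= 0, False -> keep <= 0

# Zero out negative weights at positions marked excitatory (True).
w[(w < 0) & dales_rule] = 0.0
# Zero out positive weights at positions marked inhibitory (False).
w[(w > 0) & ~dales_rule] = 0.0

In the library itself, the mask would be supplied as a connection keyword argument (Dales_rule=mask), which the __init__ hunk above converts to a boolean tensor wrapped in a Parameter with requires_grad=False.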
