Skip to content

Commit 6abbeed

Browse files
committed
init
1 parent 01bc1b0 commit 6abbeed

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

bindsnet/models/models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def __init__(
169169
)
170170

171171
# Connections
172-
w = 0.3 * torch.rand(self.n_inpt, self.n_neurons)
172+
w = 0.3 * torch.rand(self.n_inpt, self.n_neurons, dtype=torch.float16)
173173
input_exc_conn = Connection(
174174
source=input_layer,
175175
target=exc_layer,
@@ -181,13 +181,13 @@ def __init__(
181181
wmax=wmax,
182182
norm=norm,
183183
)
184-
w = self.exc * torch.diag(torch.ones(self.n_neurons))
184+
w = self.exc * torch.diag(torch.ones(self.n_neurons, dtype=torch.float16))
185185
exc_inh_conn = Connection(
186186
source=exc_layer, target=inh_layer, w=w, wmin=0, wmax=self.exc
187187
)
188188
w = -self.inh * (
189-
torch.ones(self.n_neurons, self.n_neurons)
190-
- torch.diag(torch.ones(self.n_neurons))
189+
torch.ones(self.n_neurons, self.n_neurons, dtype=torch.float16)
190+
- torch.diag(torch.ones(self.n_neurons, dtype=torch.float16))
191191
)
192192
inh_exc_conn = Connection(
193193
source=inh_layer, target=exc_layer, w=w, wmin=-self.inh, wmax=0

bindsnet/network/topology.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def compute(self, s: torch.Tensor) -> torch.Tensor:
322322
"""
323323
# Compute multiplication of spike activations by weights and add bias.
324324
if self.b is None:
325-
post = s.view(s.size(0), -1).float() @ self.w
325+
post = s.view(s.size(0), -1).to(dtype=torch.float16) @ self.w
326326
else:
327327
post = s.view(s.size(0), -1).float() @ self.w + self.b
328328
return post.view(s.size(0), *self.target.shape)

examples/mnist/eth_mnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
parser.add_argument("--test", dest="train", action="store_false")
4343
parser.add_argument("--plot", dest="plot", action="store_true")
4444
parser.add_argument("--gpu", dest="gpu", action="store_true")
45-
parser.set_defaults(plot=True, gpu=True)
45+
parser.set_defaults(plot=False, gpu=True)
4646

4747
args = parser.parse_args()
4848

0 commit comments

Comments
 (0)