Skip to content

Commit 9ff6f01

Browse files
committed
Add runtime experiment
1 parent 95626a9 commit 9ff6f01

File tree

1 file changed

+97
-27
lines changed

1 file changed

+97
-27
lines changed
Lines changed: 97 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,108 @@
11
import torch
2+
import time
3+
import argparse
4+
5+
from bindsnet.evaluation import all_activity, assign_labels, proportion_weighting
6+
7+
8+
# Command-line selection of which benchmark to run (memory footprint vs
# evaluation runtime).
parser = argparse.ArgumentParser()
parser.add_argument("--benchmark_type", choices=['memory', 'runtime'], default='memory')
args = parser.parse_args()


# Use an explicit raise instead of `assert` so the check survives `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError('Benchmark works only on cuda')
# The benchmark requires CUDA above and reads torch.cuda memory statistics
# later, so the working device must be "cuda". (The diff's "cpu" value looks
# like a debug leftover: with cpu tensors every torch.cuda.max_memory_allocated
# reading comes back as 0 MB, making the memory table meaningless.)
device = torch.device("cuda")
# Shape of every spike tensor created by this script.
shape = (500, 500, 500)
616

717

8-
def create_spikes_tensor(percent_of_true_values, sparse, return_memory_usage=True):
    """Create a random boolean spike tensor of the module-level `shape`.

    Each element is True with probability `percent_of_true_values`. When
    `sparse` is true the tensor is converted to sparse COO format.

    Returns the tensor itself when `return_memory_usage` is False;
    otherwise returns the CUDA memory currently allocated on `device`,
    rounded to whole megabytes.
    """
    # The probability tensor is kept inline so its memory is released
    # before the measurement below.
    spikes_tensor = torch.bernoulli(
        torch.full(shape, percent_of_true_values, device=device)
    ).bool()
    if sparse:
        spikes_tensor = spikes_tensor.to_sparse()

    if not return_memory_usage:
        return spikes_tensor

    # Resetting the peak first makes max_memory_allocated() report the
    # bytes allocated right now (i.e. with the spike tensor resident).
    torch.cuda.reset_peak_memory_stats(device=device)
    return round(torch.cuda.max_memory_allocated(device=device) / (1024 ** 2))
30+
31+
32+
def memory_benchmark():
    """Print a table comparing sparse vs dense spike-tensor memory usage.

    For non-zero densities from 0.5% to 9.5% in 0.5% steps, reports the
    megabytes used by the sparse and dense representations and the
    sparse/dense ratio in percent.
    """
    print('======================= ====================== ====================== ====================')
    print('Sparse (megabytes used) Dense (megabytes used) Ratio (Sparse/Dense) % % of non zero values')
    print('======================= ====================== ====================== ====================')
    # Iterate over integer steps instead of accumulating 0.005 in a float:
    # repeated float addition drifts, so the original
    # `while percent_of_true_values < 0.1` loop could run 19 or 20 times
    # depending on rounding. range(1, 20) pins the intended densities
    # 0.005 .. 0.095 exactly.
    for step in range(1, 20):
        percent_of_true_values = step * 0.005
        result = {}
        for sparse in [True, False]:
            result[sparse] = create_spikes_tensor(percent_of_true_values, sparse)
        percent = round((result[True] / result[False]) * 100)

        row = [
            str(result[True]).ljust(23),
            str(result[False]).ljust(22),
            str(percent).ljust(22),
            str(round(percent_of_true_values * 100, 1)).ljust(20),
        ]
        print(' '.join(row))

    print('======================= ====================== ====================== ====================')
53+
54+
55+
def run(sparse):
    """Time the bindsnet evaluation pipeline over a rolling spike window.

    Builds a window of 5 zero spike tensors (sparse COO or dense, per
    `sparse`), then performs 10 iterations of all_activity /
    proportion_weighting / assign_labels on their concatenation. Only the
    label sampling, concatenation, and the three evaluation calls are
    timed; refreshing the window is excluded. Returns the accumulated
    wall-clock time in seconds, rounded to one decimal.
    """
    n_classes = 10
    # State threaded through assign_labels on every iteration.
    proportions = torch.zeros((500, n_classes), device=device)
    rates = torch.zeros((500, n_classes), device=device)
    # -1 marks "unassigned"; assign_labels overwrites this below.
    assignments = -torch.ones(500, device=device)
    spike_record = []
    for _ in range(5):
        tmp = torch.zeros(shape, device=device)
        spike_record.append(tmp.to_sparse() if sparse else tmp)

    # Index of the oldest window slot, replaced each iteration.
    spike_record_idx = 0

    delta = 0
    for _ in range(10):
        start = time.perf_counter()
        # NOTE(review): labels are sampled fresh each iteration with shape
        # (n_classes,) — presumably this must match the sample count
        # expected by assign_labels; confirm against bindsnet's API.
        label_tensor = torch.randint(0, n_classes, (n_classes,), device=device)
        spike_record_tensor = torch.cat(spike_record, dim=0)
        all_activity(
            spikes=spike_record_tensor, assignments=assignments, n_labels=n_classes
        )
        proportion_weighting(
            spikes=spike_record_tensor,
            assignments=assignments,
            proportions=proportions,
            n_labels=n_classes,
        )

        assignments, proportions, rates = assign_labels(
            spikes=spike_record_tensor,
            labels=label_tensor,
            n_labels=n_classes,
            rates=rates,
        )
        # Stop the clock before refreshing the window below.
        delta += time.perf_counter() - start
        # Replace the oldest slot with fresh random spikes (3% non-zero),
        # cycling through the 5 slots.
        spike_record[spike_record_idx] = create_spikes_tensor(
            0.03,
            sparse,
            return_memory_usage=False
        )
        spike_record_idx += 1
        if spike_record_idx == len(spike_record):
            spike_record_idx = 0
    return round(delta, 1)
98+
99+
100+
def runtime_benchmark():
    """Run the evaluation benchmark in sparse and dense modes and print
    the wall-clock seconds reported by run() for each."""
    for label, use_sparse in (("Sparse", True), ("Dense", False)):
        print(f"{label} runtime: {run(use_sparse)} seconds")
103+
104+
105+
# Dispatch on the benchmark chosen via --benchmark_type; any value other
# than 'memory' (argparse only admits 'memory' or 'runtime') falls
# through to the runtime benchmark, matching the original if/else.
_benchmarks = {'memory': memory_benchmark}
_benchmarks.get(args.benchmark_type, runtime_benchmark)()

0 commit comments

Comments
 (0)