|
1 | 1 | import torch
|
| 2 | +import time |
| 3 | +import argparse |
| 4 | + |
| 5 | +from bindsnet.evaluation import all_activity, assign_labels, proportion_weighting |
| 6 | + |
| 7 | + |
| 8 | +parser = argparse.ArgumentParser() |
| 9 | +parser.add_argument("--benchmark_type", choices=['memory', 'runtime'], default='memory') |
| 10 | +args = parser.parse_args() |
2 | 11 |
|
3 | 12 |
|
4 | 13 | assert torch.cuda.is_available(), 'Benchmark works only on cuda'
|
5 |
| -device = torch.device("cuda") |
| 14 | +device = torch.device("cpu") |
| 15 | +shape = (500, 500, 500) |
6 | 16 |
|
7 | 17 |
|
8 |
| -def create_spikes_tensor(percent_of_true_values, sparse): |
| 18 | +def create_spikes_tensor(percent_of_true_values, sparse, return_memory_usage=True): |
9 | 19 | spikes_tensor = torch.bernoulli(
|
10 |
| - torch.full((500, 500, 500), percent_of_true_values, device=device) |
| 20 | + torch.full(shape, percent_of_true_values, device=device) |
11 | 21 | ).bool()
|
12 | 22 | if sparse:
|
13 | 23 | spikes_tensor = spikes_tensor.to_sparse()
|
14 | 24 |
|
15 |
| - torch.cuda.reset_peak_memory_stats(device=device) |
16 |
| - return round(torch.cuda.max_memory_allocated(device=device) / (1024 ** 2)) |
17 |
| - |
18 |
| - |
19 |
| -print('======================= ====================== ====================== ====================') |
20 |
| -print('Sparse (megabytes used) Dense (megabytes used) Ratio (Sparse/Dense) % % of non zero values') |
21 |
| -print('======================= ====================== ====================== ====================') |
22 |
| -percent_of_true_values = 0.005 |
23 |
| -while percent_of_true_values < 0.1: |
24 |
| - result = {} |
25 |
| - for sparse in [True, False]: |
26 |
| - result[sparse] = create_spikes_tensor(percent_of_true_values, sparse) |
27 |
| - percent = round((result[True] / result[False]) * 100) |
28 |
| - |
29 |
| - row = [ |
30 |
| - str(result[True]).ljust(23), |
31 |
| - str(result[False]).ljust(22), |
32 |
| - str(percent).ljust(22), |
33 |
| - str(round(percent_of_true_values * 100, 1)).ljust(20), |
34 |
| - ] |
35 |
| - print(' '.join(row)) |
36 |
| - percent_of_true_values += 0.005 |
37 |
| - |
38 |
| -print('======================= ====================== ====================== ====================') |
| 25 | + if return_memory_usage: |
| 26 | + torch.cuda.reset_peak_memory_stats(device=device) |
| 27 | + return round(torch.cuda.max_memory_allocated(device=device) / (1024 ** 2)) |
| 28 | + else: |
| 29 | + return spikes_tensor |
| 30 | + |
| 31 | + |
| 32 | +def memory_benchmark(): |
| 33 | + print('======================= ====================== ====================== ====================') |
| 34 | + print('Sparse (megabytes used) Dense (megabytes used) Ratio (Sparse/Dense) % % of non zero values') |
| 35 | + print('======================= ====================== ====================== ====================') |
| 36 | + percent_of_true_values = 0.005 |
| 37 | + while percent_of_true_values < 0.1: |
| 38 | + result = {} |
| 39 | + for sparse in [True, False]: |
| 40 | + result[sparse] = create_spikes_tensor(percent_of_true_values, sparse) |
| 41 | + percent = round((result[True] / result[False]) * 100) |
| 42 | + |
| 43 | + row = [ |
| 44 | + str(result[True]).ljust(23), |
| 45 | + str(result[False]).ljust(22), |
| 46 | + str(percent).ljust(22), |
| 47 | + str(round(percent_of_true_values * 100, 1)).ljust(20), |
| 48 | + ] |
| 49 | + print(' '.join(row)) |
| 50 | + percent_of_true_values += 0.005 |
| 51 | + |
| 52 | + print('======================= ====================== ====================== ====================') |
| 53 | + |
| 54 | + |
| 55 | +def run(sparse): |
| 56 | + n_classes = 10 |
| 57 | + proportions = torch.zeros((500, n_classes), device=device) |
| 58 | + rates = torch.zeros((500, n_classes), device=device) |
| 59 | + assignments = -torch.ones(500, device=device) |
| 60 | + spike_record = [] |
| 61 | + for _ in range(5): |
| 62 | + tmp = torch.zeros(shape, device=device) |
| 63 | + spike_record.append(tmp.to_sparse() if sparse else tmp) |
| 64 | + |
| 65 | + spike_record_idx = 0 |
| 66 | + |
| 67 | + delta = 0 |
| 68 | + for _ in range(10): |
| 69 | + start = time.perf_counter() |
| 70 | + label_tensor = torch.randint(0, n_classes, (n_classes,), device=device) |
| 71 | + spike_record_tensor = torch.cat(spike_record, dim=0) |
| 72 | + all_activity( |
| 73 | + spikes=spike_record_tensor, assignments=assignments, n_labels=n_classes |
| 74 | + ) |
| 75 | + proportion_weighting( |
| 76 | + spikes=spike_record_tensor, |
| 77 | + assignments=assignments, |
| 78 | + proportions=proportions, |
| 79 | + n_labels=n_classes, |
| 80 | + ) |
| 81 | + |
| 82 | + assignments, proportions, rates = assign_labels( |
| 83 | + spikes=spike_record_tensor, |
| 84 | + labels=label_tensor, |
| 85 | + n_labels=n_classes, |
| 86 | + rates=rates, |
| 87 | + ) |
| 88 | + delta += time.perf_counter() - start |
| 89 | + spike_record[spike_record_idx] = create_spikes_tensor( |
| 90 | + 0.03, |
| 91 | + sparse, |
| 92 | + return_memory_usage=False |
| 93 | + ) |
| 94 | + spike_record_idx += 1 |
| 95 | + if spike_record_idx == len(spike_record): |
| 96 | + spike_record_idx = 0 |
| 97 | + return round(delta, 1) |
| 98 | + |
| 99 | + |
| 100 | +def runtime_benchmark(): |
| 101 | + print(f"Sparse runtime: {run(True)} seconds") |
| 102 | + print(f"Dense runtime: {run(False)} seconds") |
| 103 | + |
| 104 | + |
| 105 | +if args.benchmark_type == 'memory': |
| 106 | + memory_benchmark() |
| 107 | +else: |
| 108 | + runtime_benchmark() |
0 commit comments