Skip to content
This repository was archived by the owner on Feb 26, 2025. It is now read-only.

Commit 5f411a8

Browse files
ilkilicKiliç Ilkan Fabrice
and
Kiliç Ilkan Fabrice
authored
having consistent sampling options across ssim methods #44 (#52)
* consistent sampling options across ssim methods * lint fix * lint fix * Updated type hint in _sample_array function to ndarray from Sequence * lint fix * bug fix * bug fix --------- Co-authored-by: Kiliç Ilkan Fabrice <[email protected]>
1 parent 81ce66d commit 5f411a8

File tree

4 files changed

+81
-43
lines changed

4 files changed

+81
-43
lines changed

bluecellulab/circuit/simulation_access.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,11 @@
1515

1616
from __future__ import annotations
1717
from pathlib import Path
18-
from platform import python_version_tuple
1918
from typing import Optional, Protocol
2019

2120
from bluecellulab.circuit.config import SimulationConfig, SonataSimulationConfig
2221
from bluecellulab.exceptions import ExtraDependencyMissingError
2322

24-
if python_version_tuple() < ('3', '9'):
25-
from typing import Sequence
26-
else:
27-
from collections.abc import Sequence
28-
2923
from bluecellulab import BLUEPY_AVAILABLE
3024
if BLUEPY_AVAILABLE:
3125
import bluepy
@@ -38,23 +32,22 @@
3832
from bluecellulab.circuit.config import BluepySimulationConfig
3933

4034

41-
def _sample_array(arr: Sequence, t_step: float, sim_t_step: float) -> Sequence:
35+
def _sample_array(arr: np.ndarray, ratio: float) -> np.ndarray:
4236
"""Sample an array at a given time step.
4337
4438
Args:
4539
arr: Array to sample.
46-
t_step: User specified time step to sample at.
47-
sim_t_step: Time step used in the main simulation.
40+
ratio: The ratio between the time step used for sampling and the time step used in the simulation
41+
(the time step should be a multiple of the simulation time step)
4842
4943
Returns:
5044
Array sampled at the given time step.
5145
"""
52-
ratio = t_step / sim_t_step
53-
if t_step == sim_t_step:
46+
if ratio == 1:
5447
return arr
5548
elif not np.isclose(ratio, round(ratio)):
56-
raise ValueError(
57-
f"Time step {t_step} is not a multiple of the simulation time step {sim_t_step}.")
49+
raise ValueError("The ratio is not close to a whole number. "
50+
"The time step should be a multiple of the simulation time step.")
5851
return arr[::round(ratio)]
5952

6053

@@ -105,15 +98,17 @@ def get_soma_voltage(
10598
.to_numpy()
10699
)
107100
if t_step is not None:
108-
arr = _sample_array(arr, t_step, self._config._soma_report_dt)
101+
ratio = t_step / self._config._soma_report_dt
102+
arr = _sample_array(arr, ratio)
109103
return arr
110104

111105
def get_soma_time_trace(self, t_step: Optional[float] = None) -> np.ndarray:
112106
"""Retrieve the time trace from the main simulation."""
113107
report = self.impl.report('soma')
114108
arr = report.get_gid(report.gids[0]).index.to_numpy()
115109
if t_step is not None:
116-
arr = _sample_array(arr, t_step, self._config._soma_report_dt)
110+
ratio = t_step / self._config._soma_report_dt
111+
arr = _sample_array(arr, ratio)
117112
return arr
118113

119114
def get_spikes(self) -> dict[CellId, np.ndarray]:
@@ -140,14 +135,16 @@ def get_soma_voltage(
140135
report = self.impl.reports["soma"].filter(cell_id.id, t_start, t_end)
141136
arr = report.report[cell_id.population_name][cell_id.id].values
142137
if t_step is not None:
143-
arr = _sample_array(arr, t_step, self.impl.dt)
138+
ratio = t_step / self.impl.dt
139+
arr = _sample_array(arr, ratio)
144140
return arr
145141

146142
def get_soma_time_trace(self, t_step: Optional[float] = None) -> np.ndarray:
147143
report = self.impl.reports["soma"]
148144
arr = report.filter().report.index.values
149145
if t_step is not None:
150-
arr = _sample_array(arr, t_step, self.impl.dt)
146+
ratio = t_step / self.impl.dt
147+
arr = _sample_array(arr, ratio)
151148
return arr
152149

153150
def get_spikes(self) -> dict[CellId, np.ndarray]:

bluecellulab/ssim.py

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from bluecellulab.circuit.config import SimulationConfig
3939
from bluecellulab.circuit.format import determine_circuit_format, CircuitFormat
4040
from bluecellulab.circuit.node_id import create_cell_id, create_cell_ids
41-
from bluecellulab.circuit.simulation_access import BluepySimulationAccess, SimulationAccess, SonataSimulationAccess
41+
from bluecellulab.circuit.simulation_access import BluepySimulationAccess, SimulationAccess, SonataSimulationAccess, _sample_array
4242
from bluecellulab.stimuli import Pattern
4343
import bluecellulab.stimuli as stimuli
4444
from bluecellulab.exceptions import BluecellulabError
@@ -656,9 +656,18 @@ def get_mainsim_voltage_trace(
656656
cell_id = create_cell_id(cell_id)
657657
return self.simulation_access.get_soma_voltage(cell_id, t_start, t_stop, t_step)
658658

659-
def get_mainsim_time_trace(self) -> np.ndarray:
660-
"""Get the time trace from the main simulation."""
661-
return self.simulation_access.get_soma_time_trace()
659+
def get_mainsim_time_trace(self, t_step=None) -> np.ndarray:
660+
"""Get the time trace from the main simulation.
661+
662+
Parameters
663+
-----------
664+
t_step: time step (should be a multiple of report time step T;
665+
equals T by default)
666+
667+
Returns:
668+
One dimentional np.ndarray to represent the times.
669+
"""
670+
return self.simulation_access.get_soma_time_trace(t_step)
662671

663672
def get_time(self) -> np.ndarray:
664673
"""Get the time vector for the recordings, contains negative times.
@@ -668,17 +677,56 @@ def get_time(self) -> np.ndarray:
668677
first_key = next(iter(self.cells))
669678
return self.cells[first_key].get_time()
670679

671-
def get_time_trace(self) -> np.ndarray:
672-
"""Get the time vector for the recordings, negative times removed."""
680+
def get_time_trace(self, t_step=None) -> np.ndarray:
681+
"""Get the time vector for the recordings, negative times removed.
682+
683+
Parameters
684+
-----------
685+
t_step: time step (should be a multiple of report time step T;
686+
equals T by default)
687+
688+
Returns:
689+
One dimentional np.ndarray to represent the times.
690+
"""
673691
time = self.get_time()
674-
return time[np.where(time >= 0.0)]
692+
time = time[np.where(time >= 0.0)]
675693

676-
def get_voltage_trace(self, cell_id: int | tuple[str, int]) -> np.ndarray:
677-
"""Get the voltage vector for the cell_id, negative times removed."""
694+
if t_step is not None:
695+
ratio = t_step / self.dt
696+
time = _sample_array(time, ratio)
697+
return time
698+
699+
def get_voltage_trace(
700+
self, cell_id: int | tuple[str, int], t_start=None, t_stop=None, t_step=None
701+
) -> np.ndarray:
702+
"""Get the voltage vector for the cell_id, negative times removed.
703+
704+
Parameters
705+
-----------
706+
cell_id: cell id of interest.
707+
t_start, t_stop: time range of interest,
708+
report time range is used by default.
709+
t_step: time step (should be a multiple of report time step T;
710+
equals T by default)
711+
712+
Returns:
713+
One dimentional np.ndarray to represent the voltages.
714+
"""
678715
cell_id = create_cell_id(cell_id)
679716
time = self.get_time()
680717
voltage = self.cells[cell_id].get_soma_voltage()
681-
return voltage[np.where(time >= 0.0)]
718+
719+
if t_start is None or t_start < 0:
720+
t_start = 0
721+
if t_stop is None:
722+
t_stop = np.inf
723+
724+
voltage = voltage[np.where((time >= t_start) & (time <= t_stop))]
725+
726+
if t_step is not None:
727+
ratio = t_step / self.dt
728+
voltage = _sample_array(voltage, ratio)
729+
return voltage
682730

683731
def delete(self):
684732
"""Delete ssim and all of its attributes.

tests/test_circuit/test_simulation_access.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,20 @@ def test_sample_array():
2525
arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
2626
t_step = 0.1
2727
sim_t_step = 0.1
28-
assert np.array_equal(_sample_array(arr, t_step, sim_t_step), arr)
28+
assert np.array_equal(_sample_array(arr, t_step / sim_t_step), arr)
2929
t_step = 0.2
3030
sim_t_step = 0.1
31-
assert np.array_equal(_sample_array(arr, t_step, sim_t_step), [1, 3, 5, 7, 9])
31+
assert np.array_equal(_sample_array(arr, t_step / sim_t_step), [1, 3, 5, 7, 9])
3232
t_step = 0.3
3333
sim_t_step = 0.1
34-
assert np.array_equal(_sample_array(arr, t_step, sim_t_step), [1, 4, 7, 10])
34+
assert np.array_equal(_sample_array(arr, t_step / sim_t_step), [1, 4, 7, 10])
3535
t_step = 0.4
3636
sim_t_step = 0.1
37-
assert np.array_equal(_sample_array(arr, t_step, sim_t_step), [1, 5, 9])
37+
assert np.array_equal(_sample_array(arr, t_step / sim_t_step), [1, 5, 9])
3838
t_step = 0.1
3939
sim_t_step = 0.2
4040
with pytest.raises(ValueError):
41-
_sample_array(arr, t_step, sim_t_step)
41+
_sample_array(arr, t_step / sim_t_step)
4242

4343

4444
class TestSonataSimulationAccess:

tests/test_ssim_sonata.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Testing SSim with SONATA simulations."""
22

33
from pathlib import Path
4-
from bluecellulab.circuit.simulation_access import _sample_array
54

65
import numpy as np
76
import pytest
@@ -34,11 +33,8 @@ def test_sim_quick_scx_sonata(input_type):
3433
sim.run(t_stop)
3534

3635
# Get the voltage trace
37-
time = sim.get_time_trace()
38-
voltage = sim.get_voltage_trace(cell_id)
39-
# Sampling
40-
voltage = _sample_array(voltage, 1.0, 0.025)
41-
time = _sample_array(time, 1.0, 0.025)
36+
time = sim.get_time_trace(1)
37+
voltage = sim.get_voltage_trace(cell_id, 0, t_stop, 1)
4238
voltage = voltage[:len(voltage) - 1] # remove last point, mainsim produces 1 less
4339
time = time[:len(time) - 1] # remove last point, mainsim produces 1 less
4440
mainsim_voltage = sim.get_mainsim_voltage_trace(cell_id)
@@ -69,11 +65,8 @@ def test_sim_quick_scx_sonata_multicircuit(input_type):
6965
t_stop = 20.0
7066
sim.run(t_stop)
7167
for cell_id in cell_ids:
72-
voltage = sim.get_voltage_trace(cell_id)
73-
# Sampling
74-
voltage = _sample_array(voltage, 1.0, 0.025)
75-
# remove the last one
76-
voltage = voltage[:len(voltage) - 1]
68+
voltage = sim.get_voltage_trace(cell_id, 0, t_stop, 1)
69+
voltage = voltage[:len(voltage) - 1] # remove last point, mainsim produces 1 less
7770
mainsim_voltage = sim.get_mainsim_voltage_trace(cell_id)
7871
voltage_diff = voltage - mainsim_voltage
7972
rms_error = np.sqrt(np.mean(voltage_diff ** 2))

0 commit comments

Comments
 (0)