Skip to content

Commit 8680cf5

Browse files
authored
Refactor SBML event processing, defer computing state updates (#2742)
To support `use_values_from_trigger_time=true`, `deltax` may have to be computed from `x_old` instead of `x`. At the stage of processing the SBML model, there is no concept of `x_old`. Thus, the state updates should only be computed inside the `DEModel`. Also independently of `use_values_from_trigger_time`, this seems to be the more appropriate place to compute the bolus. Therefore, store only the event assignment in `Event`, and compute the bolus on demand.
1 parent 2432657 commit 8680cf5

File tree

3 files changed

+68
-49
lines changed

3 files changed

+68
-49
lines changed

python/sdist/amici/de_model.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -534,11 +534,14 @@ def get_rate(symbol: sp.Symbol):
534534
)
535535

536536
for event in self.events():
537-
if event._state_update is None:
537+
state_update = event.get_state_update(
538+
x=self.sym("x"), x_old=self.sym("x")
539+
)
540+
if state_update is None:
538541
continue
539542

540-
for i_state in range(len(event._state_update)):
541-
if rate_ofs := event._state_update[i_state].find(rate_of_func):
543+
for i_state in range(len(state_update)):
544+
if rate_ofs := state_update[i_state].find(rate_of_func):
542545
raise SBMLException(
543546
"AMICI does currently not support rateOf(.) inside event state updates."
544547
)
@@ -1612,10 +1615,15 @@ def _compute_equation(self, name: str) -> None:
16121615
# would cause problems when writing the function file later
16131616
event_eqs = []
16141617
for event in self._events:
1615-
if event._state_update is None:
1618+
# TODO https://github.com/AMICI-dev/AMICI/issues/2719
1619+
# with use_values_from_trigger_time=True: x_old != x
1620+
state_update = event.get_state_update(
1621+
x=self.sym("x"), x_old=self.sym("x")
1622+
)
1623+
if state_update is None:
16161624
event_eqs.append(sp.zeros(self.num_states_solver(), 1))
16171625
else:
1618-
event_eqs.append(event._state_update)
1626+
event_eqs.append(state_update)
16191627

16201628
self._eqs[name] = event_eqs
16211629

@@ -1718,8 +1726,8 @@ def _compute_equation(self, name: str) -> None:
17181726
self.sym("stau").T,
17191727
)
17201728

1721-
# only add deltax part if there is state update
1722-
if event._state_update is not None:
1729+
# only add deltax part if there is a state update
1730+
if event._assignments is not None:
17231731
# partial derivative for the parameters
17241732
tmp_eq += self.eq("ddeltaxdp")[ie]
17251733

@@ -2259,7 +2267,7 @@ def _get_unique_root(
22592267
identifier=sp.Symbol(root_symstr),
22602268
name=root_symstr,
22612269
value=root_found,
2262-
state_update=None,
2270+
assignments=None,
22632271
)
22642272
)
22652273
return roots[-1].get_id()

python/sdist/amici/de_model_components.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -708,7 +708,7 @@ def __init__(
708708
identifier: sp.Symbol,
709709
name: str,
710710
value: sp.Expr,
711-
state_update: sp.Expr | None,
711+
assignments: dict[sp.Symbol, sp.Expr] | None = None,
712712
initial_value: bool | None = True,
713713
priority: sp.Basic | None = None,
714714
):
@@ -724,17 +724,16 @@ def __init__(
724724
:param value:
725725
formula for the root / trigger function
726726
727-
:param state_update:
728-
formula for the bolus function (None for Heaviside functions,
729-
zero vector for events without bolus)
727+
:param assignments:
728+
Dictionary of event assignments: state symbol -> new value.
730729
731730
:param initial_value:
732731
initial boolean value of the trigger function at t0. If set to
733732
`False`, events may trigger at ``t==t0``, otherwise not.
734733
"""
735734
super().__init__(identifier, name, value)
736735
# add the Event specific components
737-
self._state_update = state_update
736+
self._assignments = assignments if assignments is not None else {}
738737
self._initial_value = initial_value
739738

740739
if priority is not None and not priority.is_Number:
@@ -751,6 +750,35 @@ def __init__(
751750
# the trigger can't be solved for `t`
752751
self._t_root = []
753752

753+
def get_state_update(
754+
self, x: sp.Matrix, x_old: sp.Matrix
755+
) -> sp.Matrix | None:
756+
"""
757+
Get the state update (bolus) expression for the event assignment.
758+
759+
:param x: The current state vector.
760+
:param x_old: The previous state vector.
761+
:return: State-update matrix or ``None`` if no state update is defined.
762+
"""
763+
if len(self._assignments) == 0:
764+
return None
765+
766+
x_to_x_old = dict(zip(x, x_old))
767+
768+
def get_bolus(x_i: sp.Symbol) -> sp.Expr:
769+
"""
770+
Get the bolus expression for a state variable.
771+
772+
:param x_i: state variable symbol
773+
:return: bolus expression
774+
"""
775+
if (assignment := self._assignments.get(x_i)) is not None:
776+
return assignment.subs(x_to_x_old) - x_i
777+
else:
778+
return sp.Float(0.0)
779+
780+
return sp.Matrix([get_bolus(x_i) for x_i in x])
781+
754782
def get_initial_value(self) -> bool:
755783
"""
756784
Return the initial value for the root function.

python/sdist/amici/sbml_import.py

Lines changed: 19 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,7 @@ def _build_ode_model(
761761
args += ["value"]
762762

763763
if symbol_name == SymbolId.EVENT:
764-
args += ["state_update", "initial_value", "priority"]
764+
args += ["assignments", "initial_value", "priority"]
765765
elif symbol_name == SymbolId.OBSERVABLE:
766766
args += ["transformation"]
767767
elif symbol_name == SymbolId.EVENT_OBSERVABLE:
@@ -1769,13 +1769,6 @@ def _process_events(self) -> None:
17691769
"""Process SBML events."""
17701770
events = self.sbml.getListOfEvents()
17711771

1772-
def get_empty_bolus_value() -> sp.Float:
1773-
"""
1774-
Used in the event update vector for species that are not affected
1775-
by the event.
1776-
"""
1777-
return sp.Symbol("AMICI_EMTPY_BOLUS")
1778-
17791772
# Used to update species concentrations when an event affects a
17801773
# compartment.
17811774
concentration_species_by_compartment = {
@@ -1811,7 +1804,7 @@ def get_empty_bolus_value() -> sp.Float:
18111804
trigger = _parse_event_trigger(trigger_sym)
18121805

18131806
# parse the boluses / event assignments
1814-
bolus = [get_empty_bolus_value() for _ in state_vector]
1807+
assignment_exprs = {}
18151808
event_assignments = event.getListOfEventAssignments()
18161809
compartment_event_assignments: set[tuple[sp.Symbol, sp.Expr]] = (
18171810
set()
@@ -1826,8 +1819,8 @@ def get_empty_bolus_value() -> sp.Float:
18261819
formula = self._sympify(event_assignment)
18271820
try:
18281821
# Try to find the species in the state vector.
1829-
index = state_vector.index(variable_sym)
1830-
bolus[index] = formula
1822+
_ = state_vector.index(variable_sym)
1823+
assignment_exprs[variable_sym] = formula
18311824
except ValueError:
18321825
raise SBMLException(
18331826
"Could not process event assignment for "
@@ -1864,30 +1857,17 @@ def get_empty_bolus_value() -> sp.Float:
18641857
]:
18651858
# If the species was not affected by an event assignment,
18661859
# then the old value should be updated.
1867-
if (
1868-
bolus[state_vector.index(species_sym)]
1869-
== get_empty_bolus_value()
1870-
):
1860+
if species_sym not in assignment_exprs:
18711861
species_value = species_sym
18721862
# else the species was affected by an event assignment,
18731863
# hence the updated value should be updated further.
18741864
else:
1875-
species_value = bolus[state_vector.index(species_sym)]
1865+
species_value = assignment_exprs[species_sym]
18761866
# New species value is old amount / new volume.
1877-
bolus[state_vector.index(species_sym)] = (
1867+
assignment_exprs[species_sym] = (
18781868
species_value * compartment_sym / formula
18791869
)
18801870

1881-
# Subtract the current species value from each species with an
1882-
# update, as the bolus will be added on to the current species
1883-
# value during simulation.
1884-
for index in range(len(bolus)):
1885-
if bolus[index] != get_empty_bolus_value():
1886-
bolus[index] -= state_vector[index]
1887-
bolus[index] = bolus[index].subs(
1888-
get_empty_bolus_value(), sp.Float(0.0)
1889-
)
1890-
18911871
initial_value = (
18921872
trigger_sbml.getInitialValue()
18931873
if trigger_sbml is not None
@@ -1917,7 +1897,7 @@ def get_empty_bolus_value() -> sp.Float:
19171897
self.symbols[SymbolId.EVENT][event_sym] = {
19181898
"name": event_id,
19191899
"value": trigger,
1920-
"state_update": sp.MutableDenseMatrix(bolus),
1900+
"assignments": assignment_exprs,
19211901
"initial_value": initial_value,
19221902
"use_values_from_trigger_time": use_trig_val,
19231903
"priority": self._sympify(event.getPriority()),
@@ -1966,10 +1946,10 @@ def try_solve_t(expr: sp.Expr) -> list:
19661946
# if all assignments are absolute (not referring to other non-constant
19671947
# model entities), we are fine.
19681948
if all(
1969-
update.is_zero or (update + variable).is_Number
1949+
assignment.is_Number
19701950
for event in self.symbols[SymbolId.EVENT].values()
1971-
for variable, update in zip(state_vector, event["state_update"])
1972-
if not update.is_zero
1951+
for assignment in event["assignments"].values()
1952+
if event["assignments"] is not None
19731953
):
19741954
return
19751955

@@ -2796,13 +2776,16 @@ def _replace_in_all_expressions(
27962776
for element in self.symbols[symbol].values():
27972777
element["value"] = smart_subs(element["value"], old, new)
27982778

2799-
# replace in event state updates (boluses)
2779+
# replace in event assignments
28002780
if self.symbols.get(SymbolId.EVENT, False):
28012781
for event in self.symbols[SymbolId.EVENT].values():
2802-
for index in range(len(event["state_update"])):
2803-
event["state_update"][index] = smart_subs(
2804-
event["state_update"][index], old, new
2805-
)
2782+
if event["assignments"] is not None:
2783+
event["assignments"] = {
2784+
smart_subs(target, old, new): smart_subs(
2785+
expr, old, new
2786+
)
2787+
for target, expr in event["assignments"].items()
2788+
}
28062789

28072790
for state in {
28082791
**self.symbols[SymbolId.SPECIES],

0 commit comments

Comments
 (0)