Skip to content

Commit 7f5c3ad

Browse files
committed
updates for overhauled testsuite
1 parent 605adcb commit 7f5c3ad

File tree

7 files changed

+48
-78
lines changed

7 files changed

+48
-78
lines changed

.github/workflows/test_petab_sciml.yml

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222

2323
strategy:
2424
matrix:
25-
python-version: ["3.11"]
25+
python-version: ["3.12"]
2626

2727
steps:
2828
- name: Set up Python ${{ matrix.python-version }}
@@ -56,13 +56,10 @@ jobs:
5656
&& pip3 install wheel pytest shyaml pytest-cov
5757
5858
# retrieve test models
59-
- name: Download and install PEtab SciML test suite
59+
- name: Download and install PEtab SciML
6060
run: |
61-
git clone --depth 1 --branch main \
62-
https://github.com/sebapersson/petab_sciml.git \
63-
&& export SCIML_TESTSUITE="$(pwd)/petab_sciml" \
64-
&& source venv/bin/activate \
65-
&& python -m pip install -e $SCIML_TESTSUITE/src/python
61+
source ./venv/bin/activate \
62+
&& python -m pip install git+https://github.com/sebapersson/petab_sciml.git@unify_data#subdirectory=src/python \
6663
6764
6865
- name: Install petab

python/sdist/amici/jax/nn.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,22 @@ def tanhshrink(x: jnp.ndarray) -> jnp.ndarray:
3030
return x - jnp.tanh(x)
3131

3232

33-
def generate_equinox(ml_model: "MLModel", filename: Path | str): # noqa: F821
33+
def generate_equinox(nn_model: "NNModel", filename: Path | str): # noqa: F821
3434
# TODO: move to top level import and replace forward type definitions
3535
from petab_sciml import Layer
3636

3737
filename = Path(filename)
3838
layer_indent = 12
3939
node_indent = 8
4040

41-
layers = {layer.layer_id: layer for layer in ml_model.layers}
41+
layers = {layer.layer_id: layer for layer in nn_model.layers}
4242

4343
tpl_data = {
44-
"MODEL_ID": ml_model.mlmodel_id,
44+
"MODEL_ID": nn_model.nn_model_id,
4545
"LAYERS": ",\n".join(
4646
[
4747
_generate_layer(layer, layer_indent, ilayer)
48-
for ilayer, layer in enumerate(ml_model.layers)
48+
for ilayer, layer in enumerate(nn_model.layers)
4949
]
5050
)[layer_indent:],
5151
"FORWARD": "\n".join(
@@ -58,19 +58,19 @@ def generate_equinox(ml_model: "MLModel", filename: Path | str): # noqa: F821
5858
Layer(layer_id="dummy", layer_type="Linear"),
5959
).layer_type,
6060
)
61-
for node in ml_model.forward
61+
for node in nn_model.forward
6262
]
6363
)[node_indent:],
64-
"INPUT": ", ".join([f"'{inp.input_id}'" for inp in ml_model.inputs]),
64+
"INPUT": ", ".join([f"'{inp.input_id}'" for inp in nn_model.inputs]),
6565
"OUTPUT": ", ".join(
6666
[
6767
f"'{arg}'"
6868
for arg in next(
69-
node for node in ml_model.forward if node.op == "output"
69+
node for node in nn_model.forward if node.op == "output"
7070
).args
7171
]
7272
),
73-
"N_LAYERS": len(ml_model.layers),
73+
"N_LAYERS": len(nn_model.layers),
7474
}
7575

7676
filename.parent.mkdir(parents=True, exist_ok=True)

python/sdist/amici/jax/ode_export.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -267,11 +267,10 @@ def _generate_jax_code(self) -> None:
267267

268268
def _generate_nn_code(self) -> None:
269269
for net_name, net in self.hybridisation.items():
270-
for model in net["model"]:
271-
generate_equinox(
272-
model,
273-
self.model_path / f"{net_name}.py",
274-
)
270+
generate_equinox(
271+
net["model"],
272+
self.model_path / f"{net_name}.py",
273+
)
275274

276275
def set_paths(self, output_dir: str | Path | None = None) -> None:
277276
"""

python/sdist/amici/jax/petab.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections.abc import Sized, Iterable
77
from pathlib import Path
88
from collections.abc import Callable
9+
import logging
910

1011

1112
import diffrax
@@ -23,6 +24,7 @@
2324
ParameterMappingForCondition,
2425
create_parameter_mapping,
2526
)
27+
from amici.logging import get_logger
2628
from amici.jax.model import JAXModel, ReturnValue
2729

2830
DEFAULT_CONTROLLER_SETTINGS = {
@@ -39,6 +41,8 @@
3941
petab.LOG10: 2,
4042
}
4143

44+
logger = get_logger(__name__, logging.WARNING)
45+
4246

4347
def jax_unscale(
4448
parameter: jnp.float_,
@@ -512,28 +516,35 @@ def _get_nominal_parameter_values(
512516
}
513517
for net_id, nn in model.nns.items()
514518
}
519+
# load nn parameters from file
520+
par_arrays = {
521+
array_id: h5py.File(file_spec["location"], "r")
522+
for array_id, file_spec in self._petab_problem.extensions_config[
523+
"array_files"
524+
].items()
525+
# TODO: FIXME (https://github.com/sebapersson/petab_sciml_testsuite/issues/1)
526+
}
527+
515528
# extract nominal values from petab problem
516529
for pname, row in self._petab_problem.parameter_df.iterrows():
517530
if (net := pname.split(".")[0]) in model.nns:
518531
to_set = []
519532
nn = model_pars[net]
520-
scalar = True
521533
try:
522534
value = float(row[petab.NOMINAL_VALUE])
523535
except ValueError:
524-
value = h5py.File(row[petab.NOMINAL_VALUE], "r")
536+
value = par_arrays[row[petab.NOMINAL_VALUE]]
525537
scalar = False
526-
527538
if len(pname.split(".")) > 1:
528-
layer = nn[pname.split(".")[1]]
539+
layer_name = pname.split(".")[1]
540+
layer = nn[layer_name]
529541
if len(pname.split(".")) > 2:
530-
to_set.append(
531-
(pname.split(".")[1], pname.split(".")[2])
532-
)
542+
attribute_name = pname.split(".")[2]
543+
to_set.append((layer_name, attribute_name))
533544
else:
534545
to_set.extend(
535546
[
536-
(pname.split(".")[1], attribute)
547+
(layer_name, attribute)
537548
for attribute in layer.keys()
538549
]
539550
)
@@ -549,15 +560,20 @@ def _get_nominal_parameter_values(
549560
for layer, attribute in to_set:
550561
if scalar:
551562
nn[layer][attribute] = value * jnp.ones_like(
552-
nn[layer][attribute]
563+
model.nns[net].layers[layer][attribute]
553564
)
554565
else:
555-
nn[layer][attribute] = value[layer][attribute]
566+
nn[layer][attribute] = jnp.array(
567+
value[layer][attribute]
568+
)
556569

557570
# set values in model
558571
for net_id in model_pars:
559572
for layer_id in model_pars[net_id]:
560573
for attribute in model_pars[net_id][layer_id]:
574+
logger.debug(
575+
f"Setting {attribute} of layer {layer_id} in network {net_id} to {model_pars[net_id][layer_id][attribute]}"
576+
)
561577
model = eqx.tree_at(
562578
lambda model: getattr(
563579
model.nns[net_id].layers[layer_id], attribute

python/sdist/amici/petab/petab_import.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def import_petab_problem(
148148
logger.info(f"Compiling model {model_name} to {model_output_dir}.")
149149

150150
if "neural_nets" in petab_problem.extensions_config: # TODO: fixme
151-
from petab_sciml import PetabScimlStandard
151+
from petab_sciml.standard import NNModelStandard
152152

153153
config = petab_problem.extensions_config
154154
# TODO: only accept YAML format for now
@@ -169,9 +169,9 @@ def import_petab_problem(
169169
)
170170
hybridization = {
171171
net_id: {
172-
"model": PetabScimlStandard.load_data(
172+
"model": NNModelStandard.load_data(
173173
Path() / net_config["location"]
174-
).models,
174+
),
175175
"input_vars": [
176176
input_mapping[petab_id]
177177
for petab_id, model_id in petab_problem.mapping_df.loc[

python/sdist/pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ test = [
7272
"scipy",
7373
"pooch",
7474
"beartype",
75-
""
7675
]
7776
vis = [
7877
"matplotlib",

tests/sciml/test_sciml.py

Lines changed: 3 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import h5py
2323
from contextlib import contextmanager
2424

25-
from petab_sciml import PetabScimlStandard
25+
from petab_sciml import NNModelStandard
2626

2727

2828
@contextmanager
@@ -78,7 +78,7 @@ def test_net(test):
7878
net_file = cases_dir / test.replace("_alt", "") / solutions["net_file"]
7979
else:
8080
net_file = test_dir / solutions["net_file"]
81-
ml_models = PetabScimlStandard.load_data(net_file)
81+
ml_models = NNModelStandard.load_data(net_file)
8282

8383
nets = {}
8484
outdir = Path(__file__).parent / "models" / test
@@ -197,49 +197,8 @@ def test_ude(test):
197197
compile_=True,
198198
jax=True,
199199
)
200-
# non_numeric = pd.to_numeric(petab_problem.parameter_df[petab.NOMINAL_VALUE], errors='coerce').isna()
201-
# par_files = petab_problem.parameter_df.loc[non_numeric, petab.NOMINAL_VALUE].unique()
202-
# par_values = {
203-
# par_file: h5py.File(par_file, "r")
204-
# for par_file in par_files
205-
# }
206-
# for par_id, row in petab_problem.parameter_df.iterrows():
207-
# if not non_numeric[par_id]:
208-
# continue
209-
# petab_problem.parameter_df.loc[par_id, petab.NOMINAL_VALUE] = \
210-
# (par_values[row[petab.NOMINAL_VALUE]],)
211-
# petab_problem.parameter_df.loc[np.logical_not(non_numeric), petab.NOMINAL_VALUE] = pd.to_numeric(
212-
# petab_problem.parameter_df.loc[np.logical_not(non_numeric), petab.NOMINAL_VALUE]
213-
# )
214200

215201
jax_problem = JAXProblem(jax_model, petab_problem)
216-
# for net, net_config in petab_problem.extensions_config.items(): # TODO: FIXME (https://github.com/sebapersson/petab_sciml_testsuite/issues/1)
217-
# pars = h5py.File(
218-
# net_config["net1_ps_file"]['location'], "r" # TODO: check format and actually use propoer petab nominal parameter infrastructure
219-
# )
220-
# for layer_name, layer in jax_problem.model.nns[net].layers.items():
221-
# for attribute in dir(layer):
222-
# if not isinstance(
223-
# getattr(layer, attribute), jax.numpy.ndarray
224-
# ):
225-
# continue
226-
# value = jnp.array(pars[layer_name][attribute])
227-
#
228-
# if (
229-
# isinstance(layer, eqx.nn.ConvTranspose)
230-
# and attribute == "weight"
231-
# ):
232-
# # see FAQ in https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.ConvTranspose
233-
# value = jnp.flip(
234-
# value, axis=tuple(range(2, value.ndim))
235-
# ).swapaxes(0, 1)
236-
# jax_problem = eqx.tree_at(
237-
# lambda x: getattr(
238-
# x.model.nns[net].layers[layer_name], attribute
239-
# ),
240-
# jax_problem,
241-
# value,
242-
# )
243202

244203
# llh
245204
if test in (
@@ -281,7 +240,7 @@ def test_ude(test):
281240
controller=diffrax.PIDController(atol=1e-14, rtol=1e-14),
282241
max_steps=2**16,
283242
)
284-
for component, file in solutions["grad_llh_files"].items():
243+
for component, file in solutions["grad_files"].items():
285244
actual_dict = {}
286245
if component == "mech":
287246
expected = pd.read_csv(test_dir / file, sep="\t").set_index(

0 commit comments

Comments
 (0)