Skip to content

Commit 41a00f1

Browse files
committed
Extend pytree test
1 parent 5cec2fd commit 41a00f1

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

tests/test_pytree.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from contextlib import redirect_stdout
44

55
import jax
6+
import jax.numpy as jnp
67

78
import jaxsim.api as js
89

@@ -45,3 +46,21 @@ def test_call_jit_compiled_function_passing_different_objects(
4546
f"Compiling {js.contact.estimate_good_soft_contacts_parameters.__name__}"
4647
not in stdout
4748
)
49+
50+
# Define a new JIT-compiled function and check that is not recompiled for
51+
# different model objects having the same pytree structure.
52+
@jax.jit
53+
def my_jit_function(model: js.model.JaxSimModel, data: js.data.JaxSimModelData):
54+
# Return random elements from model and data, just to have something returned.
55+
return (
56+
jnp.sum(model.kin_dyn_parameters.link_parameters.mass),
57+
data.base_position(),
58+
)
59+
60+
data1 = js.data.JaxSimModelData.build(model=model1)
61+
62+
_ = my_jit_function(model=model1, data=data1)
63+
assert my_jit_function._cache_size() == 1
64+
65+
_ = my_jit_function(model=model2, data=data1)
66+
assert my_jit_function._cache_size() == 1

0 commit comments

Comments
 (0)