File tree Expand file tree Collapse file tree 1 file changed +19
-0
lines changed Expand file tree Collapse file tree 1 file changed +19
-0
lines changed Original file line number Diff line number Diff line change 3
3
from contextlib import redirect_stdout
4
4
5
5
import jax
6
+ import jax .numpy as jnp
6
7
7
8
import jaxsim .api as js
8
9
@@ -45,3 +46,21 @@ def test_call_jit_compiled_function_passing_different_objects(
45
46
f"Compiling { js .contact .estimate_good_soft_contacts_parameters .__name__ } "
46
47
not in stdout
47
48
)
49
+
50
+ # Define a new JIT-compiled function and check that is not recompiled for
51
+ # different model objects having the same pytree structure.
52
+ @jax .jit
53
+ def my_jit_function (model : js .model .JaxSimModel , data : js .data .JaxSimModelData ):
54
+ # Return random elements from model and data, just to have something returned.
55
+ return (
56
+ jnp .sum (model .kin_dyn_parameters .link_parameters .mass ),
57
+ data .base_position (),
58
+ )
59
+
60
+ data1 = js .data .JaxSimModelData .build (model = model1 )
61
+
62
+ _ = my_jit_function (model = model1 , data = data1 )
63
+ assert my_jit_function ._cache_size () == 1
64
+
65
+ _ = my_jit_function (model = model2 , data = data1 )
66
+ assert my_jit_function ._cache_size () == 1
You can’t perform that action at this time.
0 commit comments