Skip to content

Commit 8883e59

Browse files
committed
Abstract contact state in api.ode_data.ODEState
1 parent fee82b9 commit 8883e59

File tree

4 files changed

+54
-17
lines changed

4 files changed

+54
-17
lines changed

src/jaxsim/api/contact.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def collidable_point_dynamics(
154154
# Note that the material deformation rate is always returned in the mixed frame
155155
# C[W] = (W_p_C, [W]). This is convenient for integration purpose.
156156
W_f_Ci, CW_ṁ = jax.vmap(soft_contacts.contact_model)(
157-
W_p_Ci, W_ṗ_Ci, data.state.contact_state.tangential_deformation
157+
W_p_Ci, W_ṗ_Ci, data.state.contacts_state.tangential_deformation
158158
)
159159

160160
case _:

src/jaxsim/api/ode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def system_velocity_dynamics(
132132
W_f_Ci = None
133133

134134
# Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
135-
= jnp.zeros_like(data.state.soft_contacts.tangential_deformation).astype(float)
135+
= jnp.zeros_like(data.state.contacts_state.tangential_deformation).astype(float)
136136

137137
if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
138138
# Compute the 6D forces applied to each collidable point and the

src/jaxsim/api/ode_data.py

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from __future__ import annotations
22

3+
import importlib
4+
35
import jax.numpy as jnp
46
import jax_dataclasses
57

68
import jaxsim.api as js
79
import jaxsim.typing as jtp
10+
from jaxsim import logging
811
from jaxsim.api.soft_contacts import SoftContactsState
912
from jaxsim.utils import JaxsimDataclass
1013

@@ -117,11 +120,11 @@ class ODEState(JaxsimDataclass):
117120
118121
Attributes:
119122
physics_model: The state of the physics model.
120-
soft_contacts: The state of the soft-contacts model.
123+
contacts_state: The state of the contacts model.
121124
"""
122125

123126
physics_model: PhysicsModelState
124-
soft_contacts: SoftContactsState
127+
contacts_state: js.contact.ContactsState
125128

126129
@staticmethod
127130
def build_from_jaxsim_model(
@@ -159,6 +162,15 @@ def build_from_jaxsim_model(
159162
`JaxSimModel` and initialized to zero.
160163
"""
161164

165+
# Get the contact model from the `JaxSimModel`
166+
prefix = type(model.contact_model).__name__.split("Contact")[0]
167+
168+
if prefix:
169+
module_name = f"{prefix.lower()}_contacts"
170+
class_name = f"{prefix.capitalize()}ContactsState"
171+
else:
172+
raise ValueError("Unable to determine contact state class prefix.")
173+
162174
return ODEState.build(
163175
model=model,
164176
physics_model_state=PhysicsModelState.build_from_jaxsim_model(
@@ -170,24 +182,30 @@ def build_from_jaxsim_model(
170182
base_linear_velocity=base_linear_velocity,
171183
base_angular_velocity=base_angular_velocity,
172184
),
173-
soft_contacts_state=SoftContactsState.build_from_jaxsim_model(
185+
contacts_state=getattr(
186+
importlib.import_module(f"jaxsim.api.{module_name}"), class_name
187+
).build_from_jaxsim_model(
174188
model=model,
175-
tangential_deformation=tangential_deformation,
189+
**(
190+
dict(tangential_deformation=tangential_deformation)
191+
if tangential_deformation is not None
192+
else dict()
193+
),
176194
),
177195
)
178196

179197
@staticmethod
180198
def build(
181199
physics_model_state: PhysicsModelState | None = None,
182-
soft_contacts_state: SoftContactsState | None = None,
200+
contacts_state: js.contact.ContactsState | None = None,
183201
model: js.model.JaxSimModel | None = None,
184202
) -> ODEState:
185203
"""
186-
Build an `ODEState` from a `PhysicsModelState` and a `SoftContactsState`.
204+
Build an `ODEState` from a `PhysicsModelState` and a `ContactsState`.
187205
188206
Args:
189207
physics_model_state: The state of the physics model.
190-
soft_contacts_state: The state of the soft-contacts model.
208+
contacts_state: The state of the contacts model.
191209
model: The `JaxSimModel` associated with the ODE state.
192210
193211
Returns:
@@ -200,14 +218,33 @@ def build(
200218
else PhysicsModelState.zero(model=model)
201219
)
202220

203-
soft_contacts_state = (
204-
soft_contacts_state
205-
if soft_contacts_state is not None
221+
# Get the contact model from the `JaxSimModel`
222+
try:
223+
prefix = type(model.contact_model).__name__.split("Contact")[0]
224+
except AttributeError:
225+
logging.warning(
226+
"Unable to determine contact state class prefix. Using default soft contacts."
227+
)
228+
prefix = "Soft"
229+
230+
module_name = f"{prefix.lower()}_contacts"
231+
class_name = f"{prefix.capitalize()}ContactsState"
232+
233+
try:
234+
state_cls = getattr(
235+
importlib.import_module(f"jaxsim.api.{module_name}"), class_name
236+
)
237+
except ImportError as e:
238+
raise e
239+
240+
contacts_state = (
241+
contacts_state
242+
if contacts_state is not None
206243
else SoftContactsState.zero(model=model)
207244
)
208245

209246
return ODEState(
210-
physics_model=physics_model_state, soft_contacts=soft_contacts_state
247+
physics_model=physics_model_state, contacts_state=contacts_state
211248
)
212249

213250
@staticmethod
@@ -237,7 +274,7 @@ def valid(self, model: js.model.JaxSimModel) -> bool:
237274
`True` if the ODE state is valid for the given model, `False` otherwise.
238275
"""
239276

240-
return self.physics_model.valid(model=model) and self.soft_contacts.valid(
277+
return self.physics_model.valid(model=model) and self.contacts_state.valid(
241278
model=model
242279
)
243280

tests/test_automatic_differentiation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def test_ad_integration(
342342
s = data.joint_positions(model=model)
343343
W_v_WB = data.base_velocity()
344344
= data.joint_velocities(model=model)
345-
m = data.state.soft_contacts.tangential_deformation
345+
m = data.state.contacts_state.tangential_deformation
346346

347347
# Inputs.
348348
W_f_L = references.link_forces(model=model)
@@ -396,7 +396,7 @@ def step(
396396
base_angular_velocity=W_v_WB[3:6],
397397
joint_velocities=,
398398
),
399-
soft_contacts_state=js.ode_data.SoftContactsState.build(
399+
contacts_state=js.ode_data.SoftContactsState.build(
400400
tangential_deformation=m
401401
),
402402
),
@@ -417,7 +417,7 @@ def step(
417417
xf_s = data_xf.joint_positions(model=model)
418418
xf_W_v_WB = data_xf.base_velocity()
419419
xf_ṡ = data_xf.joint_velocities(model=model)
420-
xf_m = data_xf.state.soft_contacts.tangential_deformation
420+
xf_m = data_xf.state.contacts_state.tangential_deformation
421421

422422
return xf_W_p_B, xf_W_Q_B, xf_s, xf_W_v_WB, xf_ṡ, xf_m
423423

0 commit comments

Comments
 (0)