1
1
from __future__ import annotations
2
2
3
+ import importlib
4
+
3
5
import jax .numpy as jnp
4
6
import jax_dataclasses
5
7
6
8
import jaxsim .api as js
7
9
import jaxsim .typing as jtp
10
+ from jaxsim import logging
8
11
from jaxsim .api .soft_contacts import SoftContactsState
9
12
from jaxsim .utils import JaxsimDataclass
10
13
@@ -117,11 +120,11 @@ class ODEState(JaxsimDataclass):
117
120
118
121
Attributes:
119
122
physics_model: The state of the physics model.
120
- soft_contacts : The state of the soft- contacts model.
123
+ contacts_state : The state of the contacts model.
121
124
"""
122
125
123
126
physics_model : PhysicsModelState
124
- soft_contacts : SoftContactsState
127
+ contacts_state : js . contact . ContactsState
125
128
126
129
@staticmethod
127
130
def build_from_jaxsim_model (
@@ -159,6 +162,15 @@ def build_from_jaxsim_model(
159
162
`JaxSimModel` and initialized to zero.
160
163
"""
161
164
165
+ # Get the contact model from the `JaxSimModel`
166
+ prefix = type (model .contact_model ).__name__ .split ("Contact" )[0 ]
167
+
168
+ if prefix :
169
+ module_name = f"{ prefix .lower ()} _contacts"
170
+ class_name = f"{ prefix .capitalize ()} ContactsState"
171
+ else :
172
+ raise ValueError ("Unable to determine contact state class prefix." )
173
+
162
174
return ODEState .build (
163
175
model = model ,
164
176
physics_model_state = PhysicsModelState .build_from_jaxsim_model (
@@ -170,24 +182,30 @@ def build_from_jaxsim_model(
170
182
base_linear_velocity = base_linear_velocity ,
171
183
base_angular_velocity = base_angular_velocity ,
172
184
),
173
- soft_contacts_state = SoftContactsState .build_from_jaxsim_model (
185
+ contact_state = getattr (
186
+ importlib .import_module (f"jaxsim.api.{ module_name } " ), class_name
187
+ ).build_from_jaxsim_model (
174
188
model = model ,
175
- tangential_deformation = tangential_deformation ,
189
+ ** (
190
+ dict (tangential_deformation = tangential_deformation )
191
+ if tangential_deformation is not None
192
+ else dict ()
193
+ ),
176
194
),
177
195
)
178
196
179
197
@staticmethod
180
198
def build (
181
199
physics_model_state : PhysicsModelState | None = None ,
182
- soft_contacts_state : SoftContactsState | None = None ,
200
+ contact_state : js . contact . ContactsState | None = None ,
183
201
model : js .model .JaxSimModel | None = None ,
184
202
) -> ODEState :
185
203
"""
186
- Build an `ODEState` from a `PhysicsModelState` and a `SoftContactsState `.
204
+ Build an `ODEState` from a `PhysicsModelState` and a `ContactsState `.
187
205
188
206
Args:
189
207
physics_model_state: The state of the physics model.
190
- soft_contacts_state : The state of the soft- contacts model.
208
+ contact_state : The state of the contacts model.
191
209
model: The `JaxSimModel` associated with the ODE state.
192
210
193
211
Returns:
@@ -200,15 +218,32 @@ def build(
200
218
else PhysicsModelState .zero (model = model )
201
219
)
202
220
203
- soft_contacts_state = (
204
- soft_contacts_state
205
- if soft_contacts_state is not None
221
+ # Get the contact model from the `JaxSimModel`
222
+ try :
223
+ prefix = type (model .contact_model ).__name__ .split ("Contact" )[0 ]
224
+ except AttributeError :
225
+ logging .warning (
226
+ "Unable to determine contact state class prefix. Using default soft contacts."
227
+ )
228
+ prefix = "Soft"
229
+
230
+ module_name = f"{ prefix .lower ()} _contacts"
231
+ class_name = f"{ prefix .capitalize ()} ContactsState"
232
+
233
+ try :
234
+ state_cls = getattr (
235
+ importlib .import_module (f"jaxsim.api.{ module_name } " ), class_name
236
+ )
237
+ except ImportError as e :
238
+ raise e
239
+
240
+ contact_state = (
241
+ contact_state
242
+ if contact_state is not None
206
243
else SoftContactsState .zero (model = model )
207
244
)
208
245
209
- return ODEState (
210
- physics_model = physics_model_state , soft_contacts = soft_contacts_state
211
- )
246
+ return ODEState (physics_model = physics_model_state , contacts_state = contact_state )
212
247
213
248
@staticmethod
214
249
def zero (model : js .model .JaxSimModel ) -> ODEState :
@@ -237,7 +272,7 @@ def valid(self, model: js.model.JaxSimModel) -> bool:
237
272
`True` if the ODE state is valid for the given model, `False` otherwise.
238
273
"""
239
274
240
- return self .physics_model .valid (model = model ) and self .soft_contacts .valid (
275
+ return self .physics_model .valid (model = model ) and self .contacts_state .valid (
241
276
model = model
242
277
)
243
278
0 commit comments