Skip to content

Commit a9e3b2f

Browse files
committed
Make compute_contact_forces return tuple[jtp.Vector, tuple[Any, ...]]
1 parent 2cac06c commit a9e3b2f

File tree

3 files changed

+8
-7
lines changed

3 files changed

+8
-7
lines changed

src/jaxsim/api/contact.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def collidable_point_dynamics(
149149
# collidable point, and the corresponding material deformation rate.
150150
# Note that the material deformation rate is always returned in the mixed frame
151151
# C[W] = (W_p_C, [W]). This is convenient for integration purpose.
152-
W_f_Ci, CW_ṁ = jax.vmap(soft_contacts.compute_contact_forces)(
152+
W_f_Ci, (CW_ṁ,) = jax.vmap(soft_contacts.compute_contact_forces)(
153153
W_p_Ci, W_ṗ_Ci, data.state.contact.tangential_deformation
154154
)
155155

src/jaxsim/rbda/contacts/soft.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def compute_contact_forces(
158158
position: jtp.Vector,
159159
velocity: jtp.Vector,
160160
tangential_deformation: jtp.Vector,
161-
) -> tuple[jtp.Vector, jtp.Vector]:
161+
) -> tuple[jtp.Vector, tuple[jtp.Vector, None]]:
162162
"""
163163
Compute the contact forces and material deformation rate.
164164
@@ -237,7 +237,7 @@ def with_no_friction():
237237
# Compute lin-ang 6D forces (inertial representation)
238238
W_f = W_Xf_CW @ CW_f
239239

240-
return W_f,
240+
return W_f, (,)
241241

242242
# =========================
243243
# Compute tangential forces
@@ -255,7 +255,7 @@ def with_friction():
255255
active_contact = pz < self.terrain.height(x=px, y=py)
256256

257257
def above_terrain():
258-
return jnp.zeros(6),
258+
return jnp.zeros(6), (,)
259259

260260
def below_terrain():
261261
# Decompose the velocity in normal and tangential components
@@ -311,9 +311,9 @@ def slipping_contact():
311311
W_f = W_Xf_CW @ CW_f
312312

313313
# Return the 6D force in the world frame and the deformation derivative
314-
return W_f,
314+
return W_f, (,)
315315

316-
# (W_f, )
316+
# (W_f, (ṁ,))
317317
return jax.lax.cond(
318318
pred=active_contact,
319319
true_fun=lambda _: below_terrain(),

tests/test_automatic_differentiation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,9 +308,10 @@ def close_over_inputs_and_parameters(
308308
m: jtp.VectorLike,
309309
params: SoftContactsParams,
310310
) -> tuple[jtp.Vector, jtp.Vector]:
311-
return SoftContacts(parameters=params).compute_contact_forces(
311+
W_f_Ci, (CW_ṁ,) = SoftContacts(parameters=params).compute_contact_forces(
312312
position=p, velocity=v, tangential_deformation=m
313313
)
314+
return W_f_Ci, CW_ṁ
314315

315316
# Check derivatives against finite differences.
316317
check_grads(

0 commit comments

Comments
 (0)