Skip to content

Commit f64764f

Browse files
Apply suggestions from code review
Co-authored-by: Diego Ferigo <[email protected]>
1 parent 908ca8b commit f64764f

File tree

4 files changed

+37
-11
lines changed

4 files changed

+37
-11
lines changed

src/jaxsim/api/contact.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,10 +365,12 @@ def jacobian(
365365

366366
# Adjust the output representation.
367367
match output_vel_repr:
368+
368369
case VelRepr.Inertial:
369370
O_J_WC = W_J_WC
370371

371372
case VelRepr.Body:
373+
372374
W_H_C = transforms(model=model, data=data)
373375

374376
def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
@@ -381,9 +383,11 @@ def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
381383
O_J_WC = jax.vmap(body_jacobian)(W_H_C, W_J_WC)
382384

383385
case VelRepr.Mixed:
386+
384387
W_H_C = transforms(model=model, data=data)
385388

386389
def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
390+
387391
W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))
388392

389393
CW_X_W = jaxsim.math.Adjoint.from_transform(

src/jaxsim/api/kin_dyn_parameters.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def κb(link_index: jtp.IntLike) -> jtp.Vector:
184184
carry0 = κb, link_index
185185

186186
def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]:
187+
187188
κb, active_link_index = carry
188189

189190
κb, active_link_index = jax.lax.cond(
@@ -225,12 +226,14 @@ def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]:
225226
)
226227

227228
def __eq__(self, other: KynDynParameters) -> bool:
229+
228230
if not isinstance(other, KynDynParameters):
229231
return False
230232

231233
return hash(self) == hash(other)
232234

233235
def __hash__(self) -> int:
236+
234237
return hash(
235238
(
236239
hash(self.number_of_links()),
@@ -640,6 +643,7 @@ def build_from_inertial_parameters(
640643
def build_from_flat_parameters(
641644
index: jtp.IntLike, parameters: jtp.VectorLike
642645
) -> LinkParameters:
646+
643647
index = jnp.array(index).squeeze().astype(int)
644648

645649
m = jnp.array(parameters[0]).squeeze().astype(float)
@@ -664,11 +668,7 @@ def flat_parameters(params: LinkParameters) -> jtp.Vector:
664668

665669
return (
666670
jnp.hstack(
667-
[
668-
params.mass,
669-
params.center_of_mass.squeeze(),
670-
params.inertia_elements,
671-
]
671+
[params.mass, params.center_of_mass.squeeze(), params.inertia_elements]
672672
)
673673
.squeeze()
674674
.astype(float)

src/jaxsim/rbda/contacts/common.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class ContactsState(abc.ABC):
1717
def build(cls, **kwargs) -> ContactsState:
1818
"""
1919
Build the contact state object.
20+
2021
Returns:
2122
The contact state object.
2223
"""
@@ -28,6 +29,7 @@ def build(cls, **kwargs) -> ContactsState:
2829
def zero(cls, **kwargs) -> ContactsState:
2930
"""
3031
Build a zero contact state.
32+
3133
Returns:
3234
The zero contact state.
3335
"""
@@ -57,8 +59,7 @@ def build(cls) -> ContactsParams:
5759
The `ContactsParams` instance.
5860
"""
5961

60-
raise NotImplementedError
61-
62+
@abc.abstractmethod
6263
def valid(self, *args, **kwargs) -> bool:
6364
"""
6465
Check if the parameters are valid.
@@ -72,6 +73,7 @@ def valid(self, *args, **kwargs) -> bool:
7273
class ContactModel(abc.ABC):
7374
"""
7475
Abstract class representing a contact model.
76+
7577
Attributes:
7678
parameters: The parameters of the contact model.
7779
terrain: The terrain model.
@@ -86,12 +88,16 @@ def compute_contact_forces(
8688
position: jtp.Vector,
8789
velocity: jtp.Vector,
8890
**kwargs,
89-
) -> tuple[Any, ...]:
91+
) -> tuple[jtp.Vector, tuple[Any, ...]]:
9092
"""
9193
Compute the contact forces.
94+
9295
Args:
93-
position: The position of the collidable point.
94-
velocity: The velocity of the collidable point.
96+
position: The position of the collidable point w.r.t. the world frame.
97+
velocity:
98+
The linear velocity of the collidable point (linear component of the mixed 6D velocity).
99+
95100
Returns:
96-
A tuple containing the contact force and additional information.
101+
A tuple containing as first element the computed 6D contact force applied to the contact point and expressed in the world frame,
102+
and as second element a tuple of optional additional information.
97103
"""

src/jaxsim/rbda/contacts/soft.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class SoftContactsParams(ContactsParams):
3131
)
3232

3333
def __hash__(self) -> int:
34+
3435
from jaxsim.utils.wrappers import HashedNumpyArray
3536

3637
return hash(
@@ -42,6 +43,7 @@ def __hash__(self) -> int:
4243
)
4344

4445
def __eq__(self, other: SoftContactsParams) -> bool:
46+
4547
if not isinstance(other, SoftContactsParams):
4648
return NotImplemented
4749

@@ -126,6 +128,20 @@ def build_default_from_jaxsim_model(
126128

127129
return SoftContactsParams.build(K=K, D=D, mu=μc)
128130

131+
def valid(self) -> bool:
132+
"""
133+
Check if the parameters are valid.
134+
135+
Returns:
136+
`True` if the parameters are valid, `False` otherwise.
137+
"""
138+
139+
return (
140+
jnp.all(self.K >= 0.0)
141+
and jnp.all(self.D >= 0.0)
142+
and jnp.all(self.mu >= 0.0)
143+
)
144+
129145

130146
@jax_dataclasses.pytree_dataclass
131147
class SoftContacts(ContactModel):

0 commit comments

Comments
 (0)