Skip to content

Commit ec67648

Browse files
committed
Update hash and comparison of JaxSimModelData
1 parent 50dc320 commit ec67648

File tree

3 files changed

+20
-12
lines changed

3 files changed

+20
-12
lines changed

src/jaxsim/api/data.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,14 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
4545

4646
def __hash__(self) -> int:
4747

48+
from jaxsim.utils.wrappers import HashedNumpyArray
49+
4850
return hash(
4951
(
5052
hash(self.state),
51-
hash(tuple(self.gravity.flatten().tolist())),
53+
HashedNumpyArray.hash_of_array(self.gravity),
5254
hash(self.soft_contacts_params),
53-
hash(jnp.atleast_1d(self.time_ns).flatten().tolist()),
55+
hash(tuple(self.time_ns.flatten().tolist())),
5456
)
5557
)
5658

src/jaxsim/api/ode_data.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -283,12 +283,16 @@ class PhysicsModelState(JaxsimDataclass):
283283

284284
def __hash__(self) -> int:
285285

286+
from jaxsim.utils.wrappers import HashedNumpyArray
287+
286288
return hash(
287289
(
288-
hash(tuple(jnp.atleast_1d(self.joint_positions.flatten().tolist()))),
289-
hash(tuple(jnp.atleast_1d(self.joint_velocities.flatten().tolist()))),
290-
hash(tuple(self.base_position.flatten().tolist())),
291-
hash(tuple(self.base_quaternion.flatten().tolist())),
290+
HashedNumpyArray.hash_of_array(self.joint_positions),
291+
HashedNumpyArray.hash_of_array(self.joint_velocities),
292+
HashedNumpyArray.hash_of_array(self.base_position),
293+
HashedNumpyArray.hash_of_array(self.base_quaternion),
294+
HashedNumpyArray.hash_of_array(self.base_linear_velocity),
295+
HashedNumpyArray.hash_of_array(self.base_angular_velocity),
292296
)
293297
)
294298

@@ -613,9 +617,9 @@ class SoftContactsState(JaxsimDataclass):
613617

614618
def __hash__(self) -> int:
615619

616-
return hash(
617-
tuple(jnp.atleast_1d(self.tangential_deformation.flatten()).tolist())
618-
)
620+
from jaxsim.utils.wrappers import HashedNumpyArray
621+
622+
return HashedNumpyArray.hash_of_array(self.tangential_deformation)
619623

620624
def __eq__(self, other: SoftContactsState) -> bool:
621625

src/jaxsim/rbda/soft_contacts.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,13 @@ class SoftContactsParams(JaxsimDataclass):
3131

3232
def __hash__(self) -> int:
3333

34+
from jaxsim.utils.wrappers import HashedNumpyArray
35+
3436
return hash(
3537
(
36-
hash(tuple(jnp.atleast_1d(self.K).flatten().tolist())),
37-
hash(tuple(jnp.atleast_1d(self.D).flatten().tolist())),
38-
hash(tuple(jnp.atleast_1d(self.mu).flatten().tolist())),
38+
HashedNumpyArray.hash_of_array(self.K),
39+
HashedNumpyArray.hash_of_array(self.D),
40+
HashedNumpyArray.hash_of_array(self.mu),
3941
)
4042
)
4143

0 commit comments

Comments
 (0)