Skip to content

Commit dac2da2

Browse files
Allow construction SDMM stats object from dicts.
PiperOrigin-RevId: 751446115
1 parent 298690a commit dac2da2

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_cc_test.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -763,10 +763,7 @@ def test_multi_process_fdo(self, has_leading_dimension):
763763
allow_id_dropping=False,
764764
)
765765
)
766-
stats = embedding.SparseDenseMatmulInputStats(
767-
max_ids_per_partition=stats["max_ids"],
768-
max_unique_ids_per_partition=stats["max_unique_ids"],
769-
)
766+
stats = embedding.SparseDenseMatmulInputStats.from_dict(stats)
770767
fdo_client.record(stats)
771768
fdo_client.publish()
772769
# Duplicated ids on row 0 and 6 are combined.

jax_tpu_embedding/sparsecore/lib/nn/embedding.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,15 @@ class SparseDenseMatmulInputStats:
6868
max_ids_per_partition: Mapping[str, np.ndarray]
6969
max_unique_ids_per_partition: Mapping[str, np.ndarray]
7070

71+
@classmethod
72+
def from_dict(
73+
cls, stats: Mapping[str, Mapping[str, np.ndarray]]
74+
) -> "SparseDenseMatmulInputStats":
75+
return cls(
76+
max_ids_per_partition=stats["max_ids"],
77+
max_unique_ids_per_partition=stats["max_unique_ids"],
78+
)
79+
7180

7281
# TODO: b/346873239 - Add more checks for the feature specs to ensure all the
7382
# fields are valid.
@@ -371,10 +380,7 @@ def preprocess_sparse_dense_matmul_input(
371380

372381
return SparseDenseMatmulInput(
373382
*preprocessed_inputs
374-
), SparseDenseMatmulInputStats(
375-
max_ids_per_partition=stats["max_ids"],
376-
max_unique_ids_per_partition=stats["max_unique_ids"],
377-
)
383+
), SparseDenseMatmulInputStats.from_dict(stats)
378384

379385

380386
def _get_activation_for_feature(

0 commit comments

Comments
 (0)