diff --git a/jax_tpu_embedding/sparsecore/lib/nn/embedding.py b/jax_tpu_embedding/sparsecore/lib/nn/embedding.py index 417401c..65e7565 100644 --- a/jax_tpu_embedding/sparsecore/lib/nn/embedding.py +++ b/jax_tpu_embedding/sparsecore/lib/nn/embedding.py @@ -334,8 +334,8 @@ def preprocess_sparse_dense_matmul_input( the max_ids_per_partition or max_unique_ids_per_partition limits. Returns: - A tuple of four dictionaries mapping the stacked table names to the - preprocessed inputs for the corresponding table. The four dictionaries are + A tuple of five dictionaries mapping the stacked table names to the + preprocessed inputs for the corresponding table. The five dictionaries are lhs_row_pointers, lhs_embedding_ids, lhs_sample_ids, lhs_gains and stats. """ tree.assert_same_structure(features, feature_specs)