
Commit 298690a

Consolidate sparse matmul preprocessed inputs into a single NamedTuple

PiperOrigin-RevId: 750724913
1 parent 95cac49 · commit 298690a

Showing 17 changed files with 298 additions and 463 deletions.
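
Note for orientation: preprocess_sparse_dense_matmul_input previously returned four loose per-table buffers (lhs_row_pointers, lhs_local_embedding_ids, lhs_local_sample_ids, lhs_gains) plus stats; after this commit it returns a single preprocessed_inputs value plus stats, and every sparse-dense matmul call site takes that one value. The NamedTuple's definition is not shown in the hunks below, so the following is only a minimal sketch of its assumed shape: the class name SparseDenseMatmulInput and the field types are assumptions inferred from the call sites, with the real definition living in jax_tpu_embedding/sparsecore/lib/nn/embedding.py.

# Hedged sketch only: the class name and field types are inferred from the
# call sites in this commit, not copied from the actual definition.
from typing import Any, Mapping, NamedTuple


class SparseDenseMatmulInput(NamedTuple):  # hypothetical name
  # Per-table CSR-style buffers: row pointers, device-local embedding ids,
  # the sample ids they belong to, and per-id combiner gains.
  lhs_row_pointers: Mapping[str, Any]
  lhs_embedding_ids: Mapping[str, Any]
  lhs_sample_ids: Mapping[str, Any]
  lhs_gains: Mapping[str, Any]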

jax_tpu_embedding/sparsecore/examples/shakespeare/jax_sc_shakespeare_jit.py (+8 -26)

@@ -370,18 +370,15 @@ def run_model():
           None,
           emb_var_outsharding,
       ),
-      donate_argnums=(9),
+      donate_argnums=(6),
   )
   def train_step_fn(
       mesh: jax.sharding.Mesh,
       model: nn.Module,
       optimizer,
       feature_specs,
       train_state: TrainState,
-      lhs_row_pointers,
-      lhs_local_embedding_ids,
-      lhs_local_sample_ids,
-      lhs_gains,
+      preprocessed_inputs,
       emb_variables,
       labels,
   ) -> tuple[TrainState, TrainMetrics, Nested[jax.Array]]:
@@ -398,15 +395,12 @@ def train_step_fn(
     tpu_sparse_dense_matmul = shard_map(
         f=tpu_sparse_dense_matmul,
         mesh=mesh,
-        in_specs=(pd, pd, pd, pd, pe),
+        in_specs=(pd, pe),
         out_specs=pd,
         check_rep=False,
     )
     emb_act = tpu_sparse_dense_matmul(
-        lhs_row_pointers,
-        lhs_local_embedding_ids,
-        lhs_local_sample_ids,
-        lhs_gains,
+        preprocessed_inputs,
         emb_variables,
     )

@@ -443,16 +437,13 @@ def train_step_fn(
     tpu_sparse_dense_matmul_grad = shard_map(
         f=tpu_sparse_dense_matmul_grad,
         mesh=mesh,
-        in_specs=(pd, pd, pd, pd, pd, pe),
+        in_specs=(pd, pd, pe),
         out_specs=pe,
         check_rep=False,
     )
     emb_variables = tpu_sparse_dense_matmul_grad(
         emb_grad,
-        lhs_row_pointers,
-        lhs_local_embedding_ids,
-        lhs_local_sample_ids,
-        lhs_gains,
+        preprocessed_inputs,
         emb_variables,
     )

@@ -512,13 +503,7 @@ def train_step_fn(
         lambda y: jax.make_array_from_process_local_data(global_sharding, y),
         x,
     )
-    (
-        lhs_row_pointers,
-        lhs_local_embedding_ids,
-        lhs_local_sample_ids,
-        lhs_gains,
-        stats,
-    ) = map(
+    preprocessed_inputs, stats = map(
         make_global_view,
         embedding.preprocess_sparse_dense_matmul_input(
             features,
@@ -541,10 +526,7 @@ def train_step_fn(
         optimizer,
         feature_specs,
         train_state,
-        lhs_row_pointers,
-        lhs_local_embedding_ids,
-        lhs_local_sample_ids,
-        lhs_gains,
+        preprocessed_inputs,
         emb_variables,
         labels,
     )
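
Why donate_argnums moves from 9 to 6: the donated argument is emb_variables, and collapsing four positional inputs into one preprocessed_inputs shifts its zero-based position down by three. A quick check against the new signature in the first hunk above:

# Zero-based parameter positions of train_step_fn after this change.
params = [
    "mesh", "model", "optimizer", "feature_specs", "train_state",
    "preprocessed_inputs", "emb_variables", "labels",
]
assert params.index("emb_variables") == 6  # matches donate_argnums=(6)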

jax_tpu_embedding/sparsecore/examples/shakespeare/jax_sc_shakespeare_pmap.py (+14 -31)

@@ -265,10 +265,7 @@ def train_step_fn(
     optimizer,
     feature_specs,
     train_state: TrainState,
-    lhs_row_pointers,
-    lhs_local_embedding_ids,
-    lhs_local_sample_ids,
-    lhs_gains,
+    preprocessed_inputs,
     emb_variables: Mapping[str, embedding.EmbeddingVariables],
     labels,
 ) -> tuple[
@@ -284,10 +281,7 @@ def train_step_fn(
       sharding_strategy='MOD',
   )
   emb_act = tpu_sparse_dense_matmul(
-      lhs_row_pointers,
-      lhs_local_embedding_ids,
-      lhs_local_sample_ids,
-      lhs_gains,
+      preprocessed_inputs,
       emb_variables,
   )

@@ -323,10 +317,7 @@ def train_step_fn(
   )
   emb_variables = tpu_sparse_dense_matmul_grad(
       emb_grad,
-      lhs_row_pointers,
-      lhs_local_embedding_ids,
-      lhs_local_sample_ids,
-      lhs_gains,
+      preprocessed_inputs,
       emb_variables,
   )

@@ -385,17 +376,15 @@ def train_step_fn(
   )

   # Preprocess the inputs.
-  (lhs_row_pointers, lhs_embedding_ids, lhs_sample_ids, lhs_gains, _) = (
-      embedding.preprocess_sparse_dense_matmul_input(
-          features,
-          feature_weights,
-          feature_specs,
-          local_device_count=global_mesh.local_mesh.size,
-          global_device_count=global_mesh.size,
-          num_sc_per_device=num_sc_per_device,
-          sharding_strategy='MOD',
-          has_leading_dimension=True,
-      )
+  preprocessed_inputs, _ = embedding.preprocess_sparse_dense_matmul_input(
+      features,
+      feature_weights,
+      feature_specs,
+      local_device_count=global_mesh.local_mesh.size,
+      global_device_count=global_mesh.size,
+      num_sc_per_device=num_sc_per_device,
+      sharding_strategy='MOD',
+      has_leading_dimension=True,
   )

   # TODO(patn): This (local_slice)will go away once the input processor is
@@ -432,10 +421,7 @@ def train_step_fn(
       continue
     jaxpr = jax.make_jaxpr(p_train_step_fn)(
         train_state,
-        lhs_row_pointers,
-        lhs_embedding_ids,
-        lhs_sample_ids,
-        lhs_gains,
+        preprocessed_inputs,
         emb_variables,
         labels_sharded,
     )
@@ -448,10 +434,7 @@ def train_step_fn(

     train_state, metrics_update, emb_variables = p_train_step_fn(
         train_state,
-        lhs_row_pointers,
-        lhs_embedding_ids,
-        lhs_sample_ids,
-        lhs_gains,
+        preprocessed_inputs,
         emb_variables,
         labels_sharded,
     )

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_cc_test.py (+6 -2)

@@ -763,15 +763,19 @@ def test_multi_process_fdo(self, has_leading_dimension):
             allow_id_dropping=False,
         )
     )
+    stats = embedding.SparseDenseMatmulInputStats(
+        max_ids_per_partition=stats["max_ids"],
+        max_unique_ids_per_partition=stats["max_unique_ids"],
+    )
     fdo_client.record(stats)
     fdo_client.publish()
     # Duplicated ids on row 0 and 6 are combined.
     np.testing.assert_equal(
-        stats["max_ids"]["one_table_to_rule_them_all"],
+        stats.max_ids_per_partition["one_table_to_rule_them_all"],
         np.array([7, 4, 6, 5, 9, 5, 5, 5], dtype=np.int32),
     )
     np.testing.assert_equal(
-        stats["max_unique_ids"]["one_table_to_rule_them_all"],
+        stats.max_unique_ids_per_partition["one_table_to_rule_them_all"],
         np.array([3, 3, 4, 4, 5, 3, 3, 5], dtype=np.int32),
     )

jax_tpu_embedding/sparsecore/lib/fdo/BUILD (+6 -2)

@@ -22,14 +22,18 @@ package(
 pytype_strict_library(
     name = "fdo_client",
     srcs = ["fdo_client.py"],
-    deps = [pypi_requirement("numpy")],
+    deps = [
+        "//jax_tpu_embedding/sparsecore/lib/nn:embedding",
+        pypi_requirement("numpy"),
+    ],
 )

 pytype_strict_library(
     name = "file_fdo_client",
     srcs = ["file_fdo_client.py"],
     deps = [
         ":fdo_client",
+        "//jax_tpu_embedding/sparsecore/lib/nn:embedding",
         pypi_requirement("absl/logging"),
         pypi_requirement("jax"),
         pypi_requirement("numpy"),
@@ -42,8 +46,8 @@ pytype_strict_contrib_test(
     env = {"JAX_PLATFORMS": "cpu"},
     deps = [
         ":file_fdo_client",
+        "//jax_tpu_embedding/sparsecore/lib/nn:embedding",
         pypi_requirement("absl/testing:absltest"),
-        pypi_requirement("jax"),
         pypi_requirement("numpy"),
     ],
 )

jax_tpu_embedding/sparsecore/lib/fdo/fdo_client.py (+2 -1)

@@ -16,6 +16,7 @@
 import abc
 from collections.abc import Mapping

+from jax_tpu_embedding.sparsecore.lib.nn import embedding
 import numpy as np


@@ -41,7 +42,7 @@ class FDOClient(abc.ABC):
   @abc.abstractmethod
   def record(
       self,
-      data: Mapping[str, Mapping[str, np.ndarray]],
+      data: embedding.SparseDenseMatmulInputStats,
   ) -> None:
     """Records the raw stats to local memory.
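
For implementers and callers of FDOClient, record() now takes the typed stats object instead of a raw nested dict. A minimal usage sketch based on the updated test call sites in this commit; the table name and values are illustrative, and `client` is hypothetical, standing for any FDOClient implementation:

from jax_tpu_embedding.sparsecore.lib.nn import embedding
import numpy as np

stats = embedding.SparseDenseMatmulInputStats(
    max_ids_per_partition={"my_table": np.array([7, 4, 6, 5])},
    max_unique_ids_per_partition={"my_table": np.array([3, 3, 4, 4])},
)
# Previously: client.record({"max_ids": ..., "max_unique_ids": ...})
client.record(stats)
client.publish()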

jax_tpu_embedding/sparsecore/lib/fdo/file_fdo_client.py (+14 -19)

@@ -24,13 +24,14 @@
 from absl import logging
 import jax
 from jax_tpu_embedding.sparsecore.lib.fdo import fdo_client
+from jax_tpu_embedding.sparsecore.lib.nn import embedding
 import numpy as np


 _FILE_NAME = 'fdo_stats'
 _FILE_EXTENSION = 'npz'
-_MAX_ID_STATS_KEY = '_max_ids'
-_MAX_UNIQUE_ID_STATS_KEY = '_max_unique_ids'
+_MAX_ID_STATS_SUFFIX = '_max_ids'
+_MAX_UNIQUE_ID_STATS_SUFFIX = '_max_unique_ids'


 class NPZFileFDOClient(fdo_client.FDOClient):
@@ -57,7 +58,7 @@ def __init__(self, base_dir: str):
     self._max_ids_per_partition = collections.defaultdict(np.ndarray)
     self._max_unique_ids_per_partition = collections.defaultdict(np.ndarray)

-  def record(self, data: Mapping[str, Mapping[str, np.ndarray]]) -> None:
+  def record(self, data: embedding.SparseDenseMatmulInputStats) -> None:
     """Records stats per process.

     Accumulates the max ids observed per process per sparsecore per device for
@@ -67,9 +68,7 @@ def record(self, data: Mapping[str, Mapping[str, np.ndarray]]) -> None:
     Args:
       data: A mapping representing data to be recorded.
     """
-    if _MAX_ID_STATS_KEY[1:] not in data:
-      raise ValueError(f'Expected stat ({_MAX_ID_STATS_KEY[1:]}) not found.')
-    max_ids_per_process = data[_MAX_ID_STATS_KEY[1:]]
+    max_ids_per_process = data.max_ids_per_partition
     for table_name, stats in max_ids_per_process.items():
       logging.vlog(
           2, 'Recording observed max ids for table: %s -> %s', table_name, stats
@@ -80,11 +79,7 @@ def record(self, data: Mapping[str, Mapping[str, np.ndarray]]) -> None:
       self._max_ids_per_partition[table_name] = np.vstack(
           (self._max_ids_per_partition[table_name], stats)
       )
-    if _MAX_UNIQUE_ID_STATS_KEY[1:] not in data:
-      raise ValueError(
-          f'Expected stats ({_MAX_UNIQUE_ID_STATS_KEY[1:]}) not found.'
-      )
-    max_uniques_per_process = data[_MAX_UNIQUE_ID_STATS_KEY[1:]]
+    max_uniques_per_process = data.max_unique_ids_per_partition
     for table_name, stats in max_uniques_per_process.items():
       logging.vlog(
           2,
@@ -107,7 +102,7 @@ def _generate_file_name(self) -> str:
         _FILE_NAME, jax.process_index(), time.time_ns(), _FILE_EXTENSION
     )
     return os.path.join(self._base_dir, filename)
-  # LINT.ThenChange(:_get_latest_files_by_process)
+  # LINT.ThenChange(:_get_latest_files_by_process)

   def _get_latest_files_by_process(self, files: list[str]) -> list[str]:
     """Returns the latest file for each process."""
@@ -150,11 +145,11 @@ def publish(self) -> None:
     processes.
     """
     merged_stats = {
-        f'{table_name}{_MAX_ID_STATS_KEY}': stats
+        f'{table_name}{_MAX_ID_STATS_SUFFIX}': stats
         for table_name, stats in self._max_ids_per_partition.items()
     }
     merged_stats.update({
-        f'{table_name}{_MAX_UNIQUE_ID_STATS_KEY}': stats
+        f'{table_name}{_MAX_UNIQUE_ID_STATS_SUFFIX}': stats
         for table_name, stats in self._max_unique_ids_per_partition.items()
     })
     self._write_to_file(merged_stats)
@@ -197,16 +192,16 @@ def load(
     stats = self._read_from_file(files_glob)
     max_id_stats, max_unique_id_stats = {}, {}
     for table_name, stats in stats.items():
-      if table_name.endswith(f'{_MAX_ID_STATS_KEY}'):
-        max_id_stats[table_name[: -len(_MAX_ID_STATS_KEY)]] = stats
-      elif table_name.endswith(f'{_MAX_UNIQUE_ID_STATS_KEY}'):
-        max_unique_id_stats[table_name[: -len(_MAX_UNIQUE_ID_STATS_KEY)]] = (
+      if table_name.endswith(f'{_MAX_ID_STATS_SUFFIX}'):
+        max_id_stats[table_name[: -len(_MAX_ID_STATS_SUFFIX)]] = stats
+      elif table_name.endswith(f'{_MAX_UNIQUE_ID_STATS_SUFFIX}'):
+        max_unique_id_stats[table_name[: -len(_MAX_UNIQUE_ID_STATS_SUFFIX)]] = (
             stats
         )
       else:
         raise ValueError(
             f'Unexpected table name and stats key: {table_name}, expected to'
-            f' end with {_MAX_ID_STATS_KEY} or {_MAX_UNIQUE_ID_STATS_KEY}'
+            f' end with {_MAX_ID_STATS_SUFFIX} or {_MAX_UNIQUE_ID_STATS_SUFFIX}'
         )
     self._max_ids_per_partition = max_id_stats
     self._max_unique_ids_per_partition = max_unique_id_stats
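
The _KEY to _SUFFIX rename matches how these strings are actually used: publish() appends them to table names to build the .npz keys, and load() strips them to recover the table names. A small illustration, with values mirroring the tests below:

import numpy as np

# publish() flattens the per-table stats into one dict whose keys are
# '<table>' plus the suffix, then writes it as a single .npz archive.
merged_stats = {
    "tab_one_max_ids": np.array([10, 20, 30, 40]),
    "tab_one_max_unique_ids": np.array([1, 2, 3, 4]),
}
# load() reverses the naming by stripping the suffix to recover the table name.
assert "tab_one_max_ids"[: -len("_max_ids")] == "tab_one"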

jax_tpu_embedding/sparsecore/lib/fdo/file_fdo_client_test.py (+15 -14)

@@ -17,6 +17,7 @@

 from absl.testing import absltest
 from jax_tpu_embedding.sparsecore.lib.fdo import file_fdo_client
+from jax_tpu_embedding.sparsecore.lib.nn import embedding
 import numpy as np


@@ -36,26 +37,26 @@ def _assert_stats_equal(self, actual, expected):

   def test_record_and_publish_load(self):
     fdo_client = file_fdo_client.NPZFileFDOClient(self.base_dir)
-    max_id_stats = {"tab_one": np.array([10, 20, 30, 40])}
-    max_unique_stats = {"tab_one": np.array([1, 2, 3, 4])}
-    fdo_client.record(
-        {"max_ids": max_id_stats, "max_unique_ids": max_unique_stats}
+    stats = embedding.SparseDenseMatmulInputStats(
+        max_ids_per_partition={"tab_one": np.array([10, 20, 30, 40])},
+        max_unique_ids_per_partition={"tab_one": np.array([1, 2, 3, 4])},
     )
+    fdo_client.record(stats)
     fdo_client.publish()
     loaded_max_ids, loaded_max_uniques = fdo_client.load()
-    self._assert_stats_equal(loaded_max_ids, max_id_stats)
-    self._assert_stats_equal(loaded_max_uniques, max_unique_stats)
+    self._assert_stats_equal(loaded_max_ids, stats.max_ids_per_partition)
+    self._assert_stats_equal(
+        loaded_max_uniques, stats.max_unique_ids_per_partition
+    )

   def test_multiple_record(self):
     fdo_client = file_fdo_client.NPZFileFDOClient(self.base_dir)
-    fdo_client.record({
-        "max_ids": {"tab_one": np.array([10, 20, 30, 40])},
-        "max_unique_ids": {"tab_one": np.array([1, 2, 3, 4])},
-    })
-    fdo_client.record({
-        "max_ids": {"tab_one": np.array([10, 20, 30, 40])},
-        "max_unique_ids": {"tab_one": np.array([1, 2, 3, 4])},
-    })
+    stats = embedding.SparseDenseMatmulInputStats(
+        max_ids_per_partition={"tab_one": np.array([10, 20, 30, 40])},
+        max_unique_ids_per_partition={"tab_one": np.array([1, 2, 3, 4])},
+    )
+    fdo_client.record(stats)
+    fdo_client.record(stats)
     fdo_client.publish()
     loaded_max_ids, loaded_max_uniques = fdo_client.load()
