Commit 4e9f925

Refactor to use a config object to manage arguments to sparse dense matmul.
PiperOrigin-RevId: 751450733
1 parent: a6ce2f2

10 files changed: +386 −417 lines
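The change is mechanical but touches every call site: the per-call keyword arguments to `preprocess_sparse_dense_matmul_input`, `tpu_sparse_dense_matmul`, and `tpu_sparse_dense_matmul_grad` are bundled into a single `embedding.SparseDenseMatmulConfig`. A minimal before/after sketch, assembled from the hunks below; the import path and the surrounding variables (`mesh`, `features`, `feature_weights`, `feature_specs`, `num_sc_per_device`) are assumptions, not part of this commit:

```python
# Sketch only -- field names come from the hunks below; the import path and
# the surrounding variables are assumed for illustration.
from jax_tpu_embedding.sparsecore.lib.nn import embedding

# Before: shared arguments repeated as keywords at every call site.
preprocessed, stats = embedding.preprocess_sparse_dense_matmul_input(
    features,
    feature_weights,
    feature_specs,
    local_device_count=mesh.local_mesh.size,
    global_device_count=mesh.size,
    num_sc_per_device=num_sc_per_device,
    sharding_strategy='MOD',
)

# After: the shared arguments are constructed once...
config = embedding.SparseDenseMatmulConfig(
    feature_specs=feature_specs,
    local_device_count=mesh.local_mesh.size,
    global_device_count=mesh.size,
    num_sc_per_device=num_sc_per_device,
    sharding_strategy='MOD',
)

# ...and threaded through preprocessing, the forward matmul, and the
# gradient step.
preprocessed, stats = embedding.preprocess_sparse_dense_matmul_input(
    features, feature_weights, config=config
)
```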

jax_tpu_embedding/sparsecore/examples/shakespeare/jax_sc_shakespeare_jit.py
Lines changed: 17 additions & 29 deletions

```diff
@@ -46,6 +46,7 @@
 import orbax.checkpoint as ocp
 import tree
 
+
 np.set_printoptions(threshold=np.inf)
 Nested = embedding.Nested
 
@@ -376,7 +377,7 @@ def train_step_fn(
     mesh: jax.sharding.Mesh,
     model: nn.Module,
     optimizer,
-    feature_specs,
+    config: embedding.SparseDenseMatmulConfig,
     train_state: TrainState,
     preprocessed_inputs,
     emb_variables,
@@ -386,22 +387,15 @@ def train_step_fn(
 
   # Sparse forward pass - embedding lookup.
   with jax.named_scope('sc_forward_pass'):
-    tpu_sparse_dense_matmul = partial(
-        embedding.tpu_sparse_dense_matmul,
-        global_device_count=num_global_devices,
-        feature_specs=feature_specs,
-        sharding_strategy='MOD',
-    )
     tpu_sparse_dense_matmul = shard_map(
-        f=tpu_sparse_dense_matmul,
+        f=embedding.tpu_sparse_dense_matmul,
         mesh=mesh,
-        in_specs=(pd, pe),
+        in_specs=(pd, pe, None),
         out_specs=pd,
         check_rep=False,
     )
     emb_act = tpu_sparse_dense_matmul(
-        preprocessed_inputs,
-        emb_variables,
+        preprocessed_inputs, emb_variables, config
     )
 
   # Dense forward + backward pass.
@@ -429,22 +423,15 @@ def train_step_fn(
 
   # Sparse backward pass - embedding update.
   with jax.named_scope('sc_backward_pass'):
-    tpu_sparse_dense_matmul_grad = partial(
-        embedding.tpu_sparse_dense_matmul_grad,
-        feature_specs=feature_specs,
-        sharding_strategy='MOD',
-    )
     tpu_sparse_dense_matmul_grad = shard_map(
-        f=tpu_sparse_dense_matmul_grad,
+        f=embedding.tpu_sparse_dense_matmul_grad,
         mesh=mesh,
-        in_specs=(pd, pd, pe),
+        in_specs=(pd, pd, pe, None),
         out_specs=pe,
         check_rep=False,
     )
     emb_variables = tpu_sparse_dense_matmul_grad(
-        emb_grad,
-        preprocessed_inputs,
-        emb_variables,
+        emb_grad, preprocessed_inputs, emb_variables, config
     )
 
   train_state = train_state.replace(
@@ -503,16 +490,17 @@ def train_step_fn(
         lambda y: jax.make_array_from_process_local_data(global_sharding, y),
        x,
    )
+    config = embedding.SparseDenseMatmulConfig(
+        global_device_count=num_global_devices,
+        local_device_count=num_local_devices,
+        feature_specs=flax.core.freeze(feature_specs),
+        num_sc_per_device=num_sc_per_device,
+        sharding_strategy='MOD',
+    )
    preprocessed_inputs, stats = map(
        make_global_view,
        embedding.preprocess_sparse_dense_matmul_input(
-            features,
-            feature_weights,
-            feature_specs,
-            local_device_count=global_mesh.local_mesh.size,
-            global_device_count=global_mesh.size,
-            num_sc_per_device=num_sc_per_device,
-            sharding_strategy='MOD',
+            features, feature_weights, config=config
        ),
    )
    fdo_client.record(stats)
@@ -524,7 +512,7 @@ def train_step_fn(
        global_mesh,
        model,
        optimizer,
-        feature_specs,
+        config,
        train_state,
        preprocessed_inputs,
        emb_variables,
```
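One detail worth noting in the jit example above: because `embedding.tpu_sparse_dense_matmul` is now handed to `shard_map` directly instead of being pre-bound with `partial`, the config travels as an extra positional argument and needs its own entry in `in_specs`; the `None` spec marks it as replicated rather than sharded. A standalone sketch of the same pattern, not taken from this repo (it uses an empty `PartitionSpec` for an array argument, which plays the same replicating role as the diff's `None`):

```python
# Standalone illustration of a replicated shard_map argument; not repo code.
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=('x',))

def scale(data, factor):
  # `factor` arrives whole (replicated) on every device; `data` is a shard.
  return data * factor

scaled = shard_map(
    scale,
    mesh=mesh,
    in_specs=(P('x'), P()),  # data split along 'x'; factor replicated
    out_specs=P('x'),
)(jnp.arange(8.0), jnp.float32(2.0))
```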

jax_tpu_embedding/sparsecore/examples/shakespeare/jax_sc_shakespeare_pmap.py
Lines changed: 19 additions & 21 deletions

```diff
@@ -11,8 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""An example Shakespeare model that uses the SparseCore embedding API."""
-
 # pylint: disable=g-importing-member
 import collections
 from functools import partial
@@ -42,6 +40,8 @@
 import tree
 
 
+"""An example Shakespeare model that uses the SparseCore embedding API."""
+
 np.set_printoptions(threshold=np.inf)
 Nested = embedding.Nested
 
@@ -260,10 +260,9 @@ def run_model():
 )
 
 def train_step_fn(
-    global_device_count: int,
     model: nn.Module,
     optimizer,
-    feature_specs,
+    config,
     train_state: TrainState,
     preprocessed_inputs,
     emb_variables: Mapping[str, embedding.EmbeddingVariables],
@@ -276,9 +275,7 @@ def train_step_fn(
   with jax.named_scope('sc_forward_pass'):
     tpu_sparse_dense_matmul = partial(
         embedding.tpu_sparse_dense_matmul,
-        global_device_count=global_device_count,
-        feature_specs=feature_specs,
-        sharding_strategy='MOD',
+        config=config,
     )
     emb_act = tpu_sparse_dense_matmul(
         preprocessed_inputs,
@@ -312,8 +309,7 @@ def train_step_fn(
   with jax.named_scope('sc_backward_pass'):
     tpu_sparse_dense_matmul_grad = partial(
         embedding.tpu_sparse_dense_matmul_grad,
-        feature_specs=feature_specs,
-        sharding_strategy='MOD',
+        config=config,
     )
     emb_variables = tpu_sparse_dense_matmul_grad(
         emb_grad,
@@ -337,14 +333,16 @@ def train_step_fn(
   vlog1('Replicating train_state')
   train_state = flax_utils.replicate(train_state, local_devices)
   parameter_overview.log_parameter_overview(train_state.params)
+
+  config = embedding.SparseDenseMatmulConfig(
+      feature_specs=feature_specs,
+      local_device_count=global_mesh.local_mesh.size,
+      global_device_count=global_mesh.size,
+      num_sc_per_device=num_sc_per_device,
+      has_leading_dimension=True,
+  )
   p_train_step_fn = jax.pmap(
-      partial(
-          train_step_fn,
-          num_global_devices,
-          model,
-          optimizer,
-          feature_specs,
-      ),
+      partial(train_step_fn, model, optimizer, config),
       axis_name='batch',
   )
 
@@ -376,16 +374,16 @@ def train_step_fn(
   )
 
   # Preprocess the inputs.
-  preprocessed_inputs, _ = embedding.preprocess_sparse_dense_matmul_input(
-      features,
-      feature_weights,
-      feature_specs,
+  config = embedding.SparseDenseMatmulConfig(
+      feature_specs=flax.core.freeze(feature_specs),
       local_device_count=global_mesh.local_mesh.size,
       global_device_count=global_mesh.size,
       num_sc_per_device=num_sc_per_device,
-      sharding_strategy='MOD',
       has_leading_dimension=True,
   )
+  preprocessed_inputs, _ = embedding.preprocess_sparse_dense_matmul_input(
+      features, feature_weights, config
+  )
 
   # TODO(patn): This (local_slice)will go away once the input processor is
   # updated to only produce local batches.
```
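In the pmap variant, the config additionally carries `has_leading_dimension=True`, taking over the keyword of the same name from the old call. Judging by that keyword's role here, it asks preprocessing to keep a per-local-device leading axis so the outputs line up with what `jax.pmap` maps over. A hedged sketch; the semantics are inferred from the replaced keyword, and `feature_specs`, `features`, `feature_weights`, and `num_sc_per_device` are assumed to be defined by the caller:

```python
# Hedged sketch: has_leading_dimension semantics inferred from the keyword
# the config field replaces; surrounding variables are assumptions.
import jax
from jax_tpu_embedding.sparsecore.lib.nn import embedding

config = embedding.SparseDenseMatmulConfig(
    feature_specs=feature_specs,
    local_device_count=jax.local_device_count(),
    global_device_count=jax.device_count(),
    num_sc_per_device=num_sc_per_device,
    has_leading_dimension=True,  # keep a [num_local_devices, ...] leading axis
)
preprocessed, _ = embedding.preprocess_sparse_dense_matmul_input(
    features, feature_weights, config
)
# Arrays in `preprocessed` now map cleanly over axis 0 under jax.pmap.
```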

jax_tpu_embedding/sparsecore/lib/flax/embed.py
Lines changed: 19 additions & 38 deletions

```diff
@@ -88,6 +88,14 @@ def __post_init__(self):
         self.mesh.devices.item(0)
     )
 
+    self.config = embedding.SparseDenseMatmulConfig(
+        feature_specs=self.feature_specs,
+        local_device_count=self.mesh.local_mesh.size,
+        global_device_count=self.mesh.size,
+        num_sc_per_device=self.num_sc_per_device,
+        sharding_strategy=self.table_sharding_strategy,
+    )
+
     super().__post_init__()
 
   def setup(self):
@@ -143,11 +151,7 @@ def preprocess_inputs(
     return embedding.preprocess_sparse_dense_matmul_input(
         features,
         features_weights,
-        self.feature_specs,
-        self.mesh.local_mesh.size,
-        self.mesh.size,
-        num_sc_per_device=self.num_sc_per_device,
-        sharding_strategy=self.table_sharding_strategy,
+        self.config,
     )[0]
 
   def __call__(self, embedding_lookups: EmbeddingLookups) -> Nested[jax.Array]:
@@ -198,34 +202,22 @@ def _emb_lookup(
   pt = embedding_layer.embedding_table_partition
   pd = embedding_layer.data_partition
   return shard_map(
-      functools.partial(
-          embedding.tpu_sparse_dense_matmul,
-          global_device_count=embedding_layer.mesh.size,
-          feature_specs=embedding_layer.feature_specs,
-          sharding_strategy=embedding_layer.table_sharding_strategy,
-      ),
+      embedding.tpu_sparse_dense_matmul,
       mesh=embedding_layer.mesh,
-      in_specs=(pd, pt),
+      in_specs=(pd, pt, None),
       out_specs=pd,
       check_rep=False,
-  )(
-      embedding_lookups,
-      emb_table,
-  )
+  )(embedding_lookups, emb_table, embedding_layer.config)
 
 
 def _emb_lookup_fwd(
     embedding_layer: SparseCoreEmbed,
     embedding_lookups: EmbeddingLookups,
     emb_table: Mapping[str, tuple[jax.Array, ...]],
 ):
-  return _emb_lookup(
-      embedding_layer,
-      embedding_lookups,
-      emb_table,
-  ), (
-      embedding_lookups,
-      emb_table,
+  return (
+      _emb_lookup(embedding_layer, embedding_lookups, emb_table),
+      (embedding_lookups, emb_table),
   )
 
 
@@ -236,20 +228,12 @@ def _emb_lookup_bwd(embedding_layer, res, gradients):
   pt = embedding_layer.embedding_table_partition
   pd = embedding_layer.data_partition
   emb_table_grads = shard_map(
-      functools.partial(
-          embedding.tpu_sparse_dense_matmul_grad,
-          feature_specs=embedding_layer.feature_specs,
-          sharding_strategy=embedding_layer.table_sharding_strategy,
-      ),
+      embedding.tpu_sparse_dense_matmul_grad,
       mesh=embedding_layer.mesh,
-      in_specs=(pd, pd, pt),
+      in_specs=(pd, pd, pt, None),
       out_specs=pt,
       check_rep=False,
-  )(
-      gradients,
-      embedding_lookups,
-      emb_table,
-  )
+  )(gradients, embedding_lookups, emb_table, config=embedding_layer.config)
 
   # tpu_sparse_dense_matmul_grad returns a general Mapping (usually a dict).
   # It may not be the same type as the embedding table (e.g. FrozenDict).
@@ -258,10 +242,7 @@ def _emb_lookup_bwd(embedding_layer, res, gradients):
       jax.tree.structure(emb_table), jax.tree.leaves(emb_table_grads)
   )
 
-  return (
-      None,
-      emb_table_grads,
-  )
+  return (None, emb_table_grads)
 
 
 _emb_lookup.defvjp(_emb_lookup_fwd, _emb_lookup_bwd)
```

jax_tpu_embedding/sparsecore/lib/flax/tests/autograd_test.py
Lines changed: 11 additions & 8 deletions

```diff
@@ -129,7 +129,7 @@ def test_shakespeare_model_loss_convergence(self):
     )
 
     # Initialize the model.
-    def process_inputs(feature_batch):
+    def process_inputs(feature_batch, config):
       features = np.reshape(feature_batch, (-1, 1))
       feature_weights = np.ones(features.shape, dtype=np.float32)
 
@@ -144,15 +144,18 @@ def process_inputs(feature_batch):
           *embedding.preprocess_sparse_dense_matmul_input(
               features,
               feature_weights,
-              feature_specs,
-              mesh.local_mesh.size,
-              mesh.size,
-              num_sc_per_device=num_sc_per_device,
-              sharding_strategy='MOD',
+              config,
           )[0]
       )
 
-    first_model_input = process_inputs(feature_batches[0])
+    config = embedding.SparseDenseMatmulConfig(
+        feature_specs=feature_specs,
+        local_device_count=mesh.local_mesh.size,
+        global_device_count=mesh.size,
+        num_sc_per_device=num_sc_per_device,
+        sharding_strategy='MOD',
+    )
+    first_model_input = process_inputs(feature_batches[0], config)
     params = model.init(jax.random.key(42), first_model_input)
 
     # Create optimizer.
@@ -204,7 +207,7 @@ def forward_pass(params, embedding_lookups, labels):
     # ------------------------------------------------------------------------
     # Step 1: SC input processing.
     # ------------------------------------------------------------------------
-    processed_input_tensor = process_inputs(features)
+    processed_input_tensor = process_inputs(features, config)
 
     # ------------------------------------------------------------------------
     # Step 2: run model.
```
