diff --git a/jax_tpu_embedding/sparsecore/examples/shakespeare/jax_sc_shakespeare_jit.py b/jax_tpu_embedding/sparsecore/examples/shakespeare/jax_sc_shakespeare_jit.py index 8a9e238..aebe528 100644 --- a/jax_tpu_embedding/sparsecore/examples/shakespeare/jax_sc_shakespeare_jit.py +++ b/jax_tpu_embedding/sparsecore/examples/shakespeare/jax_sc_shakespeare_jit.py @@ -46,6 +46,7 @@ import orbax.checkpoint as ocp import tree + np.set_printoptions(threshold=np.inf) Nested = embedding.Nested @@ -376,7 +377,7 @@ def train_step_fn( mesh: jax.sharding.Mesh, model: nn.Module, optimizer, - feature_specs, + config: embedding.SparseDenseMatmulConfig, train_state: TrainState, preprocessed_inputs, emb_variables, @@ -386,22 +387,15 @@ def train_step_fn( # Sparse forward pass - embedding lookup. with jax.named_scope('sc_forward_pass'): - tpu_sparse_dense_matmul = partial( - embedding.tpu_sparse_dense_matmul, - global_device_count=num_global_devices, - feature_specs=feature_specs, - sharding_strategy='MOD', - ) tpu_sparse_dense_matmul = shard_map( - f=tpu_sparse_dense_matmul, + f=embedding.tpu_sparse_dense_matmul, mesh=mesh, - in_specs=(pd, pe), + in_specs=(pd, pe, None), out_specs=pd, check_rep=False, ) emb_act = tpu_sparse_dense_matmul( - preprocessed_inputs, - emb_variables, + preprocessed_inputs, emb_variables, config ) # Dense forward + backward pass. @@ -429,22 +423,15 @@ def train_step_fn( # Sparse backward pass - embedding update. with jax.named_scope('sc_backward_pass'): - tpu_sparse_dense_matmul_grad = partial( - embedding.tpu_sparse_dense_matmul_grad, - feature_specs=feature_specs, - sharding_strategy='MOD', - ) tpu_sparse_dense_matmul_grad = shard_map( - f=tpu_sparse_dense_matmul_grad, + f=embedding.tpu_sparse_dense_matmul_grad, mesh=mesh, - in_specs=(pd, pd, pe), + in_specs=(pd, pd, pe, None), out_specs=pe, check_rep=False, ) emb_variables = tpu_sparse_dense_matmul_grad( - emb_grad, - preprocessed_inputs, - emb_variables, + emb_grad, preprocessed_inputs, emb_variables, config ) train_state = train_state.replace( @@ -503,16 +490,17 @@ def train_step_fn( lambda y: jax.make_array_from_process_local_data(global_sharding, y), x, ) + config = embedding.SparseDenseMatmulConfig( + global_device_count=num_global_devices, + local_device_count=num_local_devices, + feature_specs=flax.core.freeze(feature_specs), + num_sc_per_device=num_sc_per_device, + sharding_strategy='MOD', + ) preprocessed_inputs, stats = map( make_global_view, embedding.preprocess_sparse_dense_matmul_input( - features, - feature_weights, - feature_specs, - local_device_count=global_mesh.local_mesh.size, - global_device_count=global_mesh.size, - num_sc_per_device=num_sc_per_device, - sharding_strategy='MOD', + features, feature_weights, config=config ), ) fdo_client.record(stats) @@ -524,7 +512,7 @@ def train_step_fn( global_mesh, model, optimizer, - feature_specs, + config, train_state, preprocessed_inputs, emb_variables, diff --git a/jax_tpu_embedding/sparsecore/examples/shakespeare/jax_sc_shakespeare_pmap.py b/jax_tpu_embedding/sparsecore/examples/shakespeare/jax_sc_shakespeare_pmap.py index 3d0212b..917f891 100644 --- a/jax_tpu_embedding/sparsecore/examples/shakespeare/jax_sc_shakespeare_pmap.py +++ b/jax_tpu_embedding/sparsecore/examples/shakespeare/jax_sc_shakespeare_pmap.py @@ -260,10 +260,9 @@ def run_model(): ) def train_step_fn( - global_device_count: int, model: nn.Module, optimizer, - feature_specs, + config, train_state: TrainState, preprocessed_inputs, emb_variables: Mapping[str, embedding.EmbeddingVariables], @@ -276,9 
+275,7 @@ def train_step_fn( with jax.named_scope('sc_forward_pass'): tpu_sparse_dense_matmul = partial( embedding.tpu_sparse_dense_matmul, - global_device_count=global_device_count, - feature_specs=feature_specs, - sharding_strategy='MOD', + config=config, ) emb_act = tpu_sparse_dense_matmul( preprocessed_inputs, @@ -312,8 +309,7 @@ def train_step_fn( with jax.named_scope('sc_backward_pass'): tpu_sparse_dense_matmul_grad = partial( embedding.tpu_sparse_dense_matmul_grad, - feature_specs=feature_specs, - sharding_strategy='MOD', + config=config, ) emb_variables = tpu_sparse_dense_matmul_grad( emb_grad, @@ -337,14 +333,16 @@ def train_step_fn( vlog1('Replicating train_state') train_state = flax_utils.replicate(train_state, local_devices) parameter_overview.log_parameter_overview(train_state.params) + + config = embedding.SparseDenseMatmulConfig( + feature_specs=feature_specs, + local_device_count=global_mesh.local_mesh.size, + global_device_count=global_mesh.size, + num_sc_per_device=num_sc_per_device, + has_leading_dimension=True, + ) p_train_step_fn = jax.pmap( - partial( - train_step_fn, - num_global_devices, - model, - optimizer, - feature_specs, - ), + partial(train_step_fn, model, optimizer, config), axis_name='batch', ) @@ -376,16 +374,16 @@ def train_step_fn( ) # Preprocess the inputs. - preprocessed_inputs, _ = embedding.preprocess_sparse_dense_matmul_input( - features, - feature_weights, - feature_specs, + config = embedding.SparseDenseMatmulConfig( + feature_specs=flax.core.freeze(feature_specs), local_device_count=global_mesh.local_mesh.size, global_device_count=global_mesh.size, num_sc_per_device=num_sc_per_device, - sharding_strategy='MOD', has_leading_dimension=True, ) + preprocessed_inputs, _ = embedding.preprocess_sparse_dense_matmul_input( + features, feature_weights, config + ) # TODO(patn): This (local_slice)will go away once the input processor is # updated to only produce local batches. 
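A minimal sketch of the pmap call pattern introduced by this change, reusing the example's own definitions (`embedding`, `partial`, `jax`, `feature_specs`, `global_mesh`, `num_sc_per_device`, `model`, `optimizer`, `train_step_fn`, `features`, `feature_weights`); the single shared `config` object here is an illustrative simplification, not part of the patch.

# Illustrative sketch; assumes the surrounding example's setup and imports.
config = embedding.SparseDenseMatmulConfig(
    feature_specs=feature_specs,
    local_device_count=global_mesh.local_mesh.size,
    global_device_count=global_mesh.size,
    num_sc_per_device=num_sc_per_device,
    has_leading_dimension=True,  # pmap: leading dim is the local device count.
)

# The same config drives both input preprocessing and the pmapped train step.
preprocessed_inputs, _ = embedding.preprocess_sparse_dense_matmul_input(
    features, feature_weights, config
)
p_train_step_fn = jax.pmap(
    partial(train_step_fn, model, optimizer, config),
    axis_name='batch',
)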
diff --git a/jax_tpu_embedding/sparsecore/lib/flax/embed.py b/jax_tpu_embedding/sparsecore/lib/flax/embed.py index 2916381..5c896b1 100644 --- a/jax_tpu_embedding/sparsecore/lib/flax/embed.py +++ b/jax_tpu_embedding/sparsecore/lib/flax/embed.py @@ -88,6 +88,14 @@ def __post_init__(self): self.mesh.devices.item(0) ) + self.config = embedding.SparseDenseMatmulConfig( + feature_specs=self.feature_specs, + local_device_count=self.mesh.local_mesh.size, + global_device_count=self.mesh.size, + num_sc_per_device=self.num_sc_per_device, + sharding_strategy=self.table_sharding_strategy, + ) + super().__post_init__() def setup(self): @@ -143,11 +151,7 @@ def preprocess_inputs( return embedding.preprocess_sparse_dense_matmul_input( features, features_weights, - self.feature_specs, - self.mesh.local_mesh.size, - self.mesh.size, - num_sc_per_device=self.num_sc_per_device, - sharding_strategy=self.table_sharding_strategy, + self.config, )[0] def __call__(self, embedding_lookups: EmbeddingLookups) -> Nested[jax.Array]: @@ -198,20 +202,12 @@ def _emb_lookup( pt = embedding_layer.embedding_table_partition pd = embedding_layer.data_partition return shard_map( - functools.partial( - embedding.tpu_sparse_dense_matmul, - global_device_count=embedding_layer.mesh.size, - feature_specs=embedding_layer.feature_specs, - sharding_strategy=embedding_layer.table_sharding_strategy, - ), + embedding.tpu_sparse_dense_matmul, mesh=embedding_layer.mesh, - in_specs=(pd, pt), + in_specs=(pd, pt, None), out_specs=pd, check_rep=False, - )( - embedding_lookups, - emb_table, - ) + )(embedding_lookups, emb_table, embedding_layer.config) def _emb_lookup_fwd( @@ -219,13 +215,9 @@ def _emb_lookup_fwd( embedding_lookups: EmbeddingLookups, emb_table: Mapping[str, tuple[jax.Array, ...]], ): - return _emb_lookup( - embedding_layer, - embedding_lookups, - emb_table, - ), ( - embedding_lookups, - emb_table, + return ( + _emb_lookup(embedding_layer, embedding_lookups, emb_table), + (embedding_lookups, emb_table), ) @@ -236,20 +228,12 @@ def _emb_lookup_bwd(embedding_layer, res, gradients): pt = embedding_layer.embedding_table_partition pd = embedding_layer.data_partition emb_table_grads = shard_map( - functools.partial( - embedding.tpu_sparse_dense_matmul_grad, - feature_specs=embedding_layer.feature_specs, - sharding_strategy=embedding_layer.table_sharding_strategy, - ), + embedding.tpu_sparse_dense_matmul_grad, mesh=embedding_layer.mesh, - in_specs=(pd, pd, pt), + in_specs=(pd, pd, pt, None), out_specs=pt, check_rep=False, - )( - gradients, - embedding_lookups, - emb_table, - ) + )(gradients, embedding_lookups, emb_table, embedding_layer.config) # tpu_sparse_dense_matmul_grad returns a general Mapping (usually a dict). # It may not be the same type as the embedding table (e.g. FrozenDict). @@ -258,10 +242,7 @@ def _emb_lookup_bwd(embedding_layer, res, gradients): jax.tree.structure(emb_table), jax.tree.leaves(emb_table_grads) ) - return ( - None, - emb_table_grads, - ) + return (None, emb_table_grads) _emb_lookup.defvjp(_emb_lookup_fwd, _emb_lookup_bwd) diff --git a/jax_tpu_embedding/sparsecore/lib/flax/tests/autograd_test.py b/jax_tpu_embedding/sparsecore/lib/flax/tests/autograd_test.py index b9d7c4d..8fe4c0c 100644 --- a/jax_tpu_embedding/sparsecore/lib/flax/tests/autograd_test.py +++ b/jax_tpu_embedding/sparsecore/lib/flax/tests/autograd_test.py @@ -129,7 +129,7 @@ def test_shakespeare_model_loss_convergence(self): ) # Initialize the model. 
- def process_inputs(feature_batch): + def process_inputs(feature_batch, config): features = np.reshape(feature_batch, (-1, 1)) feature_weights = np.ones(features.shape, dtype=np.float32) @@ -144,15 +144,18 @@ def process_inputs(feature_batch): *embedding.preprocess_sparse_dense_matmul_input( features, feature_weights, - feature_specs, - mesh.local_mesh.size, - mesh.size, - num_sc_per_device=num_sc_per_device, - sharding_strategy='MOD', + config, )[0] ) - first_model_input = process_inputs(feature_batches[0]) + config = embedding.SparseDenseMatmulConfig( + feature_specs=feature_specs, + local_device_count=mesh.local_mesh.size, + global_device_count=mesh.size, + num_sc_per_device=num_sc_per_device, + sharding_strategy='MOD', + ) + first_model_input = process_inputs(feature_batches[0], config) params = model.init(jax.random.key(42), first_model_input) # Create optimizer. @@ -204,7 +207,7 @@ def forward_pass(params, embedding_lookups, labels): # ------------------------------------------------------------------------ # Step 1: SC input processing. # ------------------------------------------------------------------------ - processed_input_tensor = process_inputs(features) + processed_input_tensor = process_inputs(features, config) # ------------------------------------------------------------------------ # Step 2: run model. diff --git a/jax_tpu_embedding/sparsecore/lib/nn/embedding.py b/jax_tpu_embedding/sparsecore/lib/nn/embedding.py index 9f2519d..af0e0c6 100644 --- a/jax_tpu_embedding/sparsecore/lib/nn/embedding.py +++ b/jax_tpu_embedding/sparsecore/lib/nn/embedding.py @@ -14,10 +14,12 @@ """List of functions for embedding lookup.""" import collections +import dataclasses import functools from typing import List, Mapping, NamedTuple, Sequence, Tuple, TypeAlias, TypeVar, Union from absl import logging +from flax import linen as nn from flax import struct import jax from jax.experimental import shard_map @@ -78,6 +80,54 @@ def from_dict( ) +@dataclasses.dataclass(frozen=True) +class SparseDenseMatmulConfig: + """Configuration for preprocessing the sparse dense matmul input. + + Attributes: + feature_specs: + The feature specs. This needs to have the same structure as features and + features_weights (e.g., if one of them is a mapping then all of them are). + local_device_count: + The number of local devices (chips). Typically `mesh.local_mesh.size`. + global_device_count: + The number of global devices (chips). Typically `mesh.size`. + num_sc_per_device: + The number of sparse cores per device. + optimizer_label: + The label for the optimizer computation. + static_buffer_size_multiplier: + If larger than 0, this is the multiplier that is used to determine the size + of the static buffers (lhs_embedding_ids, lhs_sample_ids and lhs_gains). The + size of the buffer returned is static_buffer_size_multiplier x batch_size. + If less than or equal to 0, the size of the buffer is determined based on + the max_ids_per_partition limits. + sharding_strategy: + The sharding strategy (e.g., MOD). + has_leading_dimension: + If set to True, then the first dimension of the preprocessed input will be + the number of local devices. This is useful when using the preprocessed + input in jax.pmap. If set to False, then the first dimension of the + preprocessed input will be the number of local devices * the static buffer + size. This is useful when using the preprocessed input in jax.jit. In + short, set it to True when using jax.pmap and to False when using + jax.jit.
+ allow_id_dropping: + If set to True, then ids will be dropped if they exceed the + max_ids_per_partition or max_unique_ids_per_partition limits. + """ + + feature_specs: nn.FrozenDict[str, embedding_spec.FeatureSpec] + local_device_count: int + global_device_count: int + num_sc_per_device: int + sharding_strategy: str = "MOD" + optimizer_label: str = "" + static_buffer_size_multiplier: int = 0 + has_leading_dimension: bool = False + allow_id_dropping: bool = False + + # TODO: b/346873239 - Add more checks for the feature specs to ensure all the # fields are valid. def _verify_feature_specs( @@ -314,14 +364,7 @@ def sharding_strategy_to_int(sharding_strategy: str) -> int: def preprocess_sparse_dense_matmul_input( features: Nested[ArrayLike], features_weights: Nested[ArrayLike], - feature_specs: Nested[embedding_spec.FeatureSpec], - local_device_count: int, - global_device_count: int, - num_sc_per_device: int, - static_buffer_size_multiplier: int = 0, - sharding_strategy: str = "MOD", - has_leading_dimension: bool = False, - allow_id_dropping: bool = False, + config: SparseDenseMatmulConfig, ) -> tuple[SparseDenseMatmulInput, SparseDenseMatmulInputStats]: """Preprocesses the input for sparse dense matmul. @@ -332,49 +375,27 @@ def preprocess_sparse_dense_matmul_input( arrays with dtype object (in the ragged tensor case). features_weights: The input feature weights. The structure must be identical to the features. - feature_specs: The feature specs. This needs to have the same structure as - features and features_weights (e.g., if one of them is a mapping then all - of them are). - local_device_count: The number of local devices (chips). Typically - `mesh.local_mesh.size`. - global_device_count: The number of global devices (chips). Typically - `mesh.size`. - num_sc_per_device: The number of sparse cores per device. - static_buffer_size_multiplier: If larger than 0, this is the multiplier that - is used to determine the size of the static buffers (lhs_embedding_ids, - lhs_sample_ids and lhs_gains). The size of the buffer returned is - static_buffer_size_multiplier x batch_size. If less than or equal to 0, - the size of the buffer is determined based off of the - max_ids_per_partition limits. - sharding_strategy: The sharding strategy (e.g., MOD) - has_leading_dimension: If set to True, then the first dimension of the - output will be the number of local devices. This is useful when using the - output in jax.pmap. If set to False, then the first dimension of the - output will be the number of local devices * the static buffer size. This - is useful when using the output in jax.jit. In conclusion, Set it to True - if using jax.pmap and set it to False if using jax.jit. - allow_id_dropping: If set to True, then ids will be dropped if they exceed - the max_ids_per_partition or max_unique_ids_per_partition limits. + config: The config of sparse dense matmul. Returns: A tuple of PreprocessSparseDenseMatmulInput and PreprocessSparseDenseMatmulInputStats.
""" - tree.assert_same_structure(features, feature_specs) - tree.assert_same_structure(features_weights, feature_specs) + tree.assert_same_structure(features, config.feature_specs) + tree.assert_same_structure(features_weights, config.feature_specs) *preprocessed_inputs, stats = ( input_preprocessing_cc.PreprocessSparseDenseMatmulInput( tree.flatten(features), tree.flatten(features_weights), - tree.flatten(feature_specs), - local_device_count, - global_device_count, - num_sc_per_device, - sharding_strategy_to_int(sharding_strategy), - has_leading_dimension, - static_buffer_size_multiplier, - allow_id_dropping=allow_id_dropping, + tree.flatten(config.feature_specs), + config.local_device_count, + config.global_device_count, + config.num_sc_per_device, + sharding_strategy_to_int(config.sharding_strategy), + config.has_leading_dimension, + config.static_buffer_size_multiplier, + allow_id_dropping=config.allow_id_dropping, ) ) @@ -434,9 +455,7 @@ def _unstack_embedding_activations( def tpu_sparse_dense_matmul( preprocessed_inputs: SparseDenseMatmulInput, embedding_variables: Mapping[str, EmbeddingVariables], - feature_specs: Nested[embedding_spec.FeatureSpec], - global_device_count: int, - sharding_strategy: str = "MOD", + config: SparseDenseMatmulConfig, ) -> Nested[jax.Array]: """Computes the sparse dense matmul. @@ -445,26 +464,21 @@ def tpu_sparse_dense_matmul( Example invocation: - sparse_matmul = functools.partial( - embedding.tpu_sparse_dense_matmul, - global_device_count=mesh.size, - feature_specs=feature_specs, - sharding_strategy="MOD", - ) + config = embedding.SparseDenseMatmulConfig(...) sparse_matmul = shard_map.shard_map( - sparse_matmul, + embedding.tpu_sparse_dense_matmul, mesh=mesh, in_specs=( P(mesh.axis_names[0]), P(mesh.axis_names[0]), + None, ), out_specs=P(mesh.axis_names[0]), check_rep=False, ) - sparse_matmul = jax.jit(sparse_matmul) + sparse_matmul = jax.jit(sparse_matmul, static_argnums=(2,)) activations = sparse_matmul( - preprocessed_inputs=preprocessed_inputs, - embedding_variables, + preprocessed_inputs, embedding_variables, config ) Args: @@ -472,10 +486,7 @@ def tpu_sparse_dense_matmul( embedding_variables: A tuple of embedding tables and slot variables. The first one is always the embedding table, the following ones are slot variables. The tree structure must be identical to the lhs_row_pointers. - feature_specs: The input features for the current process. - global_device_count: The number of global devices (chips). Typically - `mesh.size`. - sharding_strategy: The sharding strategy (e.g., MOD) + config: The config of sparse dense matmul. Returns: The activations structure with the same structure as feature_specs. 
@@ -491,10 +502,10 @@ def tpu_sparse_dense_matmul( assert lhs_row_pointers.keys() == embedding_variables.keys() - stacked_table_specs = get_stacked_table_specs(feature_specs) + stacked_table_specs = get_stacked_table_specs(config.feature_specs) assert lhs_row_pointers.keys() == stacked_table_specs.keys() - sharding_strategy = _sharding_strategy_to_enum(sharding_strategy) + sharding_strategy = _sharding_strategy_to_enum(config.sharding_strategy) activations = {} for stacked_table_name in stacked_table_specs: @@ -512,7 +523,7 @@ def tpu_sparse_dense_matmul( gain, embedding_variable[0], # [0] is the embedding table device_batch_size=stacked_table.total_sample_count - // global_device_count, + // config.global_device_count, max_ids_per_partition=stacked_table.max_ids_per_partition, max_unique_ids_per_partition=stacked_table.max_unique_ids_per_partition, sharding_strategy=sharding_strategy, @@ -520,7 +531,7 @@ def tpu_sparse_dense_matmul( ) return _unstack_embedding_activations( - activations, feature_specs, global_device_count + activations, config.feature_specs, config.global_device_count ) @@ -579,37 +590,30 @@ def tpu_sparse_dense_matmul_grad( activation_gradients: Nested[jax.Array], preprocessed_inputs: SparseDenseMatmulInput, embedding_variables: Mapping[str, EmbeddingVariables], - feature_specs: Nested[embedding_spec.FeatureSpec], - sharding_strategy: str = "MOD", - label: str = "", + config: SparseDenseMatmulConfig, step: int | None = None, ) -> Mapping[str, EmbeddingVariables]: """Computes the updated embedding variables based on the activation gradients. Example invocation with jit + shard_map: - grad_update = functools.partial( - embedding.tpu_sparse_dense_matmul_grad, - feature_specs=feature_specs, - sharding_strategy="MOD", - ) + config = embedding.SparseDenseMatmulConfig(...) grad_update = shard_map.shard_map( - grad_update, + embedding.tpu_sparse_dense_matmul_grad, mesh=mesh, in_specs=( P(mesh.axis_names[0]), P(mesh.axis_names[0]), P(mesh.axis_names[0]), + None, ), out_specs=P(mesh.axis_names[0]), check_rep=False, ) - grad_update = jax.jit(grad_update) + grad_update = jax.jit(grad_update, static_argnums=(3,)) updated_embedding_variables = grad_update( - activations_grad, - preprocessed_inputs=preprocessed_inputs, - embedding_variables, + activations_grad, preprocessed_inputs, embedding_variables, config ) Args: @@ -618,9 +622,7 @@ def tpu_sparse_dense_matmul_grad( embedding_variables: A tuple of embedding tables and slot variables. The first one is always the embedding table, the following ones are slot variables. The tree structure must be identical to the lhs_row_pointers. - feature_specs: The input features for the current process. - sharding_strategy: The sharding strategy (e.g., MOD) - label: The label for the optimizer computation. + config: The config of sparse dense matmul. step: The current step number. 
Returns: @@ -634,15 +636,17 @@ def tpu_sparse_dense_matmul_grad( assert lhs_row_pointers.keys() == embedding_variables.keys() # Activations match the feature specs structure - tree.assert_same_structure(feature_specs, activation_gradients) + tree.assert_same_structure(config.feature_specs, activation_gradients) - stacked_table_specs = get_stacked_table_specs(feature_specs) + stacked_table_specs = get_stacked_table_specs(config.feature_specs) assert lhs_row_pointers.keys() == stacked_table_specs.keys() - gradients = _stack_embedding_gradients(activation_gradients, feature_specs) + gradients = _stack_embedding_gradients( + activation_gradients, config.feature_specs + ) assert lhs_row_pointers.keys() == gradients.keys() - sharding_strategy = _sharding_strategy_to_enum(sharding_strategy) + sharding_strategy = _sharding_strategy_to_enum(config.sharding_strategy) updated_embedding_variables = {} for stacked_table_name in stacked_table_specs: @@ -659,7 +663,7 @@ def tpu_sparse_dense_matmul_grad( symbol_name = "{}-{}{}".format( stack_table_spec.optimizer.short_name(), stack_table_spec.stack_name, - label, + config.optimizer_label, ) optimizer_primitive = stack_table_spec.optimizer.get_optimizer_primitive() diff --git a/jax_tpu_embedding/sparsecore/lib/nn/tests/BUILD b/jax_tpu_embedding/sparsecore/lib/nn/tests/BUILD index 6daf29a..e135e45 100644 --- a/jax_tpu_embedding/sparsecore/lib/nn/tests/BUILD +++ b/jax_tpu_embedding/sparsecore/lib/nn/tests/BUILD @@ -102,6 +102,7 @@ pytype_strict_contrib_test( pypi_requirement("absl/testing:absltest"), pypi_requirement("absl/testing:parameterized"), pypi_requirement("einops"), + pypi_requirement("flax:core"), pypi_requirement("jax"), pypi_requirement("numpy"), pypi_requirement("tree"), diff --git a/jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_sparse_dense_matmul_input_test.py b/jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_sparse_dense_matmul_input_test.py index 2f57997..73fcf9e 100644 --- a/jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_sparse_dense_matmul_input_test.py +++ b/jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_sparse_dense_matmul_input_test.py @@ -172,13 +172,7 @@ def setUp(self): def test_preprocess_static_buffer_size_multiplier(self): multiplier = 32 - preprocessed_input, _ = embedding.preprocess_sparse_dense_matmul_input( - features={ - "feature_b": self.feature_b_input, - }, - features_weights={ - "feature_b": self.input_weights_b, - }, + sparse_dense_matmul_config = embedding.SparseDenseMatmulConfig( feature_specs={ "feature_b": self.feature_spec_b, }, @@ -188,6 +182,15 @@ def test_preprocess_static_buffer_size_multiplier(self): num_sc_per_device=4, sharding_strategy="MOD", ) + preprocessed_input, _ = embedding.preprocess_sparse_dense_matmul_input( + features={ + "feature_b": self.feature_b_input, + }, + features_weights={ + "feature_b": self.input_weights_b, + }, + config=sparse_dense_matmul_config, + ) self.assertLen(preprocessed_input.lhs_row_pointers, 1) self.assertLen(preprocessed_input.lhs_embedding_ids, 1) self.assertLen(preprocessed_input.lhs_sample_ids, 1) @@ -208,13 +211,7 @@ def test_preprocess_static_buffer_size_multiplier(self): ) def test_preprocess_for_single_feature_single_device(self): - preprocessed_input, _ = embedding.preprocess_sparse_dense_matmul_input( - features={ - "feature_b": self.feature_b_input, - }, - features_weights={ - "feature_b": self.input_weights_b, - }, + sparse_dense_matmul_config = embedding.SparseDenseMatmulConfig( feature_specs={ "feature_b": self.feature_spec_b, }, @@ 
-223,6 +220,15 @@ def test_preprocess_for_single_feature_single_device(self): num_sc_per_device=4, sharding_strategy="MOD", ) + preprocessed_input, _ = embedding.preprocess_sparse_dense_matmul_input( + features={ + "feature_b": self.feature_b_input, + }, + features_weights={ + "feature_b": self.input_weights_b, + }, + config=sparse_dense_matmul_config, + ) self.assertLen(preprocessed_input.lhs_row_pointers, 1) self.assertLen(preprocessed_input.lhs_embedding_ids, 1) self.assertLen(preprocessed_input.lhs_sample_ids, 1) @@ -263,6 +269,16 @@ def test_preprocess_for_single_feature_single_device(self): ) def test_preprocess_sparse_dense_matmul_input_for_two_features(self): + sparse_dense_matmul_config = embedding.SparseDenseMatmulConfig( + feature_specs={ + "feature_b": self.feature_spec_b, + "feature_a": self.feature_spec_a, + }, + local_device_count=2, + global_device_count=2, + num_sc_per_device=4, + sharding_strategy="MOD", + ) preprocessed_input, _ = embedding.preprocess_sparse_dense_matmul_input( features={ "feature_b": self.feature_b_input, @@ -272,14 +288,7 @@ def test_preprocess_sparse_dense_matmul_input_for_two_features(self): "feature_a": self.input_weights_a, "feature_b": self.input_weights_b, }, - feature_specs={ - "feature_b": self.feature_spec_b, - "feature_a": self.feature_spec_a, - }, - local_device_count=2, - global_device_count=2, - num_sc_per_device=4, - sharding_strategy="MOD", + config=sparse_dense_matmul_config, ) self.assertLen(preprocessed_input.lhs_row_pointers, 2) self.assertLen(preprocessed_input.lhs_embedding_ids, 2) @@ -379,15 +388,7 @@ def test_preprocess_sparse_dense_matmul_input_for_two_features(self): def test_preprocess_sparse_dense_matmul_input_for_two_features_with_leading_dim( self, ): - preprocessed_input, _ = embedding.preprocess_sparse_dense_matmul_input( - features={ - "feature_b": self.feature_b_input, - "feature_a": self.feature_a_input, - }, - features_weights={ - "feature_a": self.input_weights_a, - "feature_b": self.input_weights_b, - }, + sparse_dense_matmul_config = embedding.SparseDenseMatmulConfig( feature_specs={ "feature_b": self.feature_spec_b, "feature_a": self.feature_spec_a, @@ -398,6 +399,17 @@ def test_preprocess_sparse_dense_matmul_input_for_two_features_with_leading_dim( sharding_strategy="MOD", has_leading_dimension=True, ) + preprocessed_input, _ = embedding.preprocess_sparse_dense_matmul_input( + features={ + "feature_b": self.feature_b_input, + "feature_a": self.feature_a_input, + }, + features_weights={ + "feature_a": self.input_weights_a, + "feature_b": self.input_weights_b, + }, + config=sparse_dense_matmul_config, + ) self.assertLen(preprocessed_input.lhs_row_pointers, 2) self.assertLen(preprocessed_input.lhs_embedding_ids, 2) diff --git a/jax_tpu_embedding/sparsecore/lib/nn/tests/tpu_sparse_dense_matmul_grad_test.py b/jax_tpu_embedding/sparsecore/lib/nn/tests/tpu_sparse_dense_matmul_grad_test.py index 69f9ef4..ee81c99 100644 --- a/jax_tpu_embedding/sparsecore/lib/nn/tests/tpu_sparse_dense_matmul_grad_test.py +++ b/jax_tpu_embedding/sparsecore/lib/nn/tests/tpu_sparse_dense_matmul_grad_test.py @@ -226,6 +226,13 @@ def test_sparse_dense_matmul_one_chip_unsharded(self): num_sc_per_device=4, global_device_count=len(devices), ) + config = embedding.SparseDenseMatmulConfig( + feature_specs=feature_specs, + local_device_count=1, + global_device_count=1, + num_sc_per_device=4, + sharding_strategy="MOD", + ) preprocessed_inputs, _ = ( embedding.preprocess_sparse_dense_matmul_input( { @@ -238,11 +245,7 @@ def 
test_sparse_dense_matmul_one_chip_unsharded(self): "feature_spec_b": self.input_weights_table_b, "feature_spec_c": self.input_weights_table_c, }, - feature_specs, - local_device_count=1, - global_device_count=1, - num_sc_per_device=4, - sharding_strategy="MOD", + config=config, ) ) @@ -356,8 +359,7 @@ def test_sparse_dense_matmul_one_chip_unsharded(self): ) sharded_grad_update = functools.partial( embedding.tpu_sparse_dense_matmul_grad, - feature_specs=feature_specs, - sharding_strategy="MOD", + config=config, ) sharded_grad_update = jax.jit(sharded_grad_update) grad_update = sharded_grad_update( @@ -490,6 +492,13 @@ def test_tpu_sparse_dense_matmul_grad_sharded_two_tables(self): global_device_count=len(devices), num_sc_per_device=num_sc_per_device, ) + config = embedding.SparseDenseMatmulConfig( + feature_specs=feature_specs, + local_device_count=num_devices, + global_device_count=num_devices, + num_sc_per_device=4, + sharding_strategy="MOD", + ) # Add another table. preprocessed_inputs, _ = ( embedding.preprocess_sparse_dense_matmul_input( @@ -503,11 +512,7 @@ def test_tpu_sparse_dense_matmul_grad_sharded_two_tables(self): "feature_spec_b": self.input_weights_table_b, "feature_spec_c": self.input_weights_table_c, }, - feature_specs, - local_device_count=num_devices, - global_device_count=num_devices, - num_sc_per_device=4, - sharding_strategy="MOD", + config=config, ) ) table_dim_a = table_stacking._next_largest_multiple(_DIM_A, 8) @@ -620,8 +625,7 @@ def test_tpu_sparse_dense_matmul_grad_sharded_two_tables(self): ) sharded_grad_update = functools.partial( embedding.tpu_sparse_dense_matmul_grad, - feature_specs=feature_specs, - sharding_strategy="MOD", + config=config, ) sharded_grad_update = shard_map.shard_map( sharded_grad_update, @@ -764,6 +768,13 @@ def test_tpu_sparse_dense_matmul_grad_sharded_two_tables_stacked(self): global_device_count=len(devices), num_sc_per_device=num_sc_per_device, ) + config = embedding.SparseDenseMatmulConfig( + feature_specs=feature_specs, + local_device_count=num_devices, + global_device_count=num_devices, + num_sc_per_device=4, + sharding_strategy="MOD", + ) preprocessed_inputs, _ = ( embedding.preprocess_sparse_dense_matmul_input( { @@ -776,11 +787,7 @@ def test_tpu_sparse_dense_matmul_grad_sharded_two_tables_stacked(self): "feature_spec_b": self.input_weights_table_b, "feature_spec_c": self.input_weights_table_c, }, - feature_specs, - local_device_count=num_devices, - global_device_count=num_devices, - num_sc_per_device=4, - sharding_strategy="MOD", + config=config, ) ) padded_vocab_a = 64 @@ -870,8 +877,7 @@ def test_tpu_sparse_dense_matmul_grad_sharded_two_tables_stacked(self): ) sharded_grad_update = functools.partial( embedding.tpu_sparse_dense_matmul_grad, - feature_specs=feature_specs, - sharding_strategy="MOD", + config=config, ) sharded_grad_update = shard_map.shard_map( sharded_grad_update, diff --git a/jax_tpu_embedding/sparsecore/lib/nn/tests/tpu_sparse_dense_matmul_test.py b/jax_tpu_embedding/sparsecore/lib/nn/tests/tpu_sparse_dense_matmul_test.py index 9bef7c0..ae92f22 100644 --- a/jax_tpu_embedding/sparsecore/lib/nn/tests/tpu_sparse_dense_matmul_test.py +++ b/jax_tpu_embedding/sparsecore/lib/nn/tests/tpu_sparse_dense_matmul_test.py @@ -13,10 +13,12 @@ # limitations under the License. 
import collections import functools +import logging from absl.testing import absltest from absl.testing import parameterized import einops +import flax import jax from jax.experimental import shard_map import jax.numpy as jnp @@ -138,21 +140,22 @@ def test_static_buffer_size_was_too_small(self): global_device_count=1, num_sc_per_device=num_sc_per_device, ) - preprocessed_inputs, _ = ( - embedding.preprocess_sparse_dense_matmul_input( - { - "feature": long_feature, - }, - { - "feature": long_weights, - }, - feature_specs, - local_device_count=1, - global_device_count=1, - static_buffer_size_multiplier=8, - num_sc_per_device=4, - sharding_strategy="MOD", - ) + config = embedding.SparseDenseMatmulConfig( + feature_specs=flax.core.freeze(feature_specs), + local_device_count=1, + global_device_count=1, + static_buffer_size_multiplier=8, + num_sc_per_device=4, + sharding_strategy="MOD", + ) + preprocessed_inputs, _ = embedding.preprocess_sparse_dense_matmul_input( + { + "feature": long_feature, + }, + { + "feature": long_weights, + }, + config=config, ) self.assertNotEmpty(preprocessed_inputs.lhs_row_pointers) self.assertNotEmpty(preprocessed_inputs.lhs_embedding_ids) @@ -188,9 +191,7 @@ def test_static_buffer_size_was_too_small(self): ]) tpu_sparse_dense_matmul = functools.partial( embedding.tpu_sparse_dense_matmul, - global_device_count=1, - feature_specs=tuple(tree.flatten(feature_specs)), - sharding_strategy="MOD", + config=config, ) sparse_matmul = jax.jit(tpu_sparse_dense_matmul) activations = sparse_matmul( @@ -368,23 +369,24 @@ def test_sparse_dense_matmul_two_chips_sharded(self, using_pmap): global_device_count=len(devices), num_sc_per_device=num_sc_per_device, ) - preprocessed_inputs, _ = ( - embedding.preprocess_sparse_dense_matmul_input( - { - "feature_spec_a": self.input_tensor, - "feature_spec_aa": self.input_tensor, - }, - { - "feature_spec_a": self.input_weights, - "feature_spec_aa": self.input_weights, - }, - feature_specs, - local_device_count=2, - global_device_count=2, - num_sc_per_device=num_sc_per_device, - sharding_strategy="MOD", - has_leading_dimension=using_pmap, - ) + config = embedding.SparseDenseMatmulConfig( + feature_specs=flax.core.freeze(feature_specs), + local_device_count=2, + global_device_count=2, + num_sc_per_device=num_sc_per_device, + sharding_strategy="MOD", + has_leading_dimension=using_pmap, + ) + preprocessed_inputs, _ = embedding.preprocess_sparse_dense_matmul_input( + { + "feature_spec_a": self.input_tensor, + "feature_spec_aa": self.input_tensor, + }, + { + "feature_spec_a": self.input_weights, + "feature_spec_aa": self.input_weights, + }, + config=config, ) embedding_variables = {} if using_pmap: @@ -401,13 +403,11 @@ def test_sparse_dense_matmul_two_chips_sharded(self, using_pmap): activations = jax.pmap( embedding.tpu_sparse_dense_matmul, - static_broadcasted_argnums=[2, 3, 4], + static_broadcasted_argnums=[2], )( preprocessed_inputs, embedding_variables, - tuple(tree.flatten(feature_specs)), - mesh.size, - "MOD", + config, ) else: embedding_variables["table_a"] = tuple([ @@ -422,9 +422,7 @@ def test_sparse_dense_matmul_two_chips_sharded(self, using_pmap): ]) sharded_matmul = functools.partial( embedding.tpu_sparse_dense_matmul, - feature_specs=tuple(tree.flatten(feature_specs)), - global_device_count=mesh.size, - sharding_strategy="MOD", + config=config, ) sharded_matmul = shard_map.shard_map( @@ -467,8 +465,12 @@ def test_sparse_dense_matmul_two_chips_sharded(self, using_pmap): expected_emb_activations = expected_emb_activations.reshape( 
len(devices), 16 // len(devices), 6 ) - np.testing.assert_equal(activations[0], expected_emb_activations) - np.testing.assert_equal(activations[1], expected_emb_activations) + np.testing.assert_equal( + activations["feature_spec_a"], expected_emb_activations + ) + np.testing.assert_equal( + activations["feature_spec_aa"], expected_emb_activations + ) @parameterized.parameters(False, True) def test_sparse_dense_matmul_two_chips_sharded_stacked(self, using_pmap): @@ -484,23 +486,24 @@ def test_sparse_dense_matmul_two_chips_sharded_stacked(self, using_pmap): global_device_count=len(devices), num_sc_per_device=num_sc_per_device, ) - preprocessed_inputs, _ = ( - embedding.preprocess_sparse_dense_matmul_input( - { - "feature_spec_a": self.input_tensor, - "feature_spec_aa": self.input_tensor, - }, - { - "feature_spec_a": self.input_weights, - "feature_spec_aa": self.input_weights, - }, - feature_specs, - local_device_count=2, - global_device_count=2, - num_sc_per_device=num_sc_per_device, - sharding_strategy="MOD", - has_leading_dimension=using_pmap, - ) + config = embedding.SparseDenseMatmulConfig( + feature_specs=flax.core.freeze(feature_specs), + local_device_count=2, + global_device_count=2, + num_sc_per_device=num_sc_per_device, + sharding_strategy="MOD", + has_leading_dimension=using_pmap, + ) + preprocessed_inputs, _ = embedding.preprocess_sparse_dense_matmul_input( + { + "feature_spec_a": self.input_tensor, + "feature_spec_aa": self.input_tensor, + }, + { + "feature_spec_a": self.input_weights, + "feature_spec_aa": self.input_weights, + }, + config=config, ) embedding_variables = {} if using_pmap: @@ -516,13 +519,11 @@ def test_sparse_dense_matmul_two_chips_sharded_stacked(self, using_pmap): ]) activations = jax.pmap( embedding.tpu_sparse_dense_matmul, - static_broadcasted_argnums=[2, 3, 4], + static_broadcasted_argnums=[2], )( preprocessed_inputs, embedding_variables, - tuple(tree.flatten(feature_specs)), - mesh.size, - "MOD", + config, ) else: embedding_variables["table_a_table_aa"] = tuple([ @@ -537,9 +538,7 @@ def test_sparse_dense_matmul_two_chips_sharded_stacked(self, using_pmap): ]) sharded_matmul = functools.partial( embedding.tpu_sparse_dense_matmul, - feature_specs=tuple(tree.flatten(feature_specs)), - global_device_count=mesh.size, - sharding_strategy="MOD", + config=config, ) sharded_matmul = shard_map.shard_map( @@ -607,8 +606,12 @@ def test_sparse_dense_matmul_two_chips_sharded_stacked(self, using_pmap): len(devices), 16 // len(devices), 6 ) self.assertLen(activations, 2) - np.testing.assert_equal(activations[0], expected_emb_activations_a) - np.testing.assert_equal(activations[1], expected_emb_activations_aa) + np.testing.assert_equal( + activations["feature_spec_a"], expected_emb_activations_a + ) + np.testing.assert_equal( + activations["feature_spec_aa"], expected_emb_activations_aa + ) @parameterized.parameters(False, True) def test_sparse_dense_matmul_single_chip(self, using_pmap): @@ -625,23 +628,24 @@ def test_sparse_dense_matmul_single_chip(self, using_pmap): global_device_count=1, num_sc_per_device=num_sc_per_device, ) - preprocessed_inputs, _ = ( - embedding.preprocess_sparse_dense_matmul_input( - { - "feature_spec_a": self.input_tensor, - "feature_spec_b": self.input_tensor_table_b, - }, - { - "feature_spec_a": self.input_weights, - "feature_spec_b": self.input_weights_table_b, - }, - feature_specs, - local_device_count=1, - global_device_count=1, - num_sc_per_device=4, - sharding_strategy="MOD", - has_leading_dimension=using_pmap, - ) + config = 
embedding.SparseDenseMatmulConfig( + feature_specs=flax.core.freeze(feature_specs), + local_device_count=1, + global_device_count=1, + num_sc_per_device=4, + sharding_strategy="MOD", + has_leading_dimension=using_pmap, + ) + preprocessed_inputs, _ = embedding.preprocess_sparse_dense_matmul_input( + { + "feature_spec_a": self.input_tensor, + "feature_spec_b": self.input_tensor_table_b, + }, + { + "feature_spec_a": self.input_weights, + "feature_spec_b": self.input_weights_table_b, + }, + config=config, ) embedding_variables = {} @@ -658,13 +662,11 @@ def test_sparse_dense_matmul_single_chip(self, using_pmap): ]) activations = jax.pmap( embedding.tpu_sparse_dense_matmul, - static_broadcasted_argnums=[2, 3, 4], + static_broadcasted_argnums=[2], )( preprocessed_inputs, embedding_variables, - tuple(tree.flatten(feature_specs)), - mesh.size, - "MOD", + config, ) else: embedding_variables["table_a"] = tuple([ @@ -678,14 +680,12 @@ def test_sparse_dense_matmul_single_chip(self, using_pmap): ) ]) sparse_matmul = jax.jit( - embedding.tpu_sparse_dense_matmul, static_argnums=(2, 3, 4) + embedding.tpu_sparse_dense_matmul, static_argnums=(2,) ) activations = sparse_matmul( preprocessed_inputs, embedding_variables, - tuple(tree.flatten(feature_specs)), - mesh.size, - "MOD", + config, ) expected_emb_activations = np.array( [ @@ -710,7 +710,9 @@ def test_sparse_dense_matmul_single_chip(self, using_pmap): ) if using_pmap: expected_emb_activations = expected_emb_activations.reshape(1, 16, 6) - np.testing.assert_equal(activations[0], expected_emb_activations) + np.testing.assert_equal( + activations["feature_spec_a"], expected_emb_activations + ) @parameterized.parameters(False, True) def test_sparse_dense_matmul_two_tables(self, using_pmap): @@ -726,24 +728,25 @@ def test_sparse_dense_matmul_two_tables(self, using_pmap): global_device_count=len(devices), num_sc_per_device=num_sc_per_device, ) + config = embedding.SparseDenseMatmulConfig( + feature_specs=flax.core.freeze(feature_specs), + local_device_count=2, + global_device_count=2, + num_sc_per_device=num_sc_per_device, + sharding_strategy="MOD", + has_leading_dimension=using_pmap, + ) # Add another table. 
- preprocessed_inputs, _ = ( - embedding.preprocess_sparse_dense_matmul_input( - { - "feature_spec_a": self.input_tensor, - "feature_spec_b": self.input_tensor_table_b, - }, - { - "feature_spec_a": self.input_weights, - "feature_spec_b": self.input_weights_table_b, - }, - feature_specs, - local_device_count=2, - global_device_count=2, - num_sc_per_device=num_sc_per_device, - sharding_strategy="MOD", - has_leading_dimension=using_pmap, - ) + preprocessed_inputs, _ = embedding.preprocess_sparse_dense_matmul_input( + { + "feature_spec_a": self.input_tensor, + "feature_spec_b": self.input_tensor_table_b, + }, + { + "feature_spec_a": self.input_weights, + "feature_spec_b": self.input_weights_table_b, + }, + config=config, ) embedding_variables = {} if using_pmap: @@ -759,13 +762,11 @@ def test_sparse_dense_matmul_two_tables(self, using_pmap): ]) activations = jax.pmap( embedding.tpu_sparse_dense_matmul, - static_broadcasted_argnums=(2, 3, 4), + static_broadcasted_argnums=(2,), )( preprocessed_inputs, embedding_variables, - tuple(tree.flatten(feature_specs)), - mesh.size, - "MOD", + config, ) else: embedding_variables["table_a"] = tuple([ @@ -780,9 +781,7 @@ def test_sparse_dense_matmul_two_tables(self, using_pmap): ]) sharded_matmul = functools.partial( embedding.tpu_sparse_dense_matmul, - feature_specs=tuple(tree.flatten(feature_specs)), - global_device_count=mesh.size, - sharding_strategy="MOD", + config=config, ) sharded_matmul = shard_map.shard_map( sharded_matmul, @@ -1121,12 +1120,14 @@ def test_sparse_dense_matmul_two_tables(self, using_pmap): expected_emb_activations["table_b"] = expected_emb_activations[ "table_b" ].reshape(2, 8, 16) + if isinstance(activations, flax.linen.FrozenDict): + activations = activations.unfreeze() np.testing.assert_equal( activations, - ( - expected_emb_activations["table_a"], - expected_emb_activations["table_b"], - ), + { + "feature_spec_a": expected_emb_activations["table_a"], + "feature_spec_b": expected_emb_activations["table_b"], + }, ) @parameterized.parameters(False, True) @@ -1243,25 +1244,27 @@ def test_sparse_dense_matmul_four_chips_complex_stacked(self, using_pmap): ], dtype=object, ) - preprocessed_inputs, _ = ( - embedding.preprocess_sparse_dense_matmul_input( - { - "country": input_tensor, - "language": input_tensor, - "related_item": input_tensor, - }, - { - "country": input_weights, - "language": input_weights, - "related_item": input_weights, - }, - feature_specs, - local_device_count=mesh.local_mesh.size, - global_device_count=mesh.size, - num_sc_per_device=num_sc_per_device, - sharding_strategy="MOD", - has_leading_dimension=using_pmap, - ) + config = embedding.SparseDenseMatmulConfig( + feature_specs=flax.core.freeze(feature_specs), + local_device_count=mesh.local_mesh.size, + global_device_count=mesh.size, + num_sc_per_device=num_sc_per_device, + sharding_strategy="MOD", + has_leading_dimension=using_pmap, + static_buffer_size_multiplier=8, + ) + preprocessed_inputs, _ = embedding.preprocess_sparse_dense_matmul_input( + { + "country": input_tensor, + "language": input_tensor, + "related_item": input_tensor, + }, + { + "country": input_weights, + "language": input_weights, + "related_item": input_weights, + }, + config=config, ) embedding_variables = {} if using_pmap: @@ -1282,13 +1285,11 @@ def test_sparse_dense_matmul_four_chips_complex_stacked(self, using_pmap): ) activations = jax.pmap( embedding.tpu_sparse_dense_matmul, - static_broadcasted_argnums=[2, 3, 4], + static_broadcasted_argnums=[2], )( preprocessed_inputs, 
embedding_variables, - tuple(tree.flatten(feature_specs)), - mesh.size, - "MOD", + config, ) else: embedding_variables["country_table_language_table_related_item_table"] = ( @@ -1308,9 +1309,7 @@ def test_sparse_dense_matmul_four_chips_complex_stacked(self, using_pmap): ) sharded_matmul = functools.partial( embedding.tpu_sparse_dense_matmul, - feature_specs=tuple(tree.flatten(feature_specs)), - global_device_count=mesh.size, - sharding_strategy="MOD", + config=config, ) sharded_matmul = shard_map.shard_map( @@ -1324,16 +1323,13 @@ def test_sparse_dense_matmul_four_chips_complex_stacked(self, using_pmap): check_rep=False, ) sharded_matmul = jax.jit(sharded_matmul) - activations = sharded_matmul( - preprocessed_inputs, - embedding_variables, - ) + activations = sharded_matmul(preprocessed_inputs, embedding_variables) expected_act_country = np.ones((4, 4, 16), np.float32) expected_act_country[0][3, :] = 86.0 if not using_pmap: expected_act_country = expected_act_country.reshape(16, 16) np.testing.assert_equal( - activations[0], # country + activations["country"], expected_act_country, "country activations do not match", ) @@ -1342,7 +1338,7 @@ def test_sparse_dense_matmul_four_chips_complex_stacked(self, using_pmap): if not using_pmap: expected_act_language = expected_act_language.reshape(16, 16) np.testing.assert_equal( - activations[1], # language + activations["language"], expected_act_language, "language activations do not match", ) @@ -1351,7 +1347,7 @@ def test_sparse_dense_matmul_four_chips_complex_stacked(self, using_pmap): if not using_pmap: expected_act_related_item = expected_act_related_item.reshape(16, 16) np.testing.assert_equal( - activations[2], # related_item + activations["related_item"], expected_act_related_item, "related_item activations do not match", ) diff --git a/jax_tpu_embedding/sparsecore/tests/jax_sc_shakespeare_tests.py b/jax_tpu_embedding/sparsecore/tests/jax_sc_shakespeare_tests.py index 5a48e1a..671bf34 100644 --- a/jax_tpu_embedding/sparsecore/tests/jax_sc_shakespeare_tests.py +++ b/jax_tpu_embedding/sparsecore/tests/jax_sc_shakespeare_tests.py @@ -150,12 +150,17 @@ def test_shakespeare_model_loss_convergence(self): global_device_count=mesh.size, num_sc_per_device=num_sc_per_device, ) - sharded_matmul = functools.partial( - embedding.tpu_sparse_dense_matmul, - global_device_count=mesh.size, + config = embedding.SparseDenseMatmulConfig( feature_specs=feature_specs, + global_device_count=mesh.size, + num_sc_per_device=num_sc_per_device, + local_device_count=mesh.local_mesh.size, sharding_strategy='MOD', ) + sharded_matmul = functools.partial( + embedding.tpu_sparse_dense_matmul, + config=config, + ) sparse_matmul = shard_map.shard_map( sharded_matmul, mesh=mesh, @@ -170,8 +175,7 @@ def test_shakespeare_model_loss_convergence(self): sharded_grad_update = functools.partial( embedding.tpu_sparse_dense_matmul_grad, - feature_specs=feature_specs, - sharding_strategy='MOD', + config=config, ) sparse_grad_update = shard_map.shard_map( sharded_grad_update, @@ -200,16 +204,8 @@ def test_shakespeare_model_loss_convergence(self): feature_structure, [feature_weights] ) - preprocessed_inputs, _ = ( - embedding.preprocess_sparse_dense_matmul_input( - features, - feature_weights, - feature_specs, - local_device_count=mesh.local_mesh.size, - global_device_count=mesh.size, - num_sc_per_device=num_sc_per_device, - sharding_strategy='MOD', - ) + preprocessed_inputs, _ = embedding.preprocess_sparse_dense_matmul_input( + features, feature_weights, config ) # 
-------------------------------------------------------------------------- @@ -242,9 +238,7 @@ def test_shakespeare_model_loss_convergence(self): ) } embedding_variables = sparse_grad_update( - gradient_updates, - preprocessed_inputs, - embedding_variables, + gradient_updates, preprocessed_inputs, embedding_variables ) if step % 10 == 0: diff --git a/jax_tpu_embedding/sparsecore/tests/jax_spmd_tc_with_sc_tests.py b/jax_tpu_embedding/sparsecore/tests/jax_spmd_tc_with_sc_tests.py index c6cf2e4..c31f599 100644 --- a/jax_tpu_embedding/sparsecore/tests/jax_spmd_tc_with_sc_tests.py +++ b/jax_tpu_embedding/sparsecore/tests/jax_spmd_tc_with_sc_tests.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Callable +import dataclasses import functools from typing import Any @@ -19,6 +20,7 @@ from absl import logging from absl.testing import absltest import einops +import flax import flax.linen as nn import jax from jax.experimental import shard_map @@ -207,24 +209,28 @@ def loss_fn(params, emb_acts, labels): global_device_count=self.mesh.size, num_sc_per_device=self.num_sc_per_device, ) - sharded_matmul = functools.partial( - embedding.tpu_sparse_dense_matmul, + self.sdmm_config = embedding.SparseDenseMatmulConfig( + feature_specs=self.shakespeare_feature, global_device_count=self.mesh.size, - feature_specs=(self.shakespeare_feature,), + local_device_count=self.mesh.local_mesh.size, + num_sc_per_device=self.num_sc_per_device, sharding_strategy='MOD', ) + sharded_matmul = functools.partial( + embedding.tpu_sparse_dense_matmul, + config=self.sdmm_config, + ) self.sparse_matmul = shard_map.shard_map( sharded_matmul, mesh=self.mesh, - in_specs=(self.pd,) + (P(self.pd, None),), + in_specs=(self.pd, P(self.pd, None)), out_specs=self.pd, check_rep=False, ) sharded_grad_update = functools.partial( embedding.tpu_sparse_dense_matmul_grad, - feature_specs=(self.shakespeare_feature,), - sharding_strategy='MOD', + config=self.sdmm_config, ) self.sparse_grad_update = shard_map.shard_map( sharded_grad_update, @@ -249,7 +255,7 @@ def train_step( # SC forward pass activations = self.sparse_matmul(preprocessed_inputs, embedding_variables) activations = jnp.reshape( - activations[0], + activations, ( _BATCH_SIZE.value, _SEQ_LEN.value, @@ -266,7 +272,7 @@ def train_step( # SC backward pass gradient_updates = jnp.reshape(grads[1], (-1, _EMBEDDING_SIZE.value)) new_embedding_variables = self.sparse_grad_update( - (gradient_updates,), # Should be same structure as features. 
+ gradient_updates, preprocessed_inputs, embedding_variables, ) @@ -277,23 +283,13 @@ def train_step( features = np.reshape(features, (-1, 1)) # SC input processing - preprocessed_inputs, _ = ( - embedding.preprocess_sparse_dense_matmul_input( - {self.shakespeare_feature.name: features}, - { - self.shakespeare_feature.name: np.ones_like( - features, dtype=jnp.float32 - ) - }, - {self.shakespeare_feature.name: self.shakespeare_feature}, - local_device_count=self.mesh.local_mesh.size, - global_device_count=self.mesh.size, - num_sc_per_device=self.num_sc_per_device, - sharding_strategy='MOD', - ) + preprocessed_inputs, _ = embedding.preprocess_sparse_dense_matmul_input( + features, + np.ones_like(features, dtype=jnp.float32), + config=self.sdmm_config, ) self.params, self.opt_state, loss_val, self.embedding_variables = jax.jit( - train_step + train_step, donate_argnums=(3,) )( self.params, self.opt_state, @@ -307,7 +303,6 @@ def train_step( losses.append(loss_val) step += 1 - self.assertLess(losses[-1], 0.001)
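Taken together, these changes fold feature_specs, the device counts, the sharding strategy, and related knobs into a single frozen `SparseDenseMatmulConfig` that is passed positionally to preprocessing, the forward matmul, and the gradient update. A minimal end-to-end sketch of the jit + shard_map pattern, following the updated docstring examples above; `mesh`, `pd`, `pe`, `feature_specs`, `num_sc_per_device`, `features`, `feature_weights`, `emb_variables`, and `activation_gradients` are assumed to be set up as in the Shakespeare example.

# Illustrative sketch; assumes the Shakespeare example's setup and imports
# (embedding, shard_map, jax).
config = embedding.SparseDenseMatmulConfig(
    feature_specs=feature_specs,
    local_device_count=mesh.local_mesh.size,
    global_device_count=mesh.size,
    num_sc_per_device=num_sc_per_device,
    sharding_strategy='MOD',
)

# Preprocessing consumes the same config as the device-side ops.
preprocessed_inputs, stats = embedding.preprocess_sparse_dense_matmul_input(
    features, feature_weights, config
)

# Forward lookup: config rides along as an extra, replicated argument.
sparse_matmul = shard_map(
    embedding.tpu_sparse_dense_matmul,
    mesh=mesh,
    in_specs=(pd, pe, None),
    out_specs=pd,
    check_rep=False,
)
activations = jax.jit(sparse_matmul, static_argnums=(2,))(
    preprocessed_inputs, emb_variables, config
)

# Backward/update: the same config selects the sharding strategy and
# optimizer label for the gradient step.
grad_update = shard_map(
    embedding.tpu_sparse_dense_matmul_grad,
    mesh=mesh,
    in_specs=(pd, pd, pe, None),
    out_specs=pe,
    check_rep=False,
)
emb_variables = jax.jit(grad_update, static_argnums=(3,))(
    activation_gradients, preprocessed_inputs, emb_variables, config
)

Because the config is a frozen dataclass it is hashable, so it can be marked static under jax.jit (`static_argnums`) or jax.pmap (`static_broadcasted_argnums`), as the updated docstring examples and tests do.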