Refactor to use a config object to manage arguments to sparse dense matmul. #235
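
This change replaces the growing list of per-call arguments (feature_specs, local_device_count, global_device_count, num_sc_per_device, sharding_strategy, ...) with a single embedding.SparseDenseMatmulConfig that is built once and handed to preprocess_sparse_dense_matmul_input, tpu_sparse_dense_matmul, and tpu_sparse_dense_matmul_grad. Below is a minimal sketch of the new calling convention; every free name in it (mesh, feature_specs, num_sc_per_device, features, feature_weights, emb_variables, emb_grad) is assumed from the examples in the diff that follows.

```python
# Sketch of the config-object API introduced by this PR. All free names
# (mesh, feature_specs, num_sc_per_device, features, feature_weights,
# emb_variables, emb_grad) are assumed from the examples in the diff below.
config = embedding.SparseDenseMatmulConfig(
    feature_specs=flax.core.freeze(feature_specs),
    local_device_count=mesh.local_mesh.size,
    global_device_count=mesh.size,
    num_sc_per_device=num_sc_per_device,
    sharding_strategy='MOD',
)

# The same config object is reused across preprocessing, the forward lookup,
# and the gradient update instead of repeating individual keyword arguments.
preprocessed_inputs, stats = embedding.preprocess_sparse_dense_matmul_input(
    features, feature_weights, config=config
)
emb_act = embedding.tpu_sparse_dense_matmul(
    preprocessed_inputs, emb_variables, config
)
emb_variables = embedding.tpu_sparse_dense_matmul_grad(
    emb_grad, preprocessed_inputs, emb_variables, config
)
```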

@@ -46,6 +46,7 @@
import orbax.checkpoint as ocp
import tree


np.set_printoptions(threshold=np.inf)
Nested = embedding.Nested

@@ -376,7 +377,7 @@ def train_step_fn(
mesh: jax.sharding.Mesh,
model: nn.Module,
optimizer,
feature_specs,
config: embedding.SparseDenseMatmulConfig,
train_state: TrainState,
preprocessed_inputs,
emb_variables,
@@ -386,22 +387,15 @@ def train_step_fn(

# Sparse forward pass - embedding lookup.
with jax.named_scope('sc_forward_pass'):
tpu_sparse_dense_matmul = partial(
embedding.tpu_sparse_dense_matmul,
global_device_count=num_global_devices,
feature_specs=feature_specs,
sharding_strategy='MOD',
)
tpu_sparse_dense_matmul = shard_map(
f=tpu_sparse_dense_matmul,
f=embedding.tpu_sparse_dense_matmul,
mesh=mesh,
in_specs=(pd, pe),
in_specs=(pd, pe, None),
out_specs=pd,
check_rep=False,
)
emb_act = tpu_sparse_dense_matmul(
preprocessed_inputs,
emb_variables,
preprocessed_inputs, emb_variables, config
)

# Dense forward + backward pass.
@@ -429,22 +423,15 @@ def train_step_fn(

# Sparse backward pass - embedding update.
with jax.named_scope('sc_backward_pass'):
tpu_sparse_dense_matmul_grad = partial(
embedding.tpu_sparse_dense_matmul_grad,
feature_specs=feature_specs,
sharding_strategy='MOD',
)
tpu_sparse_dense_matmul_grad = shard_map(
f=tpu_sparse_dense_matmul_grad,
f=embedding.tpu_sparse_dense_matmul_grad,
mesh=mesh,
in_specs=(pd, pd, pe),
in_specs=(pd, pd, pe, None),
out_specs=pe,
check_rep=False,
)
emb_variables = tpu_sparse_dense_matmul_grad(
emb_grad,
preprocessed_inputs,
emb_variables,
emb_grad, preprocessed_inputs, emb_variables, config
)

train_state = train_state.replace(
@@ -503,16 +490,17 @@ def train_step_fn(
lambda y: jax.make_array_from_process_local_data(global_sharding, y),
x,
)
config = embedding.SparseDenseMatmulConfig(
global_device_count=num_global_devices,
local_device_count=num_local_devices,
feature_specs=flax.core.freeze(feature_specs),
num_sc_per_device=num_sc_per_device,
sharding_strategy='MOD',
)
preprocessed_inputs, stats = map(
make_global_view,
embedding.preprocess_sparse_dense_matmul_input(
features,
feature_weights,
feature_specs,
local_device_count=global_mesh.local_mesh.size,
global_device_count=global_mesh.size,
num_sc_per_device=num_sc_per_device,
sharding_strategy='MOD',
features, feature_weights, config=config
),
)
fdo_client.record(stats)
@@ -524,7 +512,7 @@ def train_step_fn(
global_mesh,
model,
optimizer,
feature_specs,
config,
train_state,
preprocessed_inputs,
emb_variables,
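
In this shard_map-based example the config rides along as a third operand whose in_spec is None, so it is presumably treated as replicated metadata rather than sharded data; only the preprocessed inputs (pd) and the embedding variables (pe) keep named partition specs. A condensed sketch of the forward wrapper, using only names that appear in the hunks above:

```python
# Condensed from the forward-pass hunk above; pd, pe, mesh,
# preprocessed_inputs, emb_variables, and config are the values defined
# in that example.
tpu_sparse_dense_matmul = shard_map(
    f=embedding.tpu_sparse_dense_matmul,
    mesh=mesh,
    in_specs=(pd, pe, None),  # None: the config is not partitioned
    out_specs=pd,
    check_rep=False,
)
emb_act = tpu_sparse_dense_matmul(preprocessed_inputs, emb_variables, config)
```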
@@ -260,10 +260,9 @@ def run_model():
)

def train_step_fn(
global_device_count: int,
model: nn.Module,
optimizer,
feature_specs,
config,
train_state: TrainState,
preprocessed_inputs,
emb_variables: Mapping[str, embedding.EmbeddingVariables],
@@ -276,9 +275,7 @@ def train_step_fn(
with jax.named_scope('sc_forward_pass'):
tpu_sparse_dense_matmul = partial(
embedding.tpu_sparse_dense_matmul,
global_device_count=global_device_count,
feature_specs=feature_specs,
sharding_strategy='MOD',
config=config,
)
emb_act = tpu_sparse_dense_matmul(
preprocessed_inputs,
@@ -312,8 +309,7 @@ def train_step_fn(
with jax.named_scope('sc_backward_pass'):
tpu_sparse_dense_matmul_grad = partial(
embedding.tpu_sparse_dense_matmul_grad,
feature_specs=feature_specs,
sharding_strategy='MOD',
config=config,
)
emb_variables = tpu_sparse_dense_matmul_grad(
emb_grad,
@@ -337,14 +333,16 @@ def train_step_fn(
vlog1('Replicating train_state')
train_state = flax_utils.replicate(train_state, local_devices)
parameter_overview.log_parameter_overview(train_state.params)

config = embedding.SparseDenseMatmulConfig(
feature_specs=feature_specs,
local_device_count=global_mesh.local_mesh.size,
global_device_count=global_mesh.size,
num_sc_per_device=num_sc_per_device,
has_leading_dimension=True,
)
p_train_step_fn = jax.pmap(
partial(
train_step_fn,
num_global_devices,
model,
optimizer,
feature_specs,
),
partial(train_step_fn, model, optimizer, config),
axis_name='batch',
)

@@ -376,16 +374,16 @@ def train_step_fn(
)

# Preprocess the inputs.
preprocessed_inputs, _ = embedding.preprocess_sparse_dense_matmul_input(
features,
feature_weights,
feature_specs,
config = embedding.SparseDenseMatmulConfig(
feature_specs=flax.core.freeze(feature_specs),
local_device_count=global_mesh.local_mesh.size,
global_device_count=global_mesh.size,
num_sc_per_device=num_sc_per_device,
sharding_strategy='MOD',
has_leading_dimension=True,
)
preprocessed_inputs, _ = embedding.preprocess_sparse_dense_matmul_input(
features, feature_weights, config
)

# TODO(patn): This (local_slice)will go away once the input processor is
# updated to only produce local batches.
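
The pmap-based example above builds the same kind of config but also sets has_leading_dimension=True, which appears to account for the leading per-local-device axis that jax.pmap expects, and then closes over the config with functools.partial instead of passing the individual arguments at every call. A short sketch reusing the names from that example:

```python
# Sketch condensed from the pmap example above; feature_specs, global_mesh,
# num_sc_per_device, model, optimizer, and train_step_fn come from that file.
config = embedding.SparseDenseMatmulConfig(
    feature_specs=feature_specs,
    local_device_count=global_mesh.local_mesh.size,
    global_device_count=global_mesh.size,
    num_sc_per_device=num_sc_per_device,
    has_leading_dimension=True,  # inputs are stacked along a local-device axis
)
p_train_step_fn = jax.pmap(
    partial(train_step_fn, model, optimizer, config),  # functools.partial
    axis_name='batch',
)
```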

jax_tpu_embedding/sparsecore/lib/flax/embed.py (19 additions, 38 deletions)
@@ -88,6 +88,14 @@ def __post_init__(self):
self.mesh.devices.item(0)
)

self.config = embedding.SparseDenseMatmulConfig(
feature_specs=self.feature_specs,
local_device_count=self.mesh.local_mesh.size,
global_device_count=self.mesh.size,
num_sc_per_device=self.num_sc_per_device,
sharding_strategy=self.table_sharding_strategy,
)

super().__post_init__()

def setup(self):
@@ -143,11 +151,7 @@ def preprocess_inputs(
return embedding.preprocess_sparse_dense_matmul_input(
features,
features_weights,
self.feature_specs,
self.mesh.local_mesh.size,
self.mesh.size,
num_sc_per_device=self.num_sc_per_device,
sharding_strategy=self.table_sharding_strategy,
self.config,
)[0]

def __call__(self, embedding_lookups: EmbeddingLookups) -> Nested[jax.Array]:
@@ -198,34 +202,22 @@ def _emb_lookup(
pt = embedding_layer.embedding_table_partition
pd = embedding_layer.data_partition
return shard_map(
functools.partial(
embedding.tpu_sparse_dense_matmul,
global_device_count=embedding_layer.mesh.size,
feature_specs=embedding_layer.feature_specs,
sharding_strategy=embedding_layer.table_sharding_strategy,
),
embedding.tpu_sparse_dense_matmul,
mesh=embedding_layer.mesh,
in_specs=(pd, pt),
in_specs=(pd, pt, None),
out_specs=pd,
check_rep=False,
)(
embedding_lookups,
emb_table,
)
)(embedding_lookups, emb_table, embedding_layer.config)


def _emb_lookup_fwd(
embedding_layer: SparseCoreEmbed,
embedding_lookups: EmbeddingLookups,
emb_table: Mapping[str, tuple[jax.Array, ...]],
):
return _emb_lookup(
embedding_layer,
embedding_lookups,
emb_table,
), (
embedding_lookups,
emb_table,
return (
_emb_lookup(embedding_layer, embedding_lookups, emb_table),
(embedding_lookups, emb_table),
)


@@ -236,20 +228,12 @@ def _emb_lookup_bwd(embedding_layer, res, gradients):
pt = embedding_layer.embedding_table_partition
pd = embedding_layer.data_partition
emb_table_grads = shard_map(
functools.partial(
embedding.tpu_sparse_dense_matmul_grad,
feature_specs=embedding_layer.feature_specs,
sharding_strategy=embedding_layer.table_sharding_strategy,
),
embedding.tpu_sparse_dense_matmul_grad,
mesh=embedding_layer.mesh,
in_specs=(pd, pd, pt),
in_specs=(pd, pd, pt, None),
out_specs=pt,
check_rep=False,
)(
gradients,
embedding_lookups,
emb_table,
)
)(gradients, embedding_lookups, emb_table, embedding_layer.config)

# tpu_sparse_dense_matmul_grad returns a general Mapping (usually a dict).
# It may not be the same type as the embedding table (e.g. FrozenDict).
@@ -258,10 +242,7 @@ def _emb_lookup_bwd(embedding_layer, res, gradients):
jax.tree.structure(emb_table), jax.tree.leaves(emb_table_grads)
)

return (
None,
emb_table_grads,
)
return (None, emb_table_grads)


_emb_lookup.defvjp(_emb_lookup_fwd, _emb_lookup_bwd)
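
In the Flax SparseCoreEmbed layer, the config is built once in __post_init__ from attributes the layer already has, and the custom-VJP lookup then threads it through shard_map as an extra unpartitioned operand in both the forward and backward passes, so code that uses the layer should not need to change. A condensed sketch (not a drop-in replacement for embed.py); `layer` stands in for the SparseCoreEmbed instance, and pd, pt, embedding_lookups, and emb_table are the names used in the hunks above:

```python
# 1) Built once in SparseCoreEmbed.__post_init__ (per the hunk above).
config = embedding.SparseDenseMatmulConfig(
    feature_specs=layer.feature_specs,
    local_device_count=layer.mesh.local_mesh.size,
    global_device_count=layer.mesh.size,
    num_sc_per_device=layer.num_sc_per_device,
    sharding_strategy=layer.table_sharding_strategy,
)

# 2) Threaded through the custom-VJP lookup: the config is the extra operand
#    with a None in_spec, in both _emb_lookup and _emb_lookup_bwd.
emb_act = shard_map(
    embedding.tpu_sparse_dense_matmul,
    mesh=layer.mesh,
    in_specs=(pd, pt, None),
    out_specs=pd,
    check_rep=False,
)(embedding_lookups, emb_table, config)
```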

jax_tpu_embedding/sparsecore/lib/flax/tests/autograd_test.py (11 additions, 8 deletions)
@@ -129,7 +129,7 @@ def test_shakespeare_model_loss_convergence(self):
)

# Initialize the model.
def process_inputs(feature_batch):
def process_inputs(feature_batch, config):
features = np.reshape(feature_batch, (-1, 1))
feature_weights = np.ones(features.shape, dtype=np.float32)

@@ -144,15 +144,18 @@ def process_inputs(feature_batch):
*embedding.preprocess_sparse_dense_matmul_input(
features,
feature_weights,
feature_specs,
mesh.local_mesh.size,
mesh.size,
num_sc_per_device=num_sc_per_device,
sharding_strategy='MOD',
config,
)[0]
)

first_model_input = process_inputs(feature_batches[0])
config = embedding.SparseDenseMatmulConfig(
feature_specs=feature_specs,
local_device_count=mesh.local_mesh.size,
global_device_count=mesh.size,
num_sc_per_device=num_sc_per_device,
sharding_strategy='MOD',
)
first_model_input = process_inputs(feature_batches[0], config)
params = model.init(jax.random.key(42), first_model_input)

# Create optimizer.
@@ -204,7 +207,7 @@ def forward_pass(params, embedding_lookups, labels):
# ------------------------------------------------------------------------
# Step 1: SC input processing.
# ------------------------------------------------------------------------
processed_input_tensor = process_inputs(features)
processed_input_tensor = process_inputs(features, config)

# ------------------------------------------------------------------------
# Step 2: run model.