
Linting low-hanging fruit across MaxText, MaxText/input_pipeline, MaxText/layers #1457


Merged
merged 1 commit into from Apr 26, 2025
30 changes: 15 additions & 15 deletions MaxText/benchmark_chunked_prefill.py
@@ -32,22 +32,22 @@
"""


# pylint: disable=ungrouped-imports
import datetime
import os
from typing import Any, Sequence
import datetime

import jax
from absl import app
from jetstream.core import prefix_cache

from jetstream.engine import chunked_prefill
from jetstream.engine import engine_api
from jetstream.engine import prefix_cache

from absl import app

from MaxText import max_utils
from MaxText import maxengine
from MaxText import pyconfig


_WARMUP_ITERS = 2
_BENCHMARK_ITERS = 5

@@ -81,7 +81,7 @@ def copy_prefix():
return jax.tree.map(lambda x: x.copy(), prefix)

# --- Fill the cache with dummy entries ---
print(f"Filling cache with {cache_num} dummy entries...")
print("Filling cache with", cache_num, "dummy entries...")
for i in range(cache_num):
# Create a unique dummy key, ensuring it's different from key_to_hit
# and has the same length for consistency (though not strictly required by Trie).
@@ -105,10 +105,10 @@ def copy_prefix():
jax.block_until_ready(load_result.prefix)
del load_result

print(f"Finished filling cache with {cache_num} dummy entries.")
print("Finished filling cache with", cache_num, "dummy entries.")

# --- Add the actual target entry ---
print(f"Adding the target entry with key length {len(key_to_hit)}...")
print("Adding the target entry with key length ", len(key_to_hit), "...", sep="")

value_to_hit = prefix_cache.Value(
prefix=copy_prefix(),
@@ -171,7 +171,7 @@ def run_chunked_prefill_utility():
prefill_result = run_chunked_prefill_utility()
jax.block_until_ready(prefill_result)
end = datetime.datetime.now()
print(f" Warmup iteration {i+1} time: {end - start}")
print(" Warmup iteration", i + 1, "time:", end - start)

print("\nStarting benchmark...")
total_time = datetime.timedelta()
@@ -182,10 +182,10 @@ def run_chunked_prefill_utility():
end = datetime.datetime.now()
iter_time = end - start
total_time += iter_time
print(f" Benchmark iteration {i+1} time: {iter_time}")
print(" Benchmark iteration", i + 1, "time:", iter_time)

average_time = total_time / _BENCHMARK_ITERS
print(f"\nAverage time taken for chunked prefill over {_BENCHMARK_ITERS} iterations: {average_time}")
print("\nAverage time taken for chunked prefill over", _BENCHMARK_ITERS, "iterations:", average_time)

# Run prefix caching benchmark
prefill_result = run_chunked_prefill_utility()
@@ -235,13 +235,13 @@ def run_chunked_prefill_with_prefix_caching(cache_hit_chunk: int, need_save: boo

for cache_hit_chunk in range(len(chunked_tokens_list)):
for need_save in [True, False]:
print(f"\nBenchmark prefix caching {cache_hit_chunk=}, {need_save=}")
print("\nBenchmark prefix caching cache_hit_chunk=", cache_hit_chunk, " need_save=", need_save, sep="")
for i in range(_WARMUP_ITERS):
start = datetime.datetime.now()
prefill_result = run_chunked_prefill_with_prefix_caching(cache_hit_chunk, need_save)
jax.block_until_ready(prefill_result)
end = datetime.datetime.now()
print(f" Warmup iteration {i+1} time: {end - start}")
print(" Warmup iteration", i + 1, "time:", end - start)

total_time = datetime.timedelta()
for i in range(_BENCHMARK_ITERS):
@@ -251,10 +251,10 @@ def run_chunked_prefill_with_prefix_caching(cache_hit_chunk: int, need_save: boo
end = datetime.datetime.now()
iter_time = end - start
total_time += iter_time
print(f" Benchmark iteration {i+1} time: {iter_time}")
print(" Benchmark iteration", i + 1, "time:", iter_time)

average_time = total_time / _BENCHMARK_ITERS
print(f"\nAverage time taken for prefix caching chunked prefill: {average_time}")
print("\nAverage time taken for prefix caching chunked prefill:", average_time)


if __name__ == "__main__":
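Most of the edits in this file regroup the imports and replace f-string prints with plain multi-argument print() calls. A minimal illustration, not part of the diff (cache_num and key_to_hit are placeholder values), of why the two print forms produce the same output:

# Illustration only, not part of the PR: print() joins its arguments with a single space,
# so the multi-argument form matches the f-string form.
cache_num = 100                 # placeholder value
key_to_hit = (1, 2, 3)          # placeholder key
print(f"Filling cache with {cache_num} dummy entries...")   # old style
print("Filling cache with", cache_num, "dummy entries...")  # new style, same output
# Where no separating space is wanted, sep="" is passed instead:
print("Adding the target entry with key length ", len(key_to_hit), "...", sep="")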
19 changes: 14 additions & 5 deletions MaxText/checkpointing.py
@@ -184,8 +184,8 @@ def load_state_if_possible(
enable_single_replica_ckpt_restoring: Optional[bool] = False,
dataset_type: Optional[str] = "tfds",
step: int = -1, # -1 means latest
use_ocdbt = True,
use_zarr3 = True,
use_ocdbt=True,
use_zarr3=True,
):
"""Loads TrainState as possible from the inputs.

@@ -293,7 +293,11 @@ def map_to_pspec(data):

if load_parameters_from_path != "":
restored_params = load_params_from_path(
load_parameters_from_path, abstract_unboxed_pre_state.params, checkpoint_storage_concurrent_gb, use_ocdbt=use_ocdbt, use_zarr3=use_zarr3
load_parameters_from_path,
abstract_unboxed_pre_state.params,
checkpoint_storage_concurrent_gb,
use_ocdbt=use_ocdbt,
use_zarr3=use_zarr3,
)
return None, restored_params
elif load_full_state_from_path != "":
@@ -329,7 +333,9 @@ def setup_checkpoint_logger(config) -> Any | None: # pytype: disable=attribute-
return orbax_cloud_logger


def load_params_from_path(load_parameters_from_path, abstract_unboxed_params, checkpoint_storage_concurrent_gb, use_ocdbt=True, use_zarr3=True):
def load_params_from_path(
load_parameters_from_path, abstract_unboxed_params, checkpoint_storage_concurrent_gb, use_ocdbt=True, use_zarr3=True
):
"""Load decode params from checkpoint at specified path."""
assert load_parameters_from_path, "load_parameters_from_path is not defined."
max_logging.log(f"restoring params from {load_parameters_from_path}")
@@ -338,7 +344,10 @@ def load_params_from_path(load_parameters_from_path, abstract_unboxed_params, ch
# *_concurrent_gb should be set for large models, the default is 96.
ckptr = ocp.Checkpointer(
ocp.PyTreeCheckpointHandler(
restore_concurrent_gb=checkpoint_storage_concurrent_gb, save_concurrent_gb=checkpoint_storage_concurrent_gb, use_ocdbt=use_ocdbt, use_zarr3=use_zarr3
restore_concurrent_gb=checkpoint_storage_concurrent_gb,
save_concurrent_gb=checkpoint_storage_concurrent_gb,
use_ocdbt=use_ocdbt,
use_zarr3=use_zarr3,
)
)

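The remaining churn in this file is line-wrapping of long calls and dropping the spaces around "=" in default parameter values, following the usual PEP 8 convention for defaults. A small sketch with a hypothetical function name, shown only to illustrate the rule:

# Hypothetical function, used only to illustrate the default-value spacing rule (PEP 8):
def restore_state(step=-1, use_ocdbt=True, use_zarr3=True):  # preferred: no spaces around "="
  ...

def restore_state_old(step = -1, use_ocdbt = True, use_zarr3 = True):  # flagged by style checkers
  ...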
2 changes: 2 additions & 0 deletions MaxText/decode.py
@@ -19,7 +19,9 @@

import jax
import jax.numpy as jnp

from absl import app

from jetstream.engine import engine_api

from MaxText import max_utils
4 changes: 2 additions & 2 deletions MaxText/inference_microbenchmark.py
@@ -538,9 +538,9 @@ def run_benchmarks(config):
return results


def main(argv):
def main(config, **kwargs):
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
run_benchmarks(pyconfig.initialize(argv))
return run_benchmarks(pyconfig.initialize(config, **kwargs))


if __name__ == "__main__":
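main() now takes a config (plus keyword overrides) and returns the benchmark results instead of parsing argv itself, so it can be driven programmatically as well as through app.run. A hedged sketch of the programmatic call, mirroring how the sweep script below uses it; the import paths, config source, and metadata dict are assumptions made for illustration:

# Sketch only, not part of the diff; config source and metadata are placeholders.
import sys

from MaxText import inference_microbenchmark, pyconfig

config = pyconfig.initialize(sys.argv)  # e.g. invoked as: python sweep.py MaxText/configs/base.yml
results = inference_microbenchmark.main(config, inference_metadata={"example": "metadata"})
if results:  # guard against an empty result, as the sweep script below now does
  metrics = {k.lower(): v for k, v in results["flattened_results"].items()}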
9 changes: 6 additions & 3 deletions MaxText/inference_microbenchmark_sweep.py
@@ -48,7 +48,7 @@ def main():
config = pyconfig.initialize(sys.argv)
base_run_name = config.run_name

with open(config.inference_metadata_file, encoding="utf-8") as json_file:
with open(config.inference_metadata_file, "rt", encoding="utf-8") as json_file:
inference_metadata = json.load(json_file)
print(f"inference_metadata: {inference_metadata}")

@@ -121,8 +121,11 @@ def main():
}
try:
microbenchmark_results = inference_microbenchmark.main(config, inference_metadata=inference_metadata)
metrics = microbenchmark_results["flattened_results"]
metrics = {k.lower(): v for k, v in metrics.items()}
if microbenchmark_results:
metrics = microbenchmark_results["flattened_results"]
metrics = {k.lower(): v for k, v in metrics.items()}
else:
metrics = {}
dimensions_json["oom"] = "False"
print(
f"Completed run {two_axis_order_product_id} out of: "
Expand Down
Loading
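Two small behavior-preserving tweaks above are worth noting: the metadata file is opened with an explicit "rt" mode ("read, text" is already open()'s default, so only the intent changes), and the result of inference_microbenchmark.main is checked before flattening, since it may be empty. A minimal sketch, with a placeholder filename:

# Sketch only; the filename is a placeholder and "rt" simply spells out open()'s default text mode.
import json

with open("inference_metadata.json", "rt", encoding="utf-8") as json_file:
  inference_metadata = json.load(json_file)
print(f"inference_metadata: {inference_metadata}")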