Commit 9fde3b3

[MaxText,MaxText/input_pipeline,MaxText/layers] Linting
1 parent dc5988b

16 files changed: +237 −243 lines

MaxText/benchmark_chunked_prefill.py (+15 −16)
@@ -32,22 +32,21 @@
 """
 
 
-# pylint: disable=ungrouped-imports
-import datetime
 import os
 from typing import Any, Sequence
+import datetime
 
 import jax
-from absl import app
-from jetstream.core import prefix_cache
+
 from jetstream.engine import chunked_prefill
 from jetstream.engine import engine_api
 
-from MaxText import max_utils
+from absl import app
+
+from MaxText import max_utils, prefix_cache
 from MaxText import maxengine
 from MaxText import pyconfig
 
-
 _WARMUP_ITERS = 2
 _BENCHMARK_ITERS = 5
 
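Two substantive changes ride along with the reordering here: the ungrouped-imports pylint waiver is dropped, presumably because the regrouped imports no longer need it, and prefix_cache is now imported from MaxText rather than jetstream.core. Call sites further down (e.g. prefix_cache.Value(...)) are unaffected, since the module keeps the same local name.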

@@ -81,7 +80,7 @@ def copy_prefix():
 return jax.tree.map(lambda x: x.copy(), prefix)
 
 # --- Fill the cache with dummy entries ---
-print(f"Filling cache with {cache_num} dummy entries...")
+print("Filling cache with", cache_num, "dummy entries...")
 for i in range(cache_num):
 # Create a unique dummy key, ensuring it's different from key_to_hit
 # and has the same length for consistency (though not strictly required by Trie).
@@ -105,10 +104,10 @@ def copy_prefix():
 jax.block_until_ready(load_result.prefix)
 del load_result
 
-print(f"Finished filling cache with {cache_num} dummy entries.")
+print("Finished filling cache with", cache_num, "dummy entries.")
 
 # --- Add the actual target entry ---
-print(f"Adding the target entry with key length {len(key_to_hit)}...")
+print("Adding the target entry with key length ", len(key_to_hit), "...", sep="")
 
 value_to_hit = prefix_cache.Value(
 prefix=copy_prefix(),
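Every print change in this file follows one pattern: f-string interpolation is replaced by passing the values as separate print() arguments, with sep="" when the pieces must join without spaces. A minimal sketch with hypothetical stand-in values:

```python
cache_num = 7  # hypothetical stand-in for the script's variable
key_len = 16   # hypothetical stand-in for len(key_to_hit)

# Default sep=" " reproduces the f-string's spacing exactly:
print(f"Filling cache with {cache_num} dummy entries...")   # before
print("Filling cache with", cache_num, "dummy entries...")  # after

# sep="" joins the arguments verbatim when no spacing is wanted:
print(f"Adding the target entry with key length {key_len}...")             # before
print("Adding the target entry with key length ", key_len, "...", sep="")  # after
```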
@@ -171,7 +170,7 @@ def run_chunked_prefill_utility():
 prefill_result = run_chunked_prefill_utility()
 jax.block_until_ready(prefill_result)
 end = datetime.datetime.now()
-print(f" Warmup iteration {i+1} time: {end - start}")
+print(" Warmup iteration", i + 1, "time:", end - start)
 
 print("\nStarting benchmark...")
 total_time = datetime.timedelta()
@@ -182,10 +181,10 @@ def run_chunked_prefill_utility():
 end = datetime.datetime.now()
 iter_time = end - start
 total_time += iter_time
-print(f" Benchmark iteration {i+1} time: {iter_time}")
+print(" Benchmark iteration", i + 1, "time:", iter_time)
 
 average_time = total_time / _BENCHMARK_ITERS
-print(f"\nAverage time taken for chunked prefill over {_BENCHMARK_ITERS} iterations: {average_time}")
+print("\nAverage time taken for chunked prefill over", _BENCHMARK_ITERS, "iterations:", average_time)
 
 # Run prefix caching benchmark
 prefill_result = run_chunked_prefill_utility()
@@ -235,13 +234,13 @@ def run_chunked_prefill_with_prefix_caching(cache_hit_chunk: int, need_save: boo
 
 for cache_hit_chunk in range(len(chunked_tokens_list)):
 for need_save in [True, False]:
-print(f"\nBenchmark prefix caching {cache_hit_chunk=}, {need_save=}")
+print("\nBenchmark prefix caching cache_hit_chunk=", cache_hit_chunk, " need_save=", need_save, sep="")
 for i in range(_WARMUP_ITERS):
 start = datetime.datetime.now()
 prefill_result = run_chunked_prefill_with_prefix_caching(cache_hit_chunk, need_save)
 jax.block_until_ready(prefill_result)
 end = datetime.datetime.now()
-print(f" Warmup iteration {i+1} time: {end - start}")
+print(" Warmup iteration", i + 1, "time:", end - start)
 
 total_time = datetime.timedelta()
 for i in range(_BENCHMARK_ITERS):
@@ -251,10 +250,10 @@ def run_chunked_prefill_with_prefix_caching(cache_hit_chunk: int, need_save: boo
 end = datetime.datetime.now()
 iter_time = end - start
 total_time += iter_time
-print(f" Benchmark iteration {i+1} time: {iter_time}")
+print(" Benchmark iteration", i + 1, "time:", iter_time)
 
 average_time = total_time / _BENCHMARK_ITERS
-print(f"\nAverage time taken for prefix caching chunked prefill: {average_time}")
+print("\nAverage time taken for prefix caching chunked prefill:", average_time)
 
 
 if __name__ == "__main__":

MaxText/checkpointing.py (+14 −5)
@@ -184,8 +184,8 @@ def load_state_if_possible(
 enable_single_replica_ckpt_restoring: Optional[bool] = False,
 dataset_type: Optional[str] = "tfds",
 step: int = -1,  # -1 means latest
-use_ocdbt = True,
-use_zarr3 = True,
+use_ocdbt=True,
+use_zarr3=True,
 ):
 """Loads TrainState as possible from the inputs.
 
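The parameter-list change applies the PEP 8 rule (flagged as E251 by pycodestyle) that a default value takes no spaces around its "="; a two-line sketch:

```python
def load(use_ocdbt = True):  # flagged: spaces around "=" in a default (E251)
  return use_ocdbt

def load_fixed(use_ocdbt=True):  # conforming
  return use_ocdbt
```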
@@ -293,7 +293,11 @@ def map_to_pspec(data):
 
 if load_parameters_from_path != "":
 restored_params = load_params_from_path(
-load_parameters_from_path, abstract_unboxed_pre_state.params, checkpoint_storage_concurrent_gb, use_ocdbt=use_ocdbt, use_zarr3=use_zarr3
+load_parameters_from_path,
+abstract_unboxed_pre_state.params,
+checkpoint_storage_concurrent_gb,
+use_ocdbt=use_ocdbt,
+use_zarr3=use_zarr3,
 )
 return None, restored_params
 elif load_full_state_from_path != "":
@@ -329,7 +333,9 @@ def setup_checkpoint_logger(config) -> Any | None:  # pytype: disable=attribute-
 return orbax_cloud_logger
 
 
-def load_params_from_path(load_parameters_from_path, abstract_unboxed_params, checkpoint_storage_concurrent_gb, use_ocdbt=True, use_zarr3=True):
+def load_params_from_path(
+    load_parameters_from_path, abstract_unboxed_params, checkpoint_storage_concurrent_gb, use_ocdbt=True, use_zarr3=True
+):
 """Load decode params from checkpoint at specified path."""
 assert load_parameters_from_path, "load_parameters_from_path is not defined."
 max_logging.log(f"restoring params from {load_parameters_from_path}")
@@ -338,7 +344,10 @@ def load_params_from_path(load_parameters_from_path, abstract_unboxed_params, ch
 # *_concurrent_gb should be set for large models, the default is 96.
 ckptr = ocp.Checkpointer(
 ocp.PyTreeCheckpointHandler(
-restore_concurrent_gb=checkpoint_storage_concurrent_gb, save_concurrent_gb=checkpoint_storage_concurrent_gb, use_ocdbt=use_ocdbt, use_zarr3=use_zarr3
+restore_concurrent_gb=checkpoint_storage_concurrent_gb,
+save_concurrent_gb=checkpoint_storage_concurrent_gb,
+use_ocdbt=use_ocdbt,
+use_zarr3=use_zarr3,
 )
 )
 
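Both call-site rewrites in this file apply the same convention: once a call exceeds the line-length limit, each argument moves to its own line with a trailing comma, so future edits touch one line apiece. A generic sketch of the shape (a plain dict with hypothetical values standing in for the real constructors):

```python
handler_kwargs = dict(
    restore_concurrent_gb=96,  # hypothetical values
    save_concurrent_gb=96,
    use_ocdbt=True,
    use_zarr3=True,
)
```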

MaxText/decode.py (+2)
@@ -19,7 +19,9 @@
 
 import jax
 import jax.numpy as jnp
+
 from absl import app
+
 from jetstream.engine import engine_api
 
 from MaxText import max_utils
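The added blank lines separate the jax, absl, and jetstream third-party import groups, matching the grouping convention applied across this commit.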

MaxText/inference_microbenchmark.py (+2 −2)
@@ -538,9 +538,9 @@ def run_benchmarks(config):
 return results
 
 
-def main(argv):
+def main(config, **kwargs):
 jax.config.update("jax_default_prng_impl", "unsafe_rbg")
-run_benchmarks(pyconfig.initialize(argv))
+return run_benchmarks(pyconfig.initialize(config, **kwargs))
 
 
 if __name__ == "__main__":
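The new signature lets other modules drive the benchmark and consume its results instead of going through argv alone; inference_microbenchmark_sweep.py below relies on both the keyword passthrough and the return value. A sketch of the call shape (not runnable without a full MaxText setup; the metadata value is hypothetical):

```python
import sys

from MaxText import inference_microbenchmark, pyconfig

config = pyconfig.initialize(sys.argv)  # argv-style configuration, as in the sweep script
results = inference_microbenchmark.main(config, inference_metadata={"accelerator": "v5e-8"})
if results:
  print(results["flattened_results"])
```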

MaxText/inference_microbenchmark_sweep.py (+6 −3)
@@ -48,7 +48,7 @@ def main():
 config = pyconfig.initialize(sys.argv)
 base_run_name = config.run_name
 
-with open(config.inference_metadata_file, encoding="utf-8") as json_file:
+with open(config.inference_metadata_file, "rt", encoding="utf-8") as json_file:
 inference_metadata = json.load(json_file)
 print(f"inference_metadata: {inference_metadata}")
 
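"rt" is already open()'s default mode, so this edit only spells out text-mode reading explicitly; behavior is unchanged. A self-contained check (the file name is a hypothetical stand-in):

```python
from pathlib import Path

Path("metadata.json").write_text("{}", encoding="utf-8")  # hypothetical stand-in file

with open("metadata.json", encoding="utf-8") as f:        # implicit "rt"
  implicit = f.read()
with open("metadata.json", "rt", encoding="utf-8") as f:  # explicit "rt"
  explicit = f.read()
assert implicit == explicit
```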

@@ -121,8 +121,11 @@ def main():
 }
 try:
 microbenchmark_results = inference_microbenchmark.main(config, inference_metadata=inference_metadata)
-metrics = microbenchmark_results["flattened_results"]
-metrics = {k.lower(): v for k, v in metrics.items()}
+if microbenchmark_results:
+  metrics = microbenchmark_results["flattened_results"]
+  metrics = {k.lower(): v for k, v in metrics.items()}
+else:
+  metrics = {}
 dimensions_json["oom"] = "False"
 print(
 f"Completed run {two_axis_order_product_id} out of: "
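The new branch guards against main() returning nothing (the surrounding try/except suggests failed or OOM runs), where subscripting None would raise TypeError. The guard in isolation, wrapped in a hypothetical helper:

```python
def collect_metrics(results):
  # Mirrors the guarded block above: fall back to an empty dict when the
  # benchmark produced no results.
  if results:
    metrics = results["flattened_results"]
    return {k.lower(): v for k, v in metrics.items()}
  return {}

assert collect_metrics(None) == {}
assert collect_metrics({"flattened_results": {"TPS": 42.0}}) == {"tps": 42.0}
```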
