Skip to content

Commit c32a090

Browse files
committed
[MaxText,MaxText/input_pipeline,MaxText/layers] Linting
1 parent dc5988b commit c32a090

27 files changed

+2329
-1706
lines changed

MaxText/benchmark_chunked_prefill.py

+15-16
Original file line number | Diff line number | Diff line change
@@ -32,22 +32,21 @@
3232
"""
3333

3434

35-
# pylint: disable=ungrouped-imports
36-
import datetime
3735
import os
3836
from typing import Any, Sequence
37+
import datetime
3938

4039
import jax
41-
from absl import app
42-
from jetstream.core import prefix_cache
40+
4341
from jetstream.engine import chunked_prefill
4442
from jetstream.engine import engine_api
4543

46-
from MaxText import max_utils
44+
from absl import app
45+
46+
from MaxText import max_utils, prefix_cache
4747
from MaxText import maxengine
4848
from MaxText import pyconfig
4949

50-
5150
_WARMUP_ITERS = 2
5251
_BENCHMARK_ITERS = 5
5352

@@ -81,7 +80,7 @@ def copy_prefix():
8180
return jax.tree.map(lambda x: x.copy(), prefix)
8281

8382
# --- Fill the cache with dummy entries ---
84-
print(f"Filling cache with {cache_num} dummy entries...")
83+
print("Filling cache with", cache_num, "dummy entries...")
8584
for i in range(cache_num):
8685
# Create a unique dummy key, ensuring it's different from key_to_hit
8786
# and has the same length for consistency (though not strictly required by Trie).
@@ -105,10 +104,10 @@ def copy_prefix():
105104
jax.block_until_ready(load_result.prefix)
106105
del load_result
107106

108-
print(f"Finished filling cache with {cache_num} dummy entries.")
107+
print("Finished filling cache with", cache_num, "dummy entries.")
109108

110109
# --- Add the actual target entry ---
111-
print(f"Adding the target entry with key length {len(key_to_hit)}...")
110+
print("Adding the target entry with key length ", len(key_to_hit), "...", sep="")
112111

113112
value_to_hit = prefix_cache.Value(
114113
prefix=copy_prefix(),
@@ -171,7 +170,7 @@ def run_chunked_prefill_utility():
171170
prefill_result = run_chunked_prefill_utility()
172171
jax.block_until_ready(prefill_result)
173172
end = datetime.datetime.now()
174-
print(f" Warmup iteration {i+1} time: {end - start}")
173+
print(" Warmup iteration", i + 1, "time:", end - start)
175174

176175
print("\nStarting benchmark...")
177176
total_time = datetime.timedelta()
@@ -182,10 +181,10 @@ def run_chunked_prefill_utility():
182181
end = datetime.datetime.now()
183182
iter_time = end - start
184183
total_time += iter_time
185-
print(f" Benchmark iteration {i+1} time: {iter_time}")
184+
print(" Benchmark iteration", i + 1, "time:", iter_time)
186185

187186
average_time = total_time / _BENCHMARK_ITERS
188-
print(f"\nAverage time taken for chunked prefill over {_BENCHMARK_ITERS} iterations: {average_time}")
187+
print("\nAverage time taken for chunked prefill over", _BENCHMARK_ITERS, "iterations:", average_time)
189188

190189
# Run prefix caching benchmark
191190
prefill_result = run_chunked_prefill_utility()
@@ -235,13 +234,13 @@ def run_chunked_prefill_with_prefix_caching(cache_hit_chunk: int, need_save: boo
235234

236235
for cache_hit_chunk in range(len(chunked_tokens_list)):
237236
for need_save in [True, False]:
238-
print(f"\nBenchmark prefix caching {cache_hit_chunk=}, {need_save=}")
237+
print("\nBenchmark prefix caching cache_hit_chunk=", cache_hit_chunk, " need_save=", need_save, sep="")
239238
for i in range(_WARMUP_ITERS):
240239
start = datetime.datetime.now()
241240
prefill_result = run_chunked_prefill_with_prefix_caching(cache_hit_chunk, need_save)
242241
jax.block_until_ready(prefill_result)
243242
end = datetime.datetime.now()
244-
print(f" Warmup iteration {i+1} time: {end - start}")
243+
print(" Warmup iteration", i + 1, "time:", end - start)
245244

246245
total_time = datetime.timedelta()
247246
for i in range(_BENCHMARK_ITERS):
@@ -251,10 +250,10 @@ def run_chunked_prefill_with_prefix_caching(cache_hit_chunk: int, need_save: boo
251250
end = datetime.datetime.now()
252251
iter_time = end - start
253252
total_time += iter_time
254-
print(f" Benchmark iteration {i+1} time: {iter_time}")
253+
print(" Benchmark iteration", i + 1, "time:", iter_time)
255254

256255
average_time = total_time / _BENCHMARK_ITERS
257-
print(f"\nAverage time taken for prefix caching chunked prefill: {average_time}")
256+
print("\nAverage time taken for prefix caching chunked prefill:", average_time)
258257

259258

260259
if __name__ == "__main__":

MaxText/checkpointing.py

+14-5
Original file line number | Diff line number | Diff line change
@@ -184,8 +184,8 @@ def load_state_if_possible(
184184
enable_single_replica_ckpt_restoring: Optional[bool] = False,
185185
dataset_type: Optional[str] = "tfds",
186186
step: int = -1, # -1 means latest
187-
use_ocdbt = True,
188-
use_zarr3 = True,
187+
use_ocdbt=True,
188+
use_zarr3=True,
189189
):
190190
"""Loads TrainState as possible from the inputs.
191191
@@ -293,7 +293,11 @@ def map_to_pspec(data):
293293

294294
if load_parameters_from_path != "":
295295
restored_params = load_params_from_path(
296-
load_parameters_from_path, abstract_unboxed_pre_state.params, checkpoint_storage_concurrent_gb, use_ocdbt=use_ocdbt, use_zarr3=use_zarr3
296+
load_parameters_from_path,
297+
abstract_unboxed_pre_state.params,
298+
checkpoint_storage_concurrent_gb,
299+
use_ocdbt=use_ocdbt,
300+
use_zarr3=use_zarr3,
297301
)
298302
return None, restored_params
299303
elif load_full_state_from_path != "":
@@ -329,7 +333,9 @@ def setup_checkpoint_logger(config) -> Any | None: # pytype: disable=attribute-
329333
return orbax_cloud_logger
330334

331335

332-
def load_params_from_path(load_parameters_from_path, abstract_unboxed_params, checkpoint_storage_concurrent_gb, use_ocdbt=True, use_zarr3=True):
336+
def load_params_from_path(
337+
load_parameters_from_path, abstract_unboxed_params, checkpoint_storage_concurrent_gb, use_ocdbt=True, use_zarr3=True
338+
):
333339
"""Load decode params from checkpoint at specified path."""
334340
assert load_parameters_from_path, "load_parameters_from_path is not defined."
335341
max_logging.log(f"restoring params from {load_parameters_from_path}")
@@ -338,7 +344,10 @@ def load_params_from_path(load_parameters_from_path, abstract_unboxed_params, ch
338344
# *_concurrent_gb should be set for large models, the default is 96.
339345
ckptr = ocp.Checkpointer(
340346
ocp.PyTreeCheckpointHandler(
341-
restore_concurrent_gb=checkpoint_storage_concurrent_gb, save_concurrent_gb=checkpoint_storage_concurrent_gb, use_ocdbt=use_ocdbt, use_zarr3=use_zarr3
347+
restore_concurrent_gb=checkpoint_storage_concurrent_gb,
348+
save_concurrent_gb=checkpoint_storage_concurrent_gb,
349+
use_ocdbt=use_ocdbt,
350+
use_zarr3=use_zarr3,
342351
)
343352
)
344353

MaxText/configs/base.yml

+7-8
Original file line number | Diff line number | Diff line change
@@ -323,8 +323,8 @@ logical_axis_rules: [
323323
['cache_kv', []],
324324
['cache_sequence', []],
325325
['exp', 'expert'],
326-
['paged_kv_heads', ['tensor']],
327-
['num_pages', []],
326+
['paged_kv_heads', []],
327+
['num_pages', ['tensor']],
328328
['tokens_per_page', []],
329329
['paged_kv_head_dim_size', []],
330330
]
@@ -654,15 +654,14 @@ sa_v_layout: "HEAD_DIM_MINOR"
654654
### Determine if we want to use load balance for context parallelism
655655
context_parallel_load_balance: True
656656

657+
#######################
657658
### Paged Attention ###
659+
#######################
658660
# These settings take effect only when `attention=paged`.
659661
# They should be adjusted based on the available HBM and model config.
660-
# Note: one page group corresponds to one request/slot
661-
pagedattn_num_pages: 64 # total number of pages to allocate
662-
pagedattn_tokens_per_page: 32 # number of tokens each page can hold
663-
pagedattn_pages_per_compute_block: 4 # number of pages processed together in pallas kernels
664-
pagedattn_max_pages_per_group: -1 # defaults to number of pages needed to reach max_target_length
665-
662+
pagedattn_num_pages: 64
663+
pagedattn_tokens_per_page: 32
664+
pagedattn_pages_per_compute_block: 8
666665

667666
# Chunked Prefill Parameters
668667
prefill_chunk_size: 256

MaxText/decode.py

+2
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,9 @@
1919

2020
import jax
2121
import jax.numpy as jnp
22+
2223
from absl import app
24+
2325
from jetstream.engine import engine_api
2426

2527
from MaxText import max_utils

0 commit comments

Comments
 (0)