32
32
"""
33
33
34
34
35
- # pylint: disable=ungrouped-imports
36
- import datetime
37
35
import os
38
36
from typing import Any , Sequence
37
+ import datetime
39
38
40
39
import jax
41
- from absl import app
42
- from jetstream .core import prefix_cache
40
+
43
41
from jetstream .engine import chunked_prefill
44
42
from jetstream .engine import engine_api
45
43
46
- from MaxText import max_utils
44
+ from absl import app
45
+
46
+ from MaxText import max_utils , prefix_cache
47
47
from MaxText import maxengine
48
48
from MaxText import pyconfig
49
49
50
-
51
50
_WARMUP_ITERS = 2
52
51
_BENCHMARK_ITERS = 5
53
52
@@ -81,7 +80,7 @@ def copy_prefix():
81
80
return jax .tree .map (lambda x : x .copy (), prefix )
82
81
83
82
# --- Fill the cache with dummy entries ---
84
- print (f "Filling cache with { cache_num } dummy entries..." )
83
+ print ("Filling cache with" , cache_num , " dummy entries..." )
85
84
for i in range (cache_num ):
86
85
# Create a unique dummy key, ensuring it's different from key_to_hit
87
86
# and has the same length for consistency (though not strictly required by Trie).
@@ -105,10 +104,10 @@ def copy_prefix():
105
104
jax .block_until_ready (load_result .prefix )
106
105
del load_result
107
106
108
- print (f "Finished filling cache with { cache_num } dummy entries." )
107
+ print ("Finished filling cache with" , cache_num , " dummy entries." )
109
108
110
109
# --- Add the actual target entry ---
111
- print (f "Adding the target entry with key length { len (key_to_hit )} ..." )
110
+ print ("Adding the target entry with key length " , len (key_to_hit ), " ..." , sep = " " )
112
111
113
112
value_to_hit = prefix_cache .Value (
114
113
prefix = copy_prefix (),
@@ -171,7 +170,7 @@ def run_chunked_prefill_utility():
171
170
prefill_result = run_chunked_prefill_utility ()
172
171
jax .block_until_ready (prefill_result )
173
172
end = datetime .datetime .now ()
174
- print (f " Warmup iteration { i + 1 } time: { end - start } " )
173
+ print (" Warmup iteration" , i + 1 , " time:" , end - start )
175
174
176
175
print ("\n Starting benchmark..." )
177
176
total_time = datetime .timedelta ()
@@ -182,10 +181,10 @@ def run_chunked_prefill_utility():
182
181
end = datetime .datetime .now ()
183
182
iter_time = end - start
184
183
total_time += iter_time
185
- print (f " Benchmark iteration { i + 1 } time: { iter_time } " )
184
+ print (" Benchmark iteration" , i + 1 , " time:" , iter_time )
186
185
187
186
average_time = total_time / _BENCHMARK_ITERS
188
- print (f "\n Average time taken for chunked prefill over { _BENCHMARK_ITERS } iterations: { average_time } " )
187
+ print ("\n Average time taken for chunked prefill over" , _BENCHMARK_ITERS , " iterations:" , average_time )
189
188
190
189
# Run prefix caching benchmark
191
190
prefill_result = run_chunked_prefill_utility ()
@@ -235,13 +234,13 @@ def run_chunked_prefill_with_prefix_caching(cache_hit_chunk: int, need_save: boo
235
234
236
235
for cache_hit_chunk in range (len (chunked_tokens_list )):
237
236
for need_save in [True , False ]:
238
- print (f "\n Benchmark prefix caching { cache_hit_chunk = } , { need_save = } " )
237
+ print ("\n Benchmark prefix caching cache_hit_chunk=" , cache_hit_chunk , " need_save=" , need_save , sep = " " )
239
238
for i in range (_WARMUP_ITERS ):
240
239
start = datetime .datetime .now ()
241
240
prefill_result = run_chunked_prefill_with_prefix_caching (cache_hit_chunk , need_save )
242
241
jax .block_until_ready (prefill_result )
243
242
end = datetime .datetime .now ()
244
- print (f " Warmup iteration { i + 1 } time: { end - start } " )
243
+ print (" Warmup iteration" , i + 1 , " time:" , end - start )
245
244
246
245
total_time = datetime .timedelta ()
247
246
for i in range (_BENCHMARK_ITERS ):
@@ -251,10 +250,10 @@ def run_chunked_prefill_with_prefix_caching(cache_hit_chunk: int, need_save: boo
251
250
end = datetime .datetime .now ()
252
251
iter_time = end - start
253
252
total_time += iter_time
254
- print (f " Benchmark iteration { i + 1 } time: { iter_time } " )
253
+ print (" Benchmark iteration" , i + 1 , " time:" , iter_time )
255
254
256
255
average_time = total_time / _BENCHMARK_ITERS
257
- print (f "\n Average time taken for prefix caching chunked prefill: { average_time } " )
256
+ print ("\n Average time taken for prefix caching chunked prefill:" , average_time )
258
257
259
258
260
259
if __name__ == "__main__" :
0 commit comments