
Commit 4fd8f37

git checkout main <files that should not have changed>
1 parent c32a090 commit 4fd8f37

File tree

9 files changed: +1454 / -2085 lines


MaxText/inference/decode_multi.py

+212
@@ -0,0 +1,212 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CLI utility for running inference with interleaved prefill and generate."""

import os
import uuid
from typing import Sequence, List

import jax
from absl import app

from MaxText import max_utils, maxengine, pyconfig

_NUM_STREAMS = 5
# How many streams to prefill initially before starting generation.
_INITIAL_PREFILL_STREAMS = 2  # Example: Start generating after 2 streams are ready


def _validate_config(config):
  """Validate configuration settings."""
  assert config.load_full_state_path == "", (
      "Decode doesn't operate on full states! Convert to a parameter checkpoint first "
      "using generate_param_only_checkpoint."
  )
  assert (
      0 < _INITIAL_PREFILL_STREAMS <= _NUM_STREAMS
  ), f"_INITIAL_PREFILL_STREAMS ({_INITIAL_PREFILL_STREAMS}) must be > 0 and <= _NUM_STREAMS ({_NUM_STREAMS})"


def main(argv: Sequence[str]) -> None:
  """Main function to run interleaved inference."""
  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

  config = pyconfig.initialize(argv)
  _validate_config(config)
  max_utils.print_system_information()

  engine = maxengine.MaxEngine(config)
  rng = jax.random.PRNGKey(1234)
  rng, rng_load_params = jax.random.split(rng)
  params = engine.load_params(rng=rng_load_params)

  text = config.prompt
  metadata = engine.get_tokenizer()
  tokenizer_model = engine.build_tokenizer(metadata)
  tokens, true_length = tokenizer_model.encode(text, is_bos=True, prefill_lengths=[config.max_prefill_predict_length])
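  # encode() pads the prompt to max_prefill_predict_length; true_length is the
  # unpadded token count that bounds the prefill below.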
  assert true_length <= config.max_prefill_predict_length, "Prompt too long for prefill length"

  batch_size = int(config.per_device_batch_size * jax.device_count())
  assert 0 < _NUM_STREAMS <= batch_size, f"The number of streams {_NUM_STREAMS} must be > 0 and <= batch size {batch_size}"

  # Initialize decode state
  rng, rng_init_decode = jax.random.split(rng)
  decode_state = engine.init_decode_state(rng=rng_init_decode)
  print("Initial decode state initialized.")

  # Keep track of results per stream (slot)
  streams_results: dict[int, List[int]] = {i: [] for i in range(_NUM_STREAMS)}
  streams_active: List[bool] = [False] * _NUM_STREAMS  # Track which slots are active
  streams_finished: List[bool] = [False] * _NUM_STREAMS  # Track finished streams
  streams_prefilled_count = 0
  streams_inserted_count = 0

  # --- Initial Prefill Phase ---
  print(f"Starting initial prefill for {_INITIAL_PREFILL_STREAMS} streams...")
  prefill_results_to_insert = {}  # Store prefill results before inserting
  for i in range(_INITIAL_PREFILL_STREAMS):
    slot_idx = i
    print(f" Prefilling stream for slot {slot_idx}...")
    rng, rng_prefill = jax.random.split(rng)
    request_id = uuid.uuid4()
    prefill_result, first_token = engine.prefill(
        params=params,
        padded_tokens=tokens,
        true_length=true_length,
        rng=rng_prefill,
        slot=slot_idx,
        request_id=request_id,
    )
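    # first_token holds results for this single prefill stream only, so its token
    # is read from slot 0 of the returned batch, not from the target decode slot.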
    prefill_results_to_insert[slot_idx] = prefill_result
    streams_results[slot_idx].append(first_token.get_result_at_slot(0).tokens.item())
    streams_prefilled_count += 1
    print(f"After prefill stream {slot_idx}")

  # --- Insert Initial Prefills ---
  print("Inserting initial prefill results...")
  for slot_idx, prefill_result in prefill_results_to_insert.items():
    request_id = uuid.uuid4()
    decode_state = engine.insert(
        prefix=prefill_result,
        decode_state=decode_state,
        slot=slot_idx,
        request_id=request_id,  # Pass request_id
    )
    streams_active[slot_idx] = True  # Mark stream as active
    streams_inserted_count += 1
    print(f" Inserted prefill for slot {slot_idx}")

  print("Starting interleaved generation loop...")
  total_steps = config.max_target_length - config.max_prefill_predict_length
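  # Each iteration: (1) run one generate step for all active slots, (2) exit early
  # if every stream has finished, (3) prefill and insert at most one new stream if
  # a slot is free.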
  for step in range(total_steps):
    print(f"\n--- Step {step + 1} / {total_steps} ---")

    # Generate step for all active streams
    active_stream_indices = [i for i, active in enumerate(streams_active) if active and not streams_finished[i]]
    if active_stream_indices:
      print(f" Generating for active slots: {active_stream_indices}")
      rng, rng_generate = jax.random.split(rng)
      decode_state, sampled_tokens = engine.generate(params, decode_state, rng=rng_generate)

      # Store the generated token and check for finished streams
      for slot_idx in active_stream_indices:
        # Check if the stream finished this step
        current_len = config.max_prefill_predict_length + step + 1  # Includes prefill + current step
        finished_this_step = False
        if current_len >= config.max_target_length:
          print(f" Stream in slot {slot_idx} reached max target length.")
          streams_finished[slot_idx] = True
          streams_active[slot_idx] = False
          finished_this_step = True

        # Store the token if the stream was still running, or if it finished on this step
        if not streams_finished[slot_idx] or finished_this_step:
          # Ensure we don't try to access results for a slot that might not exist
          if slot_idx < sampled_tokens.data.shape[0]:
            token_for_slot = sampled_tokens.get_result_at_slot(slot_idx).tokens.item()
            streams_results[slot_idx].append(token_for_slot)
          else:
            print(f"Warning: Tried to get token for slot {slot_idx}, but batch size seems smaller.")

        # Call release_pages if finished this step
        if finished_this_step:
          print(f" Calling engine to release pages for finished slot {slot_idx}...")
          engine.release_pages(slot=slot_idx)

    else:
      print(" No active streams to generate for.")

    # 2. Check if all streams are finished (can exit loop early)
    if all(streams_finished):
      print("\nAll streams finished generation.")
      break

    # 3. Prefill and Insert new streams if capacity allows
    num_active_not_finished = sum(1 for i in range(_NUM_STREAMS) if streams_active[i] and not streams_finished[i])
    available_slots = batch_size - num_active_not_finished
    can_prefill_more = streams_prefilled_count < _NUM_STREAMS

    if can_prefill_more and available_slots > 0:
      try:
        next_available_slot = streams_active.index(False)
        print(f" Prefilling new stream for slot {next_available_slot}...")
        rng, rng_prefill = jax.random.split(rng)
        request_id = uuid.uuid4()
        prefill_result, first_token = engine.prefill(
            params=params,
            padded_tokens=tokens,
            true_length=true_length,
            rng=rng_prefill,
            slot=next_available_slot,
            request_id=request_id,
        )
        streams_prefilled_count += 1

        # Insert the new prefill
        print(f" Inserting new stream into slot {next_available_slot}...")
        request_id_insert = uuid.uuid4()
        decode_state = engine.insert(
            prefix=prefill_result,
            decode_state=decode_state,
            slot=next_available_slot,
            request_id=request_id_insert,
        )
        streams_active[next_available_slot] = True
        streams_inserted_count += 1
        streams_results[next_available_slot].append(first_token.get_result_at_slot(0).tokens.item())

      except ValueError:
        print(" Warning: Available slots detected but couldn't find an inactive one.")
    elif can_prefill_more:
      print(" Generate step finished, but no available slots to prefill new stream.")
    else:
      print(" Generate step finished, all streams already prefilled.")

  print("\n--- Final Results ---")
  for i in range(_NUM_STREAMS):
    if streams_results[i]:
      output = tokenizer_model.decode(streams_results[i])
      print(f"Stream {i}: Input=`{text}` -> Output=`{output}`")

      if i == 0:  # Check first stream as an example
        assert output.startswith(
            config.autoregressive_decode_assert
        ), f"Stream {i} generated text mismatch: `{output}` vs expected start `{config.autoregressive_decode_assert}`"
    else:
      print(f"Stream {i}: Was not activated.")


if __name__ == "__main__":
  app.run(main)
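
For context: like other MaxText decode utilities, this script is driven by pyconfig, which takes a base config file plus key=value overrides on the command line. A hypothetical invocation, with a purely illustrative checkpoint path and override values, might look like:

  # Hypothetical invocation; the path and values below are illustrative only.
  python3 -m MaxText.inference.decode_multi MaxText/configs/base.yml \
    load_parameters_path=gs://my-bucket/my-run/checkpoints/0/items \
    per_device_batch_size=4 \
    max_prefill_predict_length=64 \
    max_target_length=128 \
    prompt="I love to"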

0 commit comments
