# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CLI utility for running inference with interleaved prefill and generate."""
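
# Example invocation (illustrative; the module path is a placeholder and
# checkpoint-loading flags are omitted -- the named keys are all read below):
#   python3 -m MaxText.decode_interleaved MaxText/configs/base.yml \
#     prompt="I love to" per_device_batch_size=4 \
#     max_prefill_predict_length=16 max_target_length=32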

import os
import uuid
from typing import List, Sequence

import jax
from absl import app

from MaxText import max_utils, maxengine, pyconfig

# Total number of streams (requests) to run.
_NUM_STREAMS = 5
# How many streams to prefill initially before starting generation; the rest
# are prefilled and inserted between generate steps as slots become free.
_INITIAL_PREFILL_STREAMS = 2


def _validate_config(config):
  """Validate configuration settings."""
  assert config.load_full_state_path == "", (
      "Decode doesn't operate on full states! Convert to a parameter checkpoint first "
      "using generate_param_only_checkpoint."
  )
  assert (
      0 < _INITIAL_PREFILL_STREAMS <= _NUM_STREAMS
  ), f"_INITIAL_PREFILL_STREAMS ({_INITIAL_PREFILL_STREAMS}) must be > 0 and <= _NUM_STREAMS ({_NUM_STREAMS})"


def main(argv: Sequence[str]) -> None:
  """Main function to run interleaved inference."""
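  # "unsafe_rbg" is a faster, lower-quality PRNG implementation, which is fine
  # for inference where cryptographic-strength randomness is unnecessary.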
  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

  config = pyconfig.initialize(argv)
  _validate_config(config)
  max_utils.print_system_information()

  engine = maxengine.MaxEngine(config)
  rng = jax.random.PRNGKey(1234)
  rng, rng_load_params = jax.random.split(rng)
  params = engine.load_params(rng=rng_load_params)

  text = config.prompt
  metadata = engine.get_tokenizer()
  tokenizer_model = engine.build_tokenizer(metadata)
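  # Encode the prompt, padding it to max_prefill_predict_length; true_length
  # is the unpadded token count.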
  tokens, true_length = tokenizer_model.encode(text, is_bos=True, prefill_lengths=[config.max_prefill_predict_length])
  assert true_length <= config.max_prefill_predict_length, "Prompt too long for prefill length"

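  # Each concurrent stream occupies one slot in the decode batch, so
  # concurrency is bounded by the global batch size.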
  batch_size = int(config.per_device_batch_size * jax.device_count())
  assert 0 < _NUM_STREAMS <= batch_size, f"The number of streams {_NUM_STREAMS} must be > 0 and <= batch size {batch_size}"

  # Initialize the decode state.
  rng, rng_init_decode = jax.random.split(rng)
  decode_state = engine.init_decode_state(rng=rng_init_decode)
  print("Initial decode state initialized.")

  # Keep track of results per stream (slot).
  streams_results: dict[int, List[int]] = {i: [] for i in range(_NUM_STREAMS)}
  streams_active: List[bool] = [False] * _NUM_STREAMS  # Track which slots are active
  streams_finished: List[bool] = [False] * _NUM_STREAMS  # Track finished streams
  streams_prefilled_count = 0
  streams_inserted_count = 0

  # --- Initial Prefill Phase ---
  print(f"Starting initial prefill for {_INITIAL_PREFILL_STREAMS} streams...")
  prefill_results_to_insert = {}  # Store prefill results before inserting
  for i in range(_INITIAL_PREFILL_STREAMS):
    slot_idx = i
    print(f"  Prefilling stream for slot {slot_idx}...")
    rng, rng_prefill = jax.random.split(rng)
    request_id = uuid.uuid4()
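    # prefill() runs the prompt through the model, returning the KV-cache
    # prefix for this stream along with its first sampled token.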
    prefill_result, first_token = engine.prefill(
        params=params,
        padded_tokens=tokens,
        true_length=true_length,
        rng=rng_prefill,
        slot=slot_idx,
        request_id=request_id,
    )
    prefill_results_to_insert[slot_idx] = prefill_result
    streams_results[slot_idx].append(first_token.get_result_at_slot(0).tokens.item())
    streams_prefilled_count += 1
    print(f"  Finished prefill for slot {slot_idx}")

  # --- Insert Initial Prefills ---
  print("Inserting initial prefill results...")
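  # insert() copies each prefilled KV-cache prefix into its slot of the shared
  # decode state so that generate() can continue those streams.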
  for slot_idx, prefill_result in prefill_results_to_insert.items():
    request_id = uuid.uuid4()
    decode_state = engine.insert(
        prefix=prefill_result,
        decode_state=decode_state,
        slot=slot_idx,
        request_id=request_id,
    )
    streams_active[slot_idx] = True  # Mark stream as active
    streams_inserted_count += 1
    print(f"  Inserted prefill for slot {slot_idx}")

  print("Starting interleaved generation loop...")
  total_steps = config.max_target_length - config.max_prefill_predict_length
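  # Each generate() call appends one token to every active stream, so at most
  # max_target_length - max_prefill_predict_length steps are needed.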
  for step in range(total_steps):
    print(f"\n--- Step {step + 1} / {total_steps} ---")

    # 1. Generate a step for all active streams.
    active_stream_indices = [i for i, active in enumerate(streams_active) if active and not streams_finished[i]]
    if active_stream_indices:
      print(f"  Generating for active slots: {active_stream_indices}")
      rng, rng_generate = jax.random.split(rng)
      decode_state, sampled_tokens = engine.generate(params, decode_state, rng=rng_generate)

      # Store the generated token and check for finished streams. Note that
      # sequence length is approximated from the global step count, so streams
      # inserted after step 0 are cut off correspondingly earlier.
      for slot_idx in active_stream_indices:
        current_len = config.max_prefill_predict_length + step + 1  # Prefill plus generated tokens
        finished_this_step = False
        if current_len >= config.max_target_length:
          print(f"  Stream in slot {slot_idx} reached max target length.")
          streams_finished[slot_idx] = True
          streams_active[slot_idx] = False
          finished_this_step = True

        # Store the token; streams that finished before this step were already
        # filtered out of active_stream_indices above.
        if not streams_finished[slot_idx] or finished_this_step:
          # Guard against slots beyond the sampled batch.
          if slot_idx < sampled_tokens.data.shape[0]:
            token_for_slot = sampled_tokens.get_result_at_slot(slot_idx).tokens.item()
            streams_results[slot_idx].append(token_for_slot)
          else:
            print(f"  Warning: Tried to get token for slot {slot_idx}, but it is outside the sampled batch.")

        # Release this slot's pages once the stream finishes.
        if finished_this_step:
          print(f"  Calling engine to release pages for finished slot {slot_idx}...")
          engine.release_pages(slot=slot_idx)

    else:
      print("  No active streams to generate for.")

    # 2. Check whether all streams are finished (exit the loop early).
    if all(streams_finished):
      print("\nAll streams finished generation.")
      break

    # 3. Prefill and insert new streams if capacity allows.
    num_active_not_finished = sum(1 for i in range(_NUM_STREAMS) if streams_active[i] and not streams_finished[i])
    available_slots = batch_size - num_active_not_finished
    can_prefill_more = streams_prefilled_count < _NUM_STREAMS

    if can_prefill_more and available_slots > 0:
      try:
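        # list.index(False) returns the lowest-numbered inactive slot and
        # raises ValueError when every slot is active (handled below).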
        next_available_slot = streams_active.index(False)
        print(f"  Prefilling new stream for slot {next_available_slot}...")
        rng, rng_prefill = jax.random.split(rng)
        request_id = uuid.uuid4()
        prefill_result, first_token = engine.prefill(
            params=params,
            padded_tokens=tokens,
            true_length=true_length,
            rng=rng_prefill,
            slot=next_available_slot,
            request_id=request_id,
        )
        streams_prefilled_count += 1

        # Insert the new prefill.
        print(f"  Inserting new stream into slot {next_available_slot}...")
        request_id_insert = uuid.uuid4()
        decode_state = engine.insert(
            prefix=prefill_result,
            decode_state=decode_state,
            slot=next_available_slot,
            request_id=request_id_insert,
        )
        streams_active[next_available_slot] = True
        streams_inserted_count += 1
        streams_results[next_available_slot].append(first_token.get_result_at_slot(0).tokens.item())

      except ValueError:
        print("  Warning: Available slots detected but couldn't find an inactive one.")
    elif can_prefill_more:
      print("  Generate step finished, but no available slots to prefill a new stream.")
    else:
      print("  Generate step finished; all streams already prefilled.")

  print("\n--- Final Results ---")
  for i in range(_NUM_STREAMS):
    if streams_results[i]:
      output = tokenizer_model.decode(streams_results[i])
      print(f"Stream {i}: Input=`{text}` -> Output=`{output}`")

      if i == 0:  # Check the first stream as an example.
        assert output.startswith(
            config.autoregressive_decode_assert
        ), f"Stream {i} generated text mismatch: `{output}` vs expected start `{config.autoregressive_decode_assert}`"
    else:
      print(f"Stream {i}: was never activated.")


if __name__ == "__main__":
  app.run(main)