Commit 34e2131

Add quantization support for maxtext models and update conversion scripts (#1006)
* Support quantization for llama3.3-70b and llama3.1-405b on CPUs
* Add quantization support for maxtext models and update conversion scripts
1 parent 1d7e20d commit 34e2131

File tree

3 files changed (+249, -113 lines)

tutorials-and-examples/inference-servers/checkpoints/Dockerfile (+21, -1)

```diff
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 # Ubuntu:22.04
 # Use Ubuntu 22.04 from Docker Hub.
 # https://hub.docker.com/_/ubuntu/tags?page=1&name=22.04
@@ -22,10 +36,16 @@ RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyri
 RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
 RUN apt -y update && apt install -y google-cloud-cli
 
+RUN git clone https://github.com/AI-Hypercomputer/maxtext.git && \
+    cd /maxtext && \
+    bash setup.sh
+
 RUN pip install kaggle && \
     pip install huggingface_hub[cli] && \
     pip install google-jetstream && \
-    pip install llama-toolchain
+    pip install llama-stack && \
+    pip install torch && \
+    pip install grain-nightly==0.0.10
 
 COPY checkpoint_converter.sh /usr/bin/
 RUN chmod +x /usr/bin/checkpoint_converter.sh
```
tutorials-and-examples/inference-servers/checkpoints/README.md (+22, -29)

````diff
@@ -5,43 +5,36 @@ The `checkpoint_entrypoint.sh` script overviews how to convert your inference ch
 Build the checkpoint conversion Dockerfile
 ```
 docker build -t inference-checkpoint .
-docker tag inference-checkpoint gcr.io/${PROJECT_ID}/inference-checkpoint:latest
-docker push gcr.io/${PROJECT_ID}/inference-checkpoint:latest
+docker tag inference-checkpoint ${LOCATION}-docker.pkg.dev/${PROJECT_ID}/jetstream/inference-checkpoint:latest
+docker push ${LOCATION}-docker.pkg.dev/${PROJECT_ID}/jetstream/inference-checkpoint:latest
 ```
 
 Now you can use it in a [Kubernetes job](../jetstream/maxtext/single-host-inference/checkpoint-job.yaml) and pass the following arguments
 
 ## Jetstream + MaxText
 ```
-- -s=INFERENCE_SERVER
-- -b=BUCKET_NAME
-- -m=MODEL_PATH
-- -v=VERSION (Optional)
+-b, --bucket_name: [string] The GSBucket name to store checkpoints, without gs://.
+-s, --inference_server: [string] The name of the inference server that serves your model. (Optional) (default=jetstream-maxtext)
+-m, --model_path: [string] The model path.
+-n, --model_name: [string] The model name.
+-h, --huggingface: [bool] The model is from Hugging Face. (Optional) (default=False)
+-t, --quantize_type: [string] The type of quantization. (Optional)
+-q, --quantize_weights: [bool] Whether to quantize the checkpoint. (Optional) (default=False)
+-i, --input_directory: [string] The input directory, likely a GSBucket path.
+-o, --output_directory: [string] The output directory, likely a GSBucket path.
+-u, --meta_url: [string] The URL from Meta. (Optional)
+-v, --version: [string] The version of the repository. (Optional) (default=main)
 ```
 
 ## Jetstream + Pytorch/XLA
 ```
-- -s=INFERENCE_SERVER
-- -m=MODEL_PATH
-- -n=MODEL_NAME
-- -q=QUANTIZE_WEIGHTS (Optional) (default=False)
-- -t=QUANTIZE_TYPE (Optional) (default=int8_per_channel)
-- -v=VERSION (Optional) (default=jetstream-v0.2.3)
-- -i=INPUT_DIRECTORY (Optional)
-- -o=OUTPUT_DIRECTORY
-- -h=HUGGINGFACE (Optional) (default=False)
-```
-
-## Argument descriptions:
-```
-b) BUCKET_NAME: (str) GSBucket, without gs://
-s) INFERENCE_SERVER: (str) Inference server, ex. jetstream-maxtext, jetstream-pytorch
-m) MODEL_PATH: (str) Model path, varies depending on inference server and location of base checkpoint
-n) MODEL_NAME: (str) Model name, ex. llama-2, llama-3, gemma
-h) HUGGINGFACE: (bool) Checkpoint is from HuggingFace.
-q) QUANTIZE_WEIGHTS: (str) Whether to quantize weights
-t) QUANTIZE_TYPE: (str) Quantization type, QUANTIZE_WEIGHTS must be set to true. Availabe quantize type: {"int8", "int4"} x {"per_channel", "blockwise"},
-v) VERSION: (str) Version of inference server to override, ex. jetstream-v0.2.2, jetstream-v0.2.3
-i) INPUT_DIRECTORY: (str) Input checkpoint directory, likely a GSBucket path
-o) OUTPUT_DIRECTORY: (str) Output checkpoint directory, likely a GSBucket path
+- -s, --inference_server: [string] The name of the inference server that serves your model.
+- -m, --model_path: [string] The model path.
+- -n, --model_name: [string] The model name, ex. llama-2, llama-3, gemma.
+- -q, --quantize_weights: [bool] Whether to quantize the checkpoint. (Optional) (default=False)
+- -t, --quantize_type: [string] The type of quantization. Available quantize types: {"int8", "int4"} x {"per_channel", "blockwise"}. (Optional) (default=int8_per_channel)
+- -v, --version: [string] The version of the repository to override, ex. jetstream-v0.2.2, jetstream-v0.2.3. (Optional) (default=main)
+- -i, --input_directory: [string] The input directory, likely a GSBucket path. (Optional)
+- -o, --output_directory: [string] The output directory, likely a GSBucket path.
+- -h, --huggingface: [bool] The model is from Hugging Face. (Optional) (default=False)
 ```
````