
Commit d645ad5

checkpoint conversion edits for guide and flag for increased byte i/o (#1029)
* add concurrent_gb flag to support i/o of large models
* add checkpoint conversion instructions externally
1 parent 201002a commit d645ad5

4 files changed: +172 −20 lines

tutorials-and-examples/inference-servers/checkpoints/README.md

+20 −20

@@ -13,28 +13,28 @@ Now you can use it in a [Kubernetes job](../jetstream/maxtext/single-host-infere
 
 ## Jetstream + MaxText
 ```
--b, --bucket_name: [string] The GSBucket name to store checkpoints, without gs://.
--s, --inference_server: [string] The name of the inference server that serves your model. (Optional) (default=jetstream-maxtext)
--m, --model_path: [string] The model path.
--n, --model_name: [string] The model name.
--h, --huggingface: [bool] The model is from Hugging Face. (Optional) (default=False)
--t, --quantize_type: [string] The type of quantization. (Optional)
--q, --quantize_weights: [bool] The checkpoint is to be quantized. (Optional) (default=False)
--i, --input_directory: [string] The input directory, likely a GSBucket path.
--o, --output_directory: [string] The output directory, likely a GSBucket path.
--u, --meta_url: [string] The url from Meta. (Optional)
--v, --version: [string] The version of repository. (Optional) (default=main)
+--bucket_name: [string] The GSBucket name to store checkpoints, without gs://.
+--inference_server: [string] The name of the inference server that serves your model. (Optional) (default=jetstream-maxtext)
+--model_path: [string] The model path.
+--model_name: [string] The model name. ex. llama-2, llama-3, gemma.
+--huggingface: [bool] The model is from Hugging Face. (Optional) (default=False)
+--quantize_type: [string] The type of quantization. (Optional)
+--quantize_weights: [bool] The checkpoint is to be quantized. (Optional) (default=False)
+--input_directory: [string] The input directory, likely a GSBucket path.
+--output_directory: [string] The output directory, likely a GSBucket path.
+--meta_url: [string] The url from Meta. (Optional)
+--version: [string] The version of repository. (Optional) (default=main)
 ```
 
 ## Jetstream + Pytorch/XLA
 ```
-- -s, --inference_server: [string] The name of the inference server that serves your model.
-- -m, --model_path: [string] The model path.
-- -n, --model_name: [string] The model name, Model name, ex. llama-2, llama-3, gemma.
-- -q, --quantize_weights: [bool] The checkpoint is to be quantized. (Optional) (default=False)
-- -t, --quantize_type: [string] The type of quantization. Availabe quantize type: {"int8", "int4"} x {"per_channel", "blockwise"}. (Optional) (default=int8_per_channel)
-- -v, --version: [string] The version of repository to override, ex. jetstream-v0.2.2, jetstream-v0.2.3. (Optional) (default=main)
-- -i, --input_directory: [string] The input directory, likely a GSBucket path. (Optional)
-- -o, --output_directory: [string] The output directory, likely a GSBucket path.
-- -h, --huggingface: [bool] The model is from Hugging Face. (Optional) (default=False)
+--inference_server: [string] The name of the inference server that serves your model.
+--model_path: [string] The model path.
+--model_name: [string] The model name. ex. llama-2, llama-3, gemma.
+--quantize_weights: [bool] The checkpoint is to be quantized. (Optional) (default=False)
+--quantize_type: [string] The type of quantization. Available quantize type: {"int8", "int4"} x {"per_channel", "blockwise"}. (Optional) (default=int8_per_channel)
+--version: [string] The version of repository to override, ex. jetstream-v0.2.2, jetstream-v0.2.3. (Optional) (default=main)
+--input_directory: [string] The input directory, likely a GSBucket path. (Optional)
+--output_directory: [string] The output directory, likely a GSBucket path.
+--huggingface: [bool] The model is from Hugging Face. (Optional) (default=False)
 ```
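For reference, a hypothetical invocation of `checkpoint_converter.sh` using the new long-form flags; all values below are placeholders, not defaults from this commit, and the `--flag=value` syntax is assumed:

```
# Illustrative only: placeholder bucket/model values; flag names come from
# the README above, argument syntax assumed to be --flag=value.
./checkpoint_converter.sh \
  --bucket_name=my-ckpt-bucket \
  --inference_server=jetstream-maxtext \
  --model_name=gemma \
  --model_path=google/gemma-7b \
  --huggingface=True \
  --output_directory=gs://my-ckpt-bucket/maxtext/gemma-7b
```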

tutorials-and-examples/inference-servers/checkpoints/checkpoint_converter.sh

+5
@@ -128,6 +128,8 @@ convert_maxtext_checkpoint() {
   QUANTIZE_TYPE=$8
   QUANTIZE_WEIGHTS=$9
 
+  CONCURRENT_GB=96
+
   echo -e "$(date '+%Y-%m-%d %H:%M:%S'): Bucket name=${BUCKET_NAME}"
   echo -e "$(date '+%Y-%m-%d %H:%M:%S'): Model path=${MODEL_PATH}"
   echo -e "$(date '+%Y-%m-%d %H:%M:%S'): Model name=${MODEL_NAME}"
@@ -210,6 +212,7 @@ convert_maxtext_checkpoint() {
     MODEL_SIZE="llama3.1-70b"
   elif [[ $MODEL_PATH == *"405B"* ]] || [[ $MODEL_PATH == *"405b"* ]]; then
     MODEL_SIZE="llama3.1-405b"
+    CONCURRENT_GB=500
   else
     echo -e "Unclear llama3.1 model: $MODEL_PATH"
   fi
@@ -237,6 +240,7 @@ convert_maxtext_checkpoint() {
   echo -e "$(date '+%Y-%m-%d %H:%M:%S'): Maxtext model path=${OUTPUT_CKPT_DIR_SCANNED}"
   echo -e "$(date '+%Y-%m-%d %H:%M:%S'): Model path=${MODEL_PATH}"
   echo -e "$(date '+%Y-%m-%d %H:%M:%S'): Model size=${MODEL_SIZE}"
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S'): Concurrent_gb=${CONCURRENT_GB}"
 
   export JAX_PLATFORMS=cpu
   cd /maxtext/
@@ -260,6 +264,7 @@ convert_maxtext_checkpoint() {
     run_name=${RUN_NAME} \
     model_name=${MODEL_SIZE} \
     force_unroll=true \
+    checkpoint_storage_concurrent_gb=${CONCURRENT_GB} \
     weight_dtype=bfloat16 \
     opt_type=sgd
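Taken together, the hunks implement a simple size-based default: 96 GB of concurrent checkpoint I/O for most models, raised to 500 GB for the 405B variant, and passed through to MaxText as `checkpoint_storage_concurrent_gb`. A minimal standalone sketch of that selection logic (the helper name is illustrative, not part of the script):

```
#!/bin/bash
# Sketch of the sizing logic added above; the values mirror the diff, but
# select_concurrent_gb is an illustrative helper, not in the script itself.
select_concurrent_gb() {
  local model_path=$1
  local gb=96  # default concurrent checkpoint I/O budget, in GB
  if [[ $model_path == *"405B"* ]] || [[ $model_path == *"405b"* ]]; then
    gb=500  # the 405B checkpoint needs a much larger budget
  fi
  echo "$gb"
}

CONCURRENT_GB="$(select_concurrent_gb "Llama3.1-405B-Instruct:bf16-mp16")"
echo "Concurrent_gb=${CONCURRENT_GB}"  # prints Concurrent_gb=500
```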

New file (+110 lines):

# Checkpoint conversion with Llama3.1-405B

This example walks through converting the Llama3.1-405B checkpoint from Meta into a MaxText-compatible checkpoint for inference.

The Llama3.1-405B model needs ~2000 GB of client memory to download and convert the checkpoint; the `m1-ultramem-160` machine type and 3000 GB boot disk provisioned below provide enough capacity for the download and conversion. Checkpoint conversion for the 405B model supports only weights downloaded from Meta.
## Agree to the Meta Terms and Conditions

Go to https://www.llama.com/llama-downloads/ to acknowledge the terms and conditions. Select `Llama3.1: 405B & 8B` as the model(s) you will request.

Copy the provided Meta URL to use in your manifest file.
## Create GCS Bucket to store checkpoint

```
BUCKET_NAME=<your bucket>

gcloud storage buckets create gs://$BUCKET_NAME
```
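If you want the bucket co-located with your cluster, `gcloud storage buckets create` also accepts a `--location` flag (optional; not part of the original steps):

```
# Optional variant: pin the bucket to a specific region (placeholder value).
gcloud storage buckets create gs://$BUCKET_NAME --location=us-central1
```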
## Configure a service account for Storage Object access

**Skip this step if it has already been done on the cluster with a different service account.**

Configure a Kubernetes service account to act as an IAM service account.

Create an IAM service account for your application:

```
gcloud iam service-accounts create checkpoint-converter
```

Add an IAM policy binding for your IAM service account to manage Cloud Storage:

```
gcloud projects add-iam-policy-binding ${PROJECT} \
  --member "serviceAccount:checkpoint-converter@${PROJECT}.iam.gserviceaccount.com" \
  --role roles/storage.objectUser

gcloud projects add-iam-policy-binding ${PROJECT} \
  --member "serviceAccount:checkpoint-converter@${PROJECT}.iam.gserviceaccount.com" \
  --role roles/storage.insightsCollectorService
```

Annotate the Kubernetes service account with the email address of the IAM service account.

```
kubectl annotate serviceaccount default \
  iam.gke.io/gcp-service-account=checkpoint-converter@${PROJECT}.iam.gserviceaccount.com
```
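The annotation above follows the GKE Workload Identity pattern; depending on how the cluster is configured, the IAM service account usually also needs a `roles/iam.workloadIdentityUser` binding so the Kubernetes service account can impersonate it. A sketch, assuming Workload Identity is enabled and the job uses the `default` service account in the `default` namespace:

```
# Assumption: cluster has Workload Identity enabled; this binds the KSA
# default/default to the IAM service account created above.
gcloud iam service-accounts add-iam-policy-binding \
  checkpoint-converter@${PROJECT}.iam.gserviceaccount.com \
  --role roles/iam.workloadIdentityUser \
  --member "serviceAccount:${PROJECT}.svc.id.goog[default/default]"
```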
## Provision resources to facilitate conversion

Create a node pool with machine type `m1-ultramem-160`. You may need to request m1 quota in your project.

```
CLUSTER=<your cluster>
ZONE=<your zone>
PROJECT=<project>

gcloud container node-pools create m1-pool \
  --cluster "${CLUSTER}" \
  --zone "${ZONE}" \
  --machine-type m1-ultramem-160 \
  --num-nodes 1 \
  --disk-size 3000 \
  --project "${PROJECT}" \
  --scopes cloud-platform
```
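Once the node pool is created, you can confirm the node registered and is ready before applying the job (standard kubectl; not part of the original steps):

```
# Expect one node from the new pool in Ready state; the label matches the
# nodeSelector used in the manifest below.
kubectl get nodes -l cloud.google.com/gke-nodepool=m1-pool
```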
In `checkpoint-converter.yaml`, replace `BUCKET_NAME` with the GCS Bucket that you created earlier, without gs://.

Parameter descriptions:

```
--bucket_name: [string] The GSBucket name to store checkpoints, without gs://.
--inference_server: [string] The name of the inference server that serves your model. (ex. jetstream-maxtext)
--meta_url: [string] The url from Meta.
--model_name: [string] The model name. (ex. llama-2, llama-3, llama-3.1)
--model_path: [string] The model path. For Llama models, download llama via `pip install llama-stack` and run `llama model list --show-all` for the Model Descriptor to use. (ex. Llama3.1-405B-Instruct:bf16-mp16)
--output_directory: [string] The output directory. (ex. gs://bucket_name/maxtext/llama3.1-405b)
--quantize_type: [string] The type of quantization. (ex. int8)
--quantize_weights: [bool] The checkpoint is to be quantized. (ex. True)
```

For a bf16 checkpoint only, remove the flags `--quantize_type` and `--quantize_weights`.

Apply the manifest:

```
kubectl apply -f checkpoint-converter.yaml
```
The checkpoint conversion job takes around 9-10 hours to complete. To monitor the progress of the checkpoint download and conversion, check [GCP Log Explorer](https://console.cloud.google.com/logs/query) and enter the following query:

```
resource.type="k8s_container"
resource.labels.project_id="PROJECT_ID"
resource.labels.location="LOCATION"
resource.labels.cluster_name="CLUSTER_NAME"
resource.labels.namespace_name="default"
resource.labels.pod_name:"checkpoint-converter-"
```
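Because the job runs for many hours, it can also be convenient to check its state directly on the cluster (standard kubectl; not part of the original steps):

```
# Job completion status, plus a live stream of the conversion logs.
kubectl get job checkpoint-converter
kubectl logs -f job/checkpoint-converter
```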

Once completed, you will see a log similar to:

```
# bf16 checkpoint
Completed unscanning checkpoint to gs://output_directory/bf16/unscanned/checkpoints/0/items

# int8 checkpoint
Completed quantizing model llama3.1-405b with int8 to gs://output_directory/int8
```
New file (+37 lines), the `checkpoint-converter.yaml` manifest referenced in the guide above:

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

apiVersion: batch/v1
kind: Job
metadata:
  name: checkpoint-converter
spec:
  template:
    spec:
      restartPolicy: Never
      containers:
      - name: inference-checkpoint
        image: us-docker.pkg.dev/cloud-tpu-images/inference/inference-checkpoint:v0.2.5
        imagePullPolicy: Always
        args:
        - --bucket_name=BUCKET_NAME
        - --inference_server=jetstream-maxtext
        - --meta_url=META_URL
        - --model_name=llama-3.1
        - --model_path=Llama3.1-405B-Instruct:bf16-mp16
        - --output_directory=gs://BUCKET_NAME/maxtext/llama-3.1-405b
        - --quantize_type=int8
        - --quantize_weights=True
      nodeSelector:
        cloud.google.com/gke-nodepool: m1-pool
