Commit 6a38ad5

JetStream checkpoint converter support for Llama models on MaxText (#840)
* Update pip in JetStream Pytorch and checkpoint Dockerfiles
* Add support for llama model conversions from Meta and HF to MaxText; update http server healthcheck
1 parent 8aa4d6a commit 6a38ad5

File tree

3 files changed (+163 -18 lines):

  tutorials-and-examples/inference-servers/checkpoints/Dockerfile
  tutorials-and-examples/inference-servers/checkpoints/checkpoint_converter.sh
  tutorials-and-examples/inference-servers/jetstream/http-server/http_server.py

tutorials-and-examples/inference-servers/checkpoints/Dockerfile (+2 -1)

@@ -24,7 +24,8 @@ RUN apt -y update && apt install -y google-cloud-cli
 
 RUN pip install kaggle && \
     pip install huggingface_hub[cli] && \
-    pip install google-jetstream
+    pip install google-jetstream && \
+    pip install llama-toolchain
 
 COPY checkpoint_converter.sh /usr/bin/
 RUN chmod +x /usr/bin/checkpoint_converter.sh
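The llama-toolchain package added here provides the llama CLI that checkpoint_converter.sh (next file) relies on for Meta-hosted weights. A minimal smoke test of the rebuilt image, as a sketch; the image tag checkpoint-converter is a placeholder, not a name from this commit:

# Build the converter image and confirm the llama CLI from llama-toolchain
# is available. The tag "checkpoint-converter" is hypothetical.
docker build -t checkpoint-converter .
docker run --rm checkpoint-converter llama download --help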

tutorials-and-examples/inference-servers/checkpoints/checkpoint_converter.sh (+134 -17)

@@ -1,5 +1,6 @@
 #!/bin/bash
 
+set -e
 export KAGGLE_CONFIG_DIR="/kaggle"
 export HUGGINGFACE_TOKEN_DIR="/huggingface"
 INFERENCE_SERVER="jetstream-maxtext"

@@ -19,13 +20,15 @@ check_gsbucket() {
   BUCKET_NAME=$1
   if [ -z $BUCKET_NAME ]; then
     echo "BUCKET_NAME is empty, please provide a GSBucket"
+    exit 1
   fi
 }
 
 check_model_path() {
   MODEL_PATH=$1
   if [ -z $MODEL_PATH ]; then
     echo "MODEL_PATH is empty, please provide the model path"
+    exit 1
   fi
 }

@@ -49,10 +52,15 @@ download_huggingface_checkpoint() {
   MODEL_NAME=$2
 
   INPUT_CKPT_DIR_LOCAL=/base/
-  mkdir /base/
+
+  if [ ! -d "/base" ]; then
+    mkdir /base/
+  fi
   huggingface-cli login --token $(cat ${HUGGINGFACE_TOKEN_DIR}/HUGGINGFACE_TOKEN)
   huggingface-cli download ${MODEL_PATH} --local-dir ${INPUT_CKPT_DIR_LOCAL}
 
+  echo "Completed downloading model ${MODEL_PATH}"
+
   if [[ $MODEL_NAME == *"llama"* ]]; then
     if [[ $MODEL_NAME == "llama-2" ]]; then
       TOKENIZER_PATH=/base/tokenizer.model

@@ -64,37 +72,146 @@ download_huggingface_checkpoint() {
     fi
   elif [[ $MODEL_NAME == *"gemma"* ]]; then
     TOKENIZER_PATH=/base/tokenizer.model
+    if [[ $MODEL_PATH == *"gemma-2b-it-pytorch"* ]]; then
+      huggingface-cli download google/gemma-2b-pytorch config.json --local-dir ${INPUT_CKPT_DIR_LOCAL}
+    fi
   else
     echo -e "Unclear of tokenizer.model for ${MODEL_NAME}. May have to manually upload."
   fi
 }
 
+download_meta_checkpoint() {
+  META_URL=$1
+  MODEL_PATH=$2
+  echo -e "$META_URL" | llama download --source meta --model-id $MODEL_PATH
+}
+
 convert_maxtext_checkpoint() {
   BUCKET_NAME=$1
-  MODEL_NAME=$2
-  VARIATION_NAME=$3
-  MODEL_SIZE=$4
-  MAXTEXT_VERSION=$5
-
-  if [ -z $MAXTEXT_VERSION ]; then
-    MAXTEXT_VERSION=jetstream-v0.2.2
+  MODEL_PATH=$2
+  MODEL_NAME=$3
+  OUTPUT_CKPT_DIR=$4
+  VERSION=$5
+  HUGGINGFACE=$6
+  META_URL=$7
+
+  echo -e "\nbucket name=${BUCKET_NAME}"
+  echo -e "\nmodel path=${MODEL_PATH}"
+  echo -e "\nmodel name=${MODEL_NAME}"
+  echo -e "\nversion=${VERSION}"
+  echo -e "\noutput ckpt dir=${OUTPUT_CKPT_DIR}"
+  echo -e "\nhuggingface=${HUGGINGFACE}"
+  echo -e "\nurl=${META_URL}"
+
+  if [ -z $VERSION ]; then
+    VERSION=jetstream-v0.2.2
   fi
 
   git clone https://github.com/google/maxtext.git
 
   # checkout stable MaxText commit
   cd maxtext
-  git checkout ${MAXTEXT_VERSION}
+  git checkout ${VERSION}
   python3 -m pip install -r requirements.txt
+
+  if [[ $VERSION == "jetstream-v0.2.2" || $VERSION == "jetstream-v0.2.1" || $VERSION == "jetstream-v0.2.0" ]]; then
+    pip3 install orbax-checkpoint==0.5.20
+  else
+    pip3 install orbax-checkpoint==0.6.0
+  fi
+
   echo -e "\nCloned MaxText repository and completed installing requirements"
 
-  python3 MaxText/convert_gemma_chkpt.py --base_model_path gs://${BUCKET_NAME}/base/${MODEL_NAME}_${VARIATION_NAME}/${VARIATION_NAME} --maxtext_model_path gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME} --model_size ${MODEL_SIZE}
-  echo -e "\nCompleted conversion of checkpoint to gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME}"
+  if [[ $MODEL_PATH == *"gemma"* ]]; then
+    download_kaggle_checkpoint "$BUCKET_NAME" "$MODEL_PATH"
+    OUTPUT_CKPT_DIR_SCANNED=gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME}
+    OUTPUT_CKPT_DIR_UNSCANNED=gs://${BUCKET_NAME}/final/unscanned/${MODEL_NAME}_${VARIATION_NAME}
+
+    python3 MaxText/convert_gemma_chkpt.py --base_model_path gs://${BUCKET_NAME}/base/${MODEL_NAME}_${VARIATION_NAME}/${VARIATION_NAME} --maxtext_model_path=${OUTPUT_CKPT_DIR_SCANNED} --model_size ${MODEL_SIZE}
+    echo -e "\nCompleted conversion of checkpoint to gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME}"
+
+    MAXTEXT_MODEL_NAME=${MODEL_NAME}-${MODEL_SIZE}
+
+  elif [[ $MODEL_PATH == *"Llama"* ]]; then
+
+    if [ $HUGGINGFACE == "True" ]; then
+      echo "Checkpoint weights are from HuggingFace"
+      download_huggingface_checkpoint "$MODEL_PATH" "$MODEL_NAME"
 
-  RUN_NAME=0
+    else
+      echo "Checkpoint weights are from Meta, use llama CLI"
+
+      if [ -z $META_URL ]; then
+        echo "META_URL is empty, please provide the Meta url by visiting https://www.llama.com/llama-downloads/ and agreeing to the Terms and Conditions."
+        exit 1
+      fi
+      echo "META_URL: $META_URL"
+
+      INPUT_CKPT_DIR_LOCAL=/root/.llama/checkpoints/$MODEL_PATH/
+      download_meta_checkpoint "$META_URL" "$MODEL_PATH"
+    fi
 
-  python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml force_unroll=true model_name=${MODEL_NAME}-${MODEL_SIZE} async_checkpointing=false run_name=${RUN_NAME} load_parameters_path=gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME}/0/items base_output_directory=gs://${BUCKET_NAME}/final/unscanned/${MODEL_NAME}_${VARIATION_NAME}
-  echo -e "\nCompleted unscanning checkpoint to gs://${BUCKET_NAME}/final/unscanned/${MODEL_NAME}_${VARIATION_NAME}/${RUN_NAME}/checkpoints/0/items"
+    echo "Setting model size for $MODEL_PATH"
+    if [[ $MODEL_NAME == "llama-2" ]]; then
+      if [[ $MODEL_PATH == *"7B"* ]] || [[ $MODEL_PATH == *"7b"* ]]; then
+        MODEL_SIZE="llama2-7b"
+      elif [[ $MODEL_PATH == *"13B"* ]] || [[ $MODEL_PATH == *"13b"* ]]; then
+        MODEL_SIZE="llama2-13b"
+      elif [[ $MODEL_PATH == *"70B"* ]] || [[ $MODEL_PATH == *"70b"* ]]; then
+        MODEL_SIZE="llama2-70b"
+      elif [[ $MODEL_PATH == *"405B"* ]] || [[ $MODEL_PATH == *"405b"* ]]; then
+        MODEL_SIZE="llama2-405b"
+      else
+        echo -e "\nUnclear llama2 model: $MODEL_PATH"
+      fi
+
+    elif [[ $MODEL_NAME == "llama-3" ]]; then
+      if [[ $MODEL_PATH == *"8B"* ]] || [[ $MODEL_PATH == *"8b"* ]]; then
+        MODEL_SIZE="llama3-8b"
+      elif [[ $MODEL_PATH == *"70B"* ]] || [[ $MODEL_PATH == *"70b"* ]]; then
+        MODEL_SIZE="llama3-70b"
+      elif [[ $MODEL_PATH == *"405B"* ]] || [[ $MODEL_PATH == *"405b"* ]]; then
+        MODEL_SIZE="llama3-405b"
+      else
+        echo -e "\nUnclear llama3 model: $MODEL_PATH"
+      fi
+
+    else
+      echo -e "\nUnclear llama model"
+    fi
+
+    echo "Model size for $MODEL_PATH is $MODEL_SIZE"
+
+    OUTPUT_CKPT_DIR_SCANNED=${OUTPUT_CKPT_DIR}/scanned
+    OUTPUT_CKPT_DIR_UNSCANNED=${OUTPUT_CKPT_DIR}/unscanned
+
+    TOKENIZER_PATH=${INPUT_CKPT_DIR_LOCAL}/tokenizer.model
+
+    pip3 install torch
+    echo -e "\ninput dir=${INPUT_CKPT_DIR_LOCAL}"
+    echo -e "\nmaxtext model path=${OUTPUT_CKPT_DIR_UNSCANNED}"
+    echo -e "\nmodel path=${MODEL_PATH}"
+    echo -e "\nmodel size=${MODEL_SIZE}"
+
+    cd /maxtext/
+    python3 MaxText/llama_ckpt_conversion_inference_only.py --base-model-path ${INPUT_CKPT_DIR_LOCAL} --maxtext-model-path ${OUTPUT_CKPT_DIR_UNSCANNED} --model-size ${MODEL_SIZE}
+    echo -e "\nCompleted conversion of checkpoint to ${OUTPUT_CKPT_DIR_UNSCANNED}/0/items"
+
+    gcloud storage cp ${TOKENIZER_PATH} ${OUTPUT_CKPT_DIR_UNSCANNED}
+
+    touch commit_success.txt
+    gcloud storage cp commit_success.txt ${OUTPUT_CKPT_DIR_UNSCANNED}/0/items
+
+  else
+    echo -e "\nUnclear model"
+  fi
+
+  if [[ $MODEL_PATH == *"gemma"* ]]; then
+    RUN_NAME=0
+
+    python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml force_unroll=true model_name=${MAXTEXT_MODEL_NAME} async_checkpointing=false run_name=${RUN_NAME} load_parameters_path=${OUTPUT_CKPT_DIR_SCANNED}/0/items base_output_directory=${OUTPUT_CKPT_DIR_UNSCANNED}
+    echo -e "\nCompleted unscanning checkpoint to ${OUTPUT_CKPT_DIR_UNSCANNED}/${RUN_NAME}/checkpoints/0/items"
+  fi
 }
 
 convert_pytorch_checkpoint() {

@@ -173,7 +290,7 @@ convert_pytorch_checkpoint() {
 }
 
 
-while getopts 'b:s:m:n:h:t:q:v:i:o:' flag; do
+while getopts 'b:s:m:n:h:t:q:v:i:o:u:' flag; do
 case "${flag}" in
   b) BUCKET_NAME="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
   s) INFERENCE_SERVER="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;

@@ -185,6 +302,7 @@ while getopts 'b:s:m:n:h:t:q:v:i:o:' flag; do
   v) VERSION="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
   i) INPUT_DIRECTORY="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
   o) OUTPUT_DIRECTORY="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
+  u) META_URL="$(echo ${OPTARG} | awk -F'=' '{print $2"="$3"="$4"="$5"="$6}')" ;;
   *) print_usage
      exit 1 ;;
 esac

@@ -197,8 +315,7 @@ case ${INFERENCE_SERVER} in
   jetstream-maxtext)
     check_gsbucket "$BUCKET_NAME"
     check_model_path "$MODEL_PATH"
-    download_kaggle_checkpoint "$BUCKET_NAME" "$MODEL_PATH"
-    convert_maxtext_checkpoint "$BUCKET_NAME" "$MODEL_NAME" "$VARIATION_NAME" "$MODEL_SIZE" "$VERSION"
+    convert_maxtext_checkpoint "$BUCKET_NAME" "$MODEL_PATH" "$MODEL_NAME" "$OUTPUT_DIRECTORY" "$VERSION" "$HUGGINGFACE" "$META_URL"
   ;;
   jetstream-pytorch)
     check_model_path "$MODEL_PATH"
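For reference, a sketch of invoking the updated script for a Meta-hosted Llama checkpoint. Values are passed as NAME=value pairs because each getopts handler splits ${OPTARG} on '='. The -m/-n/-h mappings to MODEL_PATH, MODEL_NAME, and HUGGINGFACE are assumptions (those case arms fall outside the hunks shown), and the bucket, model id, and signed URL are placeholders:

# Hypothetical invocation; the -m/-n/-h variable mappings are assumed, and
# the bucket, model id, and presigned META_URL are placeholders.
checkpoint_converter.sh \
  -b BUCKET_NAME=my-ckpt-bucket \
  -s INFERENCE_SERVER=jetstream-maxtext \
  -m MODEL_PATH=Llama3.1-8B \
  -n MODEL_NAME=llama-3 \
  -h HUGGINGFACE=False \
  -o OUTPUT_DIRECTORY=gs://my-ckpt-bucket/final \
  -u 'META_URL=https://example.com/llama?Policy=...&Signature=...'

Note that the -u handler rejoins awk fields 2 through 6 with '=' so that a signed Meta URL, whose query string itself contains '=' characters, survives the NAME=value split that the other flags use.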

tutorials-and-examples/inference-servers/jetstream/http-server/http_server.py (+27)

@@ -48,6 +48,33 @@ def root():
     )
     return response
 
+@app.get("/healthcheck")
+async def healthcheck():
+    try:
+        request = jetstream_pb2.HealthCheckRequest()
+
+        options = [("grpc.keepalive_timeout_ms", 10000)]
+        async with grpc.aio.insecure_channel("127.0.0.1:9000", options=options) as channel:
+            stub = jetstream_pb2_grpc.OrchestratorStub(channel)
+            response = stub.HealthCheck(request)
+            response = await response
+
+            if not response.is_live:
+                raise fastapi.HTTPException(status_code=500, detail="Healthcheck failed, is_live = False")
+
+            is_live = {"is_live": response.is_live}
+            response = {"response": is_live}
+
+            response = fastapi.Response(
+                content=json.dumps(response, indent=4), media_type="application/json"
+            )
+            return response
+
+    except Exception as e:
+        logging.exception("Exception in healthcheck")
+        logging.exception(e)
+        raise fastapi.HTTPException(status_code=500, detail="Healthcheck failed")
+
 
 @app.post("/generate", status_code=200)
 async def generate(request: GenerateRequest):
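With the server restarted, the new endpoint can be probed directly. A sketch, assuming the FastAPI app listens on port 8000 (only the backend gRPC port 9000 appears in this diff):

# The HTTP port is an assumption; only gRPC port 9000 is shown in the diff.
curl -s http://localhost:8000/healthcheck
# Expected body while the orchestrator reports liveness:
# {
#     "response": {
#         "is_live": true
#     }
# }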
