#!/bin/bash

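+ # Fail fast: exit immediately if any command returns a non-zero status.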
+ set -e
export KAGGLE_CONFIG_DIR="/kaggle"
export HUGGINGFACE_TOKEN_DIR="/huggingface"
INFERENCE_SERVER="jetstream-maxtext"
@@ -19,13 +20,15 @@ check_gsbucket() {
  BUCKET_NAME=$1
  if [ -z $BUCKET_NAME ]; then
    echo "BUCKET_NAME is empty, please provide a GSBucket"
+     exit 1
  fi
}

check_model_path() {
  MODEL_PATH=$1
  if [ -z $MODEL_PATH ]; then
    echo "MODEL_PATH is empty, please provide the model path"
+     exit 1
  fi
}

@@ -49,10 +52,15 @@ download_huggingface_checkpoint() {
  MODEL_NAME=$2

  INPUT_CKPT_DIR_LOCAL=/base/
-   mkdir /base/
+
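+   # Create the local staging directory only if it does not already exist, so re-runs are safe.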
+   if [ ! -d "/base" ]; then
+     mkdir /base/
+   fi
  huggingface-cli login --token $(cat ${HUGGINGFACE_TOKEN_DIR}/HUGGINGFACE_TOKEN)
  huggingface-cli download ${MODEL_PATH} --local-dir ${INPUT_CKPT_DIR_LOCAL}

+   echo "Completed downloading model ${MODEL_PATH}"
+
  if [[ $MODEL_NAME == *"llama"* ]]; then
    if [[ $MODEL_NAME == "llama-2" ]]; then
      TOKENIZER_PATH=/base/tokenizer.model
@@ -64,37 +72,146 @@ download_huggingface_checkpoint() {
    fi
  elif [[ $MODEL_NAME == *"gemma"* ]]; then
    TOKENIZER_PATH=/base/tokenizer.model
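+     # For the instruction-tuned checkpoint, pull config.json from the base google/gemma-2b-pytorch repo.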
+     if [[ $MODEL_PATH == *"gemma-2b-it-pytorch"* ]]; then
+       huggingface-cli download google/gemma-2b-pytorch config.json --local-dir ${INPUT_CKPT_DIR_LOCAL}
+     fi
  else
    echo -e "Unclear of tokenizer.model for ${MODEL_NAME}. May have to manually upload."
  fi
}
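+ # Download the original Meta weights with the llama CLI, piping the signed META_URL into its prompt.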
+ download_meta_checkpoint() {
+   META_URL=$1
+   MODEL_PATH=$2
+   echo -e "$META_URL" | llama download --source meta --model-id $MODEL_PATH
+ }
+
convert_maxtext_checkpoint() {
  BUCKET_NAME=$1
-   MODEL_NAME=$2
-   VARIATION_NAME=$3
-   MODEL_SIZE=$4
-   MAXTEXT_VERSION=$5
-
-   if [ -z $MAXTEXT_VERSION ]; then
-     MAXTEXT_VERSION=jetstream-v0.2.2
+   MODEL_PATH=$2
+   MODEL_NAME=$3
+   OUTPUT_CKPT_DIR=$4
+   VERSION=$5
+   HUGGINGFACE=$6
+   META_URL=$7
+
+   echo -e "\nbucket name=${BUCKET_NAME}"
+   echo -e "\nmodel path=${MODEL_PATH}"
+   echo -e "\nmodel name=${MODEL_NAME}"
+   echo -e "\nversion=${VERSION}"
+   echo -e "\noutput ckpt dir=${OUTPUT_CKPT_DIR}"
+   echo -e "\nhuggingface=${HUGGINGFACE}"
+   echo -e "\nurl=${META_URL}"
+
+   if [ -z $VERSION ]; then
+     VERSION=jetstream-v0.2.2
  fi

  git clone https://github.com/google/maxtext.git

  # checkout stable MaxText commit
  cd maxtext
-   git checkout ${MAXTEXT_VERSION}
+   git checkout ${VERSION}
  python3 -m pip install -r requirements.txt
+
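+   # Pin orbax-checkpoint to match the MaxText tag: 0.5.20 for jetstream-v0.2.0 through v0.2.2, 0.6.0 for newer tags.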
+   if [[ $VERSION == "jetstream-v0.2.2" || $VERSION == "jetstream-v0.2.1" || $VERSION == "jetstream-v0.2.0" ]]; then
+     pip3 install orbax-checkpoint==0.5.20
+   else
+     pip3 install orbax-checkpoint==0.6.0
+   fi
+
  echo -e "\nCloned MaxText repository and completed installing requirements"

-   python3 MaxText/convert_gemma_chkpt.py --base_model_path gs://${BUCKET_NAME}/base/${MODEL_NAME}_${VARIATION_NAME}/${VARIATION_NAME} --maxtext_model_path gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME} --model_size ${MODEL_SIZE}
-   echo -e "\nCompleted conversion of checkpoint to gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME}"
+   if [[ $MODEL_PATH == *"gemma"* ]]; then
+     download_kaggle_checkpoint "$BUCKET_NAME" "$MODEL_PATH"
+     OUTPUT_CKPT_DIR_SCANNED=gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME}
+     OUTPUT_CKPT_DIR_UNSCANNED=gs://${BUCKET_NAME}/final/unscanned/${MODEL_NAME}_${VARIATION_NAME}
+
+     python3 MaxText/convert_gemma_chkpt.py --base_model_path gs://${BUCKET_NAME}/base/${MODEL_NAME}_${VARIATION_NAME}/${VARIATION_NAME} --maxtext_model_path=${OUTPUT_CKPT_DIR_SCANNED} --model_size ${MODEL_SIZE}
+     echo -e "\nCompleted conversion of checkpoint to gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME}"
+
+     MAXTEXT_MODEL_NAME=${MODEL_NAME}-${MODEL_SIZE}
+
+   elif [[ $MODEL_PATH == *"Llama"* ]]; then
+
+     if [ "$HUGGINGFACE" == "True" ]; then
+       echo "Checkpoint weights are from HuggingFace"
+       download_huggingface_checkpoint "$MODEL_PATH" "$MODEL_NAME"

-   RUN_NAME=0
+     else
+       echo "Checkpoint weights are from Meta, use llama CLI"
+
+       if [ -z $META_URL ]; then
+         echo "META_URL is empty, please provide the Meta URL by visiting https://www.llama.com/llama-downloads/ and agreeing to the Terms and Conditions."
+         exit 1
+       fi
+       echo "META_URL: $META_URL"
+
+       INPUT_CKPT_DIR_LOCAL=/root/.llama/checkpoints/$MODEL_PATH/
+       download_meta_checkpoint "$META_URL" "$MODEL_PATH"
+     fi

-   python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml force_unroll=true model_name=${MODEL_NAME}-${MODEL_SIZE} async_checkpointing=false run_name=${RUN_NAME} load_parameters_path=gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME}/0/items base_output_directory=gs://${BUCKET_NAME}/final/unscanned/${MODEL_NAME}_${VARIATION_NAME}
-   echo -e "\nCompleted unscanning checkpoint to gs://${BUCKET_NAME}/final/unscanned/${MODEL_NAME}_${VARIATION_NAME}/${RUN_NAME}/checkpoints/0/items"
+     echo "Setting model size for $MODEL_PATH"
+     if [[ $MODEL_NAME == "llama-2" ]]; then
+       if [[ $MODEL_PATH == *"7B"* ]] || [[ $MODEL_PATH == *"7b"* ]]; then
+         MODEL_SIZE="llama2-7b"
+       elif [[ $MODEL_PATH == *"13B"* ]] || [[ $MODEL_PATH == *"13b"* ]]; then
+         MODEL_SIZE="llama2-13b"
+       elif [[ $MODEL_PATH == *"70B"* ]] || [[ $MODEL_PATH == *"70b"* ]]; then
+         MODEL_SIZE="llama2-70b"
+       elif [[ $MODEL_PATH == *"405B"* ]] || [[ $MODEL_PATH == *"405b"* ]]; then
+         MODEL_SIZE="llama2-405b"
+       else
+         echo -e "\nUnclear llama2 model: $MODEL_PATH"
+       fi
+
+     elif [[ $MODEL_NAME == "llama-3" ]]; then
+       if [[ $MODEL_PATH == *"8B"* ]] || [[ $MODEL_PATH == *"8b"* ]]; then
+         MODEL_SIZE="llama3-8b"
+       elif [[ $MODEL_PATH == *"70B"* ]] || [[ $MODEL_PATH == *"70b"* ]]; then
+         MODEL_SIZE="llama3-70b"
+       elif [[ $MODEL_PATH == *"405B"* ]] || [[ $MODEL_PATH == *"405b"* ]]; then
+         MODEL_SIZE="llama3-405b"
+       else
+         echo -e "\nUnclear llama3 model: $MODEL_PATH"
+       fi
+
+     else
+       echo -e "\nUnclear llama model"
+     fi
+
+     echo "Model size for $MODEL_PATH is $MODEL_SIZE"
+
+     OUTPUT_CKPT_DIR_SCANNED=${OUTPUT_CKPT_DIR}/scanned
+     OUTPUT_CKPT_DIR_UNSCANNED=${OUTPUT_CKPT_DIR}/unscanned
+
+     TOKENIZER_PATH=${INPUT_CKPT_DIR_LOCAL}/tokenizer.model
+
+     pip3 install torch
+     echo -e "\ninput dir=${INPUT_CKPT_DIR_LOCAL}"
+     echo -e "\nmaxtext model path=${OUTPUT_CKPT_DIR_UNSCANNED}"
+     echo -e "\nmodel path=${MODEL_PATH}"
+     echo -e "\nmodel size=${MODEL_SIZE}"
+
+     cd /maxtext/
+     python3 MaxText/llama_ckpt_conversion_inference_only.py --base-model-path ${INPUT_CKPT_DIR_LOCAL} --maxtext-model-path ${OUTPUT_CKPT_DIR_UNSCANNED} --model-size ${MODEL_SIZE}
+     echo -e "\nCompleted conversion of checkpoint to ${OUTPUT_CKPT_DIR_UNSCANNED}/0/items"
+
+     gcloud storage cp ${TOKENIZER_PATH} ${OUTPUT_CKPT_DIR_UNSCANNED}
+
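+     # Write a commit marker; Orbax-style loaders treat commit_success.txt as the sign of a complete checkpoint.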
+     touch commit_success.txt
+     gcloud storage cp commit_success.txt ${OUTPUT_CKPT_DIR_UNSCANNED}/0/items
+
+   else
+     echo -e "\nUnclear model"
+   fi
+
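+   # Gemma still needs the scanned checkpoint "unscanned" (unrolled per layer) into an inference-only form.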
+   if [[ $MODEL_PATH == *"gemma"* ]]; then
+     RUN_NAME=0
+
+     python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml force_unroll=true model_name=${MAXTEXT_MODEL_NAME} async_checkpointing=false run_name=${RUN_NAME} load_parameters_path=${OUTPUT_CKPT_DIR_SCANNED}/0/items base_output_directory=${OUTPUT_CKPT_DIR_UNSCANNED}
+     echo -e "\nCompleted unscanning checkpoint to ${OUTPUT_CKPT_DIR_UNSCANNED}/${RUN_NAME}/checkpoints/0/items"
+   fi
}
convert_pytorch_checkpoint() {
@@ -173,7 +290,7 @@ convert_pytorch_checkpoint() {
}


- while getopts 'b:s:m:n:h:t:q:v:i:o:' flag; do
+ while getopts 'b:s:m:n:h:t:q:v:i:o:u:' flag; do
  case "${flag}" in
    b) BUCKET_NAME="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
    s) INFERENCE_SERVER="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
@@ -185,6 +302,7 @@ while getopts 'b:s:m:n:h:t:q:v:i:o:' flag; do
    v) VERSION="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
    i) INPUT_DIRECTORY="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
    o) OUTPUT_DIRECTORY="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
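+     # The signed Meta URL itself contains '=' characters, so rejoin the awk fields instead of taking only $2.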
+     u) META_URL="$(echo ${OPTARG} | awk -F'=' '{print $2"="$3"="$4"="$5"="$6}')" ;;
    *) print_usage
       exit 1 ;;
  esac
@@ -197,8 +315,7 @@ case ${INFERENCE_SERVER} in
  jetstream-maxtext)
    check_gsbucket "$BUCKET_NAME"
    check_model_path "$MODEL_PATH"
-     download_kaggle_checkpoint "$BUCKET_NAME" "$MODEL_PATH"
-     convert_maxtext_checkpoint "$BUCKET_NAME" "$MODEL_NAME" "$VARIATION_NAME" "$MODEL_SIZE" "$VERSION"
+     convert_maxtext_checkpoint "$BUCKET_NAME" "$MODEL_PATH" "$MODEL_NAME" "$OUTPUT_DIRECTORY" "$VERSION" "$HUGGINGFACE" "$META_URL"
    ;;
  jetstream-pytorch)
    check_model_path "$MODEL_PATH"