Skip to content

Commit 6bfa661

Browse files
volatilemolotovArthurKamalovvpolikarpov-akvelon
committed
Skypilot dws kueue (GoogleCloudPlatform#942)
* SkyPilot with DWS and Kueue tutorial * typo * fix topics * fix backend placeholder * add missing terrafrom outputs * update with files needed for finetune and serve * add mount for text classification files for train task * change example environment structure * move serve to l4 * added clarification on experimental section for skypilot task/service definition files * readme update * Minor README fixes * Update README.md * Update README.md * minor updates * minor updates * newlines added * added kueue quota warning * minor whitespace * minor whitespace * main.tf fmt issues * terraform fmt newlines * terraform fmt newlines * Update tutorials-and-examples/skypilot/dws-and-kueue/example_environment.tfvars Co-authored-by: Vasilii Polikarpov <[email protected]> * Update tutorials-and-examples/skypilot/dws-and-kueue/README.md Co-authored-by: Vasilii Polikarpov <[email protected]> * Update tutorials-and-examples/skypilot/dws-and-kueue/README.md Co-authored-by: Vasilii Polikarpov <[email protected]> * remove repeated steps --------- Co-authored-by: ArthurKamalov <[email protected]> Co-authored-by: Vasilii Polikarpov <[email protected]>
1 parent 06a1d11 commit 6bfa661

File tree

17 files changed

+1375
-0
lines changed

17 files changed

+1375
-0
lines changed

infrastructure/outputs.tf

+11
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,17 @@ output "ca_certificate" {
4545

4646
}
4747

48+
output "service_account" {
49+
value = var.create_cluster && var.autopilot_cluster && var.private_cluster ? module.private-gke-autopilot-cluster[0].service_account : (
50+
var.create_cluster && !var.autopilot_cluster && var.private_cluster ? module.private-gke-standard-cluster[0].service_account : (
51+
var.create_cluster && var.autopilot_cluster && !var.private_cluster ? module.public-gke-autopilot-cluster[0].service_account : (
52+
var.create_cluster && !var.autopilot_cluster && !var.private_cluster ? module.public-gke-standard-cluster[0].service_account :
53+
"")))
54+
sensitive = true
55+
depends_on = [module.private-gke-autopilot-cluster, module.private-gke-standard-cluster, module.public-gke-autopilot-cluster, module.public-gke-standard-cluster]
56+
57+
}
58+
4859
output "private_cluster" {
4960
value = var.private_cluster
5061
}

modules/gke-autopilot-private-cluster/outputs.tf

+4
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,8 @@ output "endpoint" {
2222

2323
output "ca_certificate" {
2424
value = module.gke.ca_certificate
25+
}
26+
27+
output "service_account" {
28+
value = module.gke.service_account
2529
}

modules/gke-autopilot-public-cluster/outputs.tf

+4
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,8 @@ output "endpoint" {
2222

2323
output "ca_certificate" {
2424
value = module.gke.ca_certificate
25+
}
26+
27+
output "service_account" {
28+
value = module.gke.service_account
2529
}

modules/gke-standard-private-cluster/outputs.tf

+4
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,8 @@ output "endpoint" {
2222

2323
output "ca_certificate" {
2424
value = module.gke.ca_certificate
25+
}
26+
27+
output "service_account" {
28+
value = module.gke.service_account
2529
}

modules/gke-standard-public-cluster/outputs.tf

+5
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,8 @@ output "endpoint" {
2323
output "ca_certificate" {
2424
value = module.gke.ca_certificate
2525
}
26+
27+
28+
output "service_account" {
29+
value = module.gke.service_account
30+
}

tutorials-and-examples/skypilot/dws-and-kueue/README.md

+429
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# terraform {
16+
# backend "gcs" {
17+
# bucket = "BUCKET_NAME"
18+
# prefix = "terraform/state"
19+
# }
20+
# }
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from datetime import datetime
15+
import time
16+
import os
17+
import torch
18+
from datasets import load_dataset, Dataset
19+
from transformers import (
20+
AutoModelForCausalLM,
21+
AutoTokenizer,
22+
BitsAndBytesConfig,
23+
TrainingArguments,
24+
)
25+
from peft import LoraConfig, PeftModel, get_peft_model
26+
27+
from trl import SFTTrainer
28+
29+
30+
31+
# The model that you want to train from the Hugging Face hub
32+
model_name = os.getenv("BASE_MODEL_NAME", "google/gemma-2b")
33+
34+
# The instruction dataset to use
35+
dataset_name = "b-mc2/sql-create-context"
36+
37+
# Fine-tuned model name
38+
new_model = os.getenv("MODEL_NAME", "gemma-2b-sql")
39+
40+
################################################################################
41+
# QLoRA parameters
42+
################################################################################
43+
44+
# LoRA attention dimension
45+
lora_r = int(os.getenv("LORA_R", "4"))
46+
47+
# Alpha parameter for LoRA scaling
48+
lora_alpha = int(os.getenv("LORA_ALPHA", "8"))
49+
50+
# Dropout probability for LoRA layers
51+
lora_dropout = 0.1
52+
53+
################################################################################
54+
# bitsandbytes parameters
55+
################################################################################
56+
57+
# Activate 4-bit precision base model loading
58+
use_4bit = True
59+
60+
# Compute dtype for 4-bit base models
61+
bnb_4bit_compute_dtype = "float16"
62+
63+
# Quantization type (fp4 or nf4)
64+
bnb_4bit_quant_type = "nf4"
65+
66+
# Activate nested quantization for 4-bit base models (double quantization)
67+
use_nested_quant = False
68+
69+
################################################################################
70+
# TrainingArguments parameters
71+
################################################################################
72+
73+
# Output directory where the model predictions and checkpoints will be stored
74+
output_dir = "/data/models/" + new_model
75+
76+
# Number of training epochs
77+
num_train_epochs = 1
78+
79+
# Enable fp16/bf16 training (set bf16 to True with an A100)
80+
fp16 = True
81+
bf16 = False
82+
83+
# Batch size per GPU for training
84+
per_device_train_batch_size = int(os.getenv("TRAIN_BATCH_SIZE", "1"))
85+
86+
# Batch size per GPU for evaluation
87+
per_device_eval_batch_size = int(os.getenv("EVAL_BATCH_SIZE", "2"))
88+
89+
# Number of update steps to accumulate the gradients for
90+
gradient_accumulation_steps = int(os.getenv("GRADIENT_ACCUMULATION_STEPS", "1"))
91+
92+
# Enable gradient checkpointing
93+
gradient_checkpointing = True
94+
95+
# Maximum gradient normal (gradient clipping)
96+
max_grad_norm = 0.3
97+
98+
# Initial learning rate (AdamW optimizer)
99+
learning_rate = 2e-4
100+
101+
# Weight decay to apply to all layers except bias/LayerNorm weights
102+
weight_decay = 0.001
103+
104+
# Optimizer to use
105+
optim = "paged_adamw_32bit"
106+
107+
# Learning rate schedule
108+
lr_scheduler_type = "cosine"
109+
110+
# Number of training steps (overrides num_train_epochs)
111+
max_steps = -1
112+
113+
# Ratio of steps for a linear warmup (from 0 to learning rate)
114+
warmup_ratio = 0.03
115+
116+
# Group sequences into batches with same length
117+
# Saves memory and speeds up training considerably
118+
group_by_length = True
119+
120+
# Save checkpoint every X updates steps
121+
save_steps = 0
122+
123+
# Log every X updates steps
124+
logging_steps = int(os.getenv("LOGGING_STEPS", "50"))
125+
126+
################################################################################
127+
# SFT parameters
128+
################################################################################
129+
130+
# Maximum sequence length to use
131+
max_seq_length = int(os.getenv("MAX_SEQ_LENGTH", "512"))
132+
133+
# Pack multiple short examples in the same input sequence to increase efficiency
134+
packing = False
135+
136+
# Load the entire model on the GPU 0
137+
device_map = {'':torch.cuda.current_device()}
138+
139+
# Set limit to a positive number
140+
limit = int(os.getenv("DATASET_LIMIT", "5000"))
141+
142+
dataset = load_dataset(dataset_name, split="train")
143+
if limit != -1:
144+
dataset = dataset.shuffle(seed=42).select(range(limit))
145+
146+
147+
def transform(data):
148+
question = data['question']
149+
context = data['context']
150+
answer = data['answer']
151+
template = "Question: {question}\nContext: {context}\nAnswer: {answer}"
152+
return {'text': template.format(question=question, context=context, answer=answer)}
153+
154+
155+
transformed = dataset.map(transform)
156+
157+
# Load tokenizer and model with QLoRA configuration
158+
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
159+
160+
bnb_config = BitsAndBytesConfig(
161+
load_in_4bit=use_4bit,
162+
bnb_4bit_quant_type=bnb_4bit_quant_type,
163+
bnb_4bit_compute_dtype=compute_dtype,
164+
bnb_4bit_use_double_quant=use_nested_quant,
165+
)
166+
167+
# Check GPU compatibility with bfloat16
168+
if compute_dtype == torch.float16 and use_4bit:
169+
major, _ = torch.cuda.get_device_capability()
170+
if major >= 8:
171+
print("=" * 80)
172+
print("Your GPU supports bfloat16")
173+
print("=" * 80)
174+
175+
# Load base model
176+
# model = AutoModelForCausalLM.from_pretrained("google/gemma-7b")
177+
model = AutoModelForCausalLM.from_pretrained(
178+
model_name,
179+
quantization_config=bnb_config,
180+
device_map=device_map,
181+
torch_dtype=torch.float16,
182+
)
183+
model.config.use_cache = False
184+
model.config.pretraining_tp = 1
185+
186+
# Load LLaMA tokenizer
187+
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
188+
tokenizer.pad_token = tokenizer.eos_token
189+
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training
190+
191+
# Load LoRA configuration
192+
peft_config = LoraConfig(
193+
lora_alpha=lora_alpha,
194+
lora_dropout=lora_dropout,
195+
r=lora_r,
196+
bias="none",
197+
task_type="CAUSAL_LM",
198+
target_modules=["q_proj", "v_proj"]
199+
)
200+
201+
# Set training parameters
202+
training_arguments = TrainingArguments(
203+
output_dir=output_dir,
204+
num_train_epochs=num_train_epochs,
205+
per_device_train_batch_size=per_device_train_batch_size,
206+
gradient_accumulation_steps=gradient_accumulation_steps,
207+
optim=optim,
208+
save_steps=save_steps,
209+
logging_steps=logging_steps,
210+
learning_rate=learning_rate,
211+
weight_decay=weight_decay,
212+
fp16=fp16,
213+
bf16=bf16,
214+
max_grad_norm=max_grad_norm,
215+
max_steps=max_steps,
216+
warmup_ratio=warmup_ratio,
217+
group_by_length=group_by_length,
218+
lr_scheduler_type=lr_scheduler_type,
219+
)
220+
221+
trainer = SFTTrainer(
222+
model=model,
223+
train_dataset=transformed,
224+
peft_config=peft_config,
225+
dataset_text_field="text",
226+
max_seq_length=max_seq_length,
227+
tokenizer=tokenizer,
228+
args=training_arguments,
229+
packing=packing,
230+
)
231+
232+
233+
trainer.train()
234+
trainer.model.save_pretrained(new_model)
235+
236+
237+
# Reload model in FP16 and merge it with LoRA weights
238+
base_model = AutoModelForCausalLM.from_pretrained(
239+
model_name,
240+
low_cpu_mem_usage=True,
241+
return_dict=True,
242+
torch_dtype=torch.float16,
243+
device_map=device_map,
244+
)
245+
246+
model = PeftModel.from_pretrained(base_model, new_model)
247+
model = model.merge_and_unload()
248+
249+
# Save model to disk
250+
model.save_pretrained(output_dir)
251+
# Save the tokenizer to disk
252+
tokenizer.save_pretrained(output_dir)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
project_id = "skypilot_project"
17+
create_cluster = true
18+
cluster_name = "skypilot-test"
19+
cluster_location = "us-central1"
20+
enable_gpu = true
21+
create_service_account = false
22+
create_brand = false
23+
create_gcs_bucket = true
24+
gcs_bucket = "skypilot-model-bucket"
25+
26+
# For Autopilot clusters
27+
autopilot_cluster = true
28+
29+
# For Standard clusters, configure GPU node pools:
30+
#autopilot_cluster = false
31+
32+
# If using Standard cluster please uncomment the
33+
# following gpu_pools block to enable queued_provisioning
34+
# on the node pool
35+
# gpu_pools = [{
36+
# name = "gpu-pool"
37+
# queued_provisioning = true
38+
# machine_type = "g2-standard-24"
39+
# disk_type = "pd-balanced"
40+
# autoscaling = true
41+
# min_count = 0
42+
# max_count = 3
43+
# initial_node_count = 0
44+
# }]

0 commit comments

Comments
 (0)