# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import torch
from datasets import load_dataset
from peft import LoraConfig, PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTTrainer


# The base model to fine-tune, from the Hugging Face Hub
model_name = os.getenv("BASE_MODEL_NAME", "google/gemma-2b")

# The instruction dataset to use
dataset_name = "b-mc2/sql-create-context"

# Name of the fine-tuned model
new_model = os.getenv("MODEL_NAME", "gemma-2b-sql")

################################################################################
# QLoRA parameters
################################################################################

# LoRA attention dimension (rank of the adapter update matrices)
lora_r = int(os.getenv("LORA_R", "4"))

# Alpha parameter for LoRA scaling
lora_alpha = int(os.getenv("LORA_ALPHA", "8"))

# Dropout probability for LoRA layers
lora_dropout = 0.1

################################################################################
# bitsandbytes parameters
################################################################################

# Activate 4-bit precision base model loading
use_4bit = True

# Compute dtype for 4-bit base models
bnb_4bit_compute_dtype = "float16"

# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = "nf4"

# Activate nested quantization for 4-bit base models (double quantization)
use_nested_quant = False

################################################################################
# TrainingArguments parameters
################################################################################

# Output directory where the model predictions and checkpoints will be stored
output_dir = "/data/models/" + new_model

# Number of training epochs
num_train_epochs = 1

# Enable fp16/bf16 training (set bf16 to True on an A100 or newer GPU)
fp16 = True
bf16 = False

# Batch size per GPU for training
per_device_train_batch_size = int(os.getenv("TRAIN_BATCH_SIZE", "1"))

# Batch size per GPU for evaluation
per_device_eval_batch_size = int(os.getenv("EVAL_BATCH_SIZE", "2"))

# Number of update steps to accumulate the gradients for
gradient_accumulation_steps = int(os.getenv("GRADIENT_ACCUMULATION_STEPS", "1"))

# Enable gradient checkpointing
gradient_checkpointing = True

# Maximum gradient norm (gradient clipping)
max_grad_norm = 0.3

# Initial learning rate (AdamW optimizer)
learning_rate = 2e-4

# Weight decay to apply to all layers except bias/LayerNorm weights
weight_decay = 0.001

# Optimizer to use
optim = "paged_adamw_32bit"

# Learning rate schedule
lr_scheduler_type = "cosine"

# Number of training steps (a positive value overrides num_train_epochs; -1 disables the override)
max_steps = -1

# Ratio of steps for a linear warmup (from 0 to the learning rate)
warmup_ratio = 0.03

# Group sequences of similar length into the same batch
# Saves memory and speeds up training considerably
group_by_length = True

# Save a checkpoint every X update steps (0 disables intermediate checkpoints)
save_steps = 0

# Log every X update steps
logging_steps = int(os.getenv("LOGGING_STEPS", "50"))

################################################################################
# SFT parameters
################################################################################

# Maximum sequence length to use
max_seq_length = int(os.getenv("MAX_SEQ_LENGTH", "512"))

# Pack multiple short examples into the same input sequence to increase efficiency
packing = False

# Load the entire model on the current GPU
device_map = {"": torch.cuda.current_device()}

# Number of training examples to sample from the dataset (-1 uses the full dataset)
limit = int(os.getenv("DATASET_LIMIT", "5000"))

dataset = load_dataset(dataset_name, split="train")
if limit != -1:
    dataset = dataset.shuffle(seed=42).select(range(limit))


def transform(data):
    """Format one dataset record into a single prompt string for supervised fine-tuning."""
    question = data["question"]
    context = data["context"]
    answer = data["answer"]
    template = "Question: {question}\nContext: {context}\nAnswer: {answer}"
    return {"text": template.format(question=question, context=context, answer=answer)}


transformed = dataset.map(transform)
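# For reference, a record formatted by `transform` looks roughly like the
# following (illustrative values, not necessarily an actual dataset row):
#
#   Question: How many heads of the departments are older than 56?
#   Context: CREATE TABLE head (age INTEGER)
#   Answer: SELECT COUNT(*) FROM head WHERE age > 56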

# Load tokenizer and model with QLoRA configuration
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)
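# With this configuration, the base model weights are stored as 4-bit NF4 values
# and dequantized to the compute dtype (float16 here) on the fly during the
# forward pass; only the LoRA adapter weights added below are trained. Enabling
# use_nested_quant would also quantize the quantization constants themselves,
# saving roughly another 0.4 bits per parameter.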

# Check GPU compatibility with bfloat16
if compute_dtype == torch.float16 and use_4bit:
    major, _ = torch.cuda.get_device_capability()
    if major >= 8:
        print("=" * 80)
        print("Your GPU supports bfloat16: consider setting bf16=True to accelerate training")
        print("=" * 80)

# Load base model
# model = AutoModelForCausalLM.from_pretrained("google/gemma-7b")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map,
    torch_dtype=torch.float16,
)
model.config.use_cache = False
model.config.pretraining_tp = 1

# Load the Gemma tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 training

# Load LoRA configuration
peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],
)
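# q_proj and v_proj are the attention query and value projections, so LoRA
# adapters are injected only into those layers. Targeting additional modules
# (e.g. k_proj, o_proj, or the MLP projections) typically improves quality at
# the cost of more trainable parameters and memory.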

# Set training parameters
training_arguments = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    gradient_checkpointing=gradient_checkpointing,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    fp16=fp16,
    bf16=bf16,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
)
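# The effective global batch size is per_device_train_batch_size *
# gradient_accumulation_steps * number of GPUs, so with the defaults above a
# single-GPU run applies one optimizer update per training example.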

trainer = SFTTrainer(
    model=model,
    train_dataset=transformed,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=packing,
)

trainer.train()
trainer.model.save_pretrained(new_model)
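# Note that trainer.model is a PEFT-wrapped model, so the call above saves only
# the LoRA adapter weights, not the full base model; the merge step below
# produces standalone weights that can be loaded without the peft library.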


# Reload the base model in FP16 and merge it with the LoRA adapter weights
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map=device_map,
)

model = PeftModel.from_pretrained(base_model, new_model)
model = model.merge_and_unload()

# Save the merged model to disk
model.save_pretrained(output_dir)
# Save the tokenizer to disk
tokenizer.save_pretrained(output_dir)
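# Optional sanity check on the merged model: a minimal generation sketch, gated
# behind a RUN_SMOKE_TEST environment variable that is assumed here rather than
# part of the original workflow. The prompt is illustrative and mirrors the
# training template with the Answer left blank.
if os.getenv("RUN_SMOKE_TEST", "0") == "1":
    prompt = (
        "Question: How many users are registered in the users table?\n"
        "Context: CREATE TABLE users (id INTEGER, name VARCHAR)\n"
        "Answer:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=64)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))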