
Commit 28925e9

Ark-kun authored and copybara-github committed
feat: LLM - Added support for model distillation
PiperOrigin-RevId: 590578502
1 parent cfc5cba commit 28925e9

File tree

3 files changed: +386 -21 lines


tests/unit/aiplatform/test_language_models.py (+234 lines)
@@ -758,6 +758,124 @@ def reverse_string_2(s):""",
         "pipelineSpec": json.loads(_TEST_EVAL_PIPELINE_SPEC_JSON),
     }
 )
+_TEST_DISTILLATION_PIPELINE_SPEC = {
+    "components": {},
+    "pipelineInfo": {
+        "description": "Vertex kfp pipeline for distillation.",
+        "name": "distillation",
+    },
+    "root": {
+        "dag": {"tasks": {}},
+        "inputDefinitions": {
+            "parameters": {
+                "accelerator_type": {
+                    "defaultValue": "GPU",
+                    "isOptional": True,
+                    "parameterType": "STRING",
+                },
+                "api_endpoint": {
+                    "defaultValue": "aiplatform.googleapis.com/ui",
+                    "isOptional": True,
+                    "parameterType": "STRING",
+                },
+                "dataset_uri": {"parameterType": "STRING"},
+                "enable_checkpoint_selection": {
+                    "defaultValue": "default",
+                    "isOptional": True,
+                    "parameterType": "STRING",
+                },
+                "enable_early_stopping": {
+                    "defaultValue": True,
+                    "isOptional": True,
+                    "parameterType": "BOOLEAN",
+                },
+                "encryption_spec_key_name": {
+                    "defaultValue": "",
+                    "isOptional": True,
+                    "parameterType": "STRING",
+                },
+                "evaluation_data_uri": {
+                    "defaultValue": "",
+                    "isOptional": True,
+                    "parameterType": "STRING",
+                },
+                "evaluation_interval": {
+                    "defaultValue": 100,
+                    "isOptional": True,
+                    "parameterType": "NUMBER_INTEGER",
+                },
+                "evaluation_output_root_dir": {
+                    "defaultValue": "",
+                    "isOptional": True,
+                    "parameterType": "STRING",
+                },
+                "learning_rate_multiplier": {
+                    "defaultValue": 1,
+                    "isOptional": True,
+                    "parameterType": "NUMBER_DOUBLE",
+                },
+                "location": {
+                    "defaultValue": "",
+                    "isOptional": True,
+                    "parameterType": "STRING",
+                },
+                "max_context_length": {
+                    "defaultValue": "",
+                    "isOptional": True,
+                    "parameterType": "STRING",
+                },
+                "model_display_name": {
+                    "defaultValue": "distilled-student-model",
+                    "isOptional": True,
+                    "parameterType": "STRING",
+                },
+                "project": {"parameterType": "STRING"},
+                "student_model_reference": {
+                    "defaultValue": "text-bison@002",
+                    "isOptional": True,
+                    "parameterType": "STRING",
+                },
+                "teacher_model_reference": {
+                    "defaultValue": "text-unicorn@001",
+                    "isOptional": True,
+                    "parameterType": "STRING",
+                },
+                "temperature": {
+                    "defaultValue": 0,
+                    "isOptional": True,
+                    "parameterType": "NUMBER_DOUBLE",
+                },
+                "tensorboard_resource_id": {
+                    "defaultValue": "",
+                    "isOptional": True,
+                    "parameterType": "STRING",
+                },
+                "tpu_training_skip_cmek": {
+                    "defaultValue": False,
+                    "isOptional": True,
+                    "parameterType": "BOOLEAN",
+                },
+                "train_steps": {
+                    "defaultValue": 300,
+                    "isOptional": True,
+                    "parameterType": "NUMBER_INTEGER",
+                },
+                "version": {
+                    "defaultValue": "latest",
+                    "isOptional": True,
+                    "parameterType": "STRING",
+                },
+            }
+        },
+    },
+    "schemaVersion": "2.1.0",
+    "sdkVersion": "kfp-2.4.0",
+}
+
+_TEST_DISTILLATION_PIPELINE_SPEC_JSON = json.dumps(
+    _TEST_DISTILLATION_PIPELINE_SPEC,
+)
+
 
 # Eval classification spec
 
@@ -875,6 +993,10 @@ def reverse_string_2(s):""",
     }
 )
 
+_URL_DATA = {
+    "https://us-kfp.pkg.dev/ml-pipeline/research/distillation/v1.0.0": _TEST_DISTILLATION_PIPELINE_SPEC_JSON,
+}
+
 
 @pytest.fixture
 def mock_pipeline_bucket_exists():
@@ -1225,6 +1347,19 @@ def mock_request_urlopen_eval_classification(
         yield request.param, mock_urlopen
 
 
+@pytest.fixture
+def mock_urllib_request_urlopen(request: str) -> Tuple[str, mock.MagicMock]:
+    url = request.param
+    data = _URL_DATA[url]
+    with mock.patch.object(urllib_request, "urlopen") as mock_urlopen:
+        mock_read_response = mock.MagicMock()
+        mock_decode_response = mock.MagicMock()
+        mock_decode_response.return_value = data
+        mock_read_response.return_value.decode = mock_decode_response
+        mock_urlopen.return_value.read = mock_read_response
+        yield url, mock_urlopen
+
+
 @pytest.fixture
 def get_endpoint_mock():
     with mock.patch.object(
@@ -4251,3 +4386,102 @@ def test_model_evaluation_text_classification_base_model_only_summary_metrics(
         )
         assert eval_metrics.confidenceMetrics is None
         assert eval_metrics.auPrc == _TEST_TEXT_CLASSIFICATION_METRICS["auPrc"]
+
+    @pytest.mark.parametrize(
+        "job_spec",
+        [
+            _TEST_DISTILLATION_PIPELINE_SPEC_JSON,
+        ],
+    )
+    @pytest.mark.parametrize(
+        "mock_urllib_request_urlopen",
+        ["https://us-kfp.pkg.dev/ml-pipeline/research/distillation/v1.0.0"],
+        indirect=True,
+    )
+    def test_text_generation_model_distill_from(
+        self,
+        mock_pipeline_service_create,
+        mock_pipeline_job_get,
+        mock_pipeline_bucket_exists,
+        job_spec,
+        mock_load_yaml_and_json,
+        mock_gcs_from_string,
+        mock_gcs_upload,
+        mock_urllib_request_urlopen,
+        mock_get_tuned_model,
+    ):
+        """Tests distilling the text generation model."""
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+            encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _TEXT_BISON_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = preview_language_models.TextGenerationModel.from_pretrained(
+                "text-bison@001"
+            )
+
+        dataset_uri = "gs://bucket/distillation.training_data.jsonl"
+        evaluation_data_uri = "gs://bucket/eval.jsonl"
+        evaluation_interval = 37
+        enable_early_stopping = True
+        enable_checkpoint_selection = True
+        tensorboard_name = (
+            f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/tensorboards/123"
+        )
+
+        tuning_job = model.distill_from(
+            dataset=dataset_uri,
+            teacher_model="text-unicorn@001",
+            learning_rate_multiplier=2.0,
+            train_steps=10,
+            evaluation_spec=preview_language_models.TuningEvaluationSpec(
+                evaluation_data=evaluation_data_uri,
+                evaluation_interval=evaluation_interval,
+                enable_early_stopping=enable_early_stopping,
+                enable_checkpoint_selection=enable_checkpoint_selection,
+                tensorboard=tensorboard_name,
+            ),
+            accelerator_type="TPU",
+        )
+        call_kwargs = mock_pipeline_service_create.call_args[1]
+        pipeline_arguments = call_kwargs[
+            "pipeline_job"
+        ].runtime_config.parameter_values
+        assert pipeline_arguments["teacher_model_reference"] == "text-unicorn@001"
+        assert pipeline_arguments["student_model_reference"] == "text-bison@001"
+        assert pipeline_arguments["dataset_uri"] == dataset_uri
+        assert pipeline_arguments["project"] == _TEST_PROJECT
+        assert pipeline_arguments["location"] == _TEST_LOCATION
+        assert pipeline_arguments["train_steps"] == 10
+        assert pipeline_arguments["learning_rate_multiplier"] == 2.0
+        assert pipeline_arguments["evaluation_data_uri"] == evaluation_data_uri
+        assert pipeline_arguments["evaluation_interval"] == evaluation_interval
+        assert pipeline_arguments["enable_early_stopping"] == enable_early_stopping
+        assert (
+            pipeline_arguments["enable_checkpoint_selection"]
+            == enable_checkpoint_selection
+        )
+        assert pipeline_arguments["tensorboard_resource_id"] == tensorboard_name
+        assert pipeline_arguments["accelerator_type"] == "TPU"
+        assert (
+            pipeline_arguments["encryption_spec_key_name"]
+            == _TEST_ENCRYPTION_KEY_NAME
+        )
+        assert (
+            call_kwargs["pipeline_job"].encryption_spec.kms_key_name
+            == _TEST_ENCRYPTION_KEY_NAME
+        )
+
+        # Testing the tuned model
+        tuned_model = tuning_job.get_tuned_model()
+        assert (
+            tuned_model._endpoint_name
+            == test_constants.EndpointConstants._TEST_ENDPOINT_NAME
+        )
New file: +111 lines
@@ -0,0 +1,111 @@
+from typing import Optional, Union
+
+from google.cloud import aiplatform
+from google.cloud.aiplatform import initializer as aiplatform_initializer
+from vertexai.language_models import _language_models
+from vertexai.language_models import _language_models as tuning
+
+
+class DistillationMixin:
+    _DISTILLATION_PIPELINE_URI = (
+        "https://us-kfp.pkg.dev/ml-pipeline/research/distillation/v1.0.0"
+    )
+
+    def distill_from(
+        self,
+        *,
+        dataset: str,
+        teacher_model: Union[str, _language_models._TextGenerationModel],
+        train_steps: Optional[int] = None,
+        learning_rate_multiplier: Optional[float] = None,
+        evaluation_spec: Optional[tuning.TuningEvaluationSpec] = None,
+        accelerator_type: Optional[tuning._ACCELERATOR_TYPE_TYPE] = None,
+        model_display_name: Optional[str] = None,
+    ):
+        """Tunes a smaller model with help from another bigger model.
+
+        Args:
+            dataset: A URI pointing to data in JSON lines format.
+            teacher_model: The teacher model to use for distillation.
+            train_steps: Number of training batches to use (batch size is 8 samples).
+            learning_rate_multiplier: Learning rate multiplier to use in tuning.
+            evaluation_spec: Specification for the model evaluation during tuning.
+            accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
+            model_display_name: Custom display name for the tuned model.
+
+        Returns:
+            A tuning job for distillation.
+
+        Raises:
+            RuntimeError: If the model does not support distillation.
+        """
+        if "/models/" not in self._endpoint_name:
+            raise RuntimeError(
+                f"Model does not support distillation: {self._endpoint_name}"
+            )
+        student_short_model_id = self._endpoint_name.split("/")[-1]
+
+        if isinstance(teacher_model, str):
+            teacher_short_model_id = teacher_model
+        elif isinstance(teacher_model, _language_models._LanguageModel):
+            if "/models/" not in teacher_model._endpoint_name:
+                raise RuntimeError(
+                    f"Teacher model does not support distillation: {teacher_model._endpoint_name}"
+                )
+            teacher_short_model_id = teacher_model._endpoint_name.split("/")[-1]
+        else:
+            raise RuntimeError(f"Unsupported teacher model type: {teacher_model}")
+
+        pipeline_arguments = {
+            "teacher_model_reference": teacher_short_model_id,
+            "student_model_reference": student_short_model_id,
+            "dataset_uri": dataset,
+            "project": aiplatform_initializer.global_config.project,
+            "location": aiplatform_initializer.global_config.location,
+        }
+        if train_steps is not None:
+            pipeline_arguments["train_steps"] = train_steps
+        if learning_rate_multiplier is not None:
+            pipeline_arguments["learning_rate_multiplier"] = learning_rate_multiplier
+        if evaluation_spec is not None:
+            pipeline_arguments["evaluation_data_uri"] = evaluation_spec.evaluation_data
+            pipeline_arguments[
+                "evaluation_interval"
+            ] = evaluation_spec.evaluation_interval
+            pipeline_arguments[
+                "enable_early_stopping"
+            ] = evaluation_spec.enable_early_stopping
+            pipeline_arguments[
+                "enable_checkpoint_selection"
+            ] = evaluation_spec.enable_checkpoint_selection
+            pipeline_arguments["tensorboard_resource_id"] = evaluation_spec.tensorboard
+            # pipeline_parameter_values["evaluation_output_root_dir"] = ...
+        if accelerator_type is not None:
+            pipeline_arguments["accelerator_type"] = accelerator_type
+        if aiplatform_initializer.global_config.encryption_spec_key_name is not None:
+            pipeline_arguments[
+                "encryption_spec_key_name"
+            ] = aiplatform_initializer.global_config.encryption_spec_key_name
+        if model_display_name is None:
+            model_display_name = (
+                f"{student_short_model_id}"
+                f" distilled from {teacher_short_model_id}"
+            )
+        pipeline_arguments["model_display_name"] = model_display_name
+        # # Not exposing these parameters:
+        # temperature: Optional[float] = None,
+        # max_context_length: Optional[int] = None,
+        # tpu_training_skip_cmek: Optional[bool] = None,
+        # api_endpoint: Optional[str] = None,
+        # version: Optional[str] = None,
+        pipeline_job = aiplatform.PipelineJob(
+            template_path=self._DISTILLATION_PIPELINE_URI,
+            display_name=None,
+            parameter_values=pipeline_arguments,
+        )
+        pipeline_job.submit()
+        tuning_job = tuning._LanguageModelTuningJob(
+            base_model=self,
+            job=pipeline_job,
+        )
+        return tuning_job
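
For orientation, here is a minimal usage sketch of the distill_from API introduced by this commit, mirroring the unit test above. It assumes the classes are exposed through the vertexai.preview.language_models namespace (the test imports them via preview_language_models); the project ID, location, and Cloud Storage URIs are placeholders.

import vertexai
from vertexai.preview.language_models import (
    TextGenerationModel,
    TuningEvaluationSpec,
)

# Placeholder project/location; distill_from reads project, location, and the
# CMEK key (if any) from this global config.
vertexai.init(project="my-project", location="us-central1")

# The model distill_from is called on becomes the student model.
student = TextGenerationModel.from_pretrained("text-bison@001")

tuning_job = student.distill_from(
    dataset="gs://my-bucket/distillation.training_data.jsonl",  # placeholder URI
    teacher_model="text-unicorn@001",
    train_steps=300,
    learning_rate_multiplier=1.0,
    evaluation_spec=TuningEvaluationSpec(
        evaluation_data="gs://my-bucket/eval.jsonl",  # placeholder URI
        evaluation_interval=100,
        enable_early_stopping=True,
    ),
    accelerator_type="TPU",
)

# distill_from submits the distillation PipelineJob and returns a tuning job
# handle; the distilled student model is retrieved from that handle once the
# pipeline completes.
distilled_model = tuning_job.get_tuned_model()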
