Skip to content

Commit 226ab8b

Browse files
Ark-kuncopybara-github
authored andcommitted
feat: LLM - Made tuning asynchronous when tuning becomes GA
Previously, `tune_model` waited for the tuning is complete, then modified the model in-place. This behavior will change in the future GA (non-preview) classes: In the future, `tune_model` will become asynchronous: It will start tuning job and return a job object immediately without waiting. This will allow the user to do other work while the model is being tuned. This will also allow the user to perform multiple tuning jobs in parallel. Future breaking change: The model will no longer be updated in-place, so the user will need to get the tuned model from the job object. To make the transition easier and avoid breaking changes, the `.tune_model(...)` method will start returning the job object even in preview classes (although it will still wait for the job completion and update the model in-place too). This makes it possible to start writing future-proof code immediately. Usage: ``` tuning_job = model.tune_model(...) # Returns tuning job. In preview: Waits for the tuning job to finish tuned_model = tuning_job.get_tuned_model() # Returns tuned model after waiting for the tuning job to finish. ``` PiperOrigin-RevId: 558554561
1 parent e6d1e95 commit 226ab8b

File tree

3 files changed

+200
-26
lines changed

3 files changed

+200
-26
lines changed

tests/system/aiplatform/test_language_models.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def test_tuning(self, shared_state):
189189
df=training_data, upload_gcs_path=dataset_uri
190190
)
191191

192-
model.tune_model(
192+
tuning_job = model.tune_model(
193193
training_data=training_data,
194194
train_steps=1,
195195
tuning_job_location="europe-west4",
@@ -211,6 +211,18 @@ def test_tuning(self, shared_state):
211211
)
212212
# Deleting the Endpoint is a little less bad since the LLM SDK will recreate it, but it's not advised for the same reason.
213213

214+
# Testing the new model returned by the `tuning_job.get_tuned_model` method
215+
tuned_model1 = tuning_job.get_tuned_model()
216+
response1 = tuned_model1.predict(
217+
"What is the best recipe for banana bread? Recipe:",
218+
max_output_tokens=128,
219+
temperature=0,
220+
top_p=1,
221+
top_k=5,
222+
)
223+
assert response1.text
224+
225+
# Testing the model updated in-place (Deprecated. Preview only)
214226
response = model.predict(
215227
"What is the best recipe for banana bread? Recipe:",
216228
max_output_tokens=128,

tests/unit/aiplatform/test_language_models.py

+19-5
Original file line numberDiff line numberDiff line change
@@ -1039,13 +1039,13 @@ def mock_get_tuned_model(get_endpoint_mock):
10391039
with mock.patch.object(
10401040
_language_models._TunableModelMixin, "get_tuned_model"
10411041
) as mock_text_generation_model:
1042-
mock_text_generation_model._model_id = (
1042+
mock_text_generation_model.return_value._model_id = (
10431043
test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
10441044
)
1045-
mock_text_generation_model._endpoint_name = (
1045+
mock_text_generation_model.return_value._endpoint_name = (
10461046
test_constants.EndpointConstants._TEST_ENDPOINT_NAME
10471047
)
1048-
mock_text_generation_model._endpoint = get_endpoint_mock
1048+
mock_text_generation_model.return_value._endpoint = get_endpoint_mock
10491049
yield mock_text_generation_model
10501050

10511051

@@ -1344,7 +1344,7 @@ def test_tune_text_generation_model(
13441344
enable_early_stopping = True
13451345
tensorboard_name = f"projects/{_TEST_PROJECT}/locations/{tuning_job_location}/tensorboards/123"
13461346

1347-
model.tune_model(
1347+
tuning_job = model.tune_model(
13481348
training_data=_TEST_TEXT_BISON_TRAINING_DF,
13491349
tuning_job_location=tuning_job_location,
13501350
tuned_model_location="us-central1",
@@ -1375,6 +1375,13 @@ def test_tune_text_generation_model(
13751375
== _TEST_ENCRYPTION_KEY_NAME
13761376
)
13771377

1378+
# Testing the tuned model
1379+
tuned_model = tuning_job.get_tuned_model()
1380+
assert (
1381+
tuned_model._endpoint_name
1382+
== test_constants.EndpointConstants._TEST_ENDPOINT_NAME
1383+
)
1384+
13781385
@pytest.mark.parametrize(
13791386
"job_spec",
13801387
[_TEST_PIPELINE_SPEC_JSON],
@@ -1408,7 +1415,7 @@ def test_tune_chat_model(
14081415
model = preview_language_models.ChatModel.from_pretrained("chat-bison@001")
14091416

14101417
default_context = "Default context"
1411-
model.tune_model(
1418+
tuning_job = model.tune_model(
14121419
training_data=_TEST_TEXT_BISON_TRAINING_DF,
14131420
tuning_job_location="europe-west4",
14141421
tuned_model_location="us-central1",
@@ -1421,6 +1428,13 @@ def test_tune_chat_model(
14211428
assert pipeline_arguments["large_model_reference"] == "chat-bison@001"
14221429
assert pipeline_arguments["default_context"] == default_context
14231430

1431+
# Testing the tuned model
1432+
tuned_model = tuning_job.get_tuned_model()
1433+
assert (
1434+
tuned_model._endpoint_name
1435+
== test_constants.EndpointConstants._TEST_ENDPOINT_NAME
1436+
)
1437+
14241438
@pytest.mark.parametrize(
14251439
"job_spec",
14261440
[_TEST_PIPELINE_SPEC_JSON],

vertexai/language_models/_language_models.py

+168-20
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from google.cloud.aiplatform import base
2424
from google.cloud.aiplatform import initializer as aiplatform_initializer
2525
from google.cloud.aiplatform import utils as aiplatform_utils
26+
from google.cloud.aiplatform.compat import types as aiplatform_types
2627
from google.cloud.aiplatform.utils import gcs_utils
2728
from vertexai._model_garden import _model_garden_models
2829
from vertexai.language_models import (
@@ -148,18 +149,24 @@ def tune_model(
148149
self,
149150
training_data: Union[str, "pandas.core.frame.DataFrame"],
150151
*,
151-
train_steps: int = 1000,
152+
train_steps: Optional[int] = None,
152153
learning_rate: Optional[float] = None,
153154
learning_rate_multiplier: Optional[float] = None,
154155
tuning_job_location: Optional[str] = None,
155156
tuned_model_location: Optional[str] = None,
156157
model_display_name: Optional[str] = None,
157158
tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
158159
default_context: Optional[str] = None,
159-
):
160+
) -> "_LanguageModelTuningJob":
160161
"""Tunes a model based on training data.
161162
162-
This method launches a model tuning job that can take some time.
163+
This method launches and returns an asynchronous model tuning job.
164+
Usage:
165+
```
166+
tuning_job = model.tune_model(...)
167+
... do some other work
168+
tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
169+
```
163170
164171
Args:
165172
training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
@@ -303,16 +310,68 @@ def _tune_model(
303310
base_model=self,
304311
job=pipeline_job,
305312
)
306-
self._job = job
307-
tuned_model = job.result()
308-
# The UXR study attendees preferred to tune model in place
309-
self._endpoint = tuned_model._endpoint
310-
self._endpoint_name = tuned_model._endpoint_name
313+
return job
311314

312315

313316
class _TunableTextModelMixin(_TunableModelMixin):
314317
"""Text model that can be tuned."""
315318

319+
def tune_model(
320+
self,
321+
training_data: Union[str, "pandas.core.frame.DataFrame"],
322+
*,
323+
train_steps: Optional[int] = None,
324+
learning_rate_multiplier: Optional[float] = None,
325+
tuning_job_location: Optional[str] = None,
326+
tuned_model_location: Optional[str] = None,
327+
model_display_name: Optional[str] = None,
328+
tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
329+
) -> "_LanguageModelTuningJob":
330+
"""Tunes a model based on training data.
331+
332+
This method launches and returns an asynchronous model tuning job.
333+
Usage:
334+
```
335+
tuning_job = model.tune_model(...)
336+
... do some other work
337+
tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
338+
339+
Args:
340+
training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
341+
The dataset schema is model-specific.
342+
See https://cloud.google.com/vertex-ai/docs/generative-ai/models/tune-models#dataset_format
343+
train_steps: Number of training batches to tune on (batch size is 8 samples).
344+
learning_rate_multiplier: Learning rate multiplier to use in tuning.
345+
tuning_job_location: GCP location where the tuning job should be run.
346+
Only "europe-west4" and "us-central1" locations are supported for now.
347+
tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
348+
model_display_name: Custom display name for the tuned model.
349+
tuning_evaluation_spec: Specification for the model evaluation during tuning.
350+
351+
Returns:
352+
A `LanguageModelTuningJob` object that represents the tuning job.
353+
Calling `job.result()` blocks until the tuning is complete and returns a `LanguageModel` object.
354+
355+
Raises:
356+
ValueError: If the "tuning_job_location" value is not supported
357+
ValueError: If the "tuned_model_location" value is not supported
358+
RuntimeError: If the model does not support tuning
359+
"""
360+
# Note: Chat models do not support default_context
361+
return super().tune_model(
362+
training_data=training_data,
363+
train_steps=train_steps,
364+
learning_rate_multiplier=learning_rate_multiplier,
365+
tuning_job_location=tuning_job_location,
366+
tuned_model_location=tuned_model_location,
367+
model_display_name=model_display_name,
368+
tuning_evaluation_spec=tuning_evaluation_spec,
369+
)
370+
371+
372+
class _PreviewTunableTextModelMixin(_TunableModelMixin):
373+
"""Text model that can be tuned."""
374+
316375
def tune_model(
317376
self,
318377
training_data: Union[str, "pandas.core.frame.DataFrame"],
@@ -324,10 +383,20 @@ def tune_model(
324383
tuned_model_location: Optional[str] = None,
325384
model_display_name: Optional[str] = None,
326385
tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
327-
):
386+
) -> "_LanguageModelTuningJob":
328387
"""Tunes a model based on training data.
329388
330-
This method launches a model tuning job that can take some time.
389+
This method launches a model tuning job, waits for completion,
390+
updates the model in-place. This method returns job object for forward
391+
compatibility.
392+
In the future (GA), this method will become asynchronous and will stop
393+
updating the model in-place.
394+
395+
Usage:
396+
```
397+
tuning_job = model.tune_model(...) # Blocks until tuning is complete
398+
tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
399+
```
331400
332401
Args:
333402
training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
@@ -353,7 +422,7 @@ def tune_model(
353422
RuntimeError: If the model does not support tuning
354423
"""
355424
# Note: Chat models do not support default_context
356-
return super().tune_model(
425+
job = super().tune_model(
357426
training_data=training_data,
358427
train_steps=train_steps,
359428
learning_rate=learning_rate,
@@ -363,11 +432,74 @@ def tune_model(
363432
model_display_name=model_display_name,
364433
tuning_evaluation_spec=tuning_evaluation_spec,
365434
)
435+
tuned_model = job.get_tuned_model()
436+
self._endpoint = tuned_model._endpoint
437+
self._endpoint_name = tuned_model._endpoint_name
438+
return job
366439

367440

368441
class _TunableChatModelMixin(_TunableModelMixin):
369442
"""Chat model that can be tuned."""
370443

444+
def tune_model(
445+
self,
446+
training_data: Union[str, "pandas.core.frame.DataFrame"],
447+
*,
448+
train_steps: Optional[int] = None,
449+
learning_rate_multiplier: Optional[float] = None,
450+
tuning_job_location: Optional[str] = None,
451+
tuned_model_location: Optional[str] = None,
452+
model_display_name: Optional[str] = None,
453+
default_context: Optional[str] = None,
454+
) -> "_LanguageModelTuningJob":
455+
"""Tunes a model based on training data.
456+
457+
This method launches and returns an asynchronous model tuning job.
458+
Usage:
459+
```
460+
tuning_job = model.tune_model(...)
461+
... do some other work
462+
tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
463+
```
464+
465+
Args:
466+
training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
467+
The dataset schema is model-specific.
468+
See https://cloud.google.com/vertex-ai/docs/generative-ai/models/tune-models#dataset_format
469+
train_steps: Number of training batches to tune on (batch size is 8 samples).
470+
learning_rate: Deprecated. Use learning_rate_multiplier instead.
471+
Learning rate to use in tuning.
472+
learning_rate_multiplier: Learning rate multiplier to use in tuning.
473+
tuning_job_location: GCP location where the tuning job should be run.
474+
Only "europe-west4" and "us-central1" locations are supported for now.
475+
tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
476+
model_display_name: Custom display name for the tuned model.
477+
default_context: The context to use for all training samples by default.
478+
479+
Returns:
480+
A `LanguageModelTuningJob` object that represents the tuning job.
481+
Calling `job.result()` blocks until the tuning is complete and returns a `LanguageModel` object.
482+
483+
Raises:
484+
ValueError: If the "tuning_job_location" value is not supported
485+
ValueError: If the "tuned_model_location" value is not supported
486+
RuntimeError: If the model does not support tuning
487+
"""
488+
# Note: Chat models do not support tuning_evaluation_spec
489+
return super().tune_model(
490+
training_data=training_data,
491+
train_steps=train_steps,
492+
learning_rate_multiplier=learning_rate_multiplier,
493+
tuning_job_location=tuning_job_location,
494+
tuned_model_location=tuned_model_location,
495+
model_display_name=model_display_name,
496+
default_context=default_context,
497+
)
498+
499+
500+
class _PreviewTunableChatModelMixin(_TunableModelMixin):
501+
"""Chat model that can be tuned."""
502+
371503
def tune_model(
372504
self,
373505
training_data: Union[str, "pandas.core.frame.DataFrame"],
@@ -379,10 +511,20 @@ def tune_model(
379511
tuned_model_location: Optional[str] = None,
380512
model_display_name: Optional[str] = None,
381513
default_context: Optional[str] = None,
382-
):
514+
) -> "_LanguageModelTuningJob":
383515
"""Tunes a model based on training data.
384516
385-
This method launches a model tuning job that can take some time.
517+
This method launches a model tuning job, waits for completion,
518+
updates the model in-place. This method returns job object for forward
519+
compatibility.
520+
In the future (GA), this method will become asynchronous and will stop
521+
updating the model in-place.
522+
523+
Usage:
524+
```
525+
tuning_job = model.tune_model(...) # Blocks until tuning is complete
526+
tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
527+
```
386528
387529
Args:
388530
training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
@@ -408,7 +550,7 @@ def tune_model(
408550
RuntimeError: If the model does not support tuning
409551
"""
410552
# Note: Chat models do not support tuning_evaluation_spec
411-
return super().tune_model(
553+
job = super().tune_model(
412554
training_data=training_data,
413555
train_steps=train_steps,
414556
learning_rate=learning_rate,
@@ -418,6 +560,10 @@ def tune_model(
418560
model_display_name=model_display_name,
419561
default_context=default_context,
420562
)
563+
tuned_model = job.get_tuned_model()
564+
self._endpoint = tuned_model._endpoint
565+
self._endpoint_name = tuned_model._endpoint_name
566+
return job
421567

422568

423569
@dataclasses.dataclass
@@ -746,7 +892,7 @@ class TextGenerationModel(_TextGenerationModel, _ModelWithBatchPredict):
746892

747893
class _PreviewTextGenerationModel(
748894
_TextGenerationModel,
749-
_TunableTextModelMixin,
895+
_PreviewTunableTextModelMixin,
750896
_PreviewModelWithBatchPredict,
751897
_evaluatable_language_models._EvaluatableLanguageModel,
752898
):
@@ -1076,7 +1222,7 @@ class ChatModel(_ChatModelBase):
10761222
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml"
10771223

10781224

1079-
class _PreviewChatModel(ChatModel, _TunableChatModelMixin):
1225+
class _PreviewChatModel(ChatModel, _PreviewTunableChatModelMixin):
10801226
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
10811227

10821228

@@ -1650,11 +1796,12 @@ def __init__(
16501796
base_model: _LanguageModel,
16511797
job: aiplatform.PipelineJob,
16521798
):
1799+
"""Internal constructor. Do not call directly."""
16531800
self._base_model = base_model
16541801
self._job = job
16551802
self._model: Optional[_LanguageModel] = None
16561803

1657-
def result(self) -> "_LanguageModel":
1804+
def get_tuned_model(self) -> "_LanguageModel":
16581805
"""Blocks until the tuning is complete and returns a `LanguageModel` object."""
16591806
if self._model:
16601807
return self._model
@@ -1681,11 +1828,12 @@ def result(self) -> "_LanguageModel":
16811828
return self._model
16821829

16831830
@property
1684-
def status(self):
1685-
"""Job status"""
1831+
def _status(self) -> Optional[aiplatform_types.pipeline_state.PipelineState]:
1832+
"""Job status."""
16861833
return self._job.state
16871834

1688-
def cancel(self):
1835+
def _cancel(self):
1836+
"""Cancels the job."""
16891837
self._job.cancel()
16901838

16911839

0 commit comments

Comments
 (0)