23
23
from google .cloud .aiplatform import base
24
24
from google .cloud .aiplatform import initializer as aiplatform_initializer
25
25
from google .cloud .aiplatform import utils as aiplatform_utils
26
+ from google .cloud .aiplatform .compat import types as aiplatform_types
26
27
from google .cloud .aiplatform .utils import gcs_utils
27
28
from vertexai ._model_garden import _model_garden_models
28
29
from vertexai .language_models import (
@@ -148,18 +149,24 @@ def tune_model(
148
149
self ,
149
150
training_data : Union [str , "pandas.core.frame.DataFrame" ],
150
151
* ,
151
- train_steps : int = 1000 ,
152
+ train_steps : Optional [ int ] = None ,
152
153
learning_rate : Optional [float ] = None ,
153
154
learning_rate_multiplier : Optional [float ] = None ,
154
155
tuning_job_location : Optional [str ] = None ,
155
156
tuned_model_location : Optional [str ] = None ,
156
157
model_display_name : Optional [str ] = None ,
157
158
tuning_evaluation_spec : Optional ["TuningEvaluationSpec" ] = None ,
158
159
default_context : Optional [str ] = None ,
159
- ):
160
+ ) -> "_LanguageModelTuningJob" :
160
161
"""Tunes a model based on training data.
161
162
162
- This method launches a model tuning job that can take some time.
163
+ This method launches and returns an asynchronous model tuning job.
164
+ Usage:
165
+ ```
166
+ tuning_job = model.tune_model(...)
167
+ ... do some other work
168
+ tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
169
+ ```
163
170
164
171
Args:
165
172
training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
@@ -303,16 +310,68 @@ def _tune_model(
303
310
base_model = self ,
304
311
job = pipeline_job ,
305
312
)
306
- self ._job = job
307
- tuned_model = job .result ()
308
- # The UXR study attendees preferred to tune model in place
309
- self ._endpoint = tuned_model ._endpoint
310
- self ._endpoint_name = tuned_model ._endpoint_name
313
+ return job
311
314
312
315
313
316
class _TunableTextModelMixin (_TunableModelMixin ):
314
317
"""Text model that can be tuned."""
315
318
319
+ def tune_model (
320
+ self ,
321
+ training_data : Union [str , "pandas.core.frame.DataFrame" ],
322
+ * ,
323
+ train_steps : Optional [int ] = None ,
324
+ learning_rate_multiplier : Optional [float ] = None ,
325
+ tuning_job_location : Optional [str ] = None ,
326
+ tuned_model_location : Optional [str ] = None ,
327
+ model_display_name : Optional [str ] = None ,
328
+ tuning_evaluation_spec : Optional ["TuningEvaluationSpec" ] = None ,
329
+ ) -> "_LanguageModelTuningJob" :
330
+ """Tunes a model based on training data.
331
+
332
+ This method launches and returns an asynchronous model tuning job.
333
+ Usage:
334
+ ```
335
+ tuning_job = model.tune_model(...)
336
+ ... do some other work
337
+ tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
+ ```
338
+
339
+ Args:
340
+ training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
341
+ The dataset schema is model-specific.
342
+ See https://cloud.google.com/vertex-ai/docs/generative-ai/models/tune-models#dataset_format
343
+ train_steps: Number of training batches to tune on (batch size is 8 samples).
344
+ learning_rate_multiplier: Learning rate multiplier to use in tuning.
345
+ tuning_job_location: GCP location where the tuning job should be run.
346
+ Only "europe-west4" and "us-central1" locations are supported for now.
347
+ tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
348
+ model_display_name: Custom display name for the tuned model.
349
+ tuning_evaluation_spec: Specification for the model evaluation during tuning.
350
+
351
+ Returns:
352
+ A `LanguageModelTuningJob` object that represents the tuning job.
353
+ Calling `job.get_tuned_model()` blocks until the tuning is complete and returns a `LanguageModel` object.
354
+
355
+ Raises:
356
+ ValueError: If the "tuning_job_location" value is not supported
357
+ ValueError: If the "tuned_model_location" value is not supported
358
+ RuntimeError: If the model does not support tuning
359
+ """
360
+ # Note: Text models do not support default_context
361
+ return super ().tune_model (
362
+ training_data = training_data ,
363
+ train_steps = train_steps ,
364
+ learning_rate_multiplier = learning_rate_multiplier ,
365
+ tuning_job_location = tuning_job_location ,
366
+ tuned_model_location = tuned_model_location ,
367
+ model_display_name = model_display_name ,
368
+ tuning_evaluation_spec = tuning_evaluation_spec ,
369
+ )
370
+
371
+
372
+ class _PreviewTunableTextModelMixin (_TunableModelMixin ):
373
+ """Text model that can be tuned."""
374
+
316
375
def tune_model (
317
376
self ,
318
377
training_data : Union [str , "pandas.core.frame.DataFrame" ],
@@ -324,10 +383,20 @@ def tune_model(
324
383
tuned_model_location : Optional [str ] = None ,
325
384
model_display_name : Optional [str ] = None ,
326
385
tuning_evaluation_spec : Optional ["TuningEvaluationSpec" ] = None ,
327
- ):
386
+ ) -> "_LanguageModelTuningJob" :
328
387
"""Tunes a model based on training data.
329
388
330
- This method launches a model tuning job that can take some time.
389
+ This method launches a model tuning job, waits for completion, and
390
+ updates the model in-place. This method returns the job object for forward
391
+ compatibility.
392
+ In the future (GA), this method will become asynchronous and will stop
393
+ updating the model in-place.
394
+
395
+ Usage:
396
+ ```
397
+ tuning_job = model.tune_model(...) # Blocks until tuning is complete
398
+ tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
399
+ ```
331
400
332
401
Args:
333
402
training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
@@ -353,7 +422,7 @@ def tune_model(
353
422
RuntimeError: If the model does not support tuning
354
423
"""
355
424
# Note: Text models do not support default_context
356
- return super ().tune_model (
425
+ job = super ().tune_model (
357
426
training_data = training_data ,
358
427
train_steps = train_steps ,
359
428
learning_rate = learning_rate ,
@@ -363,11 +432,74 @@ def tune_model(
363
432
model_display_name = model_display_name ,
364
433
tuning_evaluation_spec = tuning_evaluation_spec ,
365
434
)
435
+ tuned_model = job .get_tuned_model ()
436
+ self ._endpoint = tuned_model ._endpoint
437
+ self ._endpoint_name = tuned_model ._endpoint_name
438
+ return job
366
439
367
440
368
441
class _TunableChatModelMixin (_TunableModelMixin ):
369
442
"""Chat model that can be tuned."""
370
443
444
+ def tune_model (
445
+ self ,
446
+ training_data : Union [str , "pandas.core.frame.DataFrame" ],
447
+ * ,
448
+ train_steps : Optional [int ] = None ,
449
+ learning_rate_multiplier : Optional [float ] = None ,
450
+ tuning_job_location : Optional [str ] = None ,
451
+ tuned_model_location : Optional [str ] = None ,
452
+ model_display_name : Optional [str ] = None ,
453
+ default_context : Optional [str ] = None ,
454
+ ) -> "_LanguageModelTuningJob" :
455
+ """Tunes a model based on training data.
456
+
457
+ This method launches and returns an asynchronous model tuning job.
458
+ Usage:
459
+ ```
460
+ tuning_job = model.tune_model(...)
461
+ ... do some other work
462
+ tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
463
+ ```
464
+
465
+ Args:
466
+ training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
467
+ The dataset schema is model-specific.
468
+ See https://cloud.google.com/vertex-ai/docs/generative-ai/models/tune-models#dataset_format
469
+ train_steps: Number of training batches to tune on (batch size is 8 samples).
470
+ learning_rate: Deprecated. Use learning_rate_multiplier instead.
471
+ Learning rate to use in tuning.
472
+ learning_rate_multiplier: Learning rate multiplier to use in tuning.
473
+ tuning_job_location: GCP location where the tuning job should be run.
474
+ Only "europe-west4" and "us-central1" locations are supported for now.
475
+ tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
476
+ model_display_name: Custom display name for the tuned model.
477
+ default_context: The context to use for all training samples by default.
478
+
479
+ Returns:
480
+ A `LanguageModelTuningJob` object that represents the tuning job.
481
+ Calling `job.get_tuned_model()` blocks until the tuning is complete and returns a `LanguageModel` object.
482
+
483
+ Raises:
484
+ ValueError: If the "tuning_job_location" value is not supported
485
+ ValueError: If the "tuned_model_location" value is not supported
486
+ RuntimeError: If the model does not support tuning
487
+ """
488
+ # Note: Chat models do not support tuning_evaluation_spec
489
+ return super ().tune_model (
490
+ training_data = training_data ,
491
+ train_steps = train_steps ,
492
+ learning_rate_multiplier = learning_rate_multiplier ,
493
+ tuning_job_location = tuning_job_location ,
494
+ tuned_model_location = tuned_model_location ,
495
+ model_display_name = model_display_name ,
496
+ default_context = default_context ,
497
+ )
498
+
499
+
500
+ class _PreviewTunableChatModelMixin (_TunableModelMixin ):
501
+ """Chat model that can be tuned."""
502
+
371
503
def tune_model (
372
504
self ,
373
505
training_data : Union [str , "pandas.core.frame.DataFrame" ],
@@ -379,10 +511,20 @@ def tune_model(
379
511
tuned_model_location : Optional [str ] = None ,
380
512
model_display_name : Optional [str ] = None ,
381
513
default_context : Optional [str ] = None ,
382
- ):
514
+ ) -> "_LanguageModelTuningJob" :
383
515
"""Tunes a model based on training data.
384
516
385
- This method launches a model tuning job that can take some time.
517
+ This method launches a model tuning job, waits for completion, and
518
+ updates the model in-place. This method returns the job object for forward
519
+ compatibility.
520
+ In the future (GA), this method will become asynchronous and will stop
521
+ updating the model in-place.
522
+
523
+ Usage:
524
+ ```
525
+ tuning_job = model.tune_model(...) # Blocks until tuning is complete
526
+ tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
527
+ ```
386
528
387
529
Args:
388
530
training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
@@ -408,7 +550,7 @@ def tune_model(
408
550
RuntimeError: If the model does not support tuning
409
551
"""
410
552
# Note: Chat models do not support tuning_evaluation_spec
411
- return super ().tune_model (
553
+ job = super ().tune_model (
412
554
training_data = training_data ,
413
555
train_steps = train_steps ,
414
556
learning_rate = learning_rate ,
@@ -418,6 +560,10 @@ def tune_model(
418
560
model_display_name = model_display_name ,
419
561
default_context = default_context ,
420
562
)
563
+ tuned_model = job .get_tuned_model ()
564
+ self ._endpoint = tuned_model ._endpoint
565
+ self ._endpoint_name = tuned_model ._endpoint_name
566
+ return job
421
567
422
568
423
569
@dataclasses .dataclass
@@ -746,7 +892,7 @@ class TextGenerationModel(_TextGenerationModel, _ModelWithBatchPredict):
746
892
747
893
class _PreviewTextGenerationModel (
748
894
_TextGenerationModel ,
749
- _TunableTextModelMixin ,
895
+ _PreviewTunableTextModelMixin ,
750
896
_PreviewModelWithBatchPredict ,
751
897
_evaluatable_language_models ._EvaluatableLanguageModel ,
752
898
):
@@ -1076,7 +1222,7 @@ class ChatModel(_ChatModelBase):
1076
1222
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml"
1077
1223
1078
1224
1079
- class _PreviewChatModel (ChatModel , _TunableChatModelMixin ):
1225
+ class _PreviewChatModel (ChatModel , _PreviewTunableChatModelMixin ):
1080
1226
_LAUNCH_STAGE = _model_garden_models ._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
1081
1227
1082
1228
@@ -1650,11 +1796,12 @@ def __init__(
1650
1796
base_model : _LanguageModel ,
1651
1797
job : aiplatform .PipelineJob ,
1652
1798
):
1799
+ """Internal constructor. Do not call directly."""
1653
1800
self ._base_model = base_model
1654
1801
self ._job = job
1655
1802
self ._model : Optional [_LanguageModel ] = None
1656
1803
1657
- def result (self ) -> "_LanguageModel" :
1804
+ def get_tuned_model (self ) -> "_LanguageModel" :
1658
1805
"""Blocks until the tuning is complete and returns a `LanguageModel` object."""
1659
1806
if self ._model :
1660
1807
return self ._model
@@ -1681,11 +1828,12 @@ def result(self) -> "_LanguageModel":
1681
1828
return self ._model
1682
1829
1683
1830
@property
1684
- def status (self ):
1685
- """Job status"""
1831
+ def _status (self ) -> Optional [ aiplatform_types . pipeline_state . PipelineState ] :
1832
+ """Job status. """
1686
1833
return self ._job .state
1687
1834
1688
- def cancel (self ):
1835
+ def _cancel (self ):
1836
+ """Cancels the job."""
1689
1837
self ._job .cancel ()
1690
1838
1691
1839
0 commit comments