
Commit 98ab2f9

Ark-kun authored and copybara-github committed

feat: LLM - Support accelerator_type in tuning
PiperOrigin-RevId: 574768322
1 parent cbe3a0d commit 98ab2f9
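
For orientation, here is a minimal sketch of how the parameter added by this commit is meant to be used from the SDK. This is a hedged example, not code from the commit: the project, bucket, and dataset values are placeholders.

    import vertexai
    from vertexai.language_models import TextGenerationModel

    # Placeholder project/location values for illustration only.
    vertexai.init(project="my-project", location="us-central1")

    model = TextGenerationModel.from_pretrained("text-bison@001")
    tuning_job = model.tune_model(
        training_data="gs://my-bucket/tuning_data.jsonl",  # placeholder dataset URI
        tuning_job_location="europe-west4",
        tuned_model_location="us-central1",
        accelerator_type="TPU",  # new in this commit; accepts "TPU" or "GPU"
    )

The locations mirror those exercised by the updated tests below.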


2 files changed: +41 -1 lines changed


tests/unit/aiplatform/test_language_models.py (+15)
@@ -365,6 +365,11 @@ def reverse_string_2(s):""",
     "dag": {"tasks": {}},
     "inputDefinitions": {
         "parameters": {
+            "accelerator_type": {
+                "defaultValue": "",
+                "isOptional": True,
+                "parameterType": "STRING",
+            },
             "api_endpoint": {
                 "defaultValue": "aiplatform.googleapis.com/ui",
                 "isOptional": True,
@@ -1568,6 +1573,7 @@ def test_tune_text_generation_model(
         enable_early_stopping=enable_early_stopping,
         tensorboard=tensorboard_name,
     ),
+    accelerator_type="TPU",
 )
 call_kwargs = mock_pipeline_service_create.call_args[1]
 pipeline_arguments = call_kwargs[
@@ -1581,6 +1587,7 @@ def test_tune_text_generation_model(
 assert pipeline_arguments["enable_early_stopping"] == enable_early_stopping
 assert pipeline_arguments["tensorboard_resource_id"] == tensorboard_name
 assert pipeline_arguments["large_model_reference"] == "text-bison@001"
+assert pipeline_arguments["accelerator_type"] == "TPU"
 assert (
     call_kwargs["pipeline_job"].encryption_spec.kms_key_name
     == _TEST_ENCRYPTION_KEY_NAME
@@ -1649,6 +1656,7 @@ def test_tune_text_generation_model_ga(
         enable_early_stopping=enable_early_stopping,
         tensorboard=tensorboard_name,
     ),
+    accelerator_type="TPU",
 )
 call_kwargs = mock_pipeline_service_create.call_args[1]
 pipeline_arguments = call_kwargs[
@@ -1661,6 +1669,7 @@ def test_tune_text_generation_model_ga(
 assert pipeline_arguments["enable_early_stopping"] == enable_early_stopping
 assert pipeline_arguments["tensorboard_resource_id"] == tensorboard_name
 assert pipeline_arguments["large_model_reference"] == "text-bison@001"
+assert pipeline_arguments["accelerator_type"] == "TPU"
 assert (
     call_kwargs["pipeline_job"].encryption_spec.kms_key_name
     == _TEST_ENCRYPTION_KEY_NAME
@@ -1808,13 +1817,15 @@ def test_tune_chat_model(
     tuning_job_location="europe-west4",
     tuned_model_location="us-central1",
     default_context=default_context,
+    accelerator_type="TPU",
 )
 call_kwargs = mock_pipeline_service_create.call_args[1]
 pipeline_arguments = call_kwargs[
     "pipeline_job"
 ].runtime_config.parameter_values
 assert pipeline_arguments["large_model_reference"] == "chat-bison@001"
 assert pipeline_arguments["default_context"] == default_context
+assert pipeline_arguments["accelerator_type"] == "TPU"
 
 # Testing the tuned model
 tuned_model = tuning_job.get_tuned_model()
@@ -1862,12 +1873,14 @@ def test_tune_code_generation_model(
     training_data=_TEST_TEXT_BISON_TRAINING_DF,
     tuning_job_location="europe-west4",
     tuned_model_location="us-central1",
+    accelerator_type="TPU",
 )
 call_kwargs = mock_pipeline_service_create.call_args[1]
 pipeline_arguments = call_kwargs[
     "pipeline_job"
 ].runtime_config.parameter_values
 assert pipeline_arguments["large_model_reference"] == "code-bison@001"
+assert pipeline_arguments["accelerator_type"] == "TPU"
 
 @pytest.mark.parametrize(
     "job_spec",
@@ -1909,12 +1922,14 @@ def test_tune_code_chat_model(
     training_data=_TEST_TEXT_BISON_TRAINING_DF,
     tuning_job_location="europe-west4",
     tuned_model_location="us-central1",
+    accelerator_type="TPU",
 )
 call_kwargs = mock_pipeline_service_create.call_args[1]
 pipeline_arguments = call_kwargs[
     "pipeline_job"
 ].runtime_config.parameter_values
 assert pipeline_arguments["large_model_reference"] == "codechat-bison@001"
+assert pipeline_arguments["accelerator_type"] == "TPU"
 
 @pytest.mark.usefixtures(
     "get_model_with_tuned_version_label_mock",

vertexai/language_models/_language_models.py (+26 -1)
@@ -15,7 +15,7 @@
 """Classes for working with language models."""
 
 import dataclasses
-from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Sequence, Union
 import warnings
 
 from google.cloud import aiplatform
@@ -42,6 +42,9 @@
 # Endpoint label/metadata key to preserve the base model ID information
 _TUNING_BASE_MODEL_ID_LABEL_KEY = "google-vertex-llm-tuning-base-model-id"
 
+_ACCELERATOR_TYPES = ["TPU", "GPU"]
+_ACCELERATOR_TYPE_TYPE = Literal["TPU", "GPU"]
+
 
 def _get_model_id_from_tuning_model_id(tuning_model_id: str) -> str:
     """Gets the base model ID for the model ID labels used the tuned models.
@@ -166,6 +169,7 @@ def tune_model(
     model_display_name: Optional[str] = None,
     tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
     default_context: Optional[str] = None,
+    accelerator_type: Optional[_ACCELERATOR_TYPE_TYPE] = None,
 ) -> "_LanguageModelTuningJob":
     """Tunes a model based on training data.
 
@@ -191,6 +195,7 @@ def tune_model(
     model_display_name: Custom display name for the tuned model.
     tuning_evaluation_spec: Specification for the model evaluation during tuning.
     default_context: The context to use for all training samples by default.
+    accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
 
 Returns:
     A `LanguageModelTuningJob` object that represents the tuning job.
@@ -252,6 +257,14 @@ def tune_model(
 if default_context:
     tuning_parameters["default_context"] = default_context
 
+if accelerator_type:
+    if accelerator_type not in _ACCELERATOR_TYPES:
+        raise ValueError(
+            f"Unsupported accelerator type: {accelerator_type}."
+            f" Supported types: {_ACCELERATOR_TYPES}"
+        )
+    tuning_parameters["accelerator_type"] = accelerator_type
+
 return self._tune_model(
     training_data=training_data,
     tuning_parameters=tuning_parameters,
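
The guard above fails fast, before any PipelineJob is created. A hedged sketch of the resulting behavior (assumes vertexai.init(...) as in the first example; the data URI is a placeholder):

    from vertexai.language_models import TextGenerationModel

    model = TextGenerationModel.from_pretrained("text-bison@001")
    try:
        model.tune_model(
            training_data="gs://my-bucket/tuning_data.jsonl",  # placeholder
            accelerator_type="CPU",  # deliberately invalid: not in _ACCELERATOR_TYPES
        )
    except ValueError as e:
        print(e)  # Unsupported accelerator type: CPU. Supported types: ['TPU', 'GPU']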
@@ -336,6 +349,7 @@ def tune_model(
     tuned_model_location: Optional[str] = None,
     model_display_name: Optional[str] = None,
     tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
+    accelerator_type: Optional[_ACCELERATOR_TYPE_TYPE] = None,
 ) -> "_LanguageModelTuningJob":
     """Tunes a model based on training data.
 
@@ -357,6 +371,7 @@ def tune_model(
     tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
     model_display_name: Custom display name for the tuned model.
     tuning_evaluation_spec: Specification for the model evaluation during tuning.
+    accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
 
 Returns:
     A `LanguageModelTuningJob` object that represents the tuning job.
@@ -376,6 +391,7 @@ def tune_model(
     tuned_model_location=tuned_model_location,
     model_display_name=model_display_name,
     tuning_evaluation_spec=tuning_evaluation_spec,
+    accelerator_type=accelerator_type,
 )
 
 
@@ -393,6 +409,7 @@ def tune_model(
     tuned_model_location: Optional[str] = None,
     model_display_name: Optional[str] = None,
     tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
+    accelerator_type: Optional[_ACCELERATOR_TYPE_TYPE] = None,
 ) -> "_LanguageModelTuningJob":
     """Tunes a model based on training data.
 
@@ -421,6 +438,7 @@ def tune_model(
     tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
     model_display_name: Custom display name for the tuned model.
     tuning_evaluation_spec: Specification for the model evaluation during tuning.
+    accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
 
 Returns:
     A `LanguageModelTuningJob` object that represents the tuning job.
@@ -441,6 +459,7 @@ def tune_model(
     tuned_model_location=tuned_model_location,
     model_display_name=model_display_name,
     tuning_evaluation_spec=tuning_evaluation_spec,
+    accelerator_type=accelerator_type,
 )
 tuned_model = job.get_tuned_model()
 self._endpoint = tuned_model._endpoint
@@ -461,6 +480,7 @@ def tune_model(
     tuned_model_location: Optional[str] = None,
     model_display_name: Optional[str] = None,
     default_context: Optional[str] = None,
+    accelerator_type: Optional[_ACCELERATOR_TYPE_TYPE] = None,
 ) -> "_LanguageModelTuningJob":
     """Tunes a model based on training data.
 
@@ -485,6 +505,7 @@ def tune_model(
     tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
     model_display_name: Custom display name for the tuned model.
     default_context: The context to use for all training samples by default.
+    accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
 
 Returns:
     A `LanguageModelTuningJob` object that represents the tuning job.
@@ -504,6 +525,7 @@ def tune_model(
     tuned_model_location=tuned_model_location,
     model_display_name=model_display_name,
     default_context=default_context,
+    accelerator_type=accelerator_type,
 )
 
 
@@ -521,6 +543,7 @@ def tune_model(
     tuned_model_location: Optional[str] = None,
     model_display_name: Optional[str] = None,
     default_context: Optional[str] = None,
+    accelerator_type: Optional[_ACCELERATOR_TYPE_TYPE] = None,
 ) -> "_LanguageModelTuningJob":
     """Tunes a model based on training data.
 
@@ -549,6 +572,7 @@ def tune_model(
     tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
     model_display_name: Custom display name for the tuned model.
     default_context: The context to use for all training samples by default.
+    accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
 
 Returns:
     A `LanguageModelTuningJob` object that represents the tuning job.
@@ -569,6 +593,7 @@ def tune_model(
     tuned_model_location=tuned_model_location,
     model_display_name=model_display_name,
     default_context=default_context,
+    accelerator_type=accelerator_type,
 )
 tuned_model = job.get_tuned_model()
 self._endpoint = tuned_model._endpoint
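
Because the chat and code model classes forward the new argument through their own tune_model overrides (the hunks above), the parameter behaves identically there. A minimal sketch for the chat case (placeholder data URI and context; assumes vertexai.init(...) as in the first example):

    from vertexai.language_models import ChatModel

    chat_model = ChatModel.from_pretrained("chat-bison@001")
    tuning_job = chat_model.tune_model(
        training_data="gs://my-bucket/chat_tuning_data.jsonl",  # placeholder
        default_context="You are a helpful assistant.",  # placeholder
        accelerator_type="GPU",  # the other supported value
    )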
