Skip to content

Commit eaf5d81

Browse files
Ark-kun authored and copybara-github committed
feat: LLM - Enable tuning eval TensorBoard without evaluation data
PiperOrigin-RevId: 572460871
1 parent bb8388e commit eaf5d81

File tree

2 files changed

+58
-7
lines changed

2 files changed

+58
-7
lines changed

tests/unit/aiplatform/test_language_models.py

+50
Original file line number | Diff line number | Diff line change
@@ -1638,6 +1638,56 @@ def test_tune_text_generation_model_ga(
16381638
== test_constants.EndpointConstants._TEST_ENDPOINT_NAME
16391639
)
16401640

1641+
@pytest.mark.parametrize(
1642+
"job_spec",
1643+
[_TEST_PIPELINE_SPEC_JSON],
1644+
)
1645+
@pytest.mark.parametrize(
1646+
"mock_request_urlopen",
1647+
["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
1648+
indirect=True,
1649+
)
1650+
def test_tune_text_generation_model_evaluation_with_only_tensorboard(
1651+
self,
1652+
mock_pipeline_service_create,
1653+
mock_pipeline_job_get,
1654+
mock_pipeline_bucket_exists,
1655+
job_spec,
1656+
mock_load_yaml_and_json,
1657+
mock_gcs_from_string,
1658+
mock_gcs_upload,
1659+
mock_request_urlopen,
1660+
mock_get_tuned_model,
1661+
):
1662+
"""Tests tuning the text generation model."""
1663+
with mock.patch.object(
1664+
target=model_garden_service_client.ModelGardenServiceClient,
1665+
attribute="get_publisher_model",
1666+
return_value=gca_publisher_model.PublisherModel(
1667+
_TEXT_BISON_PUBLISHER_MODEL_DICT
1668+
),
1669+
):
1670+
model = language_models.TextGenerationModel.from_pretrained(
1671+
"text-bison@001"
1672+
)
1673+
1674+
tuning_job_location = "europe-west4"
1675+
tensorboard_name = f"projects/{_TEST_PROJECT}/locations/{tuning_job_location}/tensorboards/123"
1676+
1677+
model.tune_model(
1678+
training_data=_TEST_TEXT_BISON_TRAINING_DF,
1679+
tuning_job_location=tuning_job_location,
1680+
tuned_model_location="us-central1",
1681+
tuning_evaluation_spec=preview_language_models.TuningEvaluationSpec(
1682+
tensorboard=tensorboard_name,
1683+
),
1684+
)
1685+
call_kwargs = mock_pipeline_service_create.call_args[1]
1686+
pipeline_arguments = call_kwargs[
1687+
"pipeline_job"
1688+
].runtime_config.parameter_values
1689+
assert pipeline_arguments["tensorboard_resource_id"] == tensorboard_name
1690+
16411691
@pytest.mark.parametrize(
16421692
"job_spec",
16431693
[_TEST_PIPELINE_SPEC_JSON],

vertexai/language_models/_language_models.py

+8 -7
Original file line number | Diff line number | Diff line change
@@ -214,13 +214,14 @@ def tune_model(
214214
tuning_parameters["learning_rate_multiplier"] = learning_rate_multiplier
215215
eval_spec = tuning_evaluation_spec
216216
if eval_spec is not None:
217-
if isinstance(eval_spec.evaluation_data, str):
218-
if eval_spec.evaluation_data.startswith("gs://"):
219-
tuning_parameters["evaluation_data_uri"] = eval_spec.evaluation_data
217+
if eval_spec.evaluation_data:
218+
if isinstance(eval_spec.evaluation_data, str):
219+
if eval_spec.evaluation_data.startswith("gs://"):
220+
tuning_parameters["evaluation_data_uri"] = eval_spec.evaluation_data
221+
else:
222+
raise ValueError("evaluation_data should be a GCS URI")
220223
else:
221-
raise ValueError("evaluation_data should be a GCS URI")
222-
else:
223-
raise TypeError("evaluation_data should be a URI string")
224+
raise TypeError("evaluation_data should be a URI string")
224225
if eval_spec.evaluation_interval is not None:
225226
tuning_parameters["evaluation_interval"] = eval_spec.evaluation_interval
226227
if eval_spec.enable_early_stopping is not None:
@@ -648,7 +649,7 @@ class TuningEvaluationSpec:
648649

649650
__module__ = "vertexai.language_models"
650651

651-
evaluation_data: str
652+
evaluation_data: Optional[str] = None
652653
evaluation_interval: Optional[int] = None
653654
enable_early_stopping: Optional[bool] = None
654655
tensorboard: Optional[Union[aiplatform.Tensorboard, str]] = None

0 commit comments

Comments
 (0)