Skip to content

Commit e5daae9

Browse files
Ark-kun authored and copybara-github committed
feat: LLM - Added support for the max_context_length tuning parameter
PiperOrigin-RevId: 615362109
1 parent 613ce69 commit e5daae9

File tree

3 files changed

+26
-1
lines changed

3 files changed

+26
-1
lines changed

tests/unit/aiplatform/test_language_models.py

+7
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,11 @@ def reverse_string_2(s):""",
606606
"parameterType": "NUMBER_DOUBLE",
607607
},
608608
"location": {"parameterType": "STRING"},
609+
"max_context_length": {
610+
"defaultValue": "",
611+
"isOptional": True,
612+
"parameterType": "STRING",
613+
},
609614
"model_display_name": {"parameterType": "STRING"},
610615
"project": {"parameterType": "STRING"},
611616
"tensorboard_resource_id": {
@@ -2271,6 +2276,7 @@ def test_tune_text_generation_model_ga(
22712276
tensorboard=tensorboard_name,
22722277
),
22732278
accelerator_type="TPU",
2279+
max_context_length="32k",
22742280
)
22752281
call_kwargs = mock_pipeline_service_create.call_args[1]
22762282
pipeline_arguments = call_kwargs[
@@ -2288,6 +2294,7 @@ def test_tune_text_generation_model_ga(
22882294
assert pipeline_arguments["tensorboard_resource_id"] == tensorboard_name
22892295
assert pipeline_arguments["large_model_reference"] == "text-bison@001"
22902296
assert pipeline_arguments["accelerator_type"] == "TPU"
2297+
assert pipeline_arguments["max_context_length"] == "32k"
22912298
assert (
22922299
call_kwargs["pipeline_job"].encryption_spec.kms_key_name
22932300
== _TEST_ENCRYPTION_KEY_NAME

vertexai/language_models/_distillation.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def distill_from(
2121
evaluation_spec: Optional[tuning.TuningEvaluationSpec] = None,
2222
accelerator_type: Optional[tuning._ACCELERATOR_TYPE_TYPE] = None,
2323
model_display_name: Optional[str] = None,
24+
max_context_length: Optional[str] = None,
2425
):
2526
"""Tunes a smaller model with help from another bigger model.
2627
@@ -32,6 +33,8 @@ def distill_from(
3233
evaluation_spec: Specification for the model evaluation during tuning.
3334
accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
3435
model_display_name: Custom display name for the tuned model.
36+
max_context_length: The max context length used for tuning.
37+
Can be either '8k' or '32k'
3538
3639
Returns:
3740
A tuning job for distillation.
@@ -86,6 +89,8 @@ def distill_from(
8689
pipeline_arguments[
8790
"encryption_spec_key_name"
8891
] = aiplatform_initializer.global_config.encryption_spec_key_name
92+
if max_context_length is not None:
93+
pipeline_arguments["max_context_length"] = max_context_length
8994
if model_display_name is None:
9095
model_display_name = (
9196
f"{student_short_model_id}"
@@ -94,7 +99,6 @@ def distill_from(
9499
pipeline_arguments["model_display_name"] = model_display_name
95100
# # Not exposing these parameters:
96101
# temperature: Optional[float] = None,
97-
# max_context_length: Optional[int] = None,
98102
# tpu_training_skip_cmek: Optional[bool] = None,
99103
# api_endpoint: Optional[str] = None,
100104
# version: Optional[str] = None,

vertexai/language_models/_language_models.py

+14
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ def tune_model(
228228
tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
229229
default_context: Optional[str] = None,
230230
accelerator_type: Optional[_ACCELERATOR_TYPE_TYPE] = None,
231+
max_context_length: Optional[str] = None,
231232
) -> "_LanguageModelTuningJob":
232233
"""Tunes a model based on training data.
233234
@@ -253,6 +254,8 @@ def tune_model(
253254
tuning_evaluation_spec: Specification for the model evaluation during tuning.
254255
default_context: The context to use for all training samples by default.
255256
accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
257+
max_context_length: The max context length used for tuning.
258+
Can be either '8k' or '32k'
256259
257260
Returns:
258261
A `LanguageModelTuningJob` object that represents the tuning job.
@@ -313,6 +316,9 @@ def tune_model(
313316
)
314317
tuning_parameters["accelerator_type"] = accelerator_type
315318

319+
if max_context_length:
320+
tuning_parameters["max_context_length"] = max_context_length
321+
316322
return self._tune_model(
317323
training_data=training_data,
318324
tuning_parameters=tuning_parameters,
@@ -600,6 +606,7 @@ def tune_model(
600606
model_display_name: Optional[str] = None,
601607
tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
602608
accelerator_type: Optional[_ACCELERATOR_TYPE_TYPE] = None,
609+
max_context_length: Optional[str] = None,
603610
) -> "_LanguageModelTuningJob":
604611
"""Tunes a model based on training data.
605612
@@ -621,6 +628,8 @@ def tune_model(
621628
model_display_name: Custom display name for the tuned model.
622629
tuning_evaluation_spec: Specification for the model evaluation during tuning.
623630
accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
631+
max_context_length: The max context length used for tuning.
632+
Can be either '8k' or '32k'
624633
625634
Returns:
626635
A `LanguageModelTuningJob` object that represents the tuning job.
@@ -641,6 +650,7 @@ def tune_model(
641650
model_display_name=model_display_name,
642651
tuning_evaluation_spec=tuning_evaluation_spec,
643652
accelerator_type=accelerator_type,
653+
max_context_length=max_context_length,
644654
)
645655

646656

@@ -659,6 +669,7 @@ def tune_model(
659669
model_display_name: Optional[str] = None,
660670
tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
661671
accelerator_type: Optional[_ACCELERATOR_TYPE_TYPE] = None,
672+
max_context_length: Optional[str] = None,
662673
) -> "_LanguageModelTuningJob":
663674
"""Tunes a model based on training data.
664675
@@ -687,6 +698,8 @@ def tune_model(
687698
model_display_name: Custom display name for the tuned model.
688699
tuning_evaluation_spec: Specification for the model evaluation during tuning.
689700
accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
701+
max_context_length: The max context length used for tuning.
702+
Can be either '8k' or '32k'
690703
691704
Returns:
692705
A `LanguageModelTuningJob` object that represents the tuning job.
@@ -708,6 +721,7 @@ def tune_model(
708721
model_display_name=model_display_name,
709722
tuning_evaluation_spec=tuning_evaluation_spec,
710723
accelerator_type=accelerator_type,
724+
max_context_length=max_context_length,
711725
)
712726
tuned_model = job.get_tuned_model()
713727
self._endpoint = tuned_model._endpoint

0 commit comments

Comments (0)