File tree 3 files changed +56
-1
lines changed
3 files changed +56
-1
lines changed Original file line number Diff line number Diff line change 33
33
)
34
34
from vertexai .preview .language_models import (
35
35
ChatModel ,
36
+ CodeGenerationModel ,
36
37
InputOutputTextPair ,
37
38
TextGenerationModel ,
38
39
TextGenerationResponse ,
@@ -434,6 +435,26 @@ def test_batch_prediction_for_textembedding(self):
434
435
435
436
assert gapic_job .state == gca_job_state .JobState .JOB_STATE_SUCCEEDED
436
437
438
+ def test_batch_prediction_for_code_generation (self ):
439
+ source_uri = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/code-bison.batch_prediction_prompts.1.jsonl"
440
+ destination_uri_prefix = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/predictions/code-bison@001_"
441
+
442
+ aiplatform .init (project = e2e_base ._PROJECT , location = e2e_base ._LOCATION )
443
+
444
+ model = CodeGenerationModel .from_pretrained ("code-bison@001" )
445
+ job = model .batch_predict (
446
+ dataset = source_uri ,
447
+ destination_uri_prefix = destination_uri_prefix ,
448
+ model_parameters = {"temperature" : 0 },
449
+ )
450
+
451
+ job .wait_for_resource_creation ()
452
+ job .wait ()
453
+ gapic_job = job ._gca_resource
454
+ job .delete ()
455
+
456
+ assert gapic_job .state == gca_job_state .JobState .JOB_STATE_SUCCEEDED
457
+
437
458
def test_code_generation_streaming (self ):
438
459
aiplatform .init (project = e2e_base ._PROJECT , location = e2e_base ._LOCATION )
439
460
Original file line number Diff line number Diff line change @@ -4324,6 +4324,36 @@ def test_batch_prediction(
4324
4324
model_parameters = {"temperature" : 0.1 },
4325
4325
)
4326
4326
4327
+ def test_batch_prediction_for_code_generation (self ):
4328
+ """Tests batch prediction."""
4329
+ with mock .patch .object (
4330
+ target = model_garden_service_client .ModelGardenServiceClient ,
4331
+ attribute = "get_publisher_model" ,
4332
+ return_value = gca_publisher_model .PublisherModel (
4333
+ _CODE_GENERATION_BISON_PUBLISHER_MODEL_DICT
4334
+ ),
4335
+ ):
4336
+ model = preview_language_models .CodeGenerationModel .from_pretrained (
4337
+ "code-bison@001"
4338
+ )
4339
+
4340
+ with mock .patch .object (
4341
+ target = aiplatform .BatchPredictionJob ,
4342
+ attribute = "create" ,
4343
+ ) as mock_create :
4344
+ model .batch_predict (
4345
+ dataset = "gs://test-bucket/test_table.jsonl" ,
4346
+ destination_uri_prefix = "gs://test-bucket/results/" ,
4347
+ model_parameters = {},
4348
+ )
4349
+ mock_create .assert_called_once_with (
4350
+ model_name = f"projects/{ _TEST_PROJECT } /locations/{ _TEST_LOCATION } /publishers/google/models/code-bison@001" ,
4351
+ job_display_name = None ,
4352
+ gcs_source = "gs://test-bucket/test_table.jsonl" ,
4353
+ gcs_destination_prefix = "gs://test-bucket/results/" ,
4354
+ model_parameters = {},
4355
+ )
4356
+
4327
4357
def test_batch_prediction_for_text_embedding (self ):
4328
4358
"""Tests batch prediction."""
4329
4359
aiplatform .init (
Original file line number Diff line number Diff line change @@ -3366,7 +3366,11 @@ def count_tokens(
3366
3366
)
3367
3367
3368
3368
3369
class CodeGenerationModel(
    _CodeGenerationModel,
    _TunableTextModelMixin,
    _ModelWithBatchPredict,
):
    """Preview code generation model.

    Composes the base code generation behavior with tuning
    (_TunableTextModelMixin) and batch prediction (_ModelWithBatchPredict)
    support; all functionality comes from the mixin bases.
    """

    pass
3371
3375
3372
3376
You can’t perform that action at this time.
0 commit comments