45
45
artifact as gca_artifact ,
46
46
prediction_service as gca_prediction_service ,
47
47
context as gca_context ,
48
- endpoint as gca_endpoint ,
48
+ endpoint_v1 as gca_endpoint ,
49
49
pipeline_job as gca_pipeline_job ,
50
50
pipeline_state as gca_pipeline_state ,
51
51
deployed_model_ref_v1 ,
@@ -1030,6 +1030,11 @@ def get_endpoint_mock():
1030
1030
get_endpoint_mock .return_value = gca_endpoint .Endpoint (
1031
1031
display_name = "test-display-name" ,
1032
1032
name = test_constants .EndpointConstants ._TEST_ENDPOINT_NAME ,
1033
+ deployed_models = [
1034
+ gca_endpoint .DeployedModel (
1035
+ model = test_constants .ModelConstants ._TEST_MODEL_RESOURCE_NAME
1036
+ ),
1037
+ ],
1033
1038
)
1034
1039
yield get_endpoint_mock
1035
1040
@@ -2420,7 +2425,10 @@ def test_text_embedding_ga(self):
2420
2425
assert len (vector ) == _TEXT_EMBEDDING_VECTOR_LENGTH
2421
2426
assert vector == _TEST_TEXT_EMBEDDING_PREDICTION ["embeddings" ]["values" ]
2422
2427
2423
- def test_batch_prediction (self ):
2428
+ def test_batch_prediction (
2429
+ self ,
2430
+ get_endpoint_mock ,
2431
+ ):
2424
2432
"""Tests batch prediction."""
2425
2433
aiplatform .init (
2426
2434
project = _TEST_PROJECT ,
@@ -2447,7 +2455,29 @@ def test_batch_prediction(self):
2447
2455
model_parameters = {"temperature" : 0.1 },
2448
2456
)
2449
2457
mock_create .assert_called_once_with (
2450
- model_name = "publishers/google/models/text-bison@001" ,
2458
+ model_name = f"projects/{ _TEST_PROJECT } /locations/{ _TEST_LOCATION } /publishers/google/models/text-bison@001" ,
2459
+ job_display_name = None ,
2460
+ gcs_source = "gs://test-bucket/test_table.jsonl" ,
2461
+ gcs_destination_prefix = "gs://test-bucket/results/" ,
2462
+ model_parameters = {"temperature" : 0.1 },
2463
+ )
2464
+
2465
+ # Testing tuned model batch prediction
2466
+ tuned_model = language_models .TextGenerationModel (
2467
+ model_id = model ._model_id ,
2468
+ endpoint_name = test_constants .EndpointConstants ._TEST_ENDPOINT_NAME ,
2469
+ )
2470
+ with mock .patch .object (
2471
+ target = aiplatform .BatchPredictionJob ,
2472
+ attribute = "create" ,
2473
+ ) as mock_create :
2474
+ tuned_model .batch_predict (
2475
+ dataset = "gs://test-bucket/test_table.jsonl" ,
2476
+ destination_uri_prefix = "gs://test-bucket/results/" ,
2477
+ model_parameters = {"temperature" : 0.1 },
2478
+ )
2479
+ mock_create .assert_called_once_with (
2480
+ model_name = test_constants .ModelConstants ._TEST_MODEL_RESOURCE_NAME ,
2451
2481
job_display_name = None ,
2452
2482
gcs_source = "gs://test-bucket/test_table.jsonl" ,
2453
2483
gcs_destination_prefix = "gs://test-bucket/results/" ,
@@ -2481,7 +2511,7 @@ def test_batch_prediction_for_text_embedding(self):
2481
2511
model_parameters = {},
2482
2512
)
2483
2513
mock_create .assert_called_once_with (
2484
- model_name = " publishers/google/models/textembedding-gecko@001" ,
2514
+ model_name = f"projects/ { _TEST_PROJECT } /locations/ { _TEST_LOCATION } / publishers/google/models/textembedding-gecko@001" ,
2485
2515
job_display_name = None ,
2486
2516
gcs_source = "gs://test-bucket/test_table.jsonl" ,
2487
2517
gcs_destination_prefix = "gs://test-bucket/results/" ,
0 commit comments