@@ -4526,7 +4526,48 @@ def test_text_embedding(self):
4526
4526
== expected_embedding ["statistics" ]["truncated" ]
4527
4527
)
4528
4528
4529
- def test_text_embedding_preview_count_tokens (self ):
4529
+ def test_text_embedding_count_tokens_ga (self ):
4530
+ """Tests the text embedding model."""
4531
+ aiplatform .init (
4532
+ project = _TEST_PROJECT ,
4533
+ location = _TEST_LOCATION ,
4534
+ )
4535
+ with mock .patch .object (
4536
+ target = model_garden_service_client .ModelGardenServiceClient ,
4537
+ attribute = "get_publisher_model" ,
4538
+ return_value = gca_publisher_model .PublisherModel (
4539
+ _TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT
4540
+ ),
4541
+ ):
4542
+ model = language_models .TextEmbeddingModel .from_pretrained (
4543
+ "textembedding-gecko@001"
4544
+ )
4545
+
4546
+ gca_count_tokens_response = (
4547
+ gca_prediction_service_v1beta1 .CountTokensResponse (
4548
+ total_tokens = _TEST_COUNT_TOKENS_RESPONSE ["total_tokens" ],
4549
+ total_billable_characters = _TEST_COUNT_TOKENS_RESPONSE [
4550
+ "total_billable_characters"
4551
+ ],
4552
+ )
4553
+ )
4554
+
4555
+ with mock .patch .object (
4556
+ target = prediction_service_client_v1beta1 .PredictionServiceClient ,
4557
+ attribute = "count_tokens" ,
4558
+ return_value = gca_count_tokens_response ,
4559
+ ):
4560
+ response = model .count_tokens (["What is life?" ])
4561
+
4562
+ assert (
4563
+ response .total_tokens == _TEST_COUNT_TOKENS_RESPONSE ["total_tokens" ]
4564
+ )
4565
+ assert (
4566
+ response .total_billable_characters
4567
+ == _TEST_COUNT_TOKENS_RESPONSE ["total_billable_characters" ]
4568
+ )
4569
+
4570
+ def test_text_embedding_count_tokens_preview (self ):
4530
4571
"""Tests the text embedding model."""
4531
4572
aiplatform .init (
4532
4573
project = _TEST_PROJECT ,
0 commit comments