 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
+import threading
+import time
 from unittest import mock

 from google.cloud import aiplatform
@@ ... @@
 from vertexai.preview.evaluation import _evaluation
 from vertexai.preview.evaluation import utils
 from vertexai.preview.evaluation.metrics import (
-    _summarization_quality,
+    _pairwise_question_answering_quality,
 )
 from vertexai.preview.evaluation.metrics import (
     _pairwise_summarization_quality,
 )
+from vertexai.preview.evaluation.metrics import _rouge
 from vertexai.preview.evaluation.metrics import (
-    _pairwise_question_answering_quality,
-)
-from vertexai.preview.evaluation.metrics import (
-    _rouge,
+    _summarization_quality,
 )
 import numpy as np
 import pandas as pd
 import pytest

+
 _TEST_PROJECT = "test-project"
 _TEST_LOCATION = "us-central1"
 _TEST_METRICS = (
@@ ... @@
 )


-@pytest.fixture
-def mock_async_event_loop():
-    with mock.patch("asyncio.get_event_loop") as mock_async_event_loop:
-        yield mock_async_event_loop
-
-
 @pytest.fixture
 def mock_experiment_tracker():
     with mock.patch.object(
@@ -267,32 +263,6 @@ def test_create_eval_task(self):
         assert test_eval_task.reference_column_name == test_reference_column_name
         assert test_eval_task.response_column_name == test_response_column_name

-    def test_evaluate_saved_response(self, mock_async_event_loop):
-        eval_dataset = _TEST_EVAL_DATASET
-        test_metrics = _TEST_METRICS
-        mock_summary_metrics = {
-            "row_count": 2,
-            "mock_metric/mean": 0.5,
-            "mock_metric/std": 0.5,
-        }
-        mock_metrics_table = pd.DataFrame(
-            {
-                "response": ["test", "text"],
-                "reference": ["test", "ref"],
-                "mock_metric": [1.0, 0.0],
-            }
-        )
-        mock_async_event_loop.return_value.run_until_complete.return_value = (
-            mock_summary_metrics,
-            mock_metrics_table,
-        )
-
-        test_eval_task = evaluation.EvalTask(dataset=eval_dataset, metrics=test_metrics)
-        test_result = test_eval_task.evaluate()
-
-        assert test_result.summary_metrics == mock_summary_metrics
-        assert test_result.metrics_table.equals(mock_metrics_table)
-
     @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
     def test_compute_automatic_metrics(self, api_transport):
         aiplatform.init(
@@ -310,7 +280,7 @@ def test_compute_automatic_metrics(self, api_transport):
         test_eval_task = evaluation.EvalTask(dataset=eval_dataset, metrics=test_metrics)
         mock_metric_results = _MOCK_EXACT_MATCH_RESULT
         with mock.patch.object(
-            target=gapic_evaluation_services.EvaluationServiceAsyncClient,
+            target=gapic_evaluation_services.EvaluationServiceClient,
             attribute="evaluate_instances",
             side_effect=mock_metric_results,
         ):
@@ -343,7 +313,7 @@ def test_compute_pointwise_metrics(self, api_transport):
         test_eval_task = evaluation.EvalTask(dataset=eval_dataset, metrics=test_metrics)
         mock_metric_results = _MOCK_FLUENCY_RESULT
         with mock.patch.object(
-            target=gapic_evaluation_services.EvaluationServiceAsyncClient,
+            target=gapic_evaluation_services.EvaluationServiceClient,
             attribute="evaluate_instances",
             side_effect=mock_metric_results,
         ):
@@ -398,7 +368,7 @@ def test_compute_pointwise_metrics_with_custom_metric_spec(self, api_transport):
         test_eval_task = evaluation.EvalTask(dataset=eval_dataset, metrics=test_metrics)
         mock_metric_results = _MOCK_SUMMARIZATION_QUALITY_RESULT
         with mock.patch.object(
-            target=gapic_evaluation_services.EvaluationServiceAsyncClient,
+            target=gapic_evaluation_services.EvaluationServiceClient,
             attribute="evaluate_instances",
             side_effect=mock_metric_results,
         ):
@@ -465,7 +435,7 @@ def test_compute_automatic_metrics_with_custom_metric_spec(self, api_transport):
         ]
         test_eval_task = evaluation.EvalTask(dataset=eval_dataset, metrics=test_metrics)
         with mock.patch.object(
-            target=gapic_evaluation_services.EvaluationServiceAsyncClient,
+            target=gapic_evaluation_services.EvaluationServiceClient,
             attribute="evaluate_instances",
             side_effect=_MOCK_ROUGE_RESULT,
         ) as mock_evaluate_instances:
@@ -527,7 +497,7 @@ def test_compute_pairwise_metrics_with_model_inference(self, api_transport):
         test_eval_task = evaluation.EvalTask(dataset=eval_dataset, metrics=test_metrics)
         mock_metric_results = _MOCK_PAIRWISE_SUMMARIZATION_QUALITY_RESULT
         with mock.patch.object(
-            target=gapic_evaluation_services.EvaluationServiceAsyncClient,
+            target=gapic_evaluation_services.EvaluationServiceClient,
             attribute="evaluate_instances",
             side_effect=mock_metric_results,
         ):
@@ -613,7 +583,7 @@ def test_compute_pairwise_metrics_without_inference(self, api_transport):
         test_eval_task = evaluation.EvalTask(dataset=eval_dataset, metrics=test_metrics)
         mock_metric_results = _MOCK_PAIRWISE_SUMMARIZATION_QUALITY_RESULT
         with mock.patch.object(
-            target=gapic_evaluation_services.EvaluationServiceAsyncClient,
+            target=gapic_evaluation_services.EvaluationServiceClient,
             attribute="evaluate_instances",
             side_effect=mock_metric_results,
         ):
@@ -869,9 +839,9 @@ def setup_method(self):
     def teardown_method(self):
         initializer.global_pool.shutdown(wait=True)

-    def test_create_evaluation_service_async_client(self):
-        client = utils.create_evaluation_service_async_client()
-        assert isinstance(client, utils._EvaluationServiceAsyncClientWithOverride)
+    def test_create_evaluation_service_client(self):
+        client = utils.create_evaluation_service_client()
+        assert isinstance(client, utils._EvaluationServiceClientWithOverride)

     def test_load_dataset_from_dataframe(self):
         data = {"col1": [1, 2], "col2": ["a", "b"]}
@@ -924,6 +894,57 @@ def test_load_dataset_from_bigquery(self):
         assert isinstance(loaded_df, pd.DataFrame)
         assert loaded_df.equals(_TEST_EVAL_DATASET)

+    def test_initialization(self):
+        limiter = utils.RateLimiter(rate=2)
+        assert limiter.seconds_per_event == 0.5
+
+        with pytest.raises(ValueError, match="Rate must be a positive number"):
+            utils.RateLimiter(-1)
+        with pytest.raises(ValueError, match="Rate must be a positive number"):
+            utils.RateLimiter(0)
+
+    def test_admit(self):
+        rate_limiter = utils.RateLimiter(rate=2)
+
+        assert rate_limiter._admit() == 0
+
+        time.sleep(0.1)
+        delay = rate_limiter._admit()
+        assert delay == pytest.approx(0.4, 0.01)
+
+        time.sleep(0.5)
+        delay = rate_limiter._admit()
+        assert delay == 0
+
+    def test_sleep_and_advance(self):
+        rate_limiter = utils.RateLimiter(rate=2)
+
+        start_time = time.time()
+        rate_limiter.sleep_and_advance()
+        assert (time.time() - start_time) < 0.1
+
+        start_time = time.time()
+        rate_limiter.sleep_and_advance()
+        assert (time.time() - start_time) >= 0.5
+
+    def test_thread_safety(self):
+        rate_limiter = utils.RateLimiter(rate=2)
+        start_time = time.time()
+
+        def target():
+            rate_limiter.sleep_and_advance()
+
+        threads = [threading.Thread(target=target) for _ in range(10)]
+        for thread in threads:
+            thread.start()
+        for thread in threads:
+            thread.join()
+
+        # Verify that the minimum total time is 4.5 seconds
+        # (9 intervals of 0.5 seconds each).
+        total_time = time.time() - start_time
+        assert total_time >= 4.5
+

 class TestPromptTemplate:
     def test_init(self):
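Note on the new tests: the `utils.RateLimiter` class itself is not part of this diff. A minimal sketch consistent with the behavior the tests assert — assuming the constructor rejects non-positive rates, `_admit()` returns the remaining wait in seconds (0 when the next event is already allowed), and `sleep_and_advance()` serializes callers with a lock — might look like the following. This is illustrative only, not the vertexai implementation.

import threading
import time


class RateLimiter:
    """Illustrative sketch; the real class lives in vertexai.preview.evaluation.utils."""

    def __init__(self, rate=None):
        # The tests expect ValueError for a missing, zero, or negative rate.
        if not rate or rate <= 0:
            raise ValueError("Rate must be a positive number.")
        self.seconds_per_event = 1.0 / rate
        # Allow the first event immediately.
        self.last_event_time = time.time() - self.seconds_per_event
        self._lock = threading.Lock()

    def _admit(self):
        # Return 0 if enough time has passed since the last admitted event,
        # otherwise return the remaining wait in seconds.
        now = time.time()
        elapsed = now - self.last_event_time
        if elapsed >= self.seconds_per_event:
            self.last_event_time = now
            return 0
        return self.seconds_per_event - elapsed

    def sleep_and_advance(self):
        # Serialize concurrent callers so events are spaced out by
        # seconds_per_event, sleeping off any remaining delay.
        with self._lock:
            delay = self._admit()
            if delay > 0:
                time.sleep(delay)
                self.last_event_time = time.time()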