Skip to content

Commit ccc5c85

Browse files
vertex-sdk-bot authored and copybara-github committed
fix: deepcopy error from baseline_model in pairwiseMetric
PiperOrigin-RevId: 693800802
1 parent f1da73b commit ccc5c85

File tree

2 files changed

+66
-5
lines changed

2 files changed

+66
-5
lines changed

tests/unit/vertexai/test_evaluation.py

+48
Original file line number | Diff line number | Diff line change
@@ -317,6 +317,20 @@
317317
)
318318
),
319319
)
320+
_MOCK_PAIRWISE_RESULT = (
321+
gapic_evaluation_service_types.EvaluateInstancesResponse(
322+
pairwise_metric_result=gapic_evaluation_service_types.PairwiseMetricResult(
323+
pairwise_choice=gapic_evaluation_service_types.PairwiseChoice.BASELINE,
324+
explanation="explanation",
325+
)
326+
),
327+
gapic_evaluation_service_types.EvaluateInstancesResponse(
328+
pairwise_metric_result=gapic_evaluation_service_types.PairwiseMetricResult(
329+
pairwise_choice=gapic_evaluation_service_types.PairwiseChoice.BASELINE,
330+
explanation="explanation",
331+
)
332+
),
333+
)
320334
_MOCK_SUMMARIZATION_QUALITY_RESULT = (
321335
gapic_evaluation_service_types.EvaluateInstancesResponse(
322336
pointwise_metric_result=gapic_evaluation_service_types.PointwiseMetricResult(
@@ -1216,6 +1230,40 @@ def test_evaluate_baseline_response_column_and_baseline_model_provided(self):
12161230
test_eval_task.evaluate(model=mock.MagicMock())
12171231
_TEST_PAIRWISE_METRIC._baseline_model = None
12181232

1233+
def test_evaluate_baseline_model_provided_but_no_baseline_response_column(self):
1234+
mock_baseline_model = mock.create_autospec(
1235+
generative_models.GenerativeModel, instance=True
1236+
)
1237+
mock_baseline_model.generate_content.return_value = (
1238+
_MOCK_MODEL_INFERENCE_RESPONSE
1239+
)
1240+
mock_baseline_model._model_name = "publishers/google/model/gemini-pro"
1241+
_TEST_PAIRWISE_METRIC._baseline_model = mock_baseline_model
1242+
1243+
mock_candidate_model = mock.create_autospec(
1244+
generative_models.GenerativeModel, instance=True
1245+
)
1246+
mock_candidate_model.generate_content.return_value = (
1247+
_MOCK_MODEL_INFERENCE_RESPONSE
1248+
)
1249+
mock_candidate_model._model_name = "publishers/google/model/gemini-1.0-pro"
1250+
mock_metric_results = _MOCK_PAIRWISE_RESULT
1251+
eval_dataset = _TEST_EVAL_DATASET_WITHOUT_RESPONSE.copy(deep=True)
1252+
test_eval_task = EvalTask(
1253+
dataset=eval_dataset,
1254+
metrics=[_TEST_PAIRWISE_METRIC],
1255+
)
1256+
with mock.patch.object(
1257+
target=gapic_evaluation_services.EvaluationServiceClient,
1258+
attribute="evaluate_instances",
1259+
side_effect=mock_metric_results,
1260+
):
1261+
test_result = test_eval_task.evaluate(
1262+
model=mock_candidate_model,
1263+
)
1264+
_TEST_PAIRWISE_METRIC._baseline_model = None
1265+
assert test_result.summary_metrics["row_count"] == 2
1266+
12191267
def test_evaluate_response_column_and_model_not_provided(self):
12201268
test_eval_task = EvalTask(
12211269
dataset=_TEST_EVAL_DATASET_SINGLE,

vertexai/evaluation/_evaluation.py

+18 -5
Original file line number | Diff line number | Diff line change
@@ -856,15 +856,28 @@ def evaluate(
856856
"""
857857
_validate_metrics(metrics)
858858
metrics = _convert_metric_prompt_template_example(metrics)
859-
859+
copied_metrics = []
860+
for metric in metrics:
861+
if isinstance(metric, pairwise_metric.PairwiseMetric):
862+
copied_metrics.append(
863+
pairwise_metric.PairwiseMetric(
864+
metric=metric.metric_name,
865+
metric_prompt_template=metric.metric_prompt_template,
866+
baseline_model=metric.baseline_model,
867+
)
868+
)
869+
else:
870+
copied_metrics.append(copy.deepcopy(metric))
860871
evaluation_run_config = evaluation_base.EvaluationRunConfig(
861872
dataset=dataset.copy(deep=True),
862-
metrics=copy.deepcopy(metrics),
873+
metrics=copied_metrics,
863874
metric_column_mapping=copy.deepcopy(metric_column_mapping),
864875
client=utils.create_evaluation_service_client(),
865-
evaluation_service_qps=evaluation_service_qps
866-
if evaluation_service_qps
867-
else constants.QuotaLimit.EVAL_SERVICE_QPS,
876+
evaluation_service_qps=(
877+
evaluation_service_qps
878+
if evaluation_service_qps
879+
else constants.QuotaLimit.EVAL_SERVICE_QPS
880+
),
868881
retry_timeout=retry_timeout,
869882
)
870883

0 commit comments

Comments
 (0)