
Commit 4135810

vertex-sdk-bot authored and copybara-github committed
feat: Add COMET and MetricX to the evaluation SDK
PiperOrigin-RevId: 696878382
1 parent c39334a commit 4135810
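Before the diffs, a brief usage sketch (not part of this commit) of how the new translation metrics might be called, modeled on the fixtures and the new test in tests/unit/vertexai/test_evaluation.py below. The import paths, project, location, and sample rows are assumptions, not taken from the SDK docs:

import pandas as pd
import vertexai
from vertexai.evaluation import EvalTask
from vertexai.evaluation.metrics import pointwise_metric  # import path assumed from the repo layout

vertexai.init(project="my-project", location="us-central1")  # placeholder project/location

# COMET and MetricX take a model version plus source/target languages,
# mirroring the _TEST_COMET and _TEST_METRICX fixtures in the test diff below.
comet = pointwise_metric.Comet(
    version="COMET_22_SRC_REF",
    source_language="en",
    target_language="zh",
)
metricx = pointwise_metric.MetricX(
    version="METRICX_24_SRC",
    source_language="en",
    target_language="zh",
)

# Translation metrics score an existing `response` against the `source` text
# (a `reference` translation is optional), so no candidate model is needed here.
eval_dataset = pd.DataFrame(
    {
        "source": ["Hello", "Thank you"],  # sample rows, assumption
        "response": ["你好", "谢谢"],
    }
)

result = EvalTask(dataset=eval_dataset, metrics=[comet, metricx]).evaluate()
print(result.summary_metrics)  # includes comet/mean, metricx/mean, ...
print(result.metrics_table[["comet/score", "metricx/score"]])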

File tree: 7 files changed (+408, -45 lines)


tests/unit/vertexai/test_evaluation.py  (+144, -19)
@@ -99,6 +99,16 @@
         evaluation_steps=_EVALUATION_STEPS,
     ),
 )
+_TEST_COMET = pointwise_metric.Comet(
+    version="COMET_22_SRC_REF",
+    source_language="en",
+    target_language="zh",
+)
+_TEST_METRICX = pointwise_metric.MetricX(
+    version="METRICX_24_SRC",
+    source_language="en",
+    target_language="zh",
+)
 _TEST_METRICS = (
     "exact_match",
     "bleu",
@@ -139,6 +149,7 @@
         "reference": ["test", "ref"],
         "context": ["test", "context"],
         "instruction": ["test", "instruction"],
+        "source": ["test", "source"],
     }
 )
 _TEST_EVAL_DATASET_SINGLE = pd.DataFrame({"prompt": ["test_prompt", "text_prompt"]})
@@ -305,7 +316,7 @@
         )
     ),
 )
-_MOCK_POINTEWISE_RESULT = (
+_MOCK_POINTWISE_RESULT = (
     gapic_evaluation_service_types.EvaluateInstancesResponse(
         pointwise_metric_result=gapic_evaluation_service_types.PointwiseMetricResult(
             score=5, explanation="explanation"
@@ -423,6 +434,29 @@
         )
     ),
 )
+_EXPECTED_COLUMN_MAPPING = {
+    "context": "context",
+    "reference": "reference",
+    "response": "response",
+    "instruction": "instruction",
+    "prompt": "prompt",
+    "source": "source",
+}
+_MOCK_MODEL_BASED_TRANSLATION_RESULT = (
+    # The order of the responses is important.
+    gapic_evaluation_service_types.EvaluateInstancesResponse(
+        comet_result=gapic_evaluation_service_types.CometResult(score=0.1)
+    ),
+    gapic_evaluation_service_types.EvaluateInstancesResponse(
+        metricx_result=gapic_evaluation_service_types.MetricxResult(score=5)
+    ),
+    gapic_evaluation_service_types.EvaluateInstancesResponse(
+        comet_result=gapic_evaluation_service_types.CometResult(score=0.9)
+    ),
+    gapic_evaluation_service_types.EvaluateInstancesResponse(
+        metricx_result=gapic_evaluation_service_types.MetricxResult(score=20)
+    ),
+)
 
 
 @pytest.fixture(scope="module")
@@ -465,16 +499,10 @@ def test_create_eval_task(self):
         assert test_eval_task.dataset.equals(_TEST_EVAL_DATASET_ALL_INCLUDED)
         assert test_eval_task.metrics == _TEST_METRICS
         assert test_eval_task.experiment == _TEST_EXPERIMENT
-        assert test_eval_task._metric_column_mapping == {
-            "context": "context",
-            "reference": "reference",
-            "response": "response",
-            "instruction": "instruction",
-            "prompt": "prompt",
-        }
+        assert test_eval_task._metric_column_mapping == _EXPECTED_COLUMN_MAPPING
 
     @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
-    def test_compute_automatic_metrics(self, api_transport):
+    def test_compute_exact_match_metric(self, api_transport):
         aiplatform.init(
             project=_TEST_PROJECT,
             location=_TEST_LOCATION,
@@ -521,7 +549,7 @@ def test_compute_pointwise_metrics(self, api_transport):
         test_eval_task = EvalTask(
             dataset=_TEST_EVAL_DATASET_ALL_INCLUDED, metrics=test_metrics
         )
-        mock_metric_results = _MOCK_POINTEWISE_RESULT
+        mock_metric_results = _MOCK_POINTWISE_RESULT
         with mock.patch.object(
             target=gapic_evaluation_services.EvaluationServiceClient,
             attribute="evaluate_instances",
@@ -543,6 +571,7 @@
                 "reference",
                 "test_pointwise_metric/score",
                 "test_pointwise_metric/explanation",
+                "source",
             ]
         )
         assert test_result.metrics_table["response"].equals(
@@ -567,7 +596,7 @@ def test_compute_pointwise_metrics_free_string(self):
             metrics=[_TEST_POINTWISE_METRIC_FREE_STRING],
             metric_column_mapping={"abc": "prompt"},
         )
-        mock_metric_results = _MOCK_POINTEWISE_RESULT
+        mock_metric_results = _MOCK_POINTWISE_RESULT
         with mock.patch.object(
             target=gapic_evaluation_services.EvaluationServiceClient,
             attribute="evaluate_instances",
@@ -589,6 +618,7 @@
                 "reference",
                 "test_pointwise_metric_str/score",
                 "test_pointwise_metric_str/explanation",
+                "source",
             ]
         )
         assert test_result.metrics_table["response"].equals(
@@ -695,6 +725,7 @@ def test_compute_pointwise_metrics_without_model_inference(self, api_transport):
                 "response",
                 "summarization_quality/score",
                 "summarization_quality/explanation",
+                "source",
             ]
         )
         assert list(
@@ -707,6 +738,48 @@ def test_compute_pointwise_metrics_without_model_inference(self, api_transport):
             "explanation",
         ]
 
+    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
+    def test_compute_model_based_translation_metrics_without_model_inference(
+        self, api_transport
+    ):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+            api_transport=api_transport,
+        )
+        test_metrics = [_TEST_COMET, _TEST_METRICX]
+        test_eval_task = EvalTask(
+            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED, metrics=test_metrics
+        )
+
+        mock_metric_results = _MOCK_MODEL_BASED_TRANSLATION_RESULT
+        with mock.patch.object(
+            target=gapic_evaluation_services.EvaluationServiceClient,
+            attribute="evaluate_instances",
+            side_effect=mock_metric_results,
+        ):
+            test_result = test_eval_task.evaluate()
+
+        assert test_result.summary_metrics["row_count"] == 2
+        assert test_result.summary_metrics["comet/mean"] == 0.5
+        assert test_result.summary_metrics["metricx/mean"] == 12.5
+        assert test_result.summary_metrics["comet/std"] == pytest.approx(0.5, 0.6)
+        assert test_result.summary_metrics["metricx/std"] == pytest.approx(10, 11)
+        assert set(test_result.metrics_table.columns.values) == set(
+            [
+                "context",
+                "instruction",
+                "reference",
+                "prompt",
+                "response",
+                "source",
+                "comet/score",
+                "metricx/score",
+            ]
+        )
+        assert list(test_result.metrics_table["comet/score"].values) == [0.1, 0.9]
+        assert list(test_result.metrics_table["metricx/score"].values) == [5, 20]
+
     @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
     def test_compute_automatic_metrics_with_custom_metric_spec(self, api_transport):
         aiplatform.init(
@@ -940,6 +1013,7 @@ def test_compute_pairwise_metrics_without_model_inference(self, api_transport):
                 "instruction",
                 "pairwise_summarization_quality/pairwise_choice",
                 "pairwise_summarization_quality/explanation",
+                "source",
             ]
         )
         assert list(
@@ -1281,7 +1355,7 @@ def test_evaluate_response_column_and_model_not_provided(self):
         ):
             test_eval_task.evaluate()
 
-    def test_evaluate_baseline_response_column_and_baseline_model_not_provided(
+    def test_evaluate_baseline_model_response_column_not_provided(
         self,
     ):
         test_eval_dataset = _TEST_EVAL_DATASET_SINGLE.copy(deep=True)
@@ -1302,6 +1376,63 @@ def test_evaluate_baseline_response_column_and_baseline_model_not_provided(
         ):
             test_eval_task.evaluate()
 
+    def test_evaluate_response_column_not_provided(
+        self,
+    ):
+        test_eval_dataset = _TEST_EVAL_DATASET_SINGLE
+        test_eval_task = EvalTask(
+            dataset=test_eval_dataset,
+            metrics=["exact_match"],
+        )
+        with pytest.raises(
+            KeyError,
+            match=re.escape(
+                (
+                    "Required column `response` not found in the evaluation "
+                    "dataset. The columns in the evaluation dataset are ['prompt']"
+                )
+            ),
+        ):
+            test_eval_task.evaluate()
+
+    def test_evaluate_reference_column_not_provided(
+        self,
+    ):
+        test_eval_dataset = pd.DataFrame({"response": ["test", "text"]})
+        test_eval_task = EvalTask(
+            dataset=test_eval_dataset,
+            metrics=["exact_match"],
+        )
+        with pytest.raises(
+            KeyError,
+            match=re.escape(
+                (
+                    "Required column `reference` not found in the evaluation "
+                    "dataset. The columns in the evaluation dataset are ['response']"
+                )
+            ),
+        ):
+            test_eval_task.evaluate()
+
+    def test_evaluate_reference_or_source_column_not_provided(
+        self,
+    ):
+        test_eval_dataset = pd.DataFrame({"response": ["test", "text"]})
+        test_eval_task = EvalTask(
+            dataset=test_eval_dataset,
+            metrics=[_TEST_COMET, _TEST_METRICX],
+        )
+        with pytest.raises(
+            KeyError,
+            match=re.escape(
+                (
+                    "Required column `source` not found in the evaluation "
+                    "dataset. The columns in the evaluation dataset are ['response']"
+                )
+            ),
+        ):
+            test_eval_task.evaluate()
+
     def test_evaluate_invalid_prompt_template_variables(self):
         test_eval_task = EvalTask(
             dataset=_TEST_EVAL_DATASET_SINGLE,
@@ -1530,13 +1661,7 @@ def test_initialize_metric_column_mapping(self):
             metric_column_mapping=metric_column_mapping,
             dataset=_TEST_EVAL_DATASET_ALL_INCLUDED,
         )
-        assert converted_metric_column_mapping == {
-            "prompt": "prompt",
-            "response": "response",
-            "reference": "reference",
-            "context": "context",
-            "instruction": "instruction",
-        }
+        assert converted_metric_column_mapping == _EXPECTED_COLUMN_MAPPING
 
 
 class TestPromptTemplate:
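One note on the new summary-metric assertions above: pytest.approx takes the relative tolerance as its second positional argument, so pytest.approx(0.5, 0.6) and pytest.approx(10, 11) are deliberately loose bounds on the standard deviations rather than exact checks. A minimal standalone illustration; the sample-std values assume pandas' default ddof=1, which is an assumption about how the SDK aggregates:

import pytest

# pytest.approx(expected, rel): the second positional argument is the relative
# tolerance, i.e. |actual - expected| <= rel * |expected|.
assert 0.566 == pytest.approx(0.5, 0.6)  # sample std of [0.1, 0.9] falls within 0.5 +/- 0.3
assert 10.6 == pytest.approx(10, 11)     # sample std of [5, 20] falls within 10 +/- 110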

vertexai/evaluation/_evaluation.py  (+71, -22)
@@ -124,33 +124,73 @@ def _validate_metric_column_map(
         )
 
 
-def _validate_dataset_for_automatic_metrics(
+def _validate_dataset(
     evaluation_run_config: evaluation_base.EvaluationRunConfig,
-):
-    """Validates the required columns exist in the dataset for automatic metrics."""
+) -> None:
+    """Validates the required columns exists in the dataset."""
+    _validate_response_column_required(evaluation_run_config)
+    _validate_reference_column_required(evaluation_run_config)
+    _validate_reference_or_source_column_required(evaluation_run_config)
+
+
+def _validate_response_column_required(
+    evaluation_run_config: evaluation_base.EvaluationRunConfig,
+) -> None:
+    """Validates the response column exists in the dataset."""
+    for metric in evaluation_run_config.metrics:
+        if metric in constants.Metric.AUTOMATIC_METRIC_LIST or isinstance(
+            metric, metrics_base._TranslationMetric  # pylint: disable=protected-access
+        ):
+            _validate_column_provided(
+                evaluation_run_config,
+                constants.Dataset.MODEL_RESPONSE_COLUMN,
+            )
+
+
+def _validate_reference_column_required(
+    evaluation_run_config: evaluation_base.EvaluationRunConfig,
+) -> None:
+    """Validates the reference column exists in the dataset."""
     if set(evaluation_run_config.metrics).intersection(
         set(constants.Metric.AUTOMATIC_METRIC_LIST)
     ):
-        if (
-            constants.Dataset.REFERENCE_COLUMN
-            not in evaluation_run_config.metric_column_mapping
-        ):
-            evaluation_run_config.metric_column_mapping[
-                constants.Dataset.REFERENCE_COLUMN
-            ] = constants.Dataset.REFERENCE_COLUMN
-        evaluation_run_config.validate_dataset_column(
-            constants.Dataset.REFERENCE_COLUMN
+        _validate_column_provided(
+            evaluation_run_config,
+            constants.Dataset.REFERENCE_COLUMN,
         )
-        if (
-            constants.Dataset.MODEL_RESPONSE_COLUMN
-            not in evaluation_run_config.metric_column_mapping
+
+
+def _validate_column_provided(
+    evaluation_run_config: evaluation_base.EvaluationRunConfig,
+    column_name: str,
+) -> None:
+    """Validates the required column exist in the dataset."""
+    if column_name not in evaluation_run_config.metric_column_mapping:
+        evaluation_run_config.metric_column_mapping[column_name] = column_name
+    evaluation_run_config.validate_dataset_column(column_name)
+
+
+def _validate_reference_or_source_column_required(
+    evaluation_run_config: evaluation_base.EvaluationRunConfig,
+) -> None:
+    """Validates one of reference or source columns exist in the dataset."""
+    for metric in evaluation_run_config.metrics:
+        if isinstance(
+            metric, metrics_base._TranslationMetric  # pylint: disable=protected-access
         ):
-            evaluation_run_config.metric_column_mapping[
-                constants.Dataset.MODEL_RESPONSE_COLUMN
-            ] = constants.Dataset.MODEL_RESPONSE_COLUMN
-        evaluation_run_config.validate_dataset_column(
-            constants.Dataset.MODEL_RESPONSE_COLUMN
-        )
+            # Validate the reference column.
+            # This is optional if source column is provided.
+            try:
+                _validate_column_provided(
+                    evaluation_run_config,
+                    constants.Dataset.REFERENCE_COLUMN,
+                )
+            except KeyError:
+                # Reference column is optional. Checking for source column.
+                _validate_column_provided(
+                    evaluation_run_config,
+                    constants.Dataset.SOURCE_COLUMN,
+                )
 
 
 def _compute_custom_metrics(
@@ -639,6 +679,15 @@ def _parse_metric_results_to_dataframe(
                 metrics_table,
                 constants.MetricResult.SCORE_KEY,
             )
+        elif isinstance(
+            metric, metrics_base._TranslationMetric  # pylint: disable=protected-access
+        ):
+            _set_metric_table(
+                str(metric),
+                metric_results,
+                metrics_table,
+                constants.MetricResult.SCORE_KEY,
+            )
         else:
             _LOGGER.warning(
                 f"Metric name: {str(metric)} is not supported when parsing"
@@ -889,7 +938,7 @@ def evaluate(
         evaluation_run_config=evaluation_run_config,
        response_column_name=constants.Dataset.MODEL_RESPONSE_COLUMN,
     )
-    _validate_dataset_for_automatic_metrics(evaluation_run_config)
+    _validate_dataset(evaluation_run_config)
 
     pairwise_metric_exists = any(
         isinstance(metric, pairwise_metric.PairwiseMetric)