
Commit 536f1d5

jsondai authored and copybara-github committed
feat: Add _ModelBasedMetric base class to vertexai.preview.evaluation.metrics and allow metric spec customization
PiperOrigin-RevId: 644440372
1 parent a0ad286 commit 536f1d5

5 files changed: +161 −54 lines changed

tests/unit/vertexai/test_evaluation.py (+19 −8)

@@ -31,6 +31,9 @@
 from vertexai.preview.evaluation import _base as eval_base
 from vertexai.preview.evaluation import _evaluation
 from vertexai.preview.evaluation import utils
+from vertexai.preview.evaluation.metrics import (
+    _pairwise_summarization_quality,
+)
 import numpy as np
 import pandas as pd
 import pytest
@@ -111,6 +114,18 @@
         )
     ),
 )
+_MOCK_SUMMARIZATION_QUALITY_RESULT = (
+    gapic_evaluation_service_types.EvaluateInstancesResponse(
+        summarization_quality_result=gapic_evaluation_service_types.SummarizationQualityResult(
+            score=5, explanation="explanation", confidence=1.0
+        )
+    ),
+    gapic_evaluation_service_types.EvaluateInstancesResponse(
+        summarization_quality_result=gapic_evaluation_service_types.SummarizationQualityResult(
+            score=4, explanation="explanation", confidence=0.5
+        )
+    ),
+)

 _MOCK_PAIRWISE_SUMMARIZATION_QUALITY_RESULT = (
     gapic_evaluation_service_types.EvaluateInstancesResponse(
@@ -331,8 +346,7 @@ def test_compute_pairwise_metrics_with_model_inference(self, api_transport):
         )
         mock_candidate_model._model_name = "publishers/google/model/gemini-pro"
         test_metrics = [
-            evaluation.PairwiseMetric(
-                metric="pairwise_summarization_quality",
+            _pairwise_summarization_quality.PairwiseSummarizationQuality(
                 baseline_model=mock_baseline_model,
                 use_reference=False,
             )
@@ -418,8 +432,7 @@ def test_compute_pairwise_metrics_without_inference(self, api_transport):
             }
         )
         test_metrics = [
-            evaluation.PairwiseMetric(
-                metric="summarization_quality",
+            _pairwise_summarization_quality.PairwiseSummarizationQuality(
                 baseline_model=None,
                 use_reference=True,
             )
@@ -608,12 +621,10 @@ def test_evaluate_pairwise_metrics_with_multiple_baseline_models(self):
         )
         mock_candidate_model._model_name = "publishers/google/model/gemini-1.0-ultra"
         test_metrics = [
-            evaluation.PairwiseMetric(
-                metric="pairwise_summarization_quality",
+            _pairwise_summarization_quality.PairwiseSummarizationQuality(
                 baseline_model=mock_baseline_model_1,
            ),
-            evaluation.PairwiseMetric(
-                metric="pairwise_summarization_quality",
+            _pairwise_summarization_quality.PairwiseSummarizationQuality(
                 baseline_model=mock_baseline_model_2,
            ),
         ]
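
The tests above switch from the generic `evaluation.PairwiseMetric(metric=...)` form to the dedicated `PairwiseSummarizationQuality` class. A minimal usage sketch of that pattern follows (not part of the commit; the model name is illustrative and the unit tests themselves use mocks):

```python
# Sketch only: mirrors the construction pattern used in the updated tests above.
from vertexai.generative_models import GenerativeModel
from vertexai.preview.evaluation.metrics import (
    _pairwise_summarization_quality,
)

# Illustrative baseline model; the unit tests replace this with a mock.
baseline_model = GenerativeModel("gemini-1.0-pro")

# The metric name is carried by the class itself instead of a `metric=` string.
pairwise_metric = _pairwise_summarization_quality.PairwiseSummarizationQuality(
    baseline_model=baseline_model,
    use_reference=False,
)
print(str(pairwise_metric))  # expected to print "pairwise_summarization_quality"
```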

vertexai/preview/evaluation/_eval_tasks.py (+1)

@@ -210,6 +210,7 @@ def __init__(
             ],
             metrics_base.CustomMetric,
             metrics_base.PairwiseMetric,
+            metrics_base._ModelBasedMetric,
         ]
     ],
     experiment: Optional[str] = None,
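
A hedged sketch of what this one-line change enables: passing a `_ModelBasedMetric` instance with a customized spec in an `EvalTask` metrics list. The `EvalTask` import path, the example dataset columns, and the `"coherence"`/`"fluency"` metric-name strings are assumptions drawn from the preview API rather than from this diff:

```python
# Sketch only; assumes the preview EvalTask surface and plain-string metric names.
import pandas as pd

from vertexai.preview.evaluation import EvalTask
from vertexai.preview.evaluation.metrics import _base as metrics_base

eval_dataset = pd.DataFrame(
    {
        "instruction": ["Summarize the following text."],
        "context": ["The quick brown fox jumps over the lazy dog."],
        "response": ["A fox jumps over a dog."],
    }
)

# With this change, a _ModelBasedMetric instance is accepted alongside
# metric-name strings, CustomMetric, and PairwiseMetric.
eval_task = EvalTask(
    dataset=eval_dataset,
    metrics=[
        "coherence",
        metrics_base._ModelBasedMetric(metric="fluency", version=2),
    ],
)
```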

vertexai/preview/evaluation/_evaluation.py (+45 −25)

@@ -91,8 +91,22 @@


 def _replace_metric_bundle_with_metrics(
-    metrics: List[Union[str, metrics_base.CustomMetric, metrics_base.PairwiseMetric]],
-) -> List[Union[str, metrics_base.CustomMetric, metrics_base.PairwiseMetric]]:
+    metrics: List[
+        Union[
+            str,
+            metrics_base.CustomMetric,
+            metrics_base.PairwiseMetric,
+            metrics_base._ModelBasedMetric,
+        ]
+    ],
+) -> List[
+    Union[
+        str,
+        metrics_base.CustomMetric,
+        metrics_base.PairwiseMetric,
+        metrics_base._ModelBasedMetric,
+    ]
+]:
     """Replaces metric bundles with corresponding metrics.

     Args:
@@ -147,9 +161,17 @@ def _compute_custom_metrics(


 def _separate_custom_metrics(
-    metrics: List[Union[str, metrics_base.CustomMetric, metrics_base.PairwiseMetric]],
+    metrics: List[
+        Union[
+            str,
+            metrics_base.CustomMetric,
+            metrics_base.PairwiseMetric,
+            metrics_base._ModelBasedMetric,
+        ]
+    ],
 ) -> Tuple[
-    List[Union[str, metrics_base.PairwiseMetric]], List[metrics_base.CustomMetric]
+    List[Union[str, metrics_base.PairwiseMetric, metrics_base._ModelBasedMetric]],
+    List[metrics_base.CustomMetric],
 ]:
     """Separates the metrics list into API and custom metrics."""
     custom_metrics = []
@@ -180,17 +202,12 @@ def _compute_summary_metrics(
     for metric in evaluation_run_config.metrics:
         try:
             if isinstance(metric, metrics_base.PairwiseMetric):
-                summary_metrics[
-                    f"{metric.pairwise_metric_name}/candidate_model_win_rate"
-                ] = (
-                    metrics_table[f"{metric.pairwise_metric_name}/pairwise_choice"]
+                summary_metrics[f"{metric.metric_name}/candidate_model_win_rate"] = (
+                    metrics_table[f"{metric.metric_name}/pairwise_choice"]
                     == "CANDIDATE"
                 ).mean()
-                summary_metrics[
-                    f"{metric.pairwise_metric_name}/baseline_model_win_rate"
-                ] = (
-                    metrics_table[f"{metric.pairwise_metric_name}/pairwise_choice"]
-                    == "BASELINE"
+                summary_metrics[f"{metric.metric_name}/baseline_model_win_rate"] = (
+                    metrics_table[f"{metric.metric_name}/pairwise_choice"] == "BASELINE"
                 ).mean()
             else:
                 # TODO(b/325078638): implement additional aggregate methods.
@@ -303,11 +320,11 @@ def _generate_response_from_gemini_model(
                     model=model,
                 )
             )
-    respones = [task.result() for task in tasks]
+    responses = [task.result() for task in tasks]
     if is_baseline_model:
-        evaluation_run_config.dataset = df.assign(baseline_model_response=respones)
+        evaluation_run_config.dataset = df.assign(baseline_model_response=responses)
     else:
-        evaluation_run_config.dataset = df.assign(response=respones)
+        evaluation_run_config.dataset = df.assign(response=responses)

     _LOGGER.info(
         f"All {evaluation_run_config.dataset.shape[0]} responses are successfully"
@@ -358,11 +375,11 @@ def _generate_response_from_custom_model_fn(
         except (ValueError, IndexError) as e:
             _LOGGER.warning(f"Failed to generate response from model function: {e}")

-    respones = [task.result() for task in tasks]
+    responses = [task.result() for task in tasks]
     if is_baseline_model:
-        evaluation_run_config.dataset = df.assign(baseline_model_response=respones)
+        evaluation_run_config.dataset = df.assign(baseline_model_response=responses)
     else:
-        evaluation_run_config.dataset = df.assign(response=respones)
+        evaluation_run_config.dataset = df.assign(response=responses)

     _LOGGER.info(
         f"All {evaluation_run_config.dataset.shape[0]} responses are successfully"
@@ -582,11 +599,7 @@ async def _compute_metrics(
                     retry_timeout=evaluation_run_config.retry_timeout,
                 )
             )
-            if isinstance(metric, metrics_base.PairwiseMetric):
-                metric_name = metric.pairwise_metric_name
-            else:
-                metric_name = metric
-            tasks_by_metric[metric_name].append(task)
+            tasks_by_metric[str(metric)].append(task)

     api_request_count = len(api_metrics) * len(evaluation_run_config.dataset)
     _LOGGER.info(
@@ -608,7 +621,14 @@

 def evaluate(
     dataset: "pd.DataFrame",
-    metrics: List[Union[str, metrics_base.CustomMetric, metrics_base.PairwiseMetric]],
+    metrics: List[
+        Union[
+            str,
+            metrics_base.CustomMetric,
+            metrics_base.PairwiseMetric,
+            metrics_base._ModelBasedMetric,
+        ]
+    ],
     *,
     model: Optional[
         Union[generative_models.GenerativeModel, Callable[[str], str]]
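
The `_compute_summary_metrics` change above only swaps `pairwise_metric_name` for the shared `metric_name` property; the win-rate aggregation itself is unchanged. As a standalone illustration of that aggregation, in plain pandas outside the SDK:

```python
# Standalone illustration of the pairwise win-rate aggregation shown above.
import pandas as pd

metric_name = "pairwise_summarization_quality"
metrics_table = pd.DataFrame(
    {f"{metric_name}/pairwise_choice": ["CANDIDATE", "BASELINE", "CANDIDATE", "TIE"]}
)

choices = metrics_table[f"{metric_name}/pairwise_choice"]
summary_metrics = {
    f"{metric_name}/candidate_model_win_rate": (choices == "CANDIDATE").mean(),
    f"{metric_name}/baseline_model_win_rate": (choices == "BASELINE").mean(),
}
print(summary_metrics)
# {'pairwise_summarization_quality/candidate_model_win_rate': 0.5,
#  'pairwise_summarization_quality/baseline_model_win_rate': 0.25}
```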

vertexai/preview/evaluation/metrics/_base.py (+74 −3)

@@ -16,10 +16,20 @@
 #

 from typing import Any, Callable, Dict, Literal, Optional, Union
+import warnings
+
 from vertexai import generative_models
 from vertexai.preview.evaluation import constants


+_DEPRECATION_WARNING_MESSAGE = (
+    "After google-cloud-aiplatform>1.60.0, using metric name `summarization_quality`"
+    " and `question_answering_quality` will result in an error. "
+    "Please use `pairwise_summarization_quality` and "
+    "`pairwise_question_answering_quality` instead."
+)
+
+
 class PairwiseMetric:
     """The Side-by-side(SxS) Pairwise Metric.

@@ -64,7 +74,7 @@ class PairwiseMetric:
         candidate_model = GenerativeModel("gemini-1.5-pro")

         pairwise_summarization_quality = PairwiseMetric(
-            metric = "summarization_quality",
+            metric = "pairwise_summarization_quality",
             baseline_model=baseline_model,
         )

@@ -109,16 +119,19 @@ def __init__(
         # TODO(b/311221071): Remove the legacy metric names for GA.
         if metric in ("summarization_quality", "question_answering_quality"):
             metric = f"pairwise_{metric}"
+            warnings.warn(
+                _DEPRECATION_WARNING_MESSAGE, DeprecationWarning, stacklevel=2
+            )
         self._metric = metric
         self._baseline_model = baseline_model
         self._use_reference = use_reference
         self._version = version

     def __str__(self):
-        return self.pairwise_metric_name
+        return self.metric_name

     @property
-    def pairwise_metric_name(self) -> str:
+    def metric_name(self) -> str:
         return self._metric

     @property
@@ -136,6 +149,64 @@ def version(self) -> int:
         return self._version


+class _ModelBasedMetric:
+    """The Model-based Metric.
+
+    A model-based evaluation metric that evaluates a generative model's response
+    on the given evaluation task.
+
+    For more details on when to use model-based metrics, see
+    [Evaluation methods and metrics](https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval).
+    """
+
+    def __init__(
+        self,
+        *,
+        metric: Literal[
+            constants.Metric.COHERENCE,
+            constants.Metric.FLUENCY,
+            constants.Metric.SAFETY,
+            constants.Metric.GROUNDEDNESS,
+            constants.Metric.FULFILLMENT,
+            constants.Metric.SUMMARIZATION_QUALITY,
+            constants.Metric.SUMMARIZATION_HELPFULNESS,
+            constants.Metric.SUMMARIZATION_VERBOSITY,
+            constants.Metric.QUESTION_ANSWERING_QUALITY,
+            constants.Metric.QUESTION_ANSWERING_RELEVANCE,
+            constants.Metric.QUESTION_ANSWERING_HELPFULNESS,
+            constants.Metric.QUESTION_ANSWERING_CORRECTNESS,
+        ],
+        use_reference: bool = False,
+        version: Optional[int] = None,
+    ):
+        """Initializes the model-based evaluation metric.
+
+        Args:
+            metric: The model-based evaluation metric name.
+            use_reference: Whether to use reference to compute the metric. If
+                specified, the reference column is required in the dataset.
+            version: The metric version to use for evaluation.
+        """
+        self._metric = metric
+        self._use_reference = use_reference
+        self._version = version
+
+    def __str__(self):
+        return self.metric_name
+
+    @property
+    def metric_name(self) -> str:
+        return self._metric
+
+    @property
+    def use_reference(self) -> bool:
+        return self._use_reference
+
+    @property
+    def version(self) -> int:
+        return self._version
+
+
 class CustomMetric:
     """The custom evaluation metric.

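
A short, hedged sketch of the two behaviors added in this file: `_ModelBasedMetric` as a spec-customizable metric (still a private preview class), and the new `DeprecationWarning` when `PairwiseMetric` receives a legacy metric name. The plain-string metric values below assume the `constants.Metric` members are strings:

```python
# Sketch only; exercises the classes defined in the diff above.
import warnings

from vertexai.preview.evaluation.metrics import _base as metrics_base

# Customize the metric spec: require a reference column and pin a metric version.
summarization_quality = metrics_base._ModelBasedMetric(
    metric="summarization_quality",  # assumed equal to constants.Metric.SUMMARIZATION_QUALITY
    use_reference=True,
    version=2,
)
assert str(summarization_quality) == "summarization_quality"
assert summarization_quality.use_reference is True

# Legacy pairwise names are rewritten and now also emit a DeprecationWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = metrics_base.PairwiseMetric(
        metric="summarization_quality",
        baseline_model=None,
    )
assert legacy.metric_name == "pairwise_summarization_quality"
assert caught and issubclass(caught[0].category, DeprecationWarning)
```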
