
Commit 3e7bf81

jsondai authored and copybara-github committed
feat: support customizing bring-your-own-response eval use case to use any columns
PiperOrigin-RevId: 686559721
1 parent 7246497 · commit 3e7bf81

3 files changed: +69 -58 lines changed

tests/unit/vertexai/test_evaluation.py (+8 -7)
@@ -1222,11 +1222,12 @@ def test_evaluate_response_column_and_model_not_provided(self):
             metrics=[_TEST_POINTWISE_METRIC],
         )
         with pytest.raises(
-            KeyError,
+            ValueError,
             match=re.escape(
                 (
-                    "Required column `response` not found in the evaluation dataset."
-                    " The columns in the evaluation dataset are ['prompt']."
+                    "Cannot find the `response` column in the evaluation dataset"
+                    " to fill the metric prompt template for"
+                    " `test_pointwise_metric` metric."
                 )
             ),
         ):
@@ -1242,12 +1243,12 @@ def test_evaluate_baseline_response_column_and_baseline_model_not_provided(
             metrics=[_TEST_PAIRWISE_METRIC],
         )
         with pytest.raises(
-            KeyError,
+            ValueError,
             match=re.escape(
                 (
-                    "Required column `baseline_model_response` not found in the"
-                    " evaluation dataset. The columns in the evaluation dataset are"
-                    " ['prompt', 'response']."
+                    "Cannot find the `baseline_model_response` column in the"
+                    " evaluation dataset to fill the metric prompt template for"
+                    " `test_pairwise_metric` metric."
                 )
             ),
         ):
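For orientation, here is a minimal sketch of the behavior these updated assertions check (assuming the public `EvalTask`/`PointwiseMetric` API; the metric name stands in for the `_TEST_POINTWISE_METRIC` fixture, and `vertexai.init()` project setup is omitted): evaluating in bring-your-own-response mode without a `response` column now fails with a `ValueError` carrying the new message, rather than a `KeyError`.

    import pandas as pd
    from vertexai.evaluation import EvalTask, PointwiseMetric

    eval_task = EvalTask(
        dataset=pd.DataFrame({"prompt": ["Say hi."]}),  # note: no `response` column
        metrics=[
            PointwiseMetric(
                metric="test_pointwise_metric",  # hypothetical stand-in name
                metric_prompt_template="Rate this response: {response}",
            )
        ],
    )

    try:
        eval_task.evaluate()  # no `model` passed, so BYOR mode is assumed
    except ValueError as err:
        # Expected message (per the test above): "Cannot find the `response`
        # column in the evaluation dataset to fill the metric prompt template
        # for `test_pointwise_metric` metric."
        print(err)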

vertexai/evaluation/_evaluation.py (+35 -32)
@@ -103,8 +103,8 @@ def _validate_metric_column_map(
     """Validates the column map for metric prompt template usage."""
     for metric in evaluation_run_config.metrics:
         if isinstance(
-            metric, metrics_base._ModelBasedMetric
-        ):  # pylint: disable=protected-access
+            metric, metrics_base._ModelBasedMetric  # pylint: disable=protected-access
+        ):
             for variable in prompt_template_base.PromptTemplate(
                 metric.metric_prompt_template
             ).variables:
@@ -124,6 +124,35 @@ def _validate_metric_column_map(
                 )


+def _validate_dataset_for_automatic_metrics(
+    evaluation_run_config: evaluation_base.EvaluationRunConfig,
+):
+    """Validates the required columns exist in the dataset for automatic metrics."""
+    if set(evaluation_run_config.metrics).intersection(
+        set(constants.Metric.AUTOMATIC_METRIC_LIST)
+    ):
+        if (
+            constants.Dataset.REFERENCE_COLUMN
+            not in evaluation_run_config.metric_column_mapping
+        ):
+            evaluation_run_config.metric_column_mapping[
+                constants.Dataset.REFERENCE_COLUMN
+            ] = constants.Dataset.REFERENCE_COLUMN
+        evaluation_run_config.validate_dataset_column(
+            constants.Dataset.REFERENCE_COLUMN
+        )
+        if (
+            constants.Dataset.MODEL_RESPONSE_COLUMN
+            not in evaluation_run_config.metric_column_mapping
+        ):
+            evaluation_run_config.metric_column_mapping[
+                constants.Dataset.MODEL_RESPONSE_COLUMN
+            ] = constants.Dataset.MODEL_RESPONSE_COLUMN
+        evaluation_run_config.validate_dataset_column(
+            constants.Dataset.MODEL_RESPONSE_COLUMN
+        )
+
+
 def _compute_custom_metrics(
     row_dict: Dict[str, Any],
     custom_metrics: List[metrics_base.CustomMetric],
@@ -392,8 +421,8 @@ def _run_model_inference(
     is_baseline_model = (
         response_column_name == constants.Dataset.BASELINE_MODEL_RESPONSE_COLUMN
     )
-    if response_column_name not in evaluation_run_config.metric_column_mapping:
-        if model:
+    if model:
+        if response_column_name not in evaluation_run_config.metric_column_mapping:
             if constants.Dataset.PROMPT_COLUMN in evaluation_run_config.dataset.columns:
                 t1 = time.perf_counter()
                 if isinstance(model, generative_models.GenerativeModel):
@@ -423,8 +452,7 @@ def _run_model_inference(
                     " the model. Mappings in `metric_column_mapping` do not"
                     " apply for model inference and are used for evaluation only."
                 )
-    else:
-        if model:
+        else:
             raise ValueError(
                 "The `model` parameter or `baseline_model` in pairwise metric is"
                 " specified, but the evaluation `dataset` contains model response"
@@ -840,20 +868,6 @@ def evaluate(
         retry_timeout=retry_timeout,
     )

-    if set(evaluation_run_config.metrics).intersection(
-        set(constants.Metric.AUTOMATIC_METRIC_LIST)
-    ):
-        if (
-            constants.Dataset.REFERENCE_COLUMN
-            not in evaluation_run_config.metric_column_mapping
-        ):
-            evaluation_run_config.metric_column_mapping[
-                constants.Dataset.REFERENCE_COLUMN
-            ] = constants.Dataset.REFERENCE_COLUMN
-        evaluation_run_config.validate_dataset_column(
-            constants.Dataset.REFERENCE_COLUMN
-        )
-
     if prompt_template:
         _assemble_prompt_for_dataset(evaluation_run_config, prompt_template)

@@ -862,12 +876,7 @@ def evaluate(
         evaluation_run_config=evaluation_run_config,
         response_column_name=constants.Dataset.MODEL_RESPONSE_COLUMN,
     )
-    evaluation_run_config.validate_dataset_column(
-        metric_column_mapping.get(
-            constants.Dataset.MODEL_RESPONSE_COLUMN,
-            constants.Dataset.MODEL_RESPONSE_COLUMN,
-        )
-    )
+    _validate_dataset_for_automatic_metrics(evaluation_run_config)

     pairwise_metric_exists = any(
         isinstance(metric, pairwise_metric.PairwiseMetric)
@@ -880,12 +889,6 @@ def evaluate(
             evaluation_run_config=evaluation_run_config,
             response_column_name=constants.Dataset.BASELINE_MODEL_RESPONSE_COLUMN,
         )
-        evaluation_run_config.validate_dataset_column(
-            metric_column_mapping.get(
-                constants.Dataset.BASELINE_MODEL_RESPONSE_COLUMN,
-                constants.Dataset.BASELINE_MODEL_RESPONSE_COLUMN,
-            )
-        )

     _validate_metric_column_map(evaluation_run_config)
     t1 = time.perf_counter()
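The refactor above extracts the automatic-metric column handling into `_validate_dataset_for_automatic_metrics`. As a rough illustration of the pattern it centralizes, here is a simplified, hypothetical sketch using plain dicts and sets in place of `EvaluationRunConfig`, `constants.Dataset`, and `constants.Metric` (the helper name and error text below are made up for illustration):

    def fill_and_check_column_mapping(metric_column_mapping, dataset_columns, required_columns):
        """Map each required column to itself unless a custom mapping exists, then verify it."""
        for column in required_columns:
            mapped_column = metric_column_mapping.setdefault(column, column)
            if mapped_column not in dataset_columns:
                raise ValueError(
                    f"Column `{mapped_column}` not found in the evaluation dataset."
                )
        return metric_column_mapping

    # For automatic (computation-based) metrics, `reference` and `response` are the
    # required columns; `response` may already be mapped to a custom column by the
    # new BYOR customization, in which case that mapping is respected.
    mapping = fill_and_check_column_mapping(
        metric_column_mapping={"response": "candidate_answer"},
        dataset_columns={"prompt", "candidate_answer", "reference"},
        required_columns=["reference", "response"],
    )
    # mapping == {"response": "candidate_answer", "reference": "reference"}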

vertexai/evaluation/eval_task.py (+26 -19)
@@ -71,18 +71,24 @@ class EvalTask:
           * baseline_model_response_column_name: "baseline_model_response"

         Requirement for different use cases:
-          * Bring-your-own-response: A `response` column is required. Response
-            column name can be customized by providing `response_column_name`
-            parameter. If a pairwise metric is used and a baseline model is
-            not provided, a `baseline_model_response` column is required.
-            Baseline model response column name can be customized by providing
-            `baseline_model_response_column_name` parameter. If the `response`
-            column or `baseline_model_response` column is present while the
+          * Bring-your-own-response (BYOR): You already have the data that you
+            want to evaluate stored in the dataset. Response column name can be
+            customized by providing `response_column_name` parameter, or in the
+            `metric_column_mapping`. For BYOR pairwise evaluation, the baseline
+            model response column name can be customized by providing
+            `baseline_model_response_column_name` parameter, or
+            in the `metric_column_mapping`. If the `response` column or
+            `baseline_model_response` column is present while the
             corresponding model is specified, an error will be raised.
-          * Perform model inference without a prompt template: A `prompt` column
-            in the evaluation dataset representing the input prompt to the
-            model is required and is used directly as input to the model.
-          * Perform model inference with a prompt template: Evaluation dataset
+
+          * Perform model inference without a prompt template: You have a dataset
+            containing the input prompts to the model and want to perform model
+            inference before evaluation. A column named `prompt` is required
+            in the evaluation dataset and is used directly as input to the model.
+
+          * Perform model inference with a prompt template: You have a dataset
+            containing the input variables to the prompt template and want to
+            assemble the prompts for model inference. Evaluation dataset
             must contain column names corresponding to the variable names in
             the prompt template. For example, if prompt template is
             "Instruction: {instruction}, context: {context}", the dataset must
@@ -371,18 +377,19 @@ def evaluate(

         Args:
             model: A GenerativeModel instance or a custom model function to generate
-               responses to evaluate. If not provided, the evaluation is computed with
-               the `response` column in the `dataset`.
+               responses to evaluate. If not provided, the evaluation can be performed
+               in the bring-your-own-response (BYOR) mode.
             prompt_template: The prompt template to use for the evaluation. If not
                 set, the prompt template that was used to create the EvalTask will be
                 used.
             experiment_run_name: The name of the experiment run to log the evaluation
                 to if an experiment is set for this EvalTask. If not provided, a random
                 unique experiment run name is used.
             response_column_name: The column name of model response in the dataset. If
-               provided, this will override the `response_column_name` of the `EvalTask`.
+               provided, this will override the `metric_column_mapping` of the `EvalTask`.
             baseline_model_response_column_name: The column name of baseline model
-               response in the dataset for pairwise metrics.
+               response in the dataset for pairwise metrics. If provided, this will
+               override the `metric_column_mapping` of the `EvalTask`.
             evaluation_service_qps: The custom QPS limit for the evaluation service.
             retry_timeout: How long to keep retrying the evaluation requests for
                 the whole evaluation dataset, in seconds.
@@ -400,11 +407,11 @@ def evaluate(
                 "`vertexai.init(experiment='experiment_name')`for logging this"
                 " evaluation run."
             )
-        self._verify_response_column_name(
+        self._verify_and_set_response_column_name(
             response_column_name=response_column_name,
             metric_column_mapping_key=constants.Dataset.MODEL_RESPONSE_COLUMN,
         )
-        self._verify_response_column_name(
+        self._verify_and_set_response_column_name(
             response_column_name=baseline_model_response_column_name,
             metric_column_mapping_key=constants.Dataset.BASELINE_MODEL_RESPONSE_COLUMN,
         )
@@ -503,10 +510,10 @@ def _log_eval_experiment_param(
         except (ValueError, TypeError) as e:
             _LOGGER.warning(f"Experiment metadata logging failed: {str(e)}")

-    def _verify_response_column_name(
+    def _verify_and_set_response_column_name(
         self, response_column_name: str, metric_column_mapping_key: str
     ) -> None:
-        """Verifies if model response column name or baseline model response column name is valid."""
+        """Verifies and sets the model response column names."""
         if response_column_name:
             if response_column_name in self._dataset.columns:
                 self._metric_column_mapping[
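Putting the eval_task.py changes together, here is a hedged end-to-end sketch of the customized BYOR pairwise flow (column names, metric name, and dataset contents are made up; `vertexai.init()` project setup is omitted; this is not an official sample):

    import pandas as pd
    from vertexai.evaluation import EvalTask, PairwiseMetric

    dataset = pd.DataFrame({
        "prompt": ["Summarize the article."],
        "candidate_answer": ["Summary from the candidate model."],
        "reference_answer": ["Summary from the baseline model."],
    })

    eval_task = EvalTask(
        dataset=dataset,
        metrics=[
            PairwiseMetric(
                metric="test_pairwise_metric",  # hypothetical metric name
                metric_prompt_template=(
                    "Prompt: {prompt}\nResponse A: {baseline_model_response}\n"
                    "Response B: {response}\nWhich response is better?"
                ),
            )
        ],
    )

    # No `model` or `baseline_model` is given, so both responses are read from the
    # dataset; the new parameters map them to arbitrary column names.
    result = eval_task.evaluate(
        response_column_name="candidate_answer",
        baseline_model_response_column_name="reference_answer",
    )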
