Skip to content

Commit acf8a80

Browse files
jsondai and copybara-github
authored and committed
feat: Add notebook helper functions to eval SDK to display and visualize evaluation results in an IPython environment
PiperOrigin-RevId: 732297732
1 parent 4c8c277 commit acf8a80

File tree

1 file changed

+251
-0
lines changed

1 file changed

+251
-0
lines changed

vertexai/evaluation/notebook_utils.py

+251
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2025 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
"""Python functions which run only within a Jupyter or Colab notebook."""
18+
19+
import random
20+
import string
21+
import sys
22+
from typing import List, Optional, Tuple
23+
24+
from vertexai.preview.evaluation import _base
25+
from vertexai.preview.evaluation import constants
26+
27+
# pandas is an optional dependency: the display helpers need it, but importing
# this module must not fail when pandas is absent.
try:
    import pandas as pd
except ImportError:
    # NOTE: the original bound `pandas = None`, which left the name `pd`
    # undefined on import failure and would raise NameError at first use.
    pd = None
31+
32+
# Markdown heading levels used when rendering section titles in notebook output.
_MARKDOWN_H2 = "##"
_MARKDOWN_H3 = "###"
# Column names that are always included when callers filter displayed columns
# by metric-name substrings (see `_filter_df` usage in the display helpers).
_DEFAULT_COLUMNS_TO_DISPLAY = [
    constants.Dataset.MODEL_RESPONSE_COLUMN,
    constants.Dataset.BASELINE_MODEL_RESPONSE_COLUMN,
    constants.Dataset.PROMPT_COLUMN,
    constants.MetricResult.ROW_COUNT_KEY,
]
# Default (min, max) range of the radar plot's radial axis.
_DEFAULT_RADAR_RANGE = (0, 5)
41+
42+
43+
def _get_ipython_shell_name() -> str:
44+
if "IPython" in sys.modules:
45+
from IPython import get_ipython
46+
47+
return get_ipython().__class__.__name__
48+
return ""
49+
50+
51+
def is_ipython_available() -> bool:
    """Returns True when running inside an IPython shell (Jupyter/Colab)."""
    # The helper returns a shell class name ("" outside IPython); cast to bool
    # so the return value matches the declared `-> bool` annotation.
    return bool(_get_ipython_shell_name())
53+
54+
55+
def _filter_df(
56+
df: pd.DataFrame, substrings: Optional[List[str]] = None
57+
) -> pd.DataFrame:
58+
"""Filters a DataFrame to include only columns containing the given substrings."""
59+
if substrings is None:
60+
return df
61+
62+
return df.copy().filter(
63+
[
64+
column_name
65+
for column_name in df.columns
66+
if any(substring in column_name for substring in substrings)
67+
]
68+
)
69+
70+
71+
def display_eval_result(
    *,
    eval_result: _base.EvalResult,
    title: Optional[str] = None,
    metrics: Optional[List[str]] = None,
) -> None:
    """Displays evaluation results in a notebook using IPython.display.

    No-op when not running inside an IPython environment.

    Args:
        eval_result: An object containing evaluation results with
            `summary_metrics` and `metrics_table` attributes.
        title: A string title to display above the results.
        metrics: A list of metric name substrings to filter displayed columns.
            If provided, only metrics whose names contain any of these strings
            will be displayed.
    """
    if not is_ipython_available():
        return
    # pylint: disable=g-import-not-at-top, g-importing-member
    from IPython.display import display
    from IPython.display import Markdown

    summary_metrics = eval_result.summary_metrics
    metrics_table = eval_result.metrics_table

    # One-row frame: metric names become columns, values form the single row.
    summary_df = pd.DataFrame.from_dict(summary_metrics, orient="index").T

    if metrics:
        columns_to_keep = metrics + _DEFAULT_COLUMNS_TO_DISPLAY
        summary_df = _filter_df(summary_df, columns_to_keep)
        metrics_table = _filter_df(metrics_table, columns_to_keep)

    # Optional title, then the two result sections.
    if title:
        display(Markdown(f"{_MARKDOWN_H2} {title}"))

    display(Markdown(f"{_MARKDOWN_H3} Summary Metrics"))
    display(summary_df)

    display(Markdown(f"{_MARKDOWN_H3} Row-based Metrics"))
    display(metrics_table)
116+
117+
118+
def display_explanations(
    *,
    eval_result: _base.EvalResult,
    num: int = 1,
    metrics: Optional[List[str]] = None,
) -> None:
    """Displays sampled row explanations in a notebook using IPython.display.

    No-op when not running inside an IPython environment.

    Args:
        eval_result: An object containing evaluation results. It is expected to
            have attributes `summary_metrics` and `metrics_table`.
        num: The number of row samples to display. Defaults to 1. If the number
            of rows is less than `num`, all rows will be displayed.
        metrics: A list of metric name substrings to filter displayed columns.
            If provided, only metrics whose names contain any of these strings
            will be displayed.

    Raises:
        ValueError: If `num` is less than 1.
    """
    if not is_ipython_available():
        return
    # pylint: disable=g-import-not-at-top, g-importing-member
    from IPython.display import display
    from IPython.display import HTML

    if num < 1:
        raise ValueError("Num must be greater than 0.")

    cell_style = "white-space: pre-wrap; width: 1500px; overflow-x: auto;"
    metrics_table = eval_result.metrics_table
    sample_size = min(num, len(metrics_table))
    sampled_rows = metrics_table.sample(n=sample_size)

    if metrics:
        sampled_rows = _filter_df(sampled_rows, metrics + _DEFAULT_COLUMNS_TO_DISPLAY)

    # Render each sampled row as a list of "<column>: <value>" cells, with a
    # horizontal rule separating rows.
    for _, row in sampled_rows.iterrows():
        for col in sampled_rows.columns:
            display(HTML(f"<div style='{cell_style}'><h4>{col}:</h4>{row[col]}</div>"))
        display(HTML("<hr>"))
158+
159+
160+
def display_radar_plot(
    eval_results_with_title: List[Tuple[str, _base.EvalResult]],
    metrics: List[str],
    radar_range: Tuple[float, float] = _DEFAULT_RADAR_RANGE,
) -> None:
    """Plots a radar (polar) chart comparing evaluation results.

    Args:
        eval_results_with_title: List of (title, eval_result) tuples.
        metrics: A list of metrics whose mean values will be plotted.
        radar_range: Range of the radar plot axes.

    Raises:
        ImportError: If `plotly` is not installed.
    """
    # pylint: disable=g-import-not-at-top
    try:
        import plotly.graph_objects as go
    except ImportError as exc:
        raise ImportError(
            '`plotly` is not installed. Please install using "!pip install plotly"'
        ) from exc

    figure = go.Figure()
    for run_title, run_result in eval_results_with_title:
        scores = run_result.summary_metrics
        if metrics:
            # Keep only the "<metric>/mean" entries and strip the "/mean"
            # suffix so the axis labels show bare metric names.
            scores = {
                name.replace("/mean", ""): scores[name]
                for name in scores
                if any(wanted + "/mean" in name for wanted in metrics)
            }
        figure.add_trace(
            go.Scatterpolar(
                r=list(scores.values()),
                theta=list(scores.keys()),
                fill="toself",
                name=run_title,
            )
        )
    figure.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=radar_range)),
        showlegend=True,
    )
    figure.show()
202+
203+
204+
def display_bar_plot(
    eval_results_with_title: List[Tuple[str, _base.EvalResult]],
    metrics: List[str],
) -> None:
    """Plots a grouped bar plot comparing evaluation results.

    Args:
        eval_results_with_title: List of (title, eval_result) tuples.
        metrics: A list of metrics whose mean values will be plotted.

    Raises:
        ImportError: If `plotly` is not installed.
    """
    # pylint: disable=g-import-not-at-top
    try:
        import plotly.graph_objects as go
    except ImportError as exc:
        raise ImportError(
            '`plotly` is not installed. Please install using "!pip install plotly"'
        ) from exc

    # Hoisted out of the loop: the "<metric>/mean" keys to keep are the same
    # for every eval result.
    mean_summary_metrics = [f"{metric}/mean" for metric in metrics]

    data = []
    for title, eval_result in eval_results_with_title:
        summary_metrics = eval_result.summary_metrics
        if metrics:
            # Keep only the requested per-metric means, preserving order.
            summary_metrics = {
                k: v
                for k, v in summary_metrics.items()
                if k in mean_summary_metrics
            }

        data.append(
            go.Bar(
                x=list(summary_metrics.keys()),
                y=list(summary_metrics.values()),
                name=title,
            )
        )

    fig = go.Figure(data=data)
    fig.update_layout(barmode="group", showlegend=True)
    fig.show()
247+
248+
249+
def generate_uuid(length: int = 8) -> str:
    """Generates a uuid of a specified length (default=8).

    The result is a random string drawn from lowercase ASCII letters and
    digits; it is not an RFC 4122 UUID.
    """
    alphabet = string.ascii_lowercase + string.digits
    return "".join(random.choice(alphabet) for _ in range(length))

0 commit comments

Comments
 (0)