Skip to content

Commit 1bdc235

Browse files
matthew29tangcopybara-github
authored andcommitted
feat: Add Prompt class for multimodal prompt templating
PiperOrigin-RevId: 661455642
1 parent ebbd1bf commit 1bdc235

File tree

2 files changed

+453
-0
lines changed

2 files changed

+453
-0
lines changed

tests/unit/vertexai/test_prompts.py

+152
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2024 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
"""Unit tests for generative model prompts."""
18+
# pylint: disable=protected-access,bad-continuation
19+
20+
from vertexai.generative_models._prompts import Prompt
21+
from vertexai.generative_models import Content, Part
22+
23+
import pytest
24+
25+
from typing import Any, List
26+
27+
28+
def is_list_of_type(obj: Any, T: Any) -> bool:
29+
return isinstance(obj, list) and all(isinstance(s, T) for s in obj)
30+
31+
32+
def assert_prompt_contents_equal(
33+
prompt_contents: List[Content],
34+
expected_prompt_contents: List[Content],
35+
) -> None:
36+
assert len(prompt_contents) == len(expected_prompt_contents)
37+
for i in range(len(prompt_contents)):
38+
assert prompt_contents[i].role == expected_prompt_contents[i].role
39+
assert len(prompt_contents[i].parts) == len(expected_prompt_contents[i].parts)
40+
for j in range(len(prompt_contents[i].parts)):
41+
assert (
42+
prompt_contents[i].parts[j]._raw_part.text
43+
== expected_prompt_contents[i].parts[j]._raw_part.text
44+
)
45+
46+
47+
@pytest.mark.usefixtures("google_auth_mock")
48+
class TestPrompt:
49+
"""Unit tests for generative model prompts."""
50+
51+
def test_string_prompt_constructor_string_variables(self):
52+
# Create string prompt with string only variable values
53+
prompt = Prompt(
54+
prompt_data="Rate the movie {movie1}",
55+
variables=[
56+
{
57+
"movie1": "The Avengers",
58+
}
59+
],
60+
)
61+
# String prompt data should remain as string before compilation
62+
assert prompt.prompt_data == "Rate the movie {movie1}"
63+
# Variables values should be converted to List[Part]
64+
assert is_list_of_type(prompt.variables[0]["movie1"], Part)
65+
66+
def test_string_prompt_constructor_part_variables(self):
67+
# Create string prompt with List[Part] variable values
68+
prompt = Prompt(
69+
prompt_data="Rate the movie {movie1}",
70+
variables=[
71+
{
72+
"movie1": [Part.from_text("The Avengers")],
73+
}
74+
],
75+
)
76+
# Variables values should be converted to List[Part]
77+
assert is_list_of_type(prompt.variables[0]["movie1"], Part)
78+
79+
def test_string_prompt_constructor_invalid_variables(self):
80+
# String prompt variables must be PartsType
81+
with pytest.raises(TypeError):
82+
Prompt(
83+
prompt_data="Rate the movie {movie1}",
84+
variables=[
85+
{
86+
"movie1": 12345,
87+
}
88+
],
89+
)
90+
91+
def test_string_prompt_assemble_contents(self):
92+
prompt = Prompt(
93+
prompt_data="Which movie is better, {movie1} or {movie2}?",
94+
variables=[
95+
{
96+
"movie1": "The Avengers",
97+
"movie2": "Frozen",
98+
}
99+
],
100+
)
101+
assembled_prompt_content = prompt.assemble_contents(**prompt.variables[0])
102+
expected_content = [
103+
Content(
104+
parts=[
105+
Part.from_text("Which movie is better, The Avengers or Frozen?"),
106+
],
107+
role="user",
108+
)
109+
]
110+
assert_prompt_contents_equal(assembled_prompt_content, expected_content)
111+
112+
def test_string_prompt_partial_assemble_contents(self):
113+
prompt = Prompt(
114+
prompt_data="Which movie is better, {movie1} or {movie2}?",
115+
variables=[
116+
{
117+
"movie1": "The Avengers",
118+
}
119+
],
120+
)
121+
122+
# Check partially assembled prompt content
123+
assembled1_prompt_content = prompt.assemble_contents(**prompt.variables[0])
124+
expected1_content = [
125+
Content(
126+
parts=[
127+
Part.from_text("Which movie is better, The Avengers or {movie2}?"),
128+
],
129+
role="user",
130+
)
131+
]
132+
assert_prompt_contents_equal(assembled1_prompt_content, expected1_content)
133+
134+
# Check fully assembled prompt
135+
assembled2_prompt_content = prompt.assemble_contents(
136+
movie1="Inception", movie2="Frozen"
137+
)
138+
expected2_content = [
139+
Content(
140+
parts=[
141+
Part.from_text("Which movie is better, Inception or Frozen?"),
142+
],
143+
role="user",
144+
)
145+
]
146+
assert_prompt_contents_equal(assembled2_prompt_content, expected2_content)
147+
148+
def test_string_prompt_assemble_unused_variables(self):
149+
# Variables must present in prompt_data if specified
150+
prompt = Prompt(prompt_data="Rate the movie {movie1}")
151+
with pytest.raises(ValueError):
152+
prompt.assemble_contents(day="Tuesday")

0 commit comments

Comments
 (0)