
Commit 7f1e031

matthew29tang authored and copybara-github committed
feat: Add Prompt class support for configs and Prompt.generate_content wrapper
PiperOrigin-RevId: 665012897
1 parent fd38b49 commit 7f1e031

File tree: 1 file changed, +254 −6 lines


vertexai/generative_models/_prompts.py

+254 −6
@@ -16,20 +16,34 @@
 #
 
 from google.cloud.aiplatform import base
+from google.cloud.aiplatform import initializer as aiplatform_initializer
 from vertexai.generative_models import (
     Content,
     Image,
     Part,
+    GenerativeModel,
+    GenerationConfig,
+    SafetySetting,
+    Tool,
+    ToolConfig,
 )
 from vertexai.generative_models._generative_models import (
     _to_content,
+    _validate_generate_content_parameters,
+    _reconcile_model_name,
+    _get_resource_name_from_model_name,
+    ContentsType,
+    GenerationConfigType,
+    GenerationResponse,
     PartsType,
+    SafetySettingsType,
 )
 
 import re
 from typing import (
     Any,
     Dict,
+    Iterable,
     List,
     Optional,
     Union,
@@ -55,29 +69,72 @@ class Prompt:
         prompt = Prompt(
             prompt_data="Hello, {name}! Today is {day}. How are you?",
             variables=[{"name": "Alice", "day": "Monday"}]
+            generation_config=GenerationConfig(
+                temperature=0.1,
+                top_p=0.95,
+                top_k=20,
+                candidate_count=1,
+                max_output_tokens=100,
+                stop_sequences=["\n\n\n"],
+            ),
+            model_name="gemini-1.0-pro-002",
+            safety_settings=[SafetySetting(
+                category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                threshold=SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+                method=SafetySetting.HarmBlockMethod.SEVERITY,
+            )],
+            system_instruction="Please answer in a short sentence.",
         )
 
         # Generate content using the assembled prompt.
-        model.generate_content(contents=prompt.assemble_contents(**prompt.variables[0]))
+        prompt.generate_content(
+            contents=prompt.assemble_contents(**prompt.variables)
+        )
         ```
     """
 
     def __init__(
         self,
         prompt_data: PartsType,
+        *,
         variables: Optional[List[Dict[str, PartsType]]] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        model_name: Optional[str] = None,
+        safety_settings: Optional[SafetySetting] = None,
+        system_instruction: Optional[PartsType] = None,
+        tools: Optional[List[Tool]] = None,
+        tool_config: Optional[ToolConfig] = None,
     ):
         """Initializes the Prompt with a given prompt, and variables.
 
         Args:
             prompt: A PartsType prompt which may be a template with variables or a prompt with no variables.
             variables: A list of dictionaries containing the variable names and values.
+            generation_config: A GenerationConfig object containing parameters for generation.
+            model_name: Model Garden model resource name.
+                Alternatively, a tuned model endpoint resource name can be provided.
+            safety_settings: A SafetySetting object containing safety settings for generation.
+            system_instruction: A PartsType object representing the system instruction.
+            tools: A list of Tool objects for function calling.
+            tool_config: A ToolConfig object for function calling.
         """
         self._prompt_data = None
         self._variables = None
+        self._model_name = None
+        self._generation_config = None
+        self._safety_settings = None
+        self._system_instruction = None
+        self._tools = None
+        self._tool_config = None
 
         self.prompt_data = prompt_data
         self.variables = variables if variables else [{}]
+        self.model_name = model_name
+        self.generation_config = generation_config
+        self.safety_settings = safety_settings
+        self.system_instruction = system_instruction
+        self.tools = tools
+        self.tool_config = tool_config
 
     @property
     def prompt_data(self) -> PartsType:
@@ -87,14 +144,38 @@ def prompt_data(self) -> PartsType:
     def variables(self) -> Optional[List[Dict[str, PartsType]]]:
         return self._variables
 
+    @property
+    def generation_config(self) -> Optional[GenerationConfig]:
+        return self._generation_config
+
+    @property
+    def model_name(self) -> Optional[str]:
+        return self._model_name
+
+    @property
+    def safety_settings(self) -> Optional[List[SafetySetting]]:
+        return self._safety_settings
+
+    @property
+    def system_instruction(self) -> Optional[PartsType]:
+        return self._system_instruction
+
+    @property
+    def tools(self) -> Optional[List[Tool]]:
+        return self._tools
+
+    @property
+    def tool_config(self) -> Optional[ToolConfig]:
+        return self._tool_config
+
     @prompt_data.setter
     def prompt_data(self, prompt_data: PartsType) -> None:
         """Overwrites the existing saved local prompt_data.
 
         Args:
             prompt_data: A PartsType prompt.
         """
-        Prompt._validate_prompt_data(prompt_data)
+        self._validate_parts_type_data(prompt_data)
         self._prompt_data = prompt_data
 
     @variables.setter
@@ -114,6 +195,98 @@ def variables(self, variables: List[Dict[str, PartsType]]) -> None:
                 f"Variables must be a list of dictionaries, not {type(variables)}"
             )
 
+    @model_name.setter
+    def model_name(self, model_name: Optional[str]) -> None:
+        """Overwrites the existing saved local model_name."""
+        if model_name:
+            self._model_name = Prompt._format_model_resource_name(model_name)
+        else:
+            self._model_name = None
+
+    def _format_model_resource_name(model_name: Optional[str]) -> str:
+        """Formats the model resource name."""
+        project = aiplatform_initializer.global_config.project
+        location = aiplatform_initializer.global_config.location
+        model_name = _reconcile_model_name(model_name, project, location)
+
+        prediction_resource_name = _get_resource_name_from_model_name(
+            model_name, project, location
+        )
+        return prediction_resource_name
+
+    def _validate_configs(
+        self,
+        generation_config: Optional[GenerationConfig] = None,
+        safety_settings: Optional[SafetySetting] = None,
+        system_instruction: Optional[PartsType] = None,
+        tools: Optional[List[Tool]] = None,
+        tool_config: Optional[ToolConfig] = None,
+    ):
+        generation_config = generation_config or self._generation_config
+        safety_settings = safety_settings or self._safety_settings
+        tools = tools or self._tools
+        tool_config = tool_config or self._tool_config
+        system_instruction = system_instruction or self._system_instruction
+        return _validate_generate_content_parameters(
+            contents="test",
+            generation_config=generation_config,
+            safety_settings=safety_settings,
+            system_instruction=system_instruction,
+            tools=tools,
+            tool_config=tool_config,
+        )
+
+    @generation_config.setter
+    def generation_config(self, generation_config: Optional[GenerationConfig]) -> None:
+        """Overwrites the existing saved local generation_config.
+
+        Args:
+            generation_config: A GenerationConfig object containing parameters for generation.
+        """
+        self._validate_configs(generation_config=generation_config)
+        self._generation_config = generation_config
+
+    @safety_settings.setter
+    def safety_settings(self, safety_settings: Optional[SafetySetting]) -> None:
+        """Overwrites the existing saved local safety_settings.
+
+        Args:
+            safety_settings: A SafetySetting object containing safety settings for generation.
+        """
+        self._validate_configs(safety_settings=safety_settings)
+        self._safety_settings = safety_settings
+
+    @system_instruction.setter
+    def system_instruction(self, system_instruction: Optional[PartsType]) -> None:
+        """Overwrites the existing saved local system_instruction.
+
+        Args:
+            system_instruction: A PartsType object representing the system instruction.
+        """
+        if system_instruction:
+            self._validate_parts_type_data(system_instruction)
+        self._system_instruction = system_instruction
+
+    @tools.setter
+    def tools(self, tools: Optional[List[Tool]]) -> None:
+        """Overwrites the existing saved local tools.
+
+        Args:
+            tools: A list of Tool objects for function calling.
+        """
+        self._validate_configs(tools=tools)
+        self._tools = tools
+
+    @tool_config.setter
+    def tool_config(self, tool_config: Optional[ToolConfig] = None) -> None:
+        """Overwrites the existing saved local tool_config.
+
+        Args:
+            tool_config: A ToolConfig object for function calling.
+        """
+        self._validate_configs(tool_config=tool_config)
+        self._tool_config = tool_config
+
     def _format_variable_value_to_parts(variables_dict: Dict[str, PartsType]) -> None:
         """Formats the variables values to be List[Part].
 
@@ -134,7 +307,7 @@ def _format_variable_value_to_parts(variables_dict: Dict[str, PartsType]) -> Non
             content = Content._from_gapic(_to_content(value=variables_dict[key]))
             variables_dict[key] = content.parts
 
-    def _validate_prompt_data(prompt_data: Any) -> None:
+    def _validate_parts_type_data(self, data: Any) -> None:
         """
         Args:
             prompt_data: The prompt input to validate
@@ -143,11 +316,11 @@ def _validate_prompt_data(prompt_data: Any) -> None:
             TypeError: If prompt_data is not a PartsType Object.
         """
         # Disallow Content as prompt_data.
-        if isinstance(prompt_data, Content):
+        if isinstance(data, Content):
             raise TypeError("Prompt data must be a PartsType object, not Content")
 
         # Rely on type checks in _to_content.
-        _to_content(value=prompt_data)
+        _to_content(value=data)
 
     def assemble_contents(self, **variables_dict: PartsType) -> List[Content]:
         """Returns the prompt data, as a List[Content], assembled with variables if applicable.
@@ -176,7 +349,7 @@ def assemble_contents(self, **variables_dict: PartsType) -> List[Content]:
             Prompt._format_variable_value_to_parts(variables_dict)
 
         # Step 2) Assemble the prompt.
-        # prompt_data must have been previously validated using _validate_prompt_data.
+        # prompt_data must have been previously validated using _validate_parts_type_data.
         assembled_prompt = []
         assembled_variables_cnt = {}
         if isinstance(self.prompt_data, list):
@@ -288,6 +461,81 @@ def _assemble_single_str(
 
         return assembled_data
 
+    def generate_content(
+        self,
+        contents: ContentsType,
+        *,
+        generation_config: Optional[GenerationConfigType] = None,
+        safety_settings: Optional[SafetySettingsType] = None,
+        model_name: Optional[str] = None,
+        tools: Optional[List["Tool"]] = None,
+        tool_config: Optional["ToolConfig"] = None,
+        stream: bool = False,
+        system_instruction: Optional[PartsType] = None,
+    ) -> Union["GenerationResponse", Iterable["GenerationResponse"],]:
+        """Generates content using the saved Prompt configs.
+
+        Args:
+            contents: Contents to send to the model.
+                Supports either a list of Content objects (passing a multi-turn conversation)
+                or a value that can be converted to a single Content object (passing a single message).
+                Supports
+                * str, Image, Part,
+                * List[Union[str, Image, Part]],
+                * List[Content]
+            generation_config: Parameters for the generation.
+            model_name: Prediction model resource name.
+            safety_settings: Safety settings as a mapping from HarmCategory to HarmBlockThreshold.
+            tools: A list of tools (functions) that the model can try calling.
+            tool_config: Config shared for all tools provided in the request.
+            stream: Whether to stream the response.
+            system_instruction: System instruction to pass to the model.
+
+        Returns:
+            A single GenerationResponse object if stream == False
+            A stream of GenerationResponse objects if stream == True
+
+        Usage:
+            ```
+            prompt = Prompt(
+                prompt_data="Hello, {name}! Today is {day}. How are you?",
+                variables={"name": "Alice", "day": "Monday"},
+                generation_config=GenerationConfig(temperature=0.1,),
+                system_instruction="Please answer in a short sentence.",
+                model_name="gemini-1.0-pro-002",
+            )
+
+            prompt.generate_content(
+                contents=prompt.assemble_contents(**prompt.variables)
+            )
+            ```
+        """
+
+        generation_config = generation_config or self.generation_config
+        safety_settings = safety_settings or self.safety_settings
+        model_name = model_name or self.model_name
+        tools = tools or self.tools
+        tool_config = tool_config or self.tool_config
+        system_instruction = system_instruction or self.system_instruction
+
+        if not model_name:
+            raise ValueError(
+                "Model name must be specified to use Prompt.generate_content()"
+            )
+        model_name = Prompt._format_model_resource_name(model_name)
+
+        model = GenerativeModel(
+            model_name=model_name, system_instruction=system_instruction
+        )
+        return model.generate_content(
+            contents=contents,
+            generation_config=generation_config,
+            safety_settings=safety_settings,
+            tools=tools,
+            tool_config=tool_config,
+            stream=stream,
+        )
+
     def get_unassembled_prompt_data(self) -> PartsType:
         """Returns the prompt data, without any variables replaced."""
         return self.prompt_data
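
For readers who want to try the new surface end to end, here is a minimal sketch based on the usage examples in the docstrings above. It assumes the SDK at this revision is installed and initialized; the project ID and location are placeholders, and `Prompt` is imported directly from the private module touched by this commit (the public import path may differ). It also passes a single variable set to `assemble_contents` via `prompt.variables[0]`, since `variables` is stored as a list of dictionaries.

```
import vertexai
from vertexai.generative_models import GenerationConfig, SafetySetting
from vertexai.generative_models._prompts import Prompt

# Placeholder project/location: substitute real values.
vertexai.init(project="my-project", location="us-central1")

prompt = Prompt(
    prompt_data="Hello, {name}! Today is {day}. How are you?",
    variables=[{"name": "Alice", "day": "Monday"}],
    model_name="gemini-1.0-pro-002",
    generation_config=GenerationConfig(temperature=0.1, max_output_tokens=100),
    safety_settings=[
        SafetySetting(
            category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            threshold=SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        )
    ],
    system_instruction="Please answer in a short sentence.",
)

# Substitute the first variable set into the template, then call the model
# through the new Prompt.generate_content wrapper.
response = prompt.generate_content(
    contents=prompt.assemble_contents(**prompt.variables[0])
)
print(response.text)
```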

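Because the wrapper resolves each argument as `value or self.value`, per-call arguments take precedence over the configs saved on the Prompt, and `model_name` must come from one of the two sources or a ValueError is raised. A short sketch, reusing the `prompt` object and imports from the example above, of overriding the stored generation config for one call and streaming the response:

```
# Override the saved generation_config for this call only and stream the reply;
# the Prompt object itself is not modified.
responses = prompt.generate_content(
    contents=prompt.assemble_contents(**prompt.variables[0]),
    generation_config=GenerationConfig(temperature=0.7),
    stream=True,
)
for chunk in responses:
    print(chunk.text, end="")
```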
0 commit comments
