16
16
#
17
17
18
18
from google .cloud .aiplatform import base
19
+ from google .cloud .aiplatform import initializer as aiplatform_initializer
19
20
from vertexai .generative_models import (
20
21
Content ,
21
22
Image ,
22
23
Part ,
24
+ GenerativeModel ,
25
+ GenerationConfig ,
26
+ SafetySetting ,
27
+ Tool ,
28
+ ToolConfig ,
23
29
)
24
30
from vertexai .generative_models ._generative_models import (
25
31
_to_content ,
32
+ _validate_generate_content_parameters ,
33
+ _reconcile_model_name ,
34
+ _get_resource_name_from_model_name ,
35
+ ContentsType ,
36
+ GenerationConfigType ,
37
+ GenerationResponse ,
26
38
PartsType ,
39
+ SafetySettingsType ,
27
40
)
28
41
29
42
import re
30
43
from typing import (
31
44
Any ,
32
45
Dict ,
46
+ Iterable ,
33
47
List ,
34
48
Optional ,
35
49
Union ,
@@ -55,29 +69,72 @@ class Prompt:
55
69
prompt = Prompt(
56
70
prompt_data="Hello, {name}! Today is {day}. How are you?",
57
71
variables=[{"name": "Alice", "day": "Monday"}]
72
+ generation_config=GenerationConfig(
73
+ temperature=0.1,
74
+ top_p=0.95,
75
+ top_k=20,
76
+ candidate_count=1,
77
+ max_output_tokens=100,
78
+ stop_sequences=["\n \n \n "],
79
+ ),
80
+ model_name="gemini-1.0-pro-002",
81
+ safety_settings=[SafetySetting(
82
+ category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
83
+ threshold=SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
84
+ method=SafetySetting.HarmBlockMethod.SEVERITY,
85
+ )],
86
+ system_instruction="Please answer in a short sentence.",
58
87
)
59
88
60
89
# Generate content using the assembled prompt.
61
- model.generate_content(contents=prompt.assemble_contents(**prompt.variables[0]))
90
+ prompt.generate_content(
91
+ contents=prompt.assemble_contents(**prompt.variables)
92
+ )
62
93
```
63
94
"""
64
95
65
96
def __init__ (
66
97
self ,
67
98
prompt_data : PartsType ,
99
+ * ,
68
100
variables : Optional [List [Dict [str , PartsType ]]] = None ,
101
+ generation_config : Optional [GenerationConfig ] = None ,
102
+ model_name : Optional [str ] = None ,
103
+ safety_settings : Optional [SafetySetting ] = None ,
104
+ system_instruction : Optional [PartsType ] = None ,
105
+ tools : Optional [List [Tool ]] = None ,
106
+ tool_config : Optional [ToolConfig ] = None ,
69
107
):
70
108
"""Initializes the Prompt with a given prompt, and variables.
71
109
72
110
Args:
73
111
prompt: A PartsType prompt which may be a template with variables or a prompt with no variables.
74
112
variables: A list of dictionaries containing the variable names and values.
113
+ generation_config: A GenerationConfig object containing parameters for generation.
114
+ model_name: Model Garden model resource name.
115
+ Alternatively, a tuned model endpoint resource name can be provided.
116
+ safety_settings: A SafetySetting object containing safety settings for generation.
117
+ system_instruction: A PartsType object representing the system instruction.
118
+ tools: A list of Tool objects for function calling.
119
+ tool_config: A ToolConfig object for function calling.
75
120
"""
76
121
self ._prompt_data = None
77
122
self ._variables = None
123
+ self ._model_name = None
124
+ self ._generation_config = None
125
+ self ._safety_settings = None
126
+ self ._system_instruction = None
127
+ self ._tools = None
128
+ self ._tool_config = None
78
129
79
130
self .prompt_data = prompt_data
80
131
self .variables = variables if variables else [{}]
132
+ self .model_name = model_name
133
+ self .generation_config = generation_config
134
+ self .safety_settings = safety_settings
135
+ self .system_instruction = system_instruction
136
+ self .tools = tools
137
+ self .tool_config = tool_config
81
138
82
139
@property
83
140
def prompt_data (self ) -> PartsType :
@@ -87,14 +144,38 @@ def prompt_data(self) -> PartsType:
87
144
def variables (self ) -> Optional [List [Dict [str , PartsType ]]]:
88
145
return self ._variables
89
146
147
+ @property
148
+ def generation_config (self ) -> Optional [GenerationConfig ]:
149
+ return self ._generation_config
150
+
151
+ @property
152
+ def model_name (self ) -> Optional [str ]:
153
+ return self ._model_name
154
+
155
+ @property
156
+ def safety_settings (self ) -> Optional [List [SafetySetting ]]:
157
+ return self ._safety_settings
158
+
159
+ @property
160
+ def system_instruction (self ) -> Optional [PartsType ]:
161
+ return self ._system_instruction
162
+
163
+ @property
164
+ def tools (self ) -> Optional [List [Tool ]]:
165
+ return self ._tools
166
+
167
+ @property
168
+ def tool_config (self ) -> Optional [ToolConfig ]:
169
+ return self ._tool_config
170
+
90
171
@prompt_data .setter
91
172
def prompt_data (self , prompt_data : PartsType ) -> None :
92
173
"""Overwrites the existing saved local prompt_data.
93
174
94
175
Args:
95
176
prompt_data: A PartsType prompt.
96
177
"""
97
- Prompt . _validate_prompt_data (prompt_data )
178
+ self . _validate_parts_type_data (prompt_data )
98
179
self ._prompt_data = prompt_data
99
180
100
181
@variables .setter
@@ -114,6 +195,98 @@ def variables(self, variables: List[Dict[str, PartsType]]) -> None:
114
195
f"Variables must be a list of dictionaries, not { type (variables )} "
115
196
)
116
197
198
+ @model_name .setter
199
+ def model_name (self , model_name : Optional [str ]) -> None :
200
+ """Overwrites the existing saved local model_name."""
201
+ if model_name :
202
+ self ._model_name = Prompt ._format_model_resource_name (model_name )
203
+ else :
204
+ self ._model_name = None
205
+
206
+ def _format_model_resource_name (model_name : Optional [str ]) -> str :
207
+ """Formats the model resource name."""
208
+ project = aiplatform_initializer .global_config .project
209
+ location = aiplatform_initializer .global_config .location
210
+ model_name = _reconcile_model_name (model_name , project , location )
211
+
212
+ prediction_resource_name = _get_resource_name_from_model_name (
213
+ model_name , project , location
214
+ )
215
+ return prediction_resource_name
216
+
217
+ def _validate_configs (
218
+ self ,
219
+ generation_config : Optional [GenerationConfig ] = None ,
220
+ safety_settings : Optional [SafetySetting ] = None ,
221
+ system_instruction : Optional [PartsType ] = None ,
222
+ tools : Optional [List [Tool ]] = None ,
223
+ tool_config : Optional [ToolConfig ] = None ,
224
+ ):
225
+ generation_config = generation_config or self ._generation_config
226
+ safety_settings = safety_settings or self ._safety_settings
227
+ tools = tools or self ._tools
228
+ tool_config = tool_config or self ._tool_config
229
+ system_instruction = system_instruction or self ._system_instruction
230
+ return _validate_generate_content_parameters (
231
+ contents = "test" ,
232
+ generation_config = generation_config ,
233
+ safety_settings = safety_settings ,
234
+ system_instruction = system_instruction ,
235
+ tools = tools ,
236
+ tool_config = tool_config ,
237
+ )
238
+
239
+ @generation_config .setter
240
+ def generation_config (self , generation_config : Optional [GenerationConfig ]) -> None :
241
+ """Overwrites the existing saved local generation_config.
242
+
243
+ Args:
244
+ generation_config: A GenerationConfig object containing parameters for generation.
245
+ """
246
+ self ._validate_configs (generation_config = generation_config )
247
+ self ._generation_config = generation_config
248
+
249
+ @safety_settings .setter
250
+ def safety_settings (self , safety_settings : Optional [SafetySetting ]) -> None :
251
+ """Overwrites the existing saved local safety_settings.
252
+
253
+ Args:
254
+ safety_settings: A SafetySetting object containing safety settings for generation.
255
+ """
256
+ self ._validate_configs (safety_settings = safety_settings )
257
+ self ._safety_settings = safety_settings
258
+
259
+ @system_instruction .setter
260
+ def system_instruction (self , system_instruction : Optional [PartsType ]) -> None :
261
+ """Overwrites the existing saved local system_instruction.
262
+
263
+ Args:
264
+ system_instruction: A PartsType object representing the system instruction.
265
+ """
266
+ if system_instruction :
267
+ self ._validate_parts_type_data (system_instruction )
268
+ self ._system_instruction = system_instruction
269
+
270
+ @tools .setter
271
+ def tools (self , tools : Optional [List [Tool ]]) -> None :
272
+ """Overwrites the existing saved local tools.
273
+
274
+ Args:
275
+ tools: A list of Tool objects for function calling.
276
+ """
277
+ self ._validate_configs (tools = tools )
278
+ self ._tools = tools
279
+
280
+ @tool_config .setter
281
+ def tool_config (self , tool_config : Optional [ToolConfig ] = None ) -> None :
282
+ """Overwrites the existing saved local tool_config.
283
+
284
+ Args:
285
+ tool_config: A ToolConfig object for function calling.
286
+ """
287
+ self ._validate_configs (tool_config = tool_config )
288
+ self ._tool_config = tool_config
289
+
117
290
def _format_variable_value_to_parts (variables_dict : Dict [str , PartsType ]) -> None :
118
291
"""Formats the variables values to be List[Part].
119
292
@@ -134,7 +307,7 @@ def _format_variable_value_to_parts(variables_dict: Dict[str, PartsType]) -> Non
134
307
content = Content ._from_gapic (_to_content (value = variables_dict [key ]))
135
308
variables_dict [key ] = content .parts
136
309
137
- def _validate_prompt_data ( prompt_data : Any ) -> None :
310
+ def _validate_parts_type_data ( self , data : Any ) -> None :
138
311
"""
139
312
Args:
140
313
prompt_data: The prompt input to validate
@@ -143,11 +316,11 @@ def _validate_prompt_data(prompt_data: Any) -> None:
143
316
TypeError: If prompt_data is not a PartsType Object.
144
317
"""
145
318
# Disallow Content as prompt_data.
146
- if isinstance (prompt_data , Content ):
319
+ if isinstance (data , Content ):
147
320
raise TypeError ("Prompt data must be a PartsType object, not Content" )
148
321
149
322
# Rely on type checks in _to_content.
150
- _to_content (value = prompt_data )
323
+ _to_content (value = data )
151
324
152
325
def assemble_contents (self , ** variables_dict : PartsType ) -> List [Content ]:
153
326
"""Returns the prompt data, as a List[Content], assembled with variables if applicable.
@@ -176,7 +349,7 @@ def assemble_contents(self, **variables_dict: PartsType) -> List[Content]:
176
349
Prompt ._format_variable_value_to_parts (variables_dict )
177
350
178
351
# Step 2) Assemble the prompt.
179
- # prompt_data must have been previously validated using _validate_prompt_data .
352
+ # prompt_data must have been previously validated using _validate_parts_type_data .
180
353
assembled_prompt = []
181
354
assembled_variables_cnt = {}
182
355
if isinstance (self .prompt_data , list ):
@@ -288,6 +461,81 @@ def _assemble_single_str(
288
461
289
462
return assembled_data
290
463
464
+ def generate_content (
465
+ self ,
466
+ contents : ContentsType ,
467
+ * ,
468
+ generation_config : Optional [GenerationConfigType ] = None ,
469
+ safety_settings : Optional [SafetySettingsType ] = None ,
470
+ model_name : Optional [str ] = None ,
471
+ tools : Optional [List ["Tool" ]] = None ,
472
+ tool_config : Optional ["ToolConfig" ] = None ,
473
+ stream : bool = False ,
474
+ system_instruction : Optional [PartsType ] = None ,
475
+ ) -> Union ["GenerationResponse" , Iterable ["GenerationResponse" ],]:
476
+ """Generates content using the saved Prompt configs.
477
+
478
+ Args:
479
+ contents: Contents to send to the model.
480
+ Supports either a list of Content objects (passing a multi-turn conversation)
481
+ or a value that can be converted to a single Content object (passing a single message).
482
+ Supports
483
+ * str, Image, Part,
484
+ * List[Union[str, Image, Part]],
485
+ * List[Content]
486
+ generation_config: Parameters for the generation.
487
+ model_name: Prediction model resource name.
488
+ safety_settings: Safety settings as a mapping from HarmCategory to HarmBlockThreshold.
489
+ tools: A list of tools (functions) that the model can try calling.
490
+ tool_config: Config shared for all tools provided in the request.
491
+ stream: Whether to stream the response.
492
+ system_instruction: System instruction to pass to the model.
493
+
494
+ Returns:
495
+ A single GenerationResponse object if stream == False
496
+ A stream of GenerationResponse objects if stream == True
497
+
498
+ Usage:
499
+ ```
500
+ prompt = Prompt(
501
+ prompt_data="Hello, {name}! Today is {day}. How are you?",
502
+ variables={"name": "Alice", "day": "Monday"},
503
+ generation_config=GenerationConfig(temperature=0.1,),
504
+ system_instruction="Please answer in a short sentence.",
505
+ model_name="gemini-1.0-pro-002",
506
+ )
507
+
508
+ prompt.generate_content(
509
+ contents=prompt.assemble_contents(**prompt.variables)
510
+ )
511
+ ```
512
+ """
513
+
514
+ generation_config = generation_config or self .generation_config
515
+ safety_settings = safety_settings or self .safety_settings
516
+ model_name = model_name or self .model_name
517
+ tools = tools or self .tools
518
+ tool_config = tool_config or self .tool_config
519
+ system_instruction = system_instruction or self .system_instruction
520
+
521
+ if not model_name :
522
+ raise ValueError (
523
+ "Model name must be specified to use Prompt.generate_content()"
524
+ )
525
+ model_name = Prompt ._format_model_resource_name (model_name )
526
+
527
+ model = GenerativeModel (
528
+ model_name = model_name , system_instruction = system_instruction
529
+ )
530
+ return model .generate_content (
531
+ contents = contents ,
532
+ generation_config = generation_config ,
533
+ safety_settings = safety_settings ,
534
+ tools = tools ,
535
+ tool_config = tool_config ,
536
+ stream = stream ,
537
+ )
538
+
291
539
def get_unassembled_prompt_data (self ) -> PartsType :
292
540
"""Returns the prompt data, without any variables replaced."""
293
541
return self .prompt_data
0 commit comments