Skip to content

Commit 126d10c

Browse files
Frances Hubis Thomacopybara-github
Frances Hubis Thoma
authored andcommitted
feat: Add a module-level function to create a Gemini template config for single-turn Gemini examples without having to explicitly construct the Gemini example.
PiperOrigin-RevId: 745187824
1 parent f389020 commit 126d10c

File tree

2 files changed

+152
-0
lines changed

2 files changed

+152
-0
lines changed

google/cloud/aiplatform/preview/datasets.py

+89
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,95 @@ def _generate_target_table_id(dataset_id: str):
191191
return f"{dataset_id}.{_DEFAULT_BQ_TABLE_PREFIX}_{str(uuid.uuid4())}"
192192

193193

194+
def construct_single_turn_template(
195+
*,
196+
prompt: str = None,
197+
response: Optional[str] = None,
198+
system_instruction: Optional[str] = None,
199+
model: Optional[str] = None,
200+
cached_content: Optional[str] = None,
201+
tools: Optional[List[generative_models.Tool]] = None,
202+
tool_config: Optional[generative_models.ToolConfig] = None,
203+
safety_settings: Optional[List[generative_models.SafetySetting]] = None,
204+
generation_config: Optional[generative_models.GenerationConfig] = None,
205+
field_mapping: List[Dict[str, str]] = None,
206+
) -> "GeminiTemplateConfig":
207+
"""Constructs a GeminiTemplateConfig object for single-turn cases.
208+
209+
Example:
210+
template_config = dataset.construct_single_turn_template(
211+
prompt = "Which flower is this {flower_image} ?",
212+
response="This is a {label}.",
213+
system_instruction="You are a botanical classifier."
214+
)
215+
216+
Args:
217+
218+
prompt (str):
219+
Required. User input.
220+
response (str):
221+
Optional. Model response to user input.
222+
system_instruction (str):
223+
Optional. System instructions for the model.
224+
model (str):
225+
Optional. The model to use for the GeminiExample.
226+
cached_content (str):
227+
Optional. The cached content to use for the GeminiExample.
228+
tools (List[Tool]):
229+
Optional. The tools to use for the GeminiExample.
230+
tool_config (ToolConfig):
231+
Optional. The tool config to use for the GeminiExample.
232+
safety_settings (List[SafetySetting]):
233+
Optional. The safety settings to use for the GeminiExample.
234+
generation_config (GenerationConfig):
235+
Optional. The generation config to use for the GeminiExample.
236+
field_mapping (List[Dict[str, str]]):
237+
Optional. Mapping of placeholders to dataset columns.
238+
239+
Returns:
240+
A GeminiTemplateConfig object.
241+
"""
242+
contents = []
243+
contents.append(
244+
generative_models.Content(
245+
role="user",
246+
parts=[
247+
generative_models.Part.from_text(prompt),
248+
],
249+
)
250+
)
251+
if response:
252+
contents.append(
253+
generative_models.Content(
254+
role="model",
255+
parts=[
256+
generative_models.Part.from_text(response),
257+
],
258+
)
259+
)
260+
if system_instruction:
261+
system_instruction = generative_models.Content(
262+
parts=[
263+
generative_models.Part.from_text(system_instruction),
264+
],
265+
)
266+
267+
# Set up GeminiExample.
268+
gemini_example = GeminiExample(
269+
model=model,
270+
contents=contents,
271+
system_instruction=system_instruction,
272+
cached_content=cached_content,
273+
tools=tools,
274+
tool_config=tool_config,
275+
safety_settings=safety_settings,
276+
generation_config=generation_config,
277+
)
278+
return GeminiTemplateConfig(
279+
gemini_example=gemini_example, field_mapping=field_mapping
280+
)
281+
282+
194283
class GeminiExample:
195284
"""A class representing a Gemini example."""
196285

tests/unit/aiplatform/test_multimodal_datasets.py

+63
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,69 @@ def blob_side_effect(name, mock_blob, bucket):
304304
yield mock_storage_client_bucket, mock_bucket, mock_blob
305305

306306

307+
def test_construct_single_turn_template():
308+
tools = [
309+
generative_models.Tool(
310+
function_declarations=[
311+
generative_models.FunctionDeclaration(name="function", parameters={})
312+
],
313+
)
314+
]
315+
tool_config = generative_models.ToolConfig(
316+
function_calling_config=generative_models.ToolConfig.FunctionCallingConfig(
317+
mode=generative_models.ToolConfig.FunctionCallingConfig.Mode.ANY,
318+
allowed_function_names=["get_current_weather"],
319+
)
320+
)
321+
safety_settings = [
322+
generative_models.SafetySetting(
323+
category=generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
324+
threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
325+
)
326+
]
327+
generation_config = generative_models.GenerationConfig(max_output_tokens=100)
328+
field_mapping = [{"input": "prompt", "output": "response"}]
329+
template_config = ummd.construct_single_turn_template(
330+
prompt="prompt",
331+
response="response",
332+
system_instruction="system_instruction",
333+
model="gemini-1.5-flash-002",
334+
cached_content="cached_content",
335+
tools=tools,
336+
tool_config=tool_config,
337+
safety_settings=safety_settings,
338+
generation_config=generation_config,
339+
field_mapping=field_mapping,
340+
)
341+
expected_gemini_example = ummd.GeminiExample(
342+
model="gemini-1.5-flash-002",
343+
contents=[
344+
ummd.GeminiExample.Content(
345+
role="user", parts=[generative_models.Part.from_text("prompt")]
346+
),
347+
ummd.GeminiExample.Content(
348+
role="model",
349+
parts=[generative_models.Part.from_text("response")],
350+
),
351+
],
352+
system_instruction=generative_models.Content(
353+
parts=[
354+
generative_models.Part.from_text("system_instruction"),
355+
]
356+
),
357+
cached_content="cached_content",
358+
tools=tools,
359+
tool_config=tool_config,
360+
safety_settings=safety_settings,
361+
generation_config=generation_config,
362+
)
363+
expected_gemini_template_config = ummd.GeminiTemplateConfig(
364+
gemini_example=expected_gemini_example,
365+
field_mapping=[{"input": "prompt", "output": "response"}],
366+
)
367+
assert str(template_config) == str(expected_gemini_template_config)
368+
369+
307370
@pytest.mark.usefixtures("google_auth_mock")
308371
class TestMultimodalDataset:
309372
"""Tests for the MultimodalDataset class."""

0 commit comments

Comments
 (0)