Skip to content

Commit be4922a

Browse files
Ark-kun authored and copybara-github committed
feat: GenAI - Support generating JSON Schema from Python function
The `generate_json_schema_from_function` function generates JSON Schema from a Python function so that it can be used for constructing `Schema` or `FunctionDeclaration` objects (after tweaking the schema with the `adapt_json_schema_to_google_tool_schema` function). PiperOrigin-RevId: 617074433
1 parent bdd4817 commit be4922a

File tree

3 files changed

+206
-3
lines changed

3 files changed

+206
-3
lines changed

setup.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,11 @@
111111
"immutabledict",
112112
]
113113

114+
# Requirements for the generative AI function-schema utilities, which import
# `pydantic` and `docstring_parser`. pydantic is pinned below 2 because the
# schema generator relies on pydantic 1.x APIs (`.schema()`,
# `pydantic.fields.Undefined`).
genai_requires = (
    "pydantic < 2",
    "docstring_parser < 1",
)
118+
114119
full_extra_require = list(
115120
set(
116121
tensorboard_extra_require
@@ -186,7 +191,8 @@
186191
"google-cloud-bigquery >= 1.15.0, < 4.0.0dev",
187192
"google-cloud-resource-manager >= 1.3.3, < 3.0.0dev",
188193
"shapely < 3.0.0dev",
189-
),
194+
)
195+
+ genai_requires,
190196
extras_require={
191197
"endpoint": endpoint_extra_require,
192198
"full": full_extra_require,

tests/unit/vertexai/test_generative_models.py

+50-2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
gapic_content_types,
3333
gapic_tool_types,
3434
)
35+
from vertexai.generative_models import _function_calling_utils
36+
3537

3638
_TEST_PROJECT = "test-project"
3739
_TEST_LOCATION = "us-central1"
@@ -251,12 +253,12 @@ def mock_stream_generate_content(
251253
)
252254

253255

254-
def get_current_weather(location: str, unit: str = "centigrade"):
256+
def get_current_weather(location: str, unit: Optional[str] = "centigrade"):
255257
"""Gets weather in the specified location.
256258
257259
Args:
258260
location: The location for which to get the weather.
259-
unit: Optional. Temperature unit. Can be Centigrade or Fahrenheit. Defaults to Centigrade.
261+
unit: Temperature unit. Can be Centigrade or Fahrenheit. Default: Centigrade.
260262
261263
Returns:
262264
The weather information as a dict.
@@ -535,3 +537,49 @@ def test_generate_content_grounding_vertex_ai_search_retriever(self):
535537
"Why is sky blue?", tools=[google_search_retriever_tool]
536538
)
537539
assert response.text
540+
541+
542+
# JSON Schema that `generate_json_schema_from_function` is expected to return
# for `get_current_weather`. `unit` is annotated `Optional[str]` with a default,
# so it is nullable, carries the default, and is absent from `required`.
EXPECTED_SCHEMA_FOR_GET_CURRENT_WEATHER = {
    "title": "get_current_weather",
    "type": "object",
    "description": "Gets weather in the specified location.",
    "properties": {
        "location": {
            "title": "Location",
            "type": "string",
            "description": "The location for which to get the weather.",
        },
        "unit": {
            "title": "Unit",
            "type": "string",
            "description": "Temperature unit. Can be Centigrade or Fahrenheit. Default: Centigrade.",
            "default": "centigrade",
            "nullable": True,
        },
    },
    "required": ["location"],
}
562+
563+
564+
class TestFunctionCallingUtils:
    """Tests for the helpers in `_function_calling_utils`."""

    def test_generate_json_schema_for_callable(self):
        """Checks schema generation and its use in a `FunctionDeclaration`.

        For every (callable, expected schema) pair the generated JSON Schema
        must equal the expected one, and the schema adapted for Google tools
        must be accepted as `FunctionDeclaration` parameters.
        """
        cases = [
            (get_current_weather, EXPECTED_SCHEMA_FOR_GET_CURRENT_WEATHER),
        ]
        for func, expected in cases:
            generated = _function_calling_utils.generate_json_schema_from_function(
                func
            )
            # Read name/description before the equality check, as the
            # declaration below is built from the generated schema itself.
            name = generated["title"]
            description = generated["description"]
            assert generated == expected

            adapted = _function_calling_utils.adapt_json_schema_to_google_tool_schema(
                generated
            )
            declaration = generative_models.FunctionDeclaration(
                name=name,
                description=description,
                parameters=adapted,
            )
            assert declaration
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
"""Shared utilities for working with function schemas."""
16+
17+
import inspect
18+
import typing
19+
from typing import Any, Callable, Dict
20+
import warnings
21+
22+
from google.cloud.aiplatform_v1beta1 import types as aiplatform_types
23+
24+
Struct = Dict[str, Any]
25+
26+
27+
def _generate_json_schema_from_function_using_pydantic(
    func: Callable,
) -> Struct:
    """Generates JSON Schema for a callable object.

    The schema is built by creating a pydantic model whose fields mirror the
    function's parameters (annotation, default value and, when
    `docstring_parser` is importable, the description parsed from the
    docstring).

    The `func` function needs to follow specific rules.
    All parameters must be named explicitly (`*args` and `**kwargs` are not
    supported; such parameters are left out of the schema).

    Args:
        func: Function for which to generate schema

    Returns:
        The JSON Schema for the function as a dict. `title` is the function
        name, `description` comes from the docstring, and `required` lists
        the parameters that have no default value.
    """
    import pydantic

    try:
        import docstring_parser  # pylint: disable=g-import-not-at-top
    except ImportError:
        warnings.warn("Unable to import docstring_parser")
        docstring_parser = None

    # Fallback description when the docstring cannot be parsed.
    function_description = func.__doc__

    # Parse parameter descriptions from the docstring.
    # Also parse the function description in a better way.
    parameter_descriptions = {}
    if docstring_parser:
        parsed_docstring = docstring_parser.parse(func.__doc__)
        function_description = (
            parsed_docstring.long_description or parsed_docstring.short_description
        )
        for meta in parsed_docstring.meta:
            if isinstance(meta, docstring_parser.DocstringParam):
                parameter_descriptions[meta.arg_name] = meta.description

    # `inspect.Parameter.empty` is a sentinel and must be compared by
    # identity: `!=` would invoke `__ne__` of arbitrary default values.
    empty = inspect.Parameter.empty

    # We do not support *args or **kwargs, so only named parameters are kept.
    supported_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
        inspect.Parameter.POSITIONAL_ONLY,
    )
    parameters = {
        name: param
        for name, param in inspect.signature(func).parameters.items()
        if param.kind in supported_kinds
    }

    fields_dict = {
        name: (
            # Argument type: use Any rather than None for unannotated
            # parameters so pydantic will not try to auto-infer the type
            # based on the default value.
            Any if param.annotation is empty else param.annotation,
            pydantic.Field(
                # Default value: pydantic's `Undefined` (not None) must be
                # used for parameters without a default.
                default=(
                    pydantic.fields.Undefined
                    if param.default is empty
                    else param.default
                ),
                # User-provided description from the docstring, if any.
                description=parameter_descriptions.get(name),
            ),
        )
        for name, param in parameters.items()
    }
    function_schema = pydantic.create_model(func.__name__, **fields_dict).schema()

    function_schema["title"] = func.__name__
    function_schema["description"] = function_description

    # Postprocessing
    for name, property_schema in function_schema.get("properties", {}).items():
        annotation = parameters[name].annotation
        # Nullable fields:
        # * https://github.com/pydantic/pydantic/issues/1270
        # * https://stackoverflow.com/a/58841311
        # * https://github.com/pydantic/pydantic/discussions/4872
        # NOTE(review): only `typing.Optional`/`typing.Union` annotations are
        # detected here; PEP 604 `X | None` annotations have a different
        # origin - confirm whether they need to be supported.
        if typing.get_origin(annotation) is typing.Union and type(
            None
        ) in typing.get_args(annotation):
            # For "typing.Optional" arguments, the property schema might be
            # a dictionary like
            #
            # {'anyOf': [{'type': 'integer'}, {'type': 'null'}]}
            #
            # in which case it is flattened to the first non-null type.
            for schema in property_schema.pop("anyOf", []):
                schema_type = schema.get("type")
                if schema_type and schema_type != "null":
                    property_schema["type"] = schema_type
                    break
            property_schema["nullable"] = True

    # Annotate required fields: every supported parameter without a default.
    function_schema["required"] = [
        name for name, param in parameters.items() if param.default is empty
    ]
    return function_schema
128+
129+
130+
def adapt_json_schema_to_google_tool_schema(schema: Struct) -> Struct:
    """Adapts JSON schema to Google tool schema.

    Keys that have no counterpart on `aiplatform_types.Schema` are dropped,
    and the schemas under `properties` are adapted recursively.

    Args:
        schema: The JSON Schema to adapt. It is not modified.

    Returns:
        A new dict containing only the supported keys. Values other than
        `properties` are reused from `schema`, not copied.
    """
    # `$schema` is one of the basic/most common fields of the real JSON Schema.
    # But Google's Schema proto does not support it.
    # Common attributes that we remove:
    # $schema, additionalProperties
    # A key is kept when the Schema type has an attribute with that name or
    # with a trailing underscore (presumably for fields such as `type` that
    # are exposed as `type_` - verify against the Schema type).
    fixed_schema = {
        key: value
        for key, value in schema.items()
        if hasattr(aiplatform_types.Schema, key)
        or hasattr(aiplatform_types.Schema, key + "_")
    }
    property_schemas = fixed_schema.get("properties")
    if property_schemas:
        # Build a new mapping instead of assigning into `property_schemas`:
        # that object is the caller's own `properties` dict, so writing to it
        # would mutate the input schema.
        fixed_schema["properties"] = {
            name: adapt_json_schema_to_google_tool_schema(property_schema)
            for name, property_schema in property_schemas.items()
        }
    return fixed_schema
147+
148+
149+
# Name under which schema generation is exposed to callers; bound to the
# pydantic-based implementation defined above.
generate_json_schema_from_function = _generate_json_schema_from_function_using_pydantic

0 commit comments

Comments
 (0)