
Commit cd233ef

Ark-kun authored and copybara-github committed
feat: GenAI - Vertex GenAI SDK
Added the Vertex GenAI SDK.

Features:
* Text generation (from text and images)
* Chat (stateful chat session)
* Function calling
* Tuning
* Token counting

Usage:

Imports:
```
from vertexai.preview.generative_models import GenerativeModel, Image, Content, Part, Tool, FunctionDeclaration, GenerationConfig, HarmCategory, HarmBlockThreshold
```

Basic generation:
```
from vertexai.preview.generative_models import GenerativeModel
model = GenerativeModel("gemini-pro")
model.generate_content("Why is sky blue?")
```

Using an image from a local file:
```
image = Image.load_from_file("image.jpg")
vision_model = GenerativeModel("gemini-pro-vision")
vision_model.generate_content(image)
```

Using an image from GCS:
```
vision_model = GenerativeModel("gemini-pro-vision")
vision_model.generate_content(
    Part.from_uri(
        "gs://download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg",
        mime_type="image/jpeg",
    )
)
```

Using text and an image together:
```
vision_model.generate_content(["What is shown in this image?", image])
```

Using video:
```
vision_model.generate_content([
    "What is in the video?",
    Part.from_uri("gs://cloud-samples-data/video/animals.mp4", mime_type="video/mp4"),
])
```

Chat:
```
vision_model = GenerativeModel("gemini-ultra-vision")
vision_chat = vision_model.start_chat()
print(vision_chat.send_message(["I like this image.", image]))
print(vision_chat.send_message("What things do I like?"))
```

PiperOrigin-RevId: 589962731
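Function calling is listed among the features above but not shown in the usage examples. The unit test added by this commit exercises it roughly as follows; this is a minimal sketch drawn from that test, and the weather values handed back to the model are illustrative placeholders.

```
from vertexai.preview.generative_models import (
    FunctionDeclaration,
    GenerativeModel,
    Part,
    Tool,
)

# Declare the function the model is allowed to call.
get_current_weather_func = FunctionDeclaration(
    name="get_current_weather",
    description="Get the current weather in a given location",
    parameters={
        "type": "object",
        "properties": {
            "location": {
                "type": "string",
                "description": "The city and state, e.g. San Francisco, CA",
            },
        },
        "required": ["location"],
    },
)
weather_tool = Tool(function_declarations=[get_current_weather_func])

# Specifying the tools once avoids passing them with every request.
model = GenerativeModel("gemini-pro", tools=[weather_tool])
chat = model.start_chat()

response = chat.send_message("What is the weather like in Boston?")
# The model replies with a function call instead of text.
print(response.candidates[0].content.parts[0].function_call.name)  # get_current_weather

# Run the function yourself, then return its result to the model.
response = chat.send_message(
    Part.from_function_response(
        name="get_current_weather",
        response={"content": {"weather_there": "super nice"}},  # placeholder result
    )
)
print(response.text)
```

Passing `tools` at model construction registers the declarations once for the whole chat session rather than repeating them on each request.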
1 parent 4158e53 commit cd233ef

4 files changed (+2238 −0 lines changed)
@@ -0,0 +1,261 @@
# -*- coding: utf-8 -*-

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pylint: disable=protected-access,bad-continuation
import pytest
from typing import Iterable, MutableSequence, Optional
from unittest import mock

import vertexai
from google.cloud.aiplatform import initializer
from vertexai.preview import generative_models
from vertexai.generative_models._generative_models import (
    prediction_service,
    gapic_prediction_service_types,
    gapic_content_types,
)
_TEST_PROJECT = "test-project"
_TEST_LOCATION = "us-central1"


_RESPONSE_TEXT_PART_STRUCT = {
    "text": "The sky appears blue due to a phenomenon called Rayleigh scattering."
}

_RESPONSE_FUNCTION_CALL_PART_STRUCT = {
    "function_call": {
        "name": "get_current_weather",
        "args": {
            "fields": {
                "key": "location",
                "value": {"string_value": "Boston"},
            }
        },
    }
}

_RESPONSE_AFTER_FUNCTION_CALL_PART_STRUCT = {
    "text": "The weather in Boston is super nice!"
}

_RESPONSE_SAFETY_RATINGS_STRUCT = [
    {"category": "HARM_CATEGORY_HARASSMENT", "probability": "NEGLIGIBLE"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "probability": "NEGLIGIBLE"},
    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "probability": "NEGLIGIBLE"},
    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "probability": "NEGLIGIBLE"},
]
_RESPONSE_CITATION_STRUCT = {
    "start_index": 528,
    "end_index": 656,
    "uri": "https://www.quora.com/What-makes-the-sky-blue-during-the-day",
}


_REQUEST_TOOL_STRUCT = {
    "function_declarations": [
        {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "unit": {
                        "type": "string",
                        "enum": [
                            "celsius",
                            "fahrenheit",
                        ],
                    },
                },
                "required": ["location"],
            },
        }
    ]
}

_REQUEST_FUNCTION_PARAMETER_SCHEMA_STRUCT = {
    "type": "object",
    "properties": {
        "location": {
            "type": "string",
            "description": "The city and state, e.g. San Francisco, CA",
        },
        "unit": {
            "type": "string",
            "enum": [
                "celsius",
                "fahrenheit",
            ],
        },
    },
    "required": ["location"],
}
def mock_stream_generate_content(
    self,
    request: gapic_prediction_service_types.GenerateContentRequest,
    *,
    model: Optional[str] = None,
    contents: Optional[MutableSequence[gapic_content_types.Content]] = None,
) -> Iterable[gapic_prediction_service_types.GenerateContentResponse]:
    is_continued_chat = len(request.contents) > 1
    has_tools = bool(request.tools)

    if has_tools:
        has_function_response = any(
            "function_response" in content.parts[0] for content in request.contents
        )
        needs_function_call = not has_function_response
        if needs_function_call:
            response_part_struct = _RESPONSE_FUNCTION_CALL_PART_STRUCT
        else:
            response_part_struct = _RESPONSE_AFTER_FUNCTION_CALL_PART_STRUCT
    elif is_continued_chat:
        response_part_struct = {"text": "Other planets may have different sky color."}
    else:
        response_part_struct = _RESPONSE_TEXT_PART_STRUCT

    response = gapic_prediction_service_types.GenerateContentResponse(
        candidates=[
            gapic_content_types.Candidate(
                index=0,
                content=gapic_content_types.Content(
                    # Model currently does not identify itself
                    # role="model",
                    parts=[
                        gapic_content_types.Part(response_part_struct),
                    ],
                ),
                finish_reason=gapic_content_types.Candidate.FinishReason.STOP,
                safety_ratings=[
                    gapic_content_types.SafetyRating(rating)
                    for rating in _RESPONSE_SAFETY_RATINGS_STRUCT
                ],
                citation_metadata=gapic_content_types.CitationMetadata(
                    citations=[
                        gapic_content_types.Citation(_RESPONSE_CITATION_STRUCT),
                    ]
                ),
            ),
        ],
    )
    yield response
@pytest.mark.usefixtures("google_auth_mock")
class TestGenerativeModels:
    """Unit tests for the generative models."""

    def setup_method(self):
        vertexai.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
        )

    def teardown_method(self):
        initializer.global_pool.shutdown(wait=True)

    @mock.patch.object(
        target=prediction_service.PredictionServiceClient,
        attribute="stream_generate_content",
        new=mock_stream_generate_content,
    )
    def test_generate_content(self):
        model = generative_models.GenerativeModel("gemini-pro")
        response = model.generate_content("Why is sky blue?")
        assert response.text

        response2 = model.generate_content(
            "Why is sky blue?",
            generation_config=generative_models.GenerationConfig(
                temperature=0.2,
                top_p=0.9,
                top_k=20,
                candidate_count=1,
                max_output_tokens=200,
                stop_sequences=["\n\n\n"],
            ),
        )
        assert response2.text
    @mock.patch.object(
        target=prediction_service.PredictionServiceClient,
        attribute="stream_generate_content",
        new=mock_stream_generate_content,
    )
    def test_generate_content_streaming(self):
        model = generative_models.GenerativeModel("gemini-pro")
        stream = model.generate_content("Why is sky blue?", stream=True)
        for chunk in stream:
            assert chunk.text

    @mock.patch.object(
        target=prediction_service.PredictionServiceClient,
        attribute="stream_generate_content",
        new=mock_stream_generate_content,
    )
    def test_chat_send_message(self):
        model = generative_models.GenerativeModel("gemini-pro")
        chat = model.start_chat()
        response1 = chat.send_message("Why is sky blue?")
        assert response1.text
        response2 = chat.send_message("Is sky blue on other planets?")
        assert response2.text
    @mock.patch.object(
        target=prediction_service.PredictionServiceClient,
        attribute="stream_generate_content",
        new=mock_stream_generate_content,
    )
    def test_chat_function_calling(self):
        get_current_weather_func = generative_models.FunctionDeclaration(
            name="get_current_weather",
            description="Get the current weather in a given location",
            parameters=_REQUEST_FUNCTION_PARAMETER_SCHEMA_STRUCT,
        )
        weather_tool = generative_models.Tool(
            function_declarations=[get_current_weather_func],
        )

        model = generative_models.GenerativeModel(
            "gemini-pro",
            # Specifying the tools once to avoid specifying them in every request
            tools=[weather_tool],
        )
        chat = model.start_chat()

        response1 = chat.send_message("What is the weather like in Boston?")
        assert (
            response1.candidates[0].content.parts[0].function_call.name
            == "get_current_weather"
        )
        response2 = chat.send_message(
            generative_models.Part.from_function_response(
                name="get_current_weather",
                response={
                    "content": {"weather_there": "super nice"},
                },
            ),
        )
        assert response2.text == "The weather in Boston is super nice!"

vertexai/generative_models/README.md

+13
@@ -0,0 +1,13 @@
# Vertex GenAI Python SDK

> [!IMPORTANT]
> Thanks for your interest in the Vertex AI SDKs! **You can start using this SDK and its samples on December 13, 2023.** Until then, check out our [blog post](https://blog.google/technology/ai/google-gemini-ai/) to learn more about Google's Gemini multimodal model.

The Vertex GenAI Python SDK enables developers to use Google's state-of-the-art generative AI models (like Gemini) to build AI-powered features and applications.
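As a first taste, basic text generation looks roughly like this; a minimal sketch mirroring the usage shown in this commit's message, where the project and location values are placeholders for your own settings.

```
import vertexai
from vertexai.preview.generative_models import GenerativeModel

# Placeholder project and location; substitute your own.
vertexai.init(project="my-project", location="us-central1")

model = GenerativeModel("gemini-pro")
response = model.generate_content("Why is sky blue?")
print(response.text)
```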
*More details and information coming soon!*

## License

The contents of this repository are licensed under the
[Apache License, version 2.0](http://www.apache.org/licenses/LICENSE-2.0).
