Skip to content

Commit d7f1f0f

Browse files
sararobcopybara-github
authored andcommitted
feat: add generateContent method
PiperOrigin-RevId: 584390719
1 parent 4c8f6d2 commit d7f1f0f

File tree

6 files changed

+792
-0
lines changed

6 files changed

+792
-0
lines changed

src/index.ts

+264
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
/**
2+
* @license
3+
* Copyright 2023 Google LLC
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* https://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
/* tslint:disable */
19+
import {GoogleAuth} from 'google-auth-library';
20+
21+
import {emptyGenerator, processStream} from './process_stream';
22+
import {Content, GenerateContentParams, GenerateContentRequest, GenerateContentResult, GenerationConfig, Part, SafetySetting} from './types/content';
23+
import {postRequest} from './util';
24+
25+
// TODO: update this when model names are available
26+
// const SUPPORTED_MODELS: Array<string> = ['text-bison@001'];
27+
28+
/**
29+
* Base class for authenticating to Vertex, creates the preview namespace.
30+
*/
31+
export class VertexAI {
32+
public preview: VertexAI_Internal;
33+
34+
constructor(
35+
project: string,
36+
location: string,
37+
apiKey: string, // TODO: remove when we switch to Vertex endpoint
38+
apiEndpoint?: string,
39+
) {
40+
this.preview =
41+
new VertexAI_Internal(project, location, apiKey, apiEndpoint);
42+
}
43+
}
44+
45+
/**
46+
* VertexAI class implementation
47+
*/
48+
export class VertexAI_Internal {
49+
protected googleAuth: GoogleAuth = new GoogleAuth(
50+
{scopes: 'https://www.googleapis.com/auth/cloud-platform'});
51+
private tokenInternal?: string;
52+
53+
/**
54+
* API client for authenticating to Vertex
55+
* @param project The Google Cloud project to use for the request
56+
* @param location The Google Cloud project location to use for the
57+
* request
58+
* @param apiEndpoint The base Vertex AI endpoint to use for the request. If
59+
* not provided, the default regionalized endpoint (i.e.
60+
* us-central1-aiplatform.googleapis.com) will be used.
61+
*/
62+
constructor(
63+
protected readonly project: string,
64+
protected readonly location: string,
65+
protected readonly apiKey:
66+
string, // TODO: remove when we switch to Vertex endpoint
67+
protected readonly apiEndpoint?: string,
68+
) {}
69+
70+
/**
71+
* Gets an authentication token for making Vertex REST API requests.
72+
* @param vertex The VertexAI instance.
73+
*/
74+
// TODO: change the `any` type below to be more specific
75+
protected get token(): Promise<any>|string {
76+
if (this.tokenInternal) {
77+
return this.tokenInternal;
78+
}
79+
// Generate a new token if it hasn't been set
80+
// TODO: add error handling here
81+
const token = Promise.resolve(this.googleAuth.getAccessToken());
82+
return token;
83+
}
84+
85+
/**
86+
* Make a generateContent request.
87+
* @param request A GenerateContentRequest object with the request contents.
88+
* @return The GenerateContentResponse object with the response candidates.
89+
*
90+
* NOTE: this method is stubbed in postRequest for now until the service is
91+
* available.
92+
*/
93+
async generateContent(request: GenerateContentParams):
94+
Promise<GenerateContentResult> {
95+
const publisherModelEndpoint = `publishers/google/models/${request.model}`;
96+
97+
const generateContentRequest: GenerateContentRequest = {
98+
model: request.model,
99+
contents: request.contents,
100+
generation_config: request.generation_config,
101+
safety_settings: request.safety_settings,
102+
}
103+
104+
let response;
105+
try {
106+
response = await postRequest({
107+
region: this.location,
108+
project: this.project,
109+
resourcePath: publisherModelEndpoint,
110+
resourceMethod: request.stream ? 'streamGenerateContent' :
111+
'generateContent',
112+
token: await this.token,
113+
data: generateContentRequest,
114+
apiKey: this.apiKey,
115+
apiEndpoint: this.apiEndpoint,
116+
});
117+
if (response === undefined) {
118+
throw new Error('did not get a valid response.')
119+
}
120+
if (!response.ok) {
121+
throw new Error(`${response.status} ${response.statusText}`)
122+
}
123+
} catch (e) {
124+
console.log(e);
125+
}
126+
127+
const streamResult = processStream(response);
128+
129+
if (request.stream === false && streamResult.stream !== undefined) {
130+
const responses = [];
131+
for await (const resp of streamResult.stream) {
132+
responses.push(resp);
133+
}
134+
return {
135+
stream: emptyGenerator(),
136+
responses,
137+
};
138+
} else {
139+
// True or undefined (default true)
140+
return streamResult;
141+
}
142+
143+
// TODO: handle streaming and non-streaming response here
144+
}
145+
146+
startChat(request: StartChatParams): ChatSession {
147+
const startChatRequest = {
148+
model: request.model,
149+
project: this.project,
150+
location: this.location,
151+
history: request.history,
152+
generation_config: request.generation_config,
153+
safety_settings: request.safety_settings,
154+
_vertex_instance: this,
155+
};
156+
157+
return new ChatSession(startChatRequest);
158+
}
159+
}
160+
161+
/**
162+
* Params to initiate a multiturn chat with the model via startChat
163+
*/
164+
export declare interface StartChatParams {
165+
model: string;
166+
history?: Content[];
167+
safety_settings?: SafetySetting[];
168+
generation_config?: GenerationConfig;
169+
stream?: boolean;
170+
}
171+
172+
// StartChatSessionRequest and ChatSession are defined here instead of in
173+
// src/types to avoid a circular dependency issue due the dep on
174+
// VertexAI_Internal
175+
176+
/**
177+
* All params passed to initiate multiturn chat via startChat
178+
*/
179+
export declare interface StartChatSessionRequest extends StartChatParams {
180+
project: string;
181+
location: string;
182+
_vertex_instance: VertexAI_Internal;
183+
}
184+
185+
/**
186+
* Session for a multiturn chat with the model
187+
*/
188+
export class ChatSession {
189+
// Substitute apiKey for these in Labs
190+
private project: string;
191+
private location: string;
192+
193+
private _history: Content[];
194+
private _vertex_instance: VertexAI_Internal;
195+
196+
197+
model: string;
198+
generation_config?: GenerationConfig;
199+
safety_settings?: SafetySetting[];
200+
201+
get history(): Content[] {
202+
return this._history;
203+
}
204+
205+
constructor(request: StartChatSessionRequest) {
206+
this.project = request.project;
207+
this.location = request.location;
208+
this.model = request.model;
209+
this._history = request.history ?? [];
210+
this._vertex_instance = request._vertex_instance;
211+
}
212+
213+
async sendMessage(request: string|
214+
Array<string|Part>): Promise<GenerateContentResult> {
215+
// TODO: this is stubbed until the service is available
216+
let generateContentrequest: GenerateContentParams = {
217+
model: this.model,
218+
contents: [],
219+
safety_settings: this.safety_settings,
220+
generation_config: this.generation_config,
221+
};
222+
223+
let currentContent = [];
224+
225+
if (typeof request === 'string') {
226+
currentContent = [{role: 'user', parts: [{text: request}]}];
227+
} else if (Array.isArray(request)) {
228+
for (const item of request) {
229+
if (typeof item === 'string') {
230+
currentContent.push({role: 'user', parts: [{text: item}]});
231+
} else {
232+
currentContent.push({role: 'user', parts: [item]});
233+
}
234+
}
235+
};
236+
237+
generateContentrequest.contents = currentContent;
238+
const generateContentResponse =
239+
await this._vertex_instance.generateContent(generateContentrequest);
240+
// TODO: add error handling
241+
242+
// First add the messages sent by the user
243+
for (const content of currentContent) {
244+
this._history.push(content);
245+
};
246+
247+
for (const result of generateContentResponse.responses) {
248+
for (const candidate of result.candidates) {
249+
this._history.push(candidate.content);
250+
}
251+
}
252+
return generateContentResponse;
253+
}
254+
}
255+
256+
export {
257+
Content,
258+
GenerateContentParams,
259+
GenerateContentRequest,
260+
GenerateContentResult,
261+
GenerationConfig,
262+
Part,
263+
SafetySetting
264+
};

0 commit comments

Comments
 (0)