Skip to content

Commit 5ffff81

Browse files
committed
Add Vertex embeddings to integration package
1 parent 96bba4b commit 5ffff81

File tree

15 files changed

+377
-7
lines changed

15 files changed

+377
-7
lines changed

libs/langchain-community/src/embeddings/googlevertexai.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import {
1010
import { GoogleVertexAILLMConnection } from "../utils/googlevertexai-connection.js";
1111

1212
/**
13+
* @deprecated Import and use from @langchain/google-vertexai or @langchain/google-vertexai-web
1314
* Defines the parameters required to initialize a
1415
* GoogleVertexAIEmbeddings instance. It extends EmbeddingsParams and
1516
* GoogleVertexAIConnectionParams.
@@ -19,12 +20,14 @@ export interface GoogleVertexAIEmbeddingsParams
1920
GoogleVertexAIBaseLLMInput<GoogleAuthOptions> {}
2021

2122
/**
23+
* @deprecated Import and use from @langchain/google-vertexai or @langchain/google-vertexai-web
2224
* Defines additional options specific to the
2325
* GoogleVertexAILLMEmbeddingsInstance. It extends AsyncCallerCallOptions.
2426
*/
2527
interface GoogleVertexAILLMEmbeddingsOptions extends AsyncCallerCallOptions {}
2628

2729
/**
30+
* @deprecated Import and use from @langchain/google-vertexai or @langchain/google-vertexai-web
2831
* Represents an instance for generating embeddings using the Google
2932
* Vertex AI API. It contains the content to be embedded.
3033
*/
@@ -33,6 +36,7 @@ interface GoogleVertexAILLMEmbeddingsInstance {
3336
}
3437

3538
/**
39+
* @deprecated Import and use from @langchain/google-vertexai or @langchain/google-vertexai-web
3640
* Defines the structure of the embeddings results returned by the Google
3741
* Vertex AI API. It extends GoogleVertexAIBasePrediction and contains the
3842
* embeddings and their statistics.
@@ -48,6 +52,7 @@ interface GoogleVertexEmbeddingsResults extends GoogleVertexAIBasePrediction {
4852
}
4953

5054
/**
55+
* @deprecated Import and use from @langchain/google-vertexai or @langchain/google-vertexai-web
5156
* Enables calls to the Google Cloud's Vertex AI API to access
5257
* the embeddings generated by Large Language Models.
5358
*

libs/langchain-community/src/utils/googlevertexai-connection.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,9 @@ export class GoogleVertexAILLMConnection<
212212
}
213213

214214
const projectId = await this.client.getProjectId();
// Build the request URL for `method` on the configured model.
// Deliberately not logged: the URL embeds the project id, and a library
// should not write debug output to the console on every request.
return `https://${this.endpoint}/v1/projects/${projectId}/locations/${this.location}/publishers/google/models/${this.model}:${method}`;
217219
}
218220

libs/langchain-google-common/src/connection.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,8 @@ export abstract class GoogleHostConnection<
166166
}
167167

168168
export abstract class GoogleAIConnection<
169-
CallOptions extends BaseLanguageModelCallOptions,
170-
MessageType,
169+
CallOptions extends AsyncCallerCallOptions,
170+
InputType,
171171
AuthOptions
172172
>
173173
extends GoogleHostConnection<CallOptions, GoogleLLMResponse, AuthOptions>
@@ -232,12 +232,12 @@ export abstract class GoogleAIConnection<
232232
}
233233

234234
abstract formatData(
235-
input: MessageType,
235+
input: InputType,
236236
parameters: GoogleAIModelRequestParams
237237
): unknown;
238238

239239
async request(
240-
input: MessageType,
240+
input: InputType,
241241
parameters: GoogleAIModelRequestParams,
242242
options: CallOptions
243243
): Promise<GoogleLLMResponse> {
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings";
2+
import {
3+
AsyncCaller,
4+
AsyncCallerCallOptions,
5+
} from "@langchain/core/utils/async_caller";
6+
import { chunkArray } from "@langchain/core/utils/chunk_array";
7+
import { GoogleAIConnection } from "./connection.js";
8+
import { ApiKeyGoogleAuth, GoogleAbstractedClient } from "./auth.js";
9+
import { GoogleAIModelRequestParams, GoogleConnectionParams } from "./types.js";
10+
import { getEnvironmentVariable } from "@langchain/core/utils/env";
11+
12+
/**
* Connection used by BaseGoogleEmbeddings to send embeddings requests.
* It selects the "predict" URL method and wraps the instances and the
* request parameters into the request body.
*/
class EmbeddingsConnection<
13+
CallOptions extends AsyncCallerCallOptions,
14+
AuthOptions
15+
> extends GoogleAIConnection<
16+
CallOptions,
17+
GoogleEmbeddingsInstance[],
18+
AuthOptions
19+
> {
20+
// NOTE(review): nothing in the visible code reads or assigns this field;
// it looks like a carry-over from a chat connection — confirm before
// relying on it or removing it.
convertSystemMessageToHumanContent: boolean | undefined;
21+
22+
// Passes every argument straight through to the GoogleAIConnection constructor.
constructor(
23+
fields: GoogleConnectionParams<AuthOptions> | undefined,
24+
caller: AsyncCaller,
25+
client: GoogleAbstractedClient,
26+
streaming: boolean
27+
) {
28+
super(fields, caller, client, streaming);
29+
}
30+
31+
/**
* @returns The URL method name used for embeddings requests, always "predict".
*/
async buildUrlMethod(): Promise<string> {
32+
return "predict";
33+
}
34+
35+
/**
* Builds the request body for an embeddings request.
* @param input The instances to embed.
* @param parameters Request parameters, passed through unchanged.
* @returns An object with `instances` and `parameters` properties.
*/
formatData(
36+
input: GoogleEmbeddingsInstance[],
37+
parameters: GoogleAIModelRequestParams
38+
): unknown {
39+
return {
40+
instances: input,
41+
parameters,
42+
};
43+
}
44+
}
45+
46+
/**
47+
* Defines the parameters required to initialize a
48+
* BaseGoogleEmbeddings instance. It extends EmbeddingsParams and
49+
* GoogleConnectionParams, and adds the required model name.
50+
*/
51+
export interface BaseGoogleEmbeddingsParams<AuthOptions>
52+
extends EmbeddingsParams,
53+
GoogleConnectionParams<AuthOptions> {
54+
/** Model name; the BaseGoogleEmbeddings constructor copies it onto the instance. */
model: string;
55+
}
56+
57+
/**
58+
* Defines the call options type used for embeddings requests. It extends
59+
* AsyncCallerCallOptions and currently adds no members of its own.
60+
*/
61+
export interface BaseGoogleEmbeddingsOptions extends AsyncCallerCallOptions {}
62+
63+
/**
64+
* Represents one input item of an embeddings request sent to the Google
65+
* Vertex AI API. It contains the content to be embedded.
66+
*/
67+
export interface GoogleEmbeddingsInstance {
68+
/** The text to embed; embedDocuments fills it with one input document. */
content: string;
69+
}
70+
71+
/**
72+
* Defines the structure of a single embeddings prediction returned by the
73+
* Google Vertex AI API. It contains the embedding values and their
74+
* statistics.
75+
*/
76+
export interface BaseGoogleEmbeddingsResults {
77+
embeddings: {
78+
statistics: {
79+
// NOTE(review): presumably the number of tokens counted in the input — confirm against the API docs.
token_count: number;
80+
// NOTE(review): presumably whether the input was truncated — confirm against the API docs.
truncated: boolean;
81+
};
82+
// The embedding vector; embedDocuments returns this array for each prediction.
values: number[];
83+
};
84+
}
85+
86+
/**
 * Base class that enables calls to the Google Cloud's Vertex AI API to
 * access the embeddings generated by Large Language Models. Concrete
 * subclasses supply the platform-specific client by implementing
 * `buildAbstractedClient`.
 *
 * The client is chosen by `buildClient`:
 * - If an `apiKey` is passed in the fields, or the `GOOGLE_API_KEY`
 *   environment variable is set, an API-key client is used.
 * - Otherwise the client returned by `buildAbstractedClient` is used.
 *
 * When no API key is used, you will need to have one of the following
 * authentication methods in place:
 * - You are logged into an account permitted to the Google Cloud project
 *   using Vertex AI.
 * - You are running this on a machine using a service account permitted to
 *   the Google Cloud project using Vertex AI.
 * - The `GOOGLE_APPLICATION_CREDENTIALS` environment variable is set to the
 *   path of a credentials file for a service account permitted to the
 *   Google Cloud project using Vertex AI.
 * @example
 * ```typescript
 * // `model` is required; GoogleEmbeddings is a concrete subclass.
 * const model = new GoogleEmbeddings({ model: "<embeddings-model-name>" });
 * const res = await model.embedQuery(
 *   "What would be a good company name for a company that makes colorful socks?"
 * );
 * console.log({ res });
 * ```
 */
export abstract class BaseGoogleEmbeddings<AuthOptions>
  extends Embeddings
  implements BaseGoogleEmbeddingsParams<AuthOptions>
{
  /** Model name, copied from the constructor fields. */
  model: string;

  /** Connection that sends the embeddings requests. */
  private connection: GoogleAIConnection<
    BaseGoogleEmbeddingsOptions,
    GoogleEmbeddingsInstance[],
    GoogleConnectionParams<AuthOptions>
  >;

  /**
   * @param fields Embeddings and connection parameters; `model` is required.
   */
  constructor(fields: BaseGoogleEmbeddingsParams<AuthOptions>) {
    super(fields);

    this.model = fields.model;
    // Instance properties are spread last, so they win over same-named
    // entries in `fields`.
    this.connection = new EmbeddingsConnection(
      { ...fields, ...this },
      this.caller,
      this.buildClient(fields),
      false
    );
  }

  /**
   * Builds the client used when no API key is available.
   * @param fields Connection parameters passed to the constructor.
   * @returns The client the connection should use.
   */
  abstract buildAbstractedClient(
    fields?: GoogleConnectionParams<AuthOptions>
  ): GoogleAbstractedClient;

  /**
   * Builds a client that authenticates with an API key.
   * @param apiKey The API key to use.
   * @returns An ApiKeyGoogleAuth client wrapping the key.
   */
  buildApiKeyClient(apiKey: string): GoogleAbstractedClient {
    return new ApiKeyGoogleAuth(apiKey);
  }

  /**
   * Resolves the API key, preferring `fields.apiKey` over the
   * `GOOGLE_API_KEY` environment variable.
   * @param fields Connection parameters passed to the constructor.
   * @returns The API key, or undefined when neither source provides one.
   */
  buildApiKey(
    fields?: GoogleConnectionParams<AuthOptions>
  ): string | undefined {
    return fields?.apiKey ?? getEnvironmentVariable("GOOGLE_API_KEY");
  }

  /**
   * Chooses the client: an API-key client when a non-empty API key is
   * available, otherwise the one from `buildAbstractedClient`.
   * @param fields Connection parameters passed to the constructor.
   * @returns The client the connection should use.
   */
  buildClient(
    fields?: GoogleConnectionParams<AuthOptions>
  ): GoogleAbstractedClient {
    const apiKey = this.buildApiKey(fields);
    if (apiKey) {
      return this.buildApiKeyClient(apiKey);
    } else {
      return this.buildAbstractedClient(fields);
    }
  }

  /**
   * Takes an array of documents as input and returns a promise that
   * resolves to a 2D array of embeddings for each document. It splits the
   * documents into chunks and makes requests to the Google Vertex AI API to
   * generate embeddings.
   * @param documents An array of documents to be embedded.
   * @returns A promise that resolves to a 2D array of embeddings for each document.
   */
  async embedDocuments(documents: string[]): Promise<number[][]> {
    // Vertex AI accepts max 5 instances per prediction
    const maxInstancesPerRequest = 5;
    const instanceChunks: GoogleEmbeddingsInstance[][] = chunkArray(
      documents.map((document) => ({
        content: document,
      })),
      maxInstancesPerRequest
    );
    const parameters = {};
    const options = {};
    // One request per chunk; Promise.all keeps the responses in chunk order.
    const responses = await Promise.all(
      instanceChunks.map((instances) =>
        this.connection.request(instances, parameters, options)
      )
    );
    return responses.flatMap((response) => {
      // The response body is externally supplied JSON, so treat it as
      // unknown and read only the `predictions` list from it.
      const data: unknown = response?.data;
      const predictions = (
        data as
          | { predictions?: BaseGoogleEmbeddingsResults[] }
          | null
          | undefined
      )?.predictions;
      // A response without predictions contributes no embeddings.
      return (
        predictions?.map((prediction) => prediction.embeddings.values) ?? []
      );
    });
  }

  /**
   * Takes a document as input and returns a promise that resolves to an
   * embedding for the document. It calls the embedDocuments method with the
   * document as the input.
   * @param document A document to be embedded.
   * @returns A promise that resolves to an embedding for the document.
   */
  async embedQuery(document: string): Promise<number[]> {
    const data = await this.embedDocuments([document]);
    return data[0];
  }
}

libs/langchain-google-common/src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
export * from "./chat_models.js";
22
export * from "./llms.js";
3+
export * from "./embeddings.js";
34

45
export * from "./auth.js";
56
export * from "./connection.js";

libs/langchain-google-gauth/src/auth.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import {
33
ensureAuthOptionScopes,
44
GoogleAbstractedClient,
55
GoogleAbstractedClientOps,
6-
GoogleBaseLLMInput,
6+
GoogleConnectionParams,
77
JsonStream,
88
} from "@langchain/google-common";
99
import { GoogleAuth, GoogleAuthOptions } from "google-auth-library";
@@ -27,7 +27,7 @@ export class NodeJsonStream extends JsonStream {
2727
export class GAuthClient implements GoogleAbstractedClient {
2828
gauth: GoogleAuth;
2929

30-
constructor(fields?: GoogleBaseLLMInput<GoogleAuthOptions>) {
30+
constructor(fields?: GoogleConnectionParams<GoogleAuthOptions>) {
3131
const options = ensureAuthOptionScopes<GoogleAuthOptions>(
3232
fields?.authOptions,
3333
"scopes",
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import {
2+
GoogleAbstractedClient,
3+
GoogleConnectionParams,
4+
BaseGoogleEmbeddings,
5+
BaseGoogleEmbeddingsParams,
6+
} from "@langchain/google-common";
7+
import { GoogleAuthOptions } from "google-auth-library";
8+
import { GAuthClient } from "./auth.js";
9+
10+
/**
11+
* Input to the GoogleEmbeddings class.
12+
*/
13+
export interface GoogleEmbeddingsInput
14+
extends BaseGoogleEmbeddingsParams<GoogleAuthOptions> {}
15+
16+
/**
17+
* Embeddings integration whose abstracted client is a GAuthClient.
18+
*/
19+
export class GoogleEmbeddings
20+
extends BaseGoogleEmbeddings<GoogleAuthOptions>
21+
implements GoogleEmbeddingsInput
22+
{
23+
// Used for tracing; returns this class's name
24+
static lc_name() {
25+
return "GoogleEmbeddings";
26+
}
27+
28+
lc_serializable = true;
29+
30+
constructor(fields: GoogleEmbeddingsInput) {
31+
super(fields);
32+
}
33+
34+
// Supplies the client the base class uses when no API key is configured.
buildAbstractedClient(
35+
fields?: GoogleConnectionParams<GoogleAuthOptions>
36+
): GoogleAbstractedClient {
37+
return new GAuthClient(fields);
38+
}
39+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
export * from "./chat_models.js";
22
export * from "./llms.js";
3+
export * from "./embeddings.js";
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import {
2+
type GoogleEmbeddingsInput,
3+
GoogleEmbeddings,
4+
} from "@langchain/google-webauth";
5+
6+
/**
7+
* Input to the GoogleVertexAIEmbeddings class.
8+
*/
9+
export interface GoogleVertexAIEmbeddingsInput extends GoogleEmbeddingsInput {}
10+
11+
/**
12+
* Embeddings integration for Vertex AI: a GoogleEmbeddings whose
13+
* `platformType` is always set to "gcp".
*/
14+
export class GoogleVertexAIEmbeddings extends GoogleEmbeddings {
15+
static lc_name() {
16+
return "GoogleVertexAIEmbeddings";
17+
}
18+
19+
constructor(fields: GoogleVertexAIEmbeddingsInput) {
20+
super({
21+
...fields,
22+
// Set after the spread so it overrides any platformType in `fields`.
platformType: "gcp",
23+
});
24+
}
25+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
export * from "./chat_models.js";
22
export * from "./llms.js";
3+
export * from "./embeddings.js";

0 commit comments

Comments
 (0)