Skip to content

Commit a2dfb46

Browse files
zandkojacoblee93
andauthored
community[minor]: feat: BaiduQianfan embeddings (#4926)
* feat: BaiduQianfan embeddings * docs: Update instructions for configuring BAIDU API and Secret keys as env variables * refactor: rename BaiduQianFanEmbeddings to BaiduQianfanEmbeddings for naming consistency * Add entrypoint --------- Co-authored-by: jacoblee93 <[email protected]>
1 parent d8fc97d commit a2dfb46

File tree

7 files changed

+329
-0
lines changed

7 files changed

+329
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
---
2+
sidebar_class_name: node-only
3+
---
4+
5+
# Baidu Qianfan
6+
7+
The `BaiduQianfanEmbeddings` class uses the Baidu Qianfan API to generate embeddings for a given text.
8+
9+
## Setup
10+
11+
Official Website: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu
12+
13+
An API key is required to use this embedding model. You can get one by registering at https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu.
14+
15+
Please set the acquired API key as an environment variable named BAIDU_API_KEY, and set your secret key as an environment variable named BAIDU_SECRET_KEY.
16+
17+
Then, you'll need to install the [`@langchain/community`](https://www.npmjs.com/package/@langchain/community) package:
18+
19+
import IntegrationInstallTooltip from "@mdx_components/integration_install_tooltip.mdx";
20+
21+
<IntegrationInstallTooltip></IntegrationInstallTooltip>
22+
23+
```bash npm2yarn
24+
npm install @langchain/community
25+
```
26+
27+
## Usage
28+
29+
import CodeBlock from "@theme/CodeBlock";
30+
import BaiduQianFanExample from "@examples/embeddings/baidu_qianfan.ts";
31+
32+
<CodeBlock language="typescript">{BaiduQianFanExample}</CodeBlock>
+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import { BaiduQianfanEmbeddings } from "@langchain/community/embeddings/baidu_qianfan";
2+
3+
const embeddings = new BaiduQianfanEmbeddings();
4+
const res = await embeddings.embedQuery(
5+
"What would be a good company name a company that makes colorful socks?"
6+
);
7+
console.log({ res });

libs/langchain-community/.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ embeddings/alibaba_tongyi.cjs
118118
embeddings/alibaba_tongyi.js
119119
embeddings/alibaba_tongyi.d.ts
120120
embeddings/alibaba_tongyi.d.cts
121+
embeddings/baidu_qianfan.cjs
122+
embeddings/baidu_qianfan.js
123+
embeddings/baidu_qianfan.d.ts
124+
embeddings/baidu_qianfan.d.cts
121125
embeddings/bedrock.cjs
122126
embeddings/bedrock.js
123127
embeddings/bedrock.d.ts

libs/langchain-community/langchain.config.js

+1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ export const config = {
5959
"agents/toolkits/connery": "agents/toolkits/connery/index",
6060
// embeddings
6161
"embeddings/alibaba_tongyi": "embeddings/alibaba_tongyi",
62+
"embeddings/baidu_qianfan": "embeddings/baidu_qianfan",
6263
"embeddings/bedrock": "embeddings/bedrock",
6364
"embeddings/cloudflare_workersai": "embeddings/cloudflare_workersai",
6465
"embeddings/cohere": "embeddings/cohere",

libs/langchain-community/package.json

+13
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,15 @@
822822
"import": "./embeddings/alibaba_tongyi.js",
823823
"require": "./embeddings/alibaba_tongyi.cjs"
824824
},
825+
"./embeddings/baidu_qianfan": {
826+
"types": {
827+
"import": "./embeddings/baidu_qianfan.d.ts",
828+
"require": "./embeddings/baidu_qianfan.d.cts",
829+
"default": "./embeddings/baidu_qianfan.d.ts"
830+
},
831+
"import": "./embeddings/baidu_qianfan.js",
832+
"require": "./embeddings/baidu_qianfan.cjs"
833+
},
825834
"./embeddings/bedrock": {
826835
"types": {
827836
"import": "./embeddings/bedrock.d.ts",
@@ -2359,6 +2368,10 @@
23592368
"embeddings/alibaba_tongyi.js",
23602369
"embeddings/alibaba_tongyi.d.ts",
23612370
"embeddings/alibaba_tongyi.d.cts",
2371+
"embeddings/baidu_qianfan.cjs",
2372+
"embeddings/baidu_qianfan.js",
2373+
"embeddings/baidu_qianfan.d.ts",
2374+
"embeddings/baidu_qianfan.d.cts",
23622375
"embeddings/bedrock.cjs",
23632376
"embeddings/bedrock.js",
23642377
"embeddings/bedrock.d.ts",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
2+
import { chunkArray } from "@langchain/core/utils/chunk_array";
3+
import { getEnvironmentVariable } from "@langchain/core/utils/env";
4+
5+
export interface BaiduQianfanEmbeddingsParams extends EmbeddingsParams {
6+
/** Model name to use */
7+
modelName: "embedding-v1" | "bge_large_zh" | "bge-large-en" | "tao-8k";
8+
9+
/**
10+
* Timeout to use when making requests to BaiduQianfan.
11+
*/
12+
timeout?: number;
13+
14+
/**
15+
* The maximum number of characters allowed for embedding in a single request varies by model:
16+
* - Embedding-V1 model: up to 1000 characters
17+
* - bge-large-zh model: up to 2000 characters
18+
* - bge-large-en model: up to 2000 characters
19+
* - tao-8k model: up to 28000 characters
20+
*
21+
* Note: These limits are model-specific and should be adhered to for optimal performance.
22+
*/
23+
batchSize?: number;
24+
25+
/**
26+
* Whether to strip new lines from the input text.
27+
*/
28+
stripNewLines?: boolean;
29+
}
30+
31+
interface EmbeddingCreateParams {
32+
input: string[];
33+
}
34+
35+
interface EmbeddingResponse {
36+
data: { object: "embedding"; index: number; embedding: number[] }[];
37+
38+
usage: {
39+
prompt_tokens: number;
40+
total_tokens: number;
41+
};
42+
43+
id: string;
44+
}
45+
46+
interface EmbeddingErrorResponse {
47+
error_code: number | string;
48+
error_msg: string;
49+
}
50+
51+
export class BaiduQianfanEmbeddings
52+
extends Embeddings
53+
implements BaiduQianfanEmbeddingsParams
54+
{
55+
modelName: BaiduQianfanEmbeddingsParams["modelName"] = "embedding-v1";
56+
57+
batchSize = 16;
58+
59+
stripNewLines = true;
60+
61+
baiduApiKey: string;
62+
63+
baiduSecretKey: string;
64+
65+
accessToken: string;
66+
67+
constructor(
68+
fields?: Partial<BaiduQianfanEmbeddingsParams> & {
69+
verbose?: boolean;
70+
baiduApiKey?: string;
71+
baiduSecretKey?: string;
72+
}
73+
) {
74+
const fieldsWithDefaults = { maxConcurrency: 2, ...fields };
75+
super(fieldsWithDefaults);
76+
77+
const baiduApiKey =
78+
fieldsWithDefaults?.baiduApiKey ??
79+
getEnvironmentVariable("BAIDU_API_KEY");
80+
81+
const baiduSecretKey =
82+
fieldsWithDefaults?.baiduSecretKey ??
83+
getEnvironmentVariable("BAIDU_SECRET_KEY");
84+
85+
if (!baiduApiKey) {
86+
throw new Error("Baidu API key not found");
87+
}
88+
89+
if (!baiduSecretKey) {
90+
throw new Error("Baidu Secret key not found");
91+
}
92+
93+
this.baiduApiKey = baiduApiKey;
94+
this.baiduSecretKey = baiduSecretKey;
95+
96+
this.modelName = fieldsWithDefaults?.modelName ?? this.modelName;
97+
98+
if (this.modelName === "tao-8k") {
99+
if (fieldsWithDefaults?.batchSize && fieldsWithDefaults.batchSize !== 1) {
100+
throw new Error(
101+
"tao-8k model supports only a batchSize of 1. Please adjust your batchSize accordingly"
102+
);
103+
}
104+
this.batchSize = 1;
105+
} else {
106+
this.batchSize = fieldsWithDefaults?.batchSize ?? this.batchSize;
107+
}
108+
109+
this.stripNewLines =
110+
fieldsWithDefaults?.stripNewLines ?? this.stripNewLines;
111+
}
112+
113+
/**
114+
* Method to generate embeddings for an array of documents. Splits the
115+
* documents into batches and makes requests to the BaiduQianFan API to generate
116+
* embeddings.
117+
* @param texts Array of documents to generate embeddings for.
118+
* @returns Promise that resolves to a 2D array of embeddings for each document.
119+
*/
120+
async embedDocuments(texts: string[]): Promise<number[][]> {
121+
const batches = chunkArray(
122+
this.stripNewLines ? texts.map((t) => t.replace(/\n/g, " ")) : texts,
123+
this.batchSize
124+
);
125+
126+
const batchRequests = batches.map((batch) => {
127+
const params = this.getParams(batch);
128+
129+
return this.embeddingWithRetry(params);
130+
});
131+
132+
const batchResponses = await Promise.all(batchRequests);
133+
134+
const embeddings: number[][] = [];
135+
136+
for (let i = 0; i < batchResponses.length; i += 1) {
137+
const batch = batches[i];
138+
const batchResponse = batchResponses[i] || [];
139+
for (let j = 0; j < batch.length; j += 1) {
140+
embeddings.push(batchResponse[j]);
141+
}
142+
}
143+
144+
return embeddings;
145+
}
146+
147+
/**
148+
* Method to generate an embedding for a single document. Calls the
149+
* embeddingWithRetry method with the document as the input.
150+
* @param text Document to generate an embedding for.
151+
* @returns Promise that resolves to an embedding for the document.
152+
*/
153+
async embedQuery(text: string): Promise<number[]> {
154+
const params = this.getParams([
155+
this.stripNewLines ? text.replace(/\n/g, " ") : text,
156+
]);
157+
158+
const embeddings = (await this.embeddingWithRetry(params)) || [[]];
159+
return embeddings[0];
160+
}
161+
162+
/**
163+
* Method to generate an embedding params.
164+
* @param texts Array of documents to generate embeddings for.
165+
* @returns an embedding params.
166+
*/
167+
private getParams(
168+
texts: EmbeddingCreateParams["input"]
169+
): EmbeddingCreateParams {
170+
return {
171+
input: texts,
172+
};
173+
}
174+
175+
/**
176+
* Private method to make a request to the BaiduAI API to generate
177+
* embeddings. Handles the retry logic and returns the response from the
178+
* API.
179+
* @param request Request to send to the BaiduAI API.
180+
* @returns Promise that resolves to the response from the API.
181+
*/
182+
private async embeddingWithRetry(body: EmbeddingCreateParams) {
183+
if (!this.accessToken) {
184+
this.accessToken = await this.getAccessToken();
185+
}
186+
187+
return fetch(
188+
`https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/${this.modelName}?access_token=${this.accessToken}`,
189+
{
190+
method: "POST",
191+
headers: {
192+
"Content-Type": "application/json",
193+
},
194+
body: JSON.stringify(body),
195+
}
196+
).then(async (response) => {
197+
const embeddingData: EmbeddingResponse | EmbeddingErrorResponse =
198+
await response.json();
199+
200+
if ("error_code" in embeddingData && embeddingData.error_code) {
201+
throw new Error(
202+
`${embeddingData.error_code}: ${embeddingData.error_msg}`
203+
);
204+
}
205+
206+
return (embeddingData as EmbeddingResponse).data.map(
207+
({ embedding }) => embedding
208+
);
209+
});
210+
}
211+
212+
/**
213+
* Method that retrieves the access token for making requests to the Baidu
214+
* API.
215+
* @returns The access token for making requests to the Baidu API.
216+
*/
217+
private async getAccessToken() {
218+
const url = `https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=${this.baiduApiKey}&client_secret=${this.baiduSecretKey}`;
219+
const response = await fetch(url, {
220+
method: "POST",
221+
headers: {
222+
"Content-Type": "application/json",
223+
Accept: "application/json",
224+
},
225+
});
226+
if (!response.ok) {
227+
const text = await response.text();
228+
const error = new Error(
229+
`Baidu get access token failed with status code ${response.status}, response: ${text}`
230+
);
231+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
232+
(error as any).response = response;
233+
throw error;
234+
}
235+
const json = await response.json();
236+
return json.access_token;
237+
}
238+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import { test, expect } from "@jest/globals";
2+
import { BaiduQianfanEmbeddings } from "../baidu_qianfan.js";
3+
4+
test.skip("Test BaiduQianfanEmbeddings.embedQuery", async () => {
5+
const embeddings = new BaiduQianfanEmbeddings();
6+
const res = await embeddings.embedQuery("Hello world");
7+
expect(typeof res[0]).toBe("number");
8+
});
9+
10+
test.skip("Test BaiduQianfanEmbeddings.embedDocuments", async () => {
11+
const embeddings = new BaiduQianfanEmbeddings();
12+
const res = await embeddings.embedDocuments(["Hello world", "Bye bye"]);
13+
expect(res).toHaveLength(2);
14+
expect(typeof res[0][0]).toBe("number");
15+
expect(typeof res[1][0]).toBe("number");
16+
});
17+
18+
test.skip("Test BaiduQianfanEmbeddings concurrency", async () => {
19+
const embeddings = new BaiduQianfanEmbeddings({
20+
batchSize: 1,
21+
});
22+
const res = await embeddings.embedDocuments([
23+
"Hello world",
24+
"Bye bye",
25+
"Hello world",
26+
"Bye bye",
27+
"Hello world",
28+
"Bye bye",
29+
]);
30+
expect(res).toHaveLength(6);
31+
expect(res.find((embedding) => typeof embedding[0] !== "number")).toBe(
32+
undefined
33+
);
34+
});

0 commit comments

Comments
 (0)