Skip to content

Commit 4713d02

Browse files
amirai21asafgardin
authored andcommitted
feat: rag-engine impl. - with examples, no tests yet
1 parent e5399f7 commit 4713d02

File tree

13 files changed

+188
-3
lines changed

13 files changed

+188
-3
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import { AI21 } from 'ai21';
2+
import { FileResponse, UploadFileResponse } from '../../../src/types/rag';
3+
4+
async function waitForFileProcessing(client: AI21, fileId: string, interval: number = 3000): Promise<FileResponse> {
5+
while (true) {
6+
const file: FileResponse = await client.ragEngine.get(fileId);
7+
8+
if (file.status === 'PROCESSED') {
9+
return file;
10+
}
11+
12+
console.log(`File status is '${file.status}'. Waiting for it to be 'PROCESSED'...`);
13+
await new Promise(resolve => setTimeout(resolve, interval));
14+
}
15+
}
16+
17+
async function uploadQueryUpdateDelete() {
18+
const client = new AI21({ apiKey: process.env.AI21_API_KEY });
19+
try {
20+
const uploadFileResponse: UploadFileResponse = await client.ragEngine.create(
21+
'/Users/amirkoblyansky/Documents/ukraine.txt', {path: "test10"});
22+
23+
const fileId = uploadFileResponse.fileId
24+
let file: FileResponse = await waitForFileProcessing(client, fileId);
25+
console.log(file);
26+
27+
console.log("Now updating the file labels");
28+
await client.ragEngine.update(uploadFileResponse.fileId, {labels: ["test99"], publicUrl: "https://www.miri.com"});
29+
file = await client.ragEngine.get(fileId);
30+
console.log(file);
31+
32+
console.log("Now deleting the file");
33+
await client.ragEngine.delete(uploadFileResponse.fileId);
34+
} catch (error) {
35+
console.error('Error:', error);
36+
}
37+
}
38+
39+
async function listFiles() {
40+
const client = new AI21({ apiKey: process.env.AI21_API_KEY });
41+
const files = await client.ragEngine.list({limit: 10});
42+
console.log(files);
43+
}
44+
45+
uploadQueryUpdateDelete().catch(console.error);
46+
47+
listFiles().catch(console.error);
48+

src/AI21.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import { APIClient } from './APIClient';
66
import { Headers } from './types';
77
import * as Runtime from './runtime';
88
import { ConversationalRag } from './resources/rag/conversationalRag';
9+
import { RAGEngine } from 'resources';
910

1011
export interface ClientOptions {
1112
baseURL?: string | undefined;
@@ -67,6 +68,7 @@ export class AI21 extends APIClient {
6768
// Resources
6869
chat: Chat = new Chat(this);
6970
conversationalRag: ConversationalRag = new ConversationalRag(this);
71+
ragEngine: RAGEngine = new RAGEngine(this);
7072

7173
// eslint-disable-next-line @typescript-eslint/no-unused-vars
7274
protected override authHeaders(_: Types.FinalRequestOptions): Types.Headers {

src/APIClient.ts

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ import {
1212
import { AI21EnvConfig } from './EnvConfig';
1313
import { createFetchInstance } from './runtime';
1414
import { Fetch } from 'fetch';
15+
import { createReadStream } from 'fs';
16+
import { basename as getBasename } from 'path';
17+
import FormData from 'form-data';
1518

1619
const validatePositiveInteger = (name: string, n: unknown): number => {
1720
if (typeof n !== 'number' || !Number.isInteger(n)) {
@@ -61,6 +64,49 @@ export abstract class APIClient {
6164
return this.makeRequest('delete', path, opts);
6265
}
6366

67+
upload<Req, Rsp>(path: string, filePath: string, opts?: RequestOptions<Req>): Promise<Rsp> {
68+
const formDataRequest = this.makeFormDataRequest(path, filePath, opts);
69+
return this.performRequest(formDataRequest).then(
70+
(response) => this.fetch.handleResponse<Rsp>(response) as Rsp,
71+
);
72+
}
73+
74+
protected makeFormDataRequest<Req>(path: string, filePath: string, opts?: RequestOptions<Req>): FinalRequestOptions {
75+
const formData = new FormData();
76+
const fileStream = createReadStream(filePath);
77+
const fileName = getBasename(filePath);
78+
79+
formData.append('file', fileStream, fileName);
80+
81+
if (opts?.body) {
82+
const body = opts.body as Record<string, string>;
83+
for (const [key, value] of Object.entries(body)) {
84+
if (Array.isArray(value)) {
85+
value.forEach(item => formData.append(key, item));
86+
} else {
87+
formData.append(key, value);
88+
}
89+
}
90+
}
91+
92+
const headers = {
93+
...opts?.headers,
94+
'Content-Type': `multipart/form-data; boundary=${formData.getBoundary()}`
95+
};
96+
console.log(headers);
97+
console.log("-------------------------");
98+
console.log(formData.getHeaders());
99+
console.log("-------------------------");
100+
101+
const options: FinalRequestOptions = {
102+
method: 'post',
103+
path: path,
104+
body: formData,
105+
headers,
106+
};
107+
return options;
108+
}
109+
64110
protected getUserAgent(): string {
65111
const platform =
66112
this.isRunningInBrowser() ?
@@ -96,12 +142,18 @@ export abstract class APIClient {
96142
}
97143

98144
private async performRequest(options: FinalRequestOptions): Promise<APIResponseProps> {
99-
const url = `${this.baseURL}${options.path}`;
145+
let url = `${this.baseURL}${options.path}`;
146+
147+
if (options.query) {
148+
const queryString = new URLSearchParams(options.query as Record<string, string>).toString();
149+
url += `?${queryString}`;
150+
}
100151

101152
const headers = {
102153
...this.defaultHeaders(options),
103154
...options.headers,
104155
};
156+
105157
const response = await this.fetch.call(url, { ...options, headers });
106158

107159
if (!response.ok) {

src/fetch/NodeFetch.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
import { FinalRequestOptions, CrossPlatformResponse } from 'types';
22
import { BaseFetch } from './BaseFetch';
33
import { Stream, NodeSSEDecoder } from '../streaming';
4+
import FormData from 'form-data';
45

56
export class NodeFetch extends BaseFetch {
67
async call(url: string, options: FinalRequestOptions): Promise<CrossPlatformResponse> {
78
const nodeFetchModule = await import('node-fetch');
89
const nodeFetch = nodeFetchModule.default;
910

11+
const body = options.body instanceof FormData ? options.body : JSON.stringify(options.body);
12+
1013
return nodeFetch(url, {
1114
method: options.method,
1215
headers: options?.headers ? (options.headers as Record<string, string>) : undefined,
13-
body: options?.body ? JSON.stringify(options.body) : undefined,
16+
body,
1417
});
1518
}
1619

src/resources/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
export { Chat, Completions } from './chat';
22
export { ConversationalRag } from './rag';
3+
export { RAGEngine } from './rag';

src/resources/rag/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
export { ConversationalRag } from './conversationalRag';
2+
export { RAGEngine } from './ragEngine';

src/resources/rag/ragEngine.ts

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import * as Models from '../../types';
2+
import { APIResource } from '../../APIResource';
3+
import { UploadFileResponse, UploadFileRequest, ListFilesFilters, UpdateFileRequest } from '../../types/rag';
4+
import { FileResponse } from 'types/rag/FileResponse';
5+
6+
7+
const RAG_ENGINE_PATH = '/library/files';
8+
9+
10+
export class RAGEngine extends APIResource {
11+
create(filePath: string, body: UploadFileRequest, options?: Models.RequestOptions) {
12+
return this.client.upload<UploadFileRequest, UploadFileResponse>(RAG_ENGINE_PATH, filePath, {
13+
body: body,
14+
...options,
15+
} as Models.RequestOptions<UploadFileRequest>) as Promise<UploadFileResponse>;
16+
}
17+
18+
get(fileId: string, options?: Models.RequestOptions) {
19+
return this.client.get<string, FileResponse>(
20+
`${RAG_ENGINE_PATH}/${fileId}`, options as Models.RequestOptions<string>) as Promise<FileResponse>;
21+
}
22+
23+
delete(fileId: string, options?: Models.RequestOptions) {
24+
return this.client.delete<string, null>(
25+
`${RAG_ENGINE_PATH}/${fileId}`, options as Models.RequestOptions<string>) as Promise<null>;
26+
}
27+
28+
list(body: ListFilesFilters | null, options?: Models.RequestOptions) {
29+
return this.client.get<ListFilesFilters | null, FileResponse[]>(
30+
RAG_ENGINE_PATH, {query: body, ...options} as Models.RequestOptions<ListFilesFilters | null>) as Promise<FileResponse[]>;
31+
}
32+
33+
update(fileId: string, body: UpdateFileRequest, options?: Models.RequestOptions) {
34+
return this.client.put<UpdateFileRequest, null>(
35+
`${RAG_ENGINE_PATH}/${fileId}`, {body, ...options} as Models.RequestOptions<UpdateFileRequest>) as Promise<null>;
36+
}
37+
}

src/types/API.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ export type RequestOptions<Req = unknown | Record<string, unknown> | ArrayBuffer
1010
method?: HTTPMethod;
1111
path?: string;
1212
query?: Req | undefined;
13-
body?: Req | null | undefined;
13+
body?: Req | FormData | null | undefined;
1414
headers?: Headers | undefined;
1515

1616
maxRetries?: number;

src/types/rag/FileResponse.ts

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
export interface FileResponse {
2+
fileId: string;
3+
name: string;
4+
fileType: string;
5+
sizeBytes: number;
6+
createdBy: string;
7+
creationDate: Date;
8+
lastUpdated: Date;
9+
status: string;
10+
path?: string | null;
11+
labels?: string[] | null;
12+
publicUrl?: string | null;
13+
errorCode?: number | null;
14+
errorMessage?: string | null;
15+
}

src/types/rag/ListFilesFilters.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
export interface ListFilesFilters {
2+
offset?: number | null;
3+
limit?: number | null;
4+
}

0 commit comments

Comments
 (0)