Skip to content

Commit 2e577a2

Browse files
amirai21asafgardin
authored andcommitted
feat: convrag added to client, fix response mapping
1 parent 40e0a78 commit 2e577a2

File tree

10 files changed

+111
-11
lines changed

10 files changed

+111
-11
lines changed

package.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
"lint": "npx eslint 'src/**/*.{ts,tsx}' --no-ignore",
2121
"format": "prettier --write \"(src|test)/**\" --no-error-on-unmatched-pattern",
2222
"prepare": "npm run build",
23-
"example": "npx tsx src/example.ts",
23+
"chatExample": "npx tsx src/examples/chatExample.ts",
24+
"convRagExample": "npx tsx src/examples/convRagExample.ts",
2425
"circular": "madge --circular --extensions ts src",
2526
"quality": "npm run circular && npm run lint && tsc --noEmit && npm run format && npm run unused-deps",
2627
"quality:fix": "npm run circular && npm run lint -- --fix && tsc --noEmit && npm run format"

src/AI21.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { MissingAPIKeyError } from './errors';
44
import { Chat } from './resources/chat';
55
import { APIClient } from './APIClient';
66
import { Headers } from './types';
7+
import { ConversationalRag } from './resources/rag/conversationalRag';
78

89
export type ClientOptions = {
910
baseURL?: string;
@@ -54,6 +55,7 @@ export class AI21 extends APIClient {
5455

5556
// Resources
5657
chat: Chat = new Chat(this);
58+
conversationalRag: ConversationalRag = new ConversationalRag(this);
5759

5860
// eslint-disable-next-line @typescript-eslint/no-unused-vars
5961
protected override authHeaders(_: Types.FinalRequestOptions): Types.Headers {

src/ResponseHandler.ts

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ type APIResponse<T> = {
1111
export async function handleAPIResponse<T>({
1212
response,
1313
options,
14-
}: APIResponseProps): Promise<Stream<T> | APIResponse<T>> {
14+
}: APIResponseProps): Promise<Stream<T> | Promise<APIResponse<T>>> {
1515
if (options.stream) {
1616
if (!response.body) {
1717
throw new AI21Error('Response body is null');
@@ -20,10 +20,5 @@ export async function handleAPIResponse<T>({
2020
}
2121

2222
const contentType = response.headers.get('content-type');
23-
const data = contentType?.includes('application/json') ? await response.json() : null;
24-
25-
return {
26-
data,
27-
response,
28-
};
23+
return contentType?.includes('application/json') ? await response.json() : null;
2924
}

src/example.ts renamed to src/examples/chatExample.ts

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,32 @@
1-
import { AI21 } from './AI21';
1+
import { AI21 } from "../AI21";
2+
23
/*
34
This is a temporary example to test the API streaming/non-streaming functionality.
45
*/
56
async function main() {
67
const client = new AI21({ apiKey: process.env.AI21_API_KEY });
78

89
try {
9-
const response = await client.chat.completions.create({
10+
11+
console.log('-------------------------------- streaming flow --------------------------------');
12+
13+
let streamResponse = await client.chat.completions.create({
1014
model: 'jamba-1.5-mini',
1115
messages: [{ role: 'user', content: 'Hello, how are you? tell me a 100 line story about a cat' }],
1216
stream: true,
1317
});
14-
for await (const chunk of response) {
18+
for await (const chunk of streamResponse) {
1519
process.stdout.write(chunk.choices[0]?.delta?.content || '');
1620
}
21+
22+
console.log('-------------------------------- non streaming flow --------------------------------');
23+
24+
const response = await client.chat.completions.create({
25+
model: 'jamba-1.5-mini',
26+
messages: [{ role: 'user', content: 'Hello, how are you? tell me a 100 line story about a cat' }],
27+
stream: false,
28+
});
29+
console.log(response);
1730
} catch (error) {
1831
console.error('Error:', error);
1932
}

src/examples/convRagExample.ts

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import { AI21 } from "../AI21";
2+
3+
/*
4+
This is a temporary example to test the Conversational RAG functionality.
5+
*/
6+
async function main() {
7+
8+
/* TODO - add a file upload example when library support is added and combined with the below flow */
9+
10+
const client = new AI21({ apiKey: process.env.AI21_API_KEY });
11+
try {
12+
13+
/* The following example is for a question that is not in the context of files uploaded to RAG */
14+
15+
const answer_not_in_ctx_response = await client.conversationalRag.create({
16+
messages: [{ role: 'user', content: 'Who is the Russian president?' }],
17+
});
18+
console.log(answer_not_in_ctx_response);
19+
20+
/* The following example is for a question that should be answered based on files uploaded to RAG */
21+
22+
const answer_in_ctx_response = await client.conversationalRag.create({
23+
messages: [{ role: 'user', content: 'What is headace?' }],
24+
});
25+
console.log(answer_in_ctx_response);
26+
} catch (error) {
27+
console.error('Error:', error);
28+
}
29+
}
30+
31+
main().catch(console.error);
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import * as Models from '../../types';
2+
import { APIResource } from '../../APIResource';
3+
import { ConversationalRagRequest } from '../../types/rag/ConversationalRagRequest';
4+
import { ConversationalRagResponse } from '../../types/rag/ConversationalRagResponse';
5+
6+
export class ConversationalRag extends APIResource {
7+
8+
create(body: ConversationalRagRequest, options?: Models.RequestOptions){
9+
return this.client.post<ConversationalRagRequest, ConversationalRagResponse>(
10+
'/conversational-rag',
11+
{
12+
body,
13+
...options,
14+
} as Models.RequestOptions<ConversationalRagRequest>,
15+
) as Promise<ConversationalRagResponse>;
16+
}
17+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import { ChatMessageParam } from "../chat";
2+
import { RetrievalStrategy } from "./RetrievalStrategy";
3+
4+
export interface ConversationalRagRequest {
5+
messages: ChatMessageParam[];
6+
path?: string | null;
7+
labels?: string[] | null;
8+
file_ids?: string[] | null;
9+
max_segments?: number | null;
10+
retrieval_strategy?: RetrievalStrategy | string | null;
11+
retrieval_similarity_threshold?: number | null;
12+
max_neighbors?: number | null;
13+
hybrid_search_alpha?: number | null;
14+
[key: string]: any;
15+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import { ChatMessage } from "../chat";
2+
import { ConversationalRagSource } from "./ConversationalRagSource";
3+
4+
export interface ConversationalRagResponse {
5+
id: string;
6+
choices: ChatMessage[];
7+
search_queries?: string[] | null;
8+
context_retrieved: boolean;
9+
answer_in_context: boolean;
10+
sources: ConversationalRagSource[];
11+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
export interface ConversationalRagSource {
2+
text: string;
3+
file_id: string;
4+
file_name: string;
5+
score: number;
6+
order?: number | null;
7+
public_url?: string | null;
8+
labels?: string[] | null;
9+
}

src/types/rag/RetrievalStrategy.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
export enum RetrievalStrategy {
2+
DEFAULT = "default",
3+
SEGMENTS = "segments",
4+
ADD_NEIGHBORS = "add_neighbors",
5+
FULL_DOC = "full_doc"
6+
}

0 commit comments

Comments
 (0)