Skip to content

Commit bcc9966

Browse files
Chore: add migration files and update adaptors for feedback loop feature (#506)
* chore: Add migration for creating thread_response_explain table * chore: Add analysisSql method to WrenEngineAdaptor * adding explain api to ai service adapter: * fix migration errer * chore(wren-ui): Add regenerations API to wrenAIAdaptor --------- Co-authored-by: andreashimin <[email protected]>
1 parent 8406a41 commit bcc9966

File tree

4 files changed

+204
-24
lines changed

4 files changed

+204
-24
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/**
2+
* @param { import("knex").Knex } knex
3+
* @returns { Promise<void> }
4+
*/
5+
exports.up = function (knex) {
6+
return knex.schema.createTable('thread_response_explain', (table) => {
7+
table.increments('id').comment('ID');
8+
table
9+
.integer('thread_response_id')
10+
.comment('Reference to thread_response.id');
11+
table
12+
.foreign('thread_response_id')
13+
.references('thread_response.id')
14+
.onDelete('CASCADE');
15+
16+
table.string('query_id').notNullable();
17+
table.string('status').notNullable();
18+
table.jsonb('detail').notNullable();
19+
table.jsonb('error').notNullable();
20+
21+
// timestamps
22+
table.timestamps(true, true);
23+
});
24+
};
25+
26+
/**
27+
* @param { import("knex").Knex } knex
28+
* @returns { Promise<void> }
29+
*/
30+
exports.down = function (knex) {
31+
return knex.schema.dropTable('thread_response_explain');
32+
};
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/**
2+
* @param { import("knex").Knex } knex
3+
* @returns { Promise<void> }
4+
*/
5+
exports.up = function (knex) {
6+
return knex.schema.alterTable('thread_response', (table) => {
7+
table
8+
.jsonb('corrections')
9+
.nullable()
10+
.comment('the corrections of the previous thread response'); // [{type, id, correct}, ...]
11+
});
12+
};
13+
14+
/**
15+
* @param { import("knex").Knex } knex
16+
* @returns { Promise<void> }
17+
*/
18+
exports.down = function (knex) {
19+
return knex.schema.alterTable('thread_response', (table) => {
20+
table.dropColumn('corrections');
21+
});
22+
};

wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts

+135-17
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,15 @@ export interface AsyncQueryResponse {
5555
queryId: string;
5656
}
5757

58+
export type ExplainResult = AIServiceResponse<any, ExplainPipelineStatus>;
59+
60+
export enum ExplainPipelineStatus {
61+
UNDERSTANDING = 'UNDERSTANDING',
62+
GENERATING = 'GENERATING',
63+
FINISHED = 'FINISHED',
64+
FAILED = 'FAILED',
65+
}
66+
5867
export enum AskResultStatus {
5968
UNDERSTANDING = 'UNDERSTANDING',
6069
SEARCHING = 'SEARCHING',
@@ -71,7 +80,21 @@ export enum AskCandidateType {
7180
LLM = 'LLM',
7281
}
7382

74-
export interface AskResponse<R, S> {
83+
export enum ExplainType {
84+
FILTER = 'filter',
85+
SELECT_ITEMS = 'selectItems',
86+
RELATION = 'relation',
87+
GROUP_BY_KEYS = 'groupByKeys',
88+
SORTINGS = 'sortings',
89+
}
90+
91+
// UI currently only support nl_expression
92+
export enum ExpressionType {
93+
SQL_EXPRESSION = 'sql_expression',
94+
NL_EXPRESSION = 'nl_expression',
95+
}
96+
97+
export interface AIServiceResponse<R, S> {
7598
status: S;
7699
response: R | null;
77100
error: WrenAIError | null;
@@ -83,15 +106,15 @@ export interface AskDetailInput {
83106
summary: string;
84107
}
85108

86-
export type AskDetailResult = AskResponse<
109+
export type AskDetailResult = AIServiceResponse<
87110
{
88111
description: string;
89112
steps: AskStep[];
90113
},
91114
AskResultStatus
92115
>;
93116

94-
export type AskResult = AskResponse<
117+
export type AskResult = AIServiceResponse<
95118
Array<{
96119
type: AskCandidateType;
97120
sql: string;
@@ -101,7 +124,26 @@ export type AskResult = AskResponse<
101124
AskResultStatus
102125
>;
103126

104-
const getAISerciceError = (error: any) => {
127+
export interface CorrectionObject<T> {
128+
type: T;
129+
value: string;
130+
}
131+
132+
export interface AskCorrection {
133+
before: CorrectionObject<ExplainType>;
134+
after: CorrectionObject<ExpressionType>;
135+
}
136+
137+
export interface AskStepWithCorrections extends AskStep {
138+
corrections: AskCorrection[];
139+
}
140+
141+
export interface RegenerateAskDetailInput {
142+
description: string;
143+
steps: AskStepWithCorrections[];
144+
}
145+
146+
const getAIServiceError = (error: any) => {
105147
const { data } = error.response || {};
106148
return data?.detail
107149
? `${error.message}, detail: ${data.detail}`
@@ -129,6 +171,12 @@ export interface IWrenAIAdaptor {
129171
*/
130172
generateAskDetail(input: AskDetailInput): Promise<AsyncQueryResponse>;
131173
getAskDetailResult(queryId: string): Promise<AskDetailResult>;
174+
explain(question: string, analysisResults: any): Promise<AsyncQueryResponse>;
175+
getExplainResult(queryId: string): Promise<ExplainResult>;
176+
regenerateAskDetail(
177+
input: RegenerateAskDetailInput,
178+
): Promise<AsyncQueryResponse>;
179+
getRegeneratedAskDetailResult(queryId: string): Promise<AskDetailResult>;
132180
}
133181

134182
export class WrenAIAdaptor implements IWrenAIAdaptor {
@@ -148,11 +196,11 @@ export class WrenAIAdaptor implements IWrenAIAdaptor {
148196
const res = await axios.post(`${this.wrenAIBaseEndpoint}/v1/asks`, {
149197
query: input.query,
150198
id: input.deployId,
151-
history: this.transfromHistoryInput(input.history),
199+
history: this.transformHistoryInput(input.history),
152200
});
153201
return { queryId: res.data.query_id };
154202
} catch (err: any) {
155-
logger.debug(`Got error when asking wren AI: ${getAISerciceError(err)}`);
203+
logger.debug(`Got error when asking wren AI: ${getAIServiceError(err)}`);
156204
throw err;
157205
}
158206
}
@@ -164,7 +212,7 @@ export class WrenAIAdaptor implements IWrenAIAdaptor {
164212
status: 'stopped',
165213
});
166214
} catch (err: any) {
167-
logger.debug(`Got error when canceling ask: ${getAISerciceError(err)}`);
215+
logger.debug(`Got error when canceling ask: ${getAIServiceError(err)}`);
168216
throw err;
169217
}
170218
}
@@ -178,7 +226,7 @@ export class WrenAIAdaptor implements IWrenAIAdaptor {
178226
return this.transformAskResult(res.data);
179227
} catch (err: any) {
180228
logger.debug(
181-
`Got error when getting ask result: ${getAISerciceError(err)}`,
229+
`Got error when getting ask result: ${getAIServiceError(err)}`,
182230
);
183231
// throw err;
184232
throw Errors.create(Errors.GeneralErrorCodes.INTERNAL_SERVER_ERROR, {
@@ -187,6 +235,44 @@ export class WrenAIAdaptor implements IWrenAIAdaptor {
187235
}
188236
}
189237

238+
public async explain(
239+
question: string,
240+
analysisResults: any,
241+
): Promise<AsyncQueryResponse> {
242+
try {
243+
const res = await axios.post(
244+
`${this.wrenAIBaseEndpoint}/v1/sql-explanations`,
245+
{
246+
question,
247+
steps_with_analysis_results: analysisResults,
248+
},
249+
);
250+
return { queryId: res.data.query_id };
251+
} catch (err: any) {
252+
logger.debug(`Got error when explaining: ${getAIServiceError(err)}`);
253+
throw err;
254+
}
255+
}
256+
257+
public async getExplainResult(queryId: string): Promise<ExplainResult> {
258+
// make GET request /v1/sql-explanations/:query_id/result to get the result
259+
try {
260+
const res = await axios.get(
261+
`${this.wrenAIBaseEndpoint}/v1/sql-explanations/${queryId}/result`,
262+
);
263+
return {
264+
status: res.data.status as ExplainPipelineStatus,
265+
response: res.data.response,
266+
error: this.transformStatusAndError(res.data).error,
267+
};
268+
} catch (err: any) {
269+
logger.debug(
270+
`Got error when getting explain result: ${getAIServiceError(err)}`,
271+
);
272+
throw err;
273+
}
274+
}
275+
190276
/**
191277
* After you choose a candidate, you can request AI service to generate the detail.
192278
*/
@@ -202,7 +288,7 @@ export class WrenAIAdaptor implements IWrenAIAdaptor {
202288
return { queryId: res.data.query_id };
203289
} catch (err: any) {
204290
logger.debug(
205-
`Got error when generating ask detail: ${getAISerciceError(err)}`,
291+
`Got error when generating ask detail: ${getAIServiceError(err)}`,
206292
);
207293
throw err;
208294
}
@@ -217,7 +303,37 @@ export class WrenAIAdaptor implements IWrenAIAdaptor {
217303
return this.transformAskDetailResult(res.data);
218304
} catch (err: any) {
219305
logger.debug(
220-
`Got error when getting ask detail result: ${getAISerciceError(err)}`,
306+
`Got error when getting ask detail result: ${getAIServiceError(err)}`,
307+
);
308+
throw err;
309+
}
310+
}
311+
312+
public async regenerateAskDetail(input: RegenerateAskDetailInput) {
313+
try {
314+
const res = await axios.post(
315+
`${this.wrenAIBaseEndpoint}/v1/sql-regenerations`,
316+
input,
317+
);
318+
return { queryId: res.data.query_id };
319+
} catch (err: any) {
320+
logger.debug(
321+
`Got error when regenerating ask detail: ${getAIServiceError(err)}`,
322+
);
323+
throw err;
324+
}
325+
}
326+
327+
public async getRegeneratedAskDetailResult(queryId: string) {
328+
// make GET request /v1/sql-regenerations/:query_id/result to get the result
329+
try {
330+
const res = await axios.get(
331+
`${this.wrenAIBaseEndpoint}/v1/sql-regenerations/${queryId}/result`,
332+
);
333+
return this.transformAskDetailResult(res.data);
334+
} catch (err: any) {
335+
logger.debug(
336+
`Got error when getting regenerated ask detail result: ${getAIServiceError(err)}`,
221337
);
222338
throw err;
223339
}
@@ -309,7 +425,7 @@ export class WrenAIAdaptor implements IWrenAIAdaptor {
309425
}));
310426

311427
return {
312-
status,
428+
status: status as AskResultStatus,
313429
error,
314430
response: candidates,
315431
};
@@ -326,7 +442,7 @@ export class WrenAIAdaptor implements IWrenAIAdaptor {
326442
}));
327443

328444
return {
329-
status,
445+
status: status as AskResultStatus,
330446
error,
331447
response: {
332448
description: body?.response?.description,
@@ -336,17 +452,19 @@ export class WrenAIAdaptor implements IWrenAIAdaptor {
336452
}
337453

338454
private transformStatusAndError(body: any): {
339-
status: AskResultStatus;
455+
status: AskResultStatus | ExplainPipelineStatus;
340456
error?: {
341457
code: Errors.GeneralErrorCodes;
342458
message: string;
343459
shortMessage: string;
344460
} | null;
345461
} {
346462
// transform status to enum
347-
const status = AskResultStatus[
348-
body?.status?.toUpperCase()
349-
] as AskResultStatus;
463+
const status =
464+
(AskResultStatus[body?.status?.toUpperCase()] as AskResultStatus) ||
465+
(ExplainPipelineStatus[
466+
body.status
467+
]?.toUpperCase() as ExplainPipelineStatus);
350468

351469
if (!status) {
352470
throw new Error(`Unknown ask status: ${body?.status}`);
@@ -380,7 +498,7 @@ export class WrenAIAdaptor implements IWrenAIAdaptor {
380498
};
381499
}
382500

383-
private transfromHistoryInput(history: AskHistory) {
501+
private transformHistoryInput(history: AskHistory) {
384502
if (!history) {
385503
return null;
386504
}

wren-ui/src/apollo/server/adaptors/wrenEngineAdaptor.ts

+15-7
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ export interface IWrenEngineAdaptor {
9494
sql: string,
9595
options: WrenEngineDryRunOption,
9696
): Promise<DryRunResponse[]>;
97+
98+
// analysis
99+
analysisSql(sql: string, mdl: Manifest): Promise<any>;
97100
}
98101

99102
export class WrenEngineAdaptor implements IWrenEngineAdaptor {
@@ -105,6 +108,7 @@ export class WrenEngineAdaptor implements IWrenEngineAdaptor {
105108
private dryPlanUrlPath = '/v1/mdl/dry-plan';
106109
private dryRunUrlPath = '/v1/mdl/dry-run';
107110
private validateUrlPath = '/v1/mdl/validate';
111+
private analysisUrlPath = '/v1/analysis/sql';
108112

109113
constructor({ wrenEngineEndpoint }: { wrenEngineEndpoint: string }) {
110114
this.wrenEngineBaseEndpoint = wrenEngineEndpoint;
@@ -315,16 +319,20 @@ export class WrenEngineAdaptor implements IWrenEngineAdaptor {
315319
}
316320
}
317321

318-
private async getDeployStatus(): Promise<WrenEngineDeployStatusResponse> {
322+
public async analysisSql(sql: string, mdl: Manifest) {
319323
try {
320-
const res = await axios.get(
321-
`${this.wrenEngineBaseEndpoint}/v1/mdl/status`,
324+
const url = new URL(this.analysisUrlPath, this.wrenEngineBaseEndpoint);
325+
const headers = {
326+
'Content-Type': 'application/json',
327+
};
328+
const res = await axios.post(
329+
url.href,
330+
{ sql, manifest: mdl },
331+
{ headers },
322332
);
323-
return res.data as WrenEngineDeployStatusResponse;
333+
return res.data;
324334
} catch (err: any) {
325-
logger.debug(
326-
`WrenEngine: Got error when getting deploy status: ${err.message}`,
327-
);
335+
logger.debug(`Got error when analyzing sql: ${err.message}`);
328336
throw err;
329337
}
330338
}

0 commit comments

Comments
 (0)