Skip to content

Commit 7589a8e

Browse files
committed
batchSize option for aiRowByRow and refactor
1 parent cd7dde4 commit 7589a8e

File tree

4 files changed

+147
-120
lines changed

4 files changed

+147
-120
lines changed

src/class/SimpleTable.ts

Lines changed: 5 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -63,17 +63,14 @@ import shouldFlipBeforeExport from "../helpers/shouldFlipBeforeExport.ts";
6363
import getProjection from "../helpers/getProjection.ts";
6464
import cache from "../methods/cache.ts";
6565
import {
66-
askAI,
6766
camelCase,
6867
createDirectory,
6968
formatNumber,
7069
logBarChart,
7170
logDotChart,
7271
logLineChart,
73-
prettyDuration,
7472
rewind,
7573
saveChart,
76-
sleep,
7774
} from "jsr:@nshiab/journalism@1";
7875
import writeDataAsArrays from "../helpers/writeDataAsArrays.ts";
7976
import logHistogram from "../methods/logHistogram.ts";
@@ -96,6 +93,8 @@ import unifyColumns from "../helpers/unifyColumns.ts";
9693
import accumulateQuery from "../helpers/accumulateQuery.ts";
9794
import stringifyDates from "../helpers/stringifyDates.ts";
9895
import stringifyDatesInvert from "../helpers/stringifyDatesInvert.ts";
96+
import aiRowByRow from "../methods/aiRowByRow.ts";
97+
import aiQuery from "../methods/aiQuery.ts";
9998

10099
/**
101100
* SimpleTable is a class representing a table in a SimpleDB. It can handle tabular and geospatial data. To create one, it's best to instantiate a SimpleDB first.
@@ -563,15 +562,13 @@ export default class SimpleTable extends Simple {
563562
* @param options - Configuration options for the AI request.
564563
* @param options.batchSize - The number of rows to process in each batch. By default, it is 1.
565564
* @param options.cache - If true, the results will be cached locally. By default, it is false.
566-
* @param options.cacheVerbose - If true, more information about the cache will be logged. By default, it is false.
567565
* @param options.rateLimitPerMinute - The rate limit for the AI requests in requests per minute. If necessary, the method will wait between requests. By default, there is no limit.
568566
* @param options.model - The model to use. Defaults to the `AI_MODEL` environment variable.
569567
* @param options.apiKey - The API key. Defaults to the `AI_KEY` environment variable.
570568
* @param options.vertex - Whether to use Vertex AI. Defaults to `false`. If `AI_PROJECT` and `AI_LOCATION` are set in the environment, it will automatically switch to true.
571569
* @param options.project - The Google Cloud project ID. Defaults to the `AI_PROJECT` environment variable.
572570
* @param options.location - The Google Cloud location. Defaults to the `AI_LOCATION` environment variable.
573571
* @param options.verbose - Whether to log additional information. Defaults to `false`.
574-
* @param options.costEstimate - Whether to estimate the cost of the request. Defaults to `false`.
575572
*/
576573
async aiRowByRow(
577574
column: string,
@@ -580,97 +577,16 @@ export default class SimpleTable extends Simple {
580577
options: {
581578
batchSize?: number;
582579
cache?: boolean;
583-
cacheVerbose?: boolean;
584580
model?: string;
585581
apiKey?: string;
586582
vertex?: boolean;
587583
project?: string;
588584
location?: string;
589585
verbose?: boolean;
590-
costEstimate?: boolean;
591586
rateLimitPerMinute?: number;
592587
} = {},
593588
) {
594-
await this.updateWithJS(async (rows) => {
595-
if (options.verbose) {
596-
console.log("\naiRowByRow()");
597-
}
598-
599-
const batchSize = options.batchSize ?? 1;
600-
601-
for (let i = 0; i < rows.length; i += batchSize) {
602-
options.verbose &&
603-
console.log(
604-
`\n${Math.min(i + batchSize, rows.length)}/${rows.length} | ${
605-
formatNumber(
606-
(Math.min(i + batchSize, rows.length)) / rows.length * 100,
607-
{
608-
significantDigits: 3,
609-
suffix: "%",
610-
},
611-
)
612-
}`,
613-
);
614-
const batch = rows.slice(i, i + batchSize);
615-
const fullPrompt =
616-
`${prompt}\nHere are the ${column} values as a list: ${
617-
JSON.stringify(batch.map((d) => d[column]))
618-
}\nReturn the results in a list as well, in the same order.`;
619-
620-
if (options.verbose) {
621-
console.log("\nPrompt:");
622-
console.log(fullPrompt);
623-
}
624-
625-
const start = new Date();
626-
627-
// Types could be improved
628-
const newValues = await askAI(
629-
fullPrompt,
630-
{
631-
...options,
632-
verbose: options.costEstimate || options.cacheVerbose,
633-
returnJson: true,
634-
},
635-
) as (string | number | boolean | Date | null)[];
636-
637-
if (newValues.length !== batch.length) {
638-
throw new Error(
639-
`The AI returned ${newValues.length} values, but the batch size is ${batchSize}.`,
640-
);
641-
}
642-
643-
const end = new Date();
644-
645-
if (options.verbose) {
646-
console.log("\nResponse:", newValues);
647-
if (!options.costEstimate) {
648-
console.log("Execution time:", prettyDuration(start, { end }));
649-
}
650-
}
651-
652-
for (let j = 0; j < newValues.length; j++) {
653-
rows[i + j][newColumn] = newValues[j];
654-
}
655-
656-
if (typeof options.rateLimitPerMinute === "number") {
657-
const delay = Math.round((60 / options.rateLimitPerMinute) * 1000) -
658-
(end.getTime() - start.getTime());
659-
if (delay > 0) {
660-
if (options.verbose) {
661-
console.log(
662-
`Waiting ${
663-
prettyDuration(0, { end: delay })
664-
} to respect rate limit...`,
665-
);
666-
}
667-
await sleep(delay);
668-
}
669-
}
670-
}
671-
672-
return rows;
673-
});
589+
await aiRowByRow(this, column, newColumn, prompt, options);
674590
}
675591

676592
/**
@@ -691,57 +607,30 @@ export default class SimpleTable extends Simple {
691607
* // Don't forget to add .journalism to your .gitignore file!
692608
* await table.aiQuery(
693609
* "Give me the average salary by department",
694-
* { cache: true }
610+
* { cache: true, verbose: true }
695611
* )
696612
* ```
697613
*
698614
* @param prompt - The input string to guide the AI in generating the SQL query.
699615
* @param options - Configuration options for the AI request.
700616
* @param options.cache - If true, the query will be cached locally. By default, it is false.
701-
* @param options.cacheVerbose - If true, more information about the cache will be logged. By default, it is false.
702617
* @param options.model - The model to use. Defaults to the `AI_MODEL` environment variable.
703618
* @param options.apiKey - The API key. Defaults to the `AI_KEY` environment variable.
704619
* @param options.vertex - Whether to use Vertex AI. Defaults to `false`. If `AI_PROJECT` and `AI_LOCATION` are set in the environment, it will automatically switch to true.
705620
* @param options.project - The Google Cloud project ID. Defaults to the `AI_PROJECT` environment variable.
706621
* @param options.location - The Google Cloud location. Defaults to the `AI_LOCATION` environment variable.
707622
* @param options.verbose - Whether to log additional information. Defaults to `false`.
708-
* @param options.costEstimate - Whether to estimate the cost of the request. Defaults to `false`.
709623
*/
710624
async aiQuery(prompt: string, options: {
711625
cache?: boolean;
712-
cacheVerbose?: boolean;
713626
model?: string;
714627
apiKey?: string;
715628
vertex?: boolean;
716629
project?: string;
717630
location?: string;
718631
verbose?: boolean;
719-
costEstimate?: boolean;
720632
} = {}) {
721-
const p =
722-
`I have a SQL table named "${this.name}". The data is already in it with these columns:\n${
723-
JSON.stringify(await this.getTypes(), undefined, 2)
724-
}\nI want you to give me a SQL query to do this:\n- ${prompt}\nThe query must replace the existing "${this.name}" table with 'CREATE OR REPLACE TABLE "${this.name}"'. Return just the query, nothing else.`;
725-
726-
if (options.verbose) {
727-
console.log("\naiQuery()");
728-
console.log("\nPrompt:");
729-
console.log(p);
730-
}
731-
732-
// Types could be improved
733-
let query = await askAI(p, {
734-
...options,
735-
verbose: options.costEstimate || options.cacheVerbose,
736-
}) as unknown as string;
737-
query = query.replace("```sql", "").replace("```", "").trim();
738-
739-
if (options.verbose) {
740-
console.log("\nResponse:");
741-
console.log(query);
742-
}
743-
744-
await this.sdb.customQuery(query);
633+
await aiQuery(this, prompt, options);
745634
}
746635

747636
/**

src/methods/aiQuery.ts

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import { askAI } from "@nshiab/journalism";
2+
import type { SimpleTable } from "../index.ts";
3+
4+
/**
 * Asks the AI to write a SQL query that fulfils `prompt`, then runs the
 * returned query through `simpleTable.sdb.customQuery()`.
 *
 * The prompt sent to the AI contains the table name and the JSON-serialized
 * result of `simpleTable.getTypes()`, and requests a
 * `CREATE OR REPLACE TABLE` statement targeting that same table.
 *
 * NOTE(review): the SQL comes from the AI and is executed as returned (after
 * Markdown fences are stripped). Use `verbose` to inspect it.
 *
 * @param simpleTable - The table the query is generated for and run against.
 * @param prompt - Plain-language description of the wanted transformation.
 * @param options - Forwarded unchanged to `askAI`.
 * @param options.verbose - When true, logs the full prompt and the cleaned query.
 * @throws Error if the AI response is not a string.
 */
export default async function aiQuery(
  simpleTable: SimpleTable,
  prompt: string,
  options: {
    cache?: boolean;
    model?: string;
    apiKey?: string;
    vertex?: boolean;
    project?: string;
    location?: string;
    verbose?: boolean;
  } = {},
) {
  const p =
    `I have a SQL table named "${simpleTable.name}". The data is already in it with these columns:\n${
      JSON.stringify(await simpleTable.getTypes(), undefined, 2)
    }\nI want you to give me a SQL query to do this:\n- ${prompt}\nThe query must replace the existing "${simpleTable.name}" table with 'CREATE OR REPLACE TABLE "${simpleTable.name}"'. Return just the query, nothing else.`;

  if (options.verbose) {
    console.log("\naiQuery()");
    console.log("\nPrompt:");
    console.log(p);
  }

  // Types could be improved
  const response = await askAI(p, options) as unknown;

  if (typeof response !== "string") {
    throw new Error(
      `The AI returned a non-string value: ${JSON.stringify(response)}`,
    );
  }

  // String patterns in replace() only touch the first occurrence, which left
  // the closing fence in place when the opening one was a bare ``` (or
  // ```SQL). Strip every fence marker, with or without the "sql" tag.
  const query = response.replace(/```(?:sql)?/gi, "").trim();

  if (options.verbose) {
    console.log("\nResponse:");
    console.log(query);
  }

  await simpleTable.sdb.customQuery(query);
}

src/methods/aiRowByRow.ts

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import { askAI, formatNumber, prettyDuration, sleep } from "@nshiab/journalism";
2+
import type { SimpleTable } from "../index.ts";
3+
4+
/**
 * Fills `newColumn` by sending the values of `column` to the AI in batches.
 *
 * For each batch, the prompt is `prompt` followed by the JSON list of the
 * batch's `column` values and a request to return a list in the same order.
 * The returned values are written to `newColumn` on the matching rows. The
 * whole pass runs inside `simpleTable.updateWithJS()`.
 *
 * @param simpleTable - The table to update.
 * @param column - Name of the column whose values are sent to the AI.
 * @param newColumn - Name of the column that receives the AI results.
 * @param prompt - Instruction prepended to each batch of values.
 * @param options - Spread into the options passed to `askAI` (with `returnJson: true`).
 * @param options.batchSize - Rows per request; must be a positive integer. Defaults to 1.
 * @param options.rateLimitPerMinute - When a number, waits after each request so that
 *   one request plus its wait lasts at least `60 / rateLimitPerMinute` seconds.
 * @param options.verbose - When true, logs progress, prompts, responses and waits.
 * @throws Error if `batchSize` is not a positive integer, if the AI response is not
 *   an array, or if it does not contain exactly one value per row of the batch.
 */
export default async function aiRowByRow(
  simpleTable: SimpleTable,
  column: string,
  newColumn: string,
  prompt: string,
  options: {
    batchSize?: number;
    cache?: boolean;
    model?: string;
    apiKey?: string;
    vertex?: boolean;
    project?: string;
    location?: string;
    verbose?: boolean;
    rateLimitPerMinute?: number;
  } = {},
) {
  const batchSize = options.batchSize ?? 1;

  // A zero or negative step would make the loop below never terminate, and a
  // fractional step would produce fractional row indices. Fail fast instead.
  if (!Number.isInteger(batchSize) || batchSize < 1) {
    throw new Error(
      `The batchSize option must be a positive integer, but received ${batchSize}.`,
    );
  }

  await simpleTable.updateWithJS(async (rows) => {
    if (options.verbose) {
      console.log("\naiRowByRow()");
    }

    for (let i = 0; i < rows.length; i += batchSize) {
      if (options.verbose) {
        // Number of rows handled once this batch is done (capped at the total).
        const processed = Math.min(i + batchSize, rows.length);
        console.log(
          `\n${processed}/${rows.length} | ${
            formatNumber(
              processed / rows.length * 100,
              {
                significantDigits: 3,
                suffix: "%",
              },
            )
          }`,
        );
      }

      const batch = rows.slice(i, i + batchSize);
      const fullPrompt = `${prompt}\nHere are the ${column} values as a list: ${
        JSON.stringify(batch.map((d) => d[column]))
      }\nReturn the results in a list as well, in the same order.`;

      if (options.verbose) {
        console.log("\nPrompt:");
        console.log(fullPrompt);
      }

      const start = new Date();

      // Types could be improved
      const newValues = await askAI(
        fullPrompt,
        {
          ...options,
          returnJson: true,
        },
      ) as (string | number | boolean | Date | null)[];

      if (!Array.isArray(newValues)) {
        throw new Error(
          `The AI returned a non-array value: ${JSON.stringify(newValues)}`,
        );
      }
      // Compare against the actual batch length: the last batch can be
      // shorter than the configured batchSize.
      if (newValues.length !== batch.length) {
        throw new Error(
          `The AI returned ${newValues.length} values, but the batch contains ${batch.length} values.`,
        );
      }

      const end = new Date();

      if (options.verbose) {
        console.log("\nResponse:", newValues);
      }

      // Write each result back to the row it was computed from.
      for (let j = 0; j < newValues.length; j++) {
        rows[i + j][newColumn] = newValues[j];
      }

      if (typeof options.rateLimitPerMinute === "number") {
        // Time budget per request, minus the time the request already took.
        const delay = Math.round((60 / options.rateLimitPerMinute) * 1000) -
          (end.getTime() - start.getTime());
        if (delay > 0) {
          if (options.verbose) {
            console.log(
              `Waiting ${
                prettyDuration(0, { end: delay })
              } to respect rate limit...`,
            );
          }
          await sleep(delay);
        }
      }
    }

    return rows;
  });
}

test/unit/methods/aiRowByRow.test.ts

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ if (typeof aiKey === "string" && aiKey !== "") {
6868
"city",
6969
"country",
7070
`Give me the country of the city.`,
71-
{ batchSize: 10, cache: true, cacheVerbose: true },
71+
{ batchSize: 10, cache: true, verbose: true },
7272
);
7373
const data = await table.getData();
7474

@@ -107,7 +107,7 @@ if (typeof aiKey === "string" && aiKey !== "") {
107107
"city",
108108
"country",
109109
`Give me the country of the city.`,
110-
{ batchSize: 10, cache: true, cacheVerbose: true },
110+
{ batchSize: 10, cache: true, verbose: true },
111111
);
112112
const data = await table.getData();
113113

@@ -146,7 +146,7 @@ if (typeof aiKey === "string" && aiKey !== "") {
146146
"city",
147147
"country",
148148
`Give me the country of the city.`,
149-
{ batchSize: 10, cacheVerbose: true },
149+
{ batchSize: 10, verbose: true },
150150
);
151151
const data = await table.getData();
152152

@@ -228,7 +228,6 @@ if (typeof aiKey === "string" && aiKey !== "") {
228228
batchSize: 10,
229229
verbose: true,
230230
rateLimitPerMinute: 15,
231-
costEstimate: true,
232231
},
233232
);
234233
const data = await table.getData();

0 commit comments

Comments
 (0)