Skip to content

Commit 88ccc93

Browse files
committed
New method aiEmbeddings
1 parent 00a465a commit 88ccc93

File tree

8 files changed

+311
-16
lines changed

8 files changed

+311
-16
lines changed

deno.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"nodeModulesDir": "auto",
2222
"imports": {
2323
"@duckdb/node-api": "npm:@duckdb/[email protected]",
24-
"@nshiab/journalism": "jsr:@nshiab/journalism@^1.28.2",
24+
"@nshiab/journalism": "jsr:@nshiab/[email protected].5",
2525
"@observablehq/plot": "npm:@observablehq/[email protected]",
2626
"@std/assert": "jsr:@std/[email protected]"
2727
},

deno.lock

Lines changed: 16 additions & 15 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/class/SimpleTable.ts

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ import stringifyDates from "../helpers/stringifyDates.ts";
9595
import stringifyDatesInvert from "../helpers/stringifyDatesInvert.ts";
9696
import aiRowByRow from "../methods/aiRowByRow.ts";
9797
import aiQuery from "../methods/aiQuery.ts";
98+
import aiEmbeddings from "../methods/aiEmbeddings.ts";
9899

99100
/**
100101
* SimpleTable is a class representing a table in a SimpleDB. It can handle tabular and geospatial data. To create one, it's best to instantiate a SimpleDB first.
@@ -627,6 +628,74 @@ export default class SimpleTable extends Simple {
627628
await aiRowByRow(this, column, newColumn, prompt, options);
628629
}
629630

631+
/**
632+
* Generates embeddings for a specified column and stores the results in a new column.
633+
*
634+
* This method currently supports Google Gemini, Vertex AI, and local models running with Ollama. It retrieves credentials and the model from environment variables (`AI_KEY`, `AI_PROJECT`, `AI_LOCATION`, `AI_EMBEDDINGS_MODEL`) or accepts them as options. Options take precedence over environment variables.
635+
*
636+
* To run local models with Ollama, set the `OLLAMA` environment variable to `true` and start Ollama on your machine. Make sure to install the model you want and set the `AI_EMBEDDINGS_MODEL` environment variable to the model name.
637+
*
638+
* To avoid exceeding rate limits, you can use the `rateLimitPerMinute` option to automatically add a delay between requests to comply with the rate limit.
639+
*
640+
* If you have a business or professional account with high rate limits, you can set the `concurrent` option to process multiple requests concurrently and speed up the process.
641+
*
642+
* The `cache` option allows you to cache the results of each request locally, saving resources and time. The data is cached in the local hidden folder `.journalism-cache` (because this method uses the `getEmbedding` function from the [journalism library](https://github.com/nshiab/journalism)). Don't forget to add `.journalism-cache` to your `.gitignore` file!
643+
*
644+
* This method won't work if your table contains geometries.
645+
*
646+
* @example
647+
* Basic usage with cache, rate limit, and verbose logging
648+
* ```ts
649+
* // New table with column "food".
650+
* await table.loadArray([
651+
* { food: "pizza" },
652+
* { food: "sushi" },
653+
* { food: "burger" },
654+
* { food: "pasta" },
655+
* { food: "salad" },
656+
* { food: "tacos" }
657+
* ]);
658+
*
659+
* // Ask the AI to generate embeddings in a new column "embeddings".
660+
* await table.aiEmbeddings("food", "embeddings", {
661+
* // Cache the results locally
662+
* cache: true,
663+
* // Avoid exceeding a rate limit by waiting between requests
664+
* rateLimitPerMinute: 15,
665+
* // Log details
666+
* verbose: true,
667+
* });
668+
* ```
669+
*
670+
* @param column - The column to be used as input for the embeddings.
671+
* @param newColumn - The name of the new column where the embeddings will be stored.
672+
* @param options - Configuration options for the AI request.
673+
* @param options.concurrent - The number of concurrent requests to send. Defaults to 1.
674+
* @param options.cache - If true, the results will be cached locally. Defaults to false.
675+
* @param options.rateLimitPerMinute - The rate limit for the AI requests in requests per minute. If necessary, the method will wait between requests. Defaults to no limit.
676+
* @param options.model - The model to use. Defaults to the `AI_EMBEDDINGS_MODEL` environment variable.
677+
* @param options.apiKey - The API key. Defaults to the `AI_KEY` environment variable.
678+
* @param options.vertex - Whether to use Vertex AI. Defaults to `false`. If `AI_PROJECT` and `AI_LOCATION` are set in the environment, it will automatically switch to true.
679+
* @param options.project - The Google Cloud project ID. Defaults to the `AI_PROJECT` environment variable.
680+
* @param options.location - The Google Cloud location. Defaults to the `AI_LOCATION` environment variable.
681+
* @param options.ollama - Whether to use Ollama. Defaults to the `OLLAMA` environment variable.
682+
* @param options.verbose - Whether to log additional information. Defaults to `false`.
683+
*/
684+
async aiEmbeddings(column: string, newColumn: string, options: {
685+
concurrent?: number;
686+
cache?: boolean;
687+
model?: string;
688+
apiKey?: string;
689+
vertex?: boolean;
690+
project?: string;
691+
location?: string;
692+
ollama?: boolean;
693+
verbose?: boolean;
694+
rateLimitPerMinute?: number;
695+
} = {}) {
696+
await aiEmbeddings(this, column, newColumn, options); // Thin wrapper: delegates to the helper imported from "../methods/aiEmbeddings.ts", passing this table and the options unchanged.
697+
}
698+
630699
/**
631700
* Generates and executes a SQL query based on a prompt. Additional instructions are automatically added before and after your prompt, such as the column types. To see the full prompt, set the `verbose` option to true.
632701
*

src/helpers/convertForJS.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ export default function convertForJS(rows: {
2929
for (const row of rows) {
3030
row[key] = row[key] === null ? null : "<Geometry>";
3131
}
32+
} else if (types[key].includes("FLOAT[")) {
33+
for (const row of rows) {
34+
row[key] = row[key] === null ? null : `<${types[key]}>`;
35+
}
3236
}
3337
}
3438
}

src/helpers/parseDuckDBType.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import {
2+
ARRAY,
23
BIGINT,
34
BOOLEAN,
45
DATE,
56
DOUBLE,
7+
FLOAT,
68
INTEGER,
79
TIME,
810
TIMESTAMP,
@@ -29,6 +31,10 @@ export default function parseDuckDBType(type: string) {
2931
return TIME;
3032
} else if (type === "BOOLEAN") {
3133
return BOOLEAN;
34+
} else if (type.includes("FLOAT[")) {
35+
// For embeddings
36+
const size = type.replace("FLOAT[", "").replace("]", "");
37+
return ARRAY(FLOAT, parseInt(size));
3238
} else {
3339
throw new Error(`Type ${type} not supported.`);
3440
}

src/methods/aiEmbeddings.ts

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import { formatNumber, getEmbedding, sleep } from "@nshiab/journalism";
2+
import type { SimpleTable } from "../index.ts";
3+
4+
/**
 * Computes an embedding for every value of `column` and writes it to
 * `newColumn`, going through `simpleTable.updateWithJS`.
 *
 * Rows are processed in batches of `options.concurrent` requests (default 1).
 * Each batch is awaited with `Promise.all` before the next one starts. The
 * whole `options` object is forwarded to `getEmbedding` from
 * `@nshiab/journalism`.
 *
 * @param simpleTable - The table to update.
 * @param column - The column holding the text to embed. Every value must be a string.
 * @param newColumn - The column that receives the embeddings.
 * @param options - Options forwarded to `getEmbedding`, plus the batching options below.
 * @param options.concurrent - Number of requests sent per batch. Must be a positive integer. Defaults to 1.
 * @param options.rateLimitPerMinute - When set, the function sleeps between batches so that at most this many requests are sent per minute.
 * @param options.verbose - When true, logs progress for each row.
 * @throws Error if `options.concurrent` is not a positive integer.
 * @throws Error if a value in `column` is not a string.
 */
export default async function aiEmbeddings(
  simpleTable: SimpleTable,
  column: string,
  newColumn: string,
  options: {
    concurrent?: number;
    cache?: boolean;
    model?: string;
    apiKey?: string;
    vertex?: boolean;
    project?: string;
    location?: string;
    ollama?: boolean;
    verbose?: boolean;
    rateLimitPerMinute?: number;
  } = {},
) {
  const concurrent = options.concurrent ?? 1;
  // A batch size of 0 would never send a request and a fractional one would
  // never satisfy the `=== concurrent` flush below, silently skipping rows.
  if (!Number.isInteger(concurrent) || concurrent < 1) {
    throw new Error(
      `The option "concurrent" must be a positive integer. Found ${concurrent} instead.`,
    );
  }

  await simpleTable.updateWithJS(async (rows) => {
    if (options.verbose) {
      console.log("\naiEmbeddings()");
    }

    let requests = [];
    for (let i = 0; i < rows.length; i++) {
      if (options.verbose) {
        console.log(
          `\nProcessing row ${i + 1} of ${rows.length}... (${
            formatNumber(
              (i + 1) / rows.length * 100,
              {
                significantDigits: 3,
                suffix: "%",
              },
            )
          })`,
        );
      }

      if (requests.length < concurrent) {
        const text = rows[i][column];
        if (typeof text !== "string") {
          throw new Error(
            `The column "${column}" must be a string. Found ${text} instead.`,
          );
        }
        requests.push(
          getEmbedding(text, options),
        );
      }

      // Flush when the batch is full or when the last row has been queued.
      if (requests.length === concurrent || i + 1 >= rows.length) {
        const start = new Date();
        const newValues = await Promise.all(requests);
        // Row i is the LAST row of the batch, so the batch started
        // newValues.length - 1 rows earlier. Promise.all keeps the order of
        // the requests, so newValues[j] belongs to row firstRow + j.
        const firstRow = i - newValues.length + 1;
        for (let j = 0; j < newValues.length; j++) {
          // The row value type has no array member, hence the cast.
          // Should be improved...
          rows[firstRow + j][newColumn] = newValues[j] as unknown as number;
        }
        const end = new Date();

        const duration = end.getTime() - start.getTime();
        // If the batch took less than 10 ms per request, the data most likely
        // came from the cache and there is no need to wait.
        if (
          typeof options.rateLimitPerMinute === "number" &&
          duration > 10 * requests.length && i + 1 < rows.length
        ) {
          // Time budget of one batch: 60 s divided by the number of batches
          // allowed per minute.
          const delay = Math.round(
            (60 / (options.rateLimitPerMinute / concurrent)) * 1000,
          );
          // `start` is passed along, presumably so that the time already
          // spent on the batch counts toward the delay (confirm in sleep docs).
          await sleep(delay, { start, log: options.verbose });
        }

        requests = [];
      }
    }

    return rows;
  });
}

src/methods/loadArray.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import {
2+
arrayValue,
23
type DuckDBConnection,
34
DuckDBDataChunk,
45
DuckDBTimestampValue,
@@ -43,6 +44,13 @@ export default async function loadArray(
4344
);
4445
}
4546
}
47+
} else if (Array.isArray(arrayOfObjects[0][key])) {
48+
types[i] = `FLOAT[${arrayOfObjects[0][key].length}]`;
49+
50+
for (let j = 0; j < arrayOfObjects.length; j++) {
51+
const d = arrayOfObjects[j][key];
52+
dataForChunk[j][i] = arrayValue(d as number[]);
53+
}
4654
} else {
4755
throw new Error(`Type object not supported.`);
4856
}

0 commit comments

Comments
 (0)