Skip to content

Commit 477e145

Browse files
authored
feat: brain queryMemories tests (#23)
* feat: moved embeddings to it's own file * merge * feat: better testing * feat: even more tests * chore: check * feat: more tests * chore: check
1 parent ae38037 commit 477e145

File tree

13 files changed

+2248
-537
lines changed

13 files changed

+2248
-537
lines changed

packages/brain/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
"@biomejs/biome": "catalog:",
5858
"typescript": "catalog:",
5959
"@testcontainers/postgresql": "catalog:",
60+
"nanoid": "catalog:",
6061
"testcontainers": "catalog:"
6162
}
6263
}

packages/brain/src/memories/query.ts

Lines changed: 84 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { filterNull, groupBy } from "@kokoro/common/poldash";
1+
import { filterNull, groupBy, lookup } from "@kokoro/common/poldash";
22
import type { PgColumn, SQLWrapper, SubqueryWithSelection } from "@kokoro/db";
33
import {
44
and,
@@ -46,6 +46,7 @@ import {
4646
type TaskState,
4747
} from "@kokoro/validators/db";
4848

49+
import { add, addDays, addMilliseconds, isSameDay, min } from "date-fns";
4950
import { getEmbedding } from "../embeddings";
5051

5152
function createMemoryEventsSubquery(
@@ -129,8 +130,8 @@ function createMemoryEventsSubquery(
129130
lastUpdate: memoryTable.lastUpdate,
130131
})
131132
.from(memoryTable)
132-
.leftJoin(memoryEventTable, eq(memoryTable.id, memoryEventTable.memoryId))
133-
.leftJoin(calendarTable, eq(memoryEventTable.calendarId, calendarTable.id))
133+
.innerJoin(memoryEventTable, eq(memoryTable.id, memoryEventTable.memoryId))
134+
.innerJoin(calendarTable, eq(memoryEventTable.calendarId, calendarTable.id))
134135
.where(
135136
and(
136137
...baseFilters,
@@ -211,8 +212,11 @@ function createMemoryTasksSubquery(
211212
lastUpdate: memoryTable.lastUpdate,
212213
})
213214
.from(memoryTable)
214-
.leftJoin(memoryTaskTable, eq(memoryTable.id, memoryTaskTable.memoryId))
215-
.leftJoin(tasklistsTable, eq(memoryTaskTable.tasklistId, tasklistsTable.id))
215+
.innerJoin(memoryTaskTable, eq(memoryTable.id, memoryTaskTable.memoryId))
216+
.innerJoin(
217+
tasklistsTable,
218+
eq(memoryTaskTable.tasklistId, tasklistsTable.id),
219+
)
216220
.leftJoinLateral(latestStateSubquery, sql`true`)
217221
.where(
218222
and(
@@ -275,13 +279,28 @@ function processMemoryEvents(
275279
startDate: Date,
276280
endDate?: Date,
277281
) {
282+
const processedMemories: QueriedMemory[] = [];
283+
278284
for (const memory of memories) {
285+
processedMemories.push(memory);
286+
279287
if (!memory.event?.rrule) {
280288
continue;
281289
}
282290

283291
const memoryEvent = memory.event;
284292

293+
const instancesLookup = lookup(
294+
memories.filter(
295+
(m) =>
296+
m.event?.startOriginal &&
297+
m.event.recurringEventPlatformId === memoryEvent.platformId &&
298+
m.event.calendarId === memoryEvent.calendarId,
299+
),
300+
// biome-ignore lint/style/noNonNullAssertion: Only events with a startOriginal are processed
301+
(m) => m.event?.startOriginal!,
302+
);
303+
285304
const diff =
286305
memoryEvent.endDate.getTime() - memoryEvent.startDate.getTime();
287306

@@ -292,20 +311,24 @@ function processMemoryEvents(
292311
endDate ?? new Date(startDate.getTime() + 30 * 24 * 60 * 60 * 1000),
293312
);
294313

295-
memories.push(
296-
...otherDates.map((date) => ({
314+
for (const virtualMemoryDate of otherDates) {
315+
if (instancesLookup(virtualMemoryDate)) {
316+
continue;
317+
}
318+
319+
processedMemories.push({
297320
...memory,
298321
event: {
299322
...memoryEvent,
300-
startDate: date,
301-
endDate: new Date(date.getTime() + diff),
323+
startDate: virtualMemoryDate,
324+
endDate: addMilliseconds(virtualMemoryDate, diff),
302325
},
303326
isVirtual: true,
304-
})),
305-
);
327+
});
328+
}
306329
}
307330

308-
return memories;
331+
return processedMemories;
309332
}
310333

311334
export async function getMemories(
@@ -345,7 +368,7 @@ export async function getMemories(
345368
Object.values(groupedMemories).map(processMemoryTasks),
346369
);
347370

348-
return processMemoryEvents(processedMemories, new Date(), undefined);
371+
return processedMemories;
349372
}
350373

351374
export async function queryMemories(
@@ -425,19 +448,18 @@ export async function queryMemories(
425448
eq(memoryTable.userId, userId),
426449
sourceCondition,
427450
textEmbedding
428-
? sql<boolean>`${memoryTable.content} <> '' or ${memoryTable.description} <> ''`
451+
? or(
452+
sql<boolean>`${memoryTable.content} <> ''`,
453+
sql<boolean>`${memoryTable.description} <> ''`,
454+
)
429455
: undefined,
430456
];
431-
432457
const shouldIncludeMemoryType = (type: MemoryType) =>
433-
options.memoryTypes && options.memoryTypes.size > 0
434-
? options.memoryTypes.has(type)
435-
: true;
458+
options.memoryTypes?.has(type) ?? true;
436459

437460
const memoryEventsSubquery = shouldIncludeMemoryType(EVENT_MEMORY_TYPE)
438461
? createMemoryEventsSubquery(
439462
[
440-
isNotNull(memoryEventTable.id),
441463
...baseFilters,
442464
calendarSources
443465
? inArray(memoryEventTable.source, Array.from(calendarSources))
@@ -463,7 +485,6 @@ export async function queryMemories(
463485
const memoryTasksSubquery = shouldIncludeMemoryType(TASK_MEMORY_TYPE)
464486
? createMemoryTasksSubquery(
465487
[
466-
isNotNull(memoryTaskTable.id),
467488
...baseFilters,
468489
taskSources
469490
? inArray(memoryTaskTable.source, Array.from(taskSources))
@@ -494,7 +515,7 @@ export async function queryMemories(
494515
const memoriesSubquery =
495516
// biome-ignore lint/style/noNonNullAssertion: this is wrong
496517
(
497-
memoryEventsSubquery && memoryTasksSubquery
518+
memoryEventsSubquery !== undefined && memoryTasksSubquery !== undefined
498519
? union(
499520
db.select().from(memoryEventsSubquery),
500521
db.select().from(memoryTasksSubquery),
@@ -675,15 +696,49 @@ export async function queryMemories(
675696

676697
const groupedMemories = groupBy(rows, "id");
677698

678-
// Process tasks with attributes
679-
const processedMemories = filterNull(
680-
Object.values(groupedMemories).map(processMemoryTasks),
681-
);
682-
683699
// Process recurrent events
684-
return processMemoryEvents(
685-
processedMemories,
700+
const processedMemories = processMemoryEvents(
701+
filterNull(
702+
// Process tasks with attributes
703+
Object.values(groupedMemories).map(processMemoryTasks),
704+
),
686705
startDate ?? new Date(),
687-
endDate,
706+
endDate ? min([endDate, addDays(startDate ?? new Date(), 90)]) : undefined,
688707
);
708+
709+
return processedMemories.sort((a, b) => {
710+
switch (sortBy) {
711+
case "createdAt": {
712+
return orderBy === "asc"
713+
? a.createdAt.getTime() - b.createdAt.getTime()
714+
: b.createdAt.getTime() - a.createdAt.getTime();
715+
}
716+
case "updatedAt": {
717+
return orderBy === "asc"
718+
? a.lastUpdate.getTime() - b.lastUpdate.getTime()
719+
: b.lastUpdate.getTime() - a.lastUpdate.getTime();
720+
}
721+
// We need to figure out how to handle this for events
722+
// case "priority": {
723+
// return orderBy === "asc"
724+
// ? a.taskAttributes.priority - b.taskAttributes.priority
725+
// : b.taskAttributes.priority - a.taskAttributes.priority;
726+
// }
727+
case "relevantDate": {
728+
const aRelevantDate = a.event?.startDate ?? a.task?.dueDate;
729+
const bRelevantDate = b.event?.startDate ?? b.task?.dueDate;
730+
731+
if (aRelevantDate && bRelevantDate) {
732+
return orderBy === "asc"
733+
? aRelevantDate.getTime() - bRelevantDate.getTime()
734+
: bRelevantDate.getTime() - aRelevantDate.getTime();
735+
}
736+
737+
return 0;
738+
}
739+
default: {
740+
return 0;
741+
}
742+
}
743+
});
689744
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import { migrateDatabase } from "@kokoro/db/migration";
2+
import {
3+
PostgreSqlContainer,
4+
type StartedPostgreSqlContainer,
5+
} from "@testcontainers/postgresql";
6+
import {
7+
GenericContainer,
8+
type StartedTestContainer,
9+
Wait,
10+
} from "testcontainers";
11+
import { afterAll, beforeAll } from "vitest";
12+
13+
export function useDatabaseContainer() {
14+
let postgresContainer: StartedPostgreSqlContainer | undefined;
15+
16+
beforeAll(async () => {
17+
postgresContainer = await new PostgreSqlContainer(
18+
"timescale/timescaledb-ha:pg16-all",
19+
)
20+
.withDatabase("postgres")
21+
.withUsername("postgres")
22+
.withPassword("password")
23+
.withExposedPorts({ container: 5432, host: 5432 })
24+
.start();
25+
26+
await migrateDatabase(postgresContainer.getConnectionUri());
27+
}, 120000);
28+
29+
afterAll(async () => {
30+
await postgresContainer?.stop();
31+
});
32+
33+
return () => postgresContainer;
34+
}
35+
36+
export function useEmbeddingServiceContainer() {
37+
let embeddingServiceContainer: StartedTestContainer | undefined;
38+
39+
beforeAll(async () => {
40+
embeddingServiceContainer = await new GenericContainer(
41+
"ghcr.io/wosherco/all-minilm-l6-v2-restapi-service",
42+
)
43+
.withExposedPorts({ container: 3000, host: 3000 })
44+
.withWaitStrategy(Wait.forHttp("/health", 3000))
45+
.start();
46+
}, 30000);
47+
48+
afterAll(async () => {
49+
await embeddingServiceContainer?.stop();
50+
});
51+
52+
return () => embeddingServiceContainer;
53+
}

0 commit comments

Comments
 (0)