1
- import { filterNull , groupBy } from "@kokoro/common/poldash" ;
1
+ import { filterNull , groupBy , lookup } from "@kokoro/common/poldash" ;
2
2
import type { PgColumn , SQLWrapper , SubqueryWithSelection } from "@kokoro/db" ;
3
3
import {
4
4
and ,
@@ -46,6 +46,7 @@ import {
46
46
type TaskState ,
47
47
} from "@kokoro/validators/db" ;
48
48
49
+ import { add , addDays , addMilliseconds , isSameDay , min } from "date-fns" ;
49
50
import { getEmbedding } from "../embeddings" ;
50
51
51
52
function createMemoryEventsSubquery (
@@ -129,8 +130,8 @@ function createMemoryEventsSubquery(
129
130
lastUpdate : memoryTable . lastUpdate ,
130
131
} )
131
132
. from ( memoryTable )
132
- . leftJoin ( memoryEventTable , eq ( memoryTable . id , memoryEventTable . memoryId ) )
133
- . leftJoin ( calendarTable , eq ( memoryEventTable . calendarId , calendarTable . id ) )
133
+ . innerJoin ( memoryEventTable , eq ( memoryTable . id , memoryEventTable . memoryId ) )
134
+ . innerJoin ( calendarTable , eq ( memoryEventTable . calendarId , calendarTable . id ) )
134
135
. where (
135
136
and (
136
137
...baseFilters ,
@@ -211,8 +212,11 @@ function createMemoryTasksSubquery(
211
212
lastUpdate : memoryTable . lastUpdate ,
212
213
} )
213
214
. from ( memoryTable )
214
- . leftJoin ( memoryTaskTable , eq ( memoryTable . id , memoryTaskTable . memoryId ) )
215
- . leftJoin ( tasklistsTable , eq ( memoryTaskTable . tasklistId , tasklistsTable . id ) )
215
+ . innerJoin ( memoryTaskTable , eq ( memoryTable . id , memoryTaskTable . memoryId ) )
216
+ . innerJoin (
217
+ tasklistsTable ,
218
+ eq ( memoryTaskTable . tasklistId , tasklistsTable . id ) ,
219
+ )
216
220
. leftJoinLateral ( latestStateSubquery , sql `true` )
217
221
. where (
218
222
and (
@@ -275,13 +279,28 @@ function processMemoryEvents(
275
279
startDate : Date ,
276
280
endDate ?: Date ,
277
281
) {
282
+ const processedMemories : QueriedMemory [ ] = [ ] ;
283
+
278
284
for ( const memory of memories ) {
285
+ processedMemories . push ( memory ) ;
286
+
279
287
if ( ! memory . event ?. rrule ) {
280
288
continue ;
281
289
}
282
290
283
291
const memoryEvent = memory . event ;
284
292
293
+ const instancesLookup = lookup (
294
+ memories . filter (
295
+ ( m ) =>
296
+ m . event ?. startOriginal &&
297
+ m . event . recurringEventPlatformId === memoryEvent . platformId &&
298
+ m . event . calendarId === memoryEvent . calendarId ,
299
+ ) ,
300
+ // biome-ignore lint/style/noNonNullAssertion: Only events with a startOriginal are processed
301
+ ( m ) => m . event ?. startOriginal ! ,
302
+ ) ;
303
+
285
304
const diff =
286
305
memoryEvent . endDate . getTime ( ) - memoryEvent . startDate . getTime ( ) ;
287
306
@@ -292,20 +311,24 @@ function processMemoryEvents(
292
311
endDate ?? new Date ( startDate . getTime ( ) + 30 * 24 * 60 * 60 * 1000 ) ,
293
312
) ;
294
313
295
- memories . push (
296
- ...otherDates . map ( ( date ) => ( {
314
+ for ( const virtualMemoryDate of otherDates ) {
315
+ if ( instancesLookup ( virtualMemoryDate ) ) {
316
+ continue ;
317
+ }
318
+
319
+ processedMemories . push ( {
297
320
...memory ,
298
321
event : {
299
322
...memoryEvent ,
300
- startDate : date ,
301
- endDate : new Date ( date . getTime ( ) + diff ) ,
323
+ startDate : virtualMemoryDate ,
324
+ endDate : addMilliseconds ( virtualMemoryDate , diff ) ,
302
325
} ,
303
326
isVirtual : true ,
304
- } ) ) ,
305
- ) ;
327
+ } ) ;
328
+ }
306
329
}
307
330
308
- return memories ;
331
+ return processedMemories ;
309
332
}
310
333
311
334
export async function getMemories (
@@ -345,7 +368,7 @@ export async function getMemories(
345
368
Object . values ( groupedMemories ) . map ( processMemoryTasks ) ,
346
369
) ;
347
370
348
- return processMemoryEvents ( processedMemories , new Date ( ) , undefined ) ;
371
+ return processedMemories ;
349
372
}
350
373
351
374
export async function queryMemories (
@@ -425,19 +448,18 @@ export async function queryMemories(
425
448
eq ( memoryTable . userId , userId ) ,
426
449
sourceCondition ,
427
450
textEmbedding
428
- ? sql < boolean > `${ memoryTable . content } <> '' or ${ memoryTable . description } <> ''`
451
+ ? or (
452
+ sql < boolean > `${ memoryTable . content } <> ''` ,
453
+ sql < boolean > `${ memoryTable . description } <> ''` ,
454
+ )
429
455
: undefined ,
430
456
] ;
431
-
432
457
const shouldIncludeMemoryType = ( type : MemoryType ) =>
433
- options . memoryTypes && options . memoryTypes . size > 0
434
- ? options . memoryTypes . has ( type )
435
- : true ;
458
+ options . memoryTypes ?. has ( type ) ?? true ;
436
459
437
460
const memoryEventsSubquery = shouldIncludeMemoryType ( EVENT_MEMORY_TYPE )
438
461
? createMemoryEventsSubquery (
439
462
[
440
- isNotNull ( memoryEventTable . id ) ,
441
463
...baseFilters ,
442
464
calendarSources
443
465
? inArray ( memoryEventTable . source , Array . from ( calendarSources ) )
@@ -463,7 +485,6 @@ export async function queryMemories(
463
485
const memoryTasksSubquery = shouldIncludeMemoryType ( TASK_MEMORY_TYPE )
464
486
? createMemoryTasksSubquery (
465
487
[
466
- isNotNull ( memoryTaskTable . id ) ,
467
488
...baseFilters ,
468
489
taskSources
469
490
? inArray ( memoryTaskTable . source , Array . from ( taskSources ) )
@@ -494,7 +515,7 @@ export async function queryMemories(
494
515
const memoriesSubquery =
495
516
// biome-ignore lint/style/noNonNullAssertion: this is wrong
496
517
(
497
- memoryEventsSubquery && memoryTasksSubquery
518
+ memoryEventsSubquery !== undefined && memoryTasksSubquery !== undefined
498
519
? union (
499
520
db . select ( ) . from ( memoryEventsSubquery ) ,
500
521
db . select ( ) . from ( memoryTasksSubquery ) ,
@@ -675,15 +696,49 @@ export async function queryMemories(
675
696
676
697
const groupedMemories = groupBy ( rows , "id" ) ;
677
698
678
- // Process tasks with attributes
679
- const processedMemories = filterNull (
680
- Object . values ( groupedMemories ) . map ( processMemoryTasks ) ,
681
- ) ;
682
-
683
699
// Process recurrent events
684
- return processMemoryEvents (
685
- processedMemories ,
700
+ const processedMemories = processMemoryEvents (
701
+ filterNull (
702
+ // Process tasks with attributes
703
+ Object . values ( groupedMemories ) . map ( processMemoryTasks ) ,
704
+ ) ,
686
705
startDate ?? new Date ( ) ,
687
- endDate ,
706
+ endDate ? min ( [ endDate , addDays ( startDate ?? new Date ( ) , 90 ) ] ) : undefined ,
688
707
) ;
708
+
709
+ return processedMemories . sort ( ( a , b ) => {
710
+ switch ( sortBy ) {
711
+ case "createdAt" : {
712
+ return orderBy === "asc"
713
+ ? a . createdAt . getTime ( ) - b . createdAt . getTime ( )
714
+ : b . createdAt . getTime ( ) - a . createdAt . getTime ( ) ;
715
+ }
716
+ case "updatedAt" : {
717
+ return orderBy === "asc"
718
+ ? a . lastUpdate . getTime ( ) - b . lastUpdate . getTime ( )
719
+ : b . lastUpdate . getTime ( ) - a . lastUpdate . getTime ( ) ;
720
+ }
721
+ // We need to figure out how to handle this for events
722
+ // case "priority": {
723
+ // return orderBy === "asc"
724
+ // ? a.taskAttributes.priority - b.taskAttributes.priority
725
+ // : b.taskAttributes.priority - a.taskAttributes.priority;
726
+ // }
727
+ case "relevantDate" : {
728
+ const aRelevantDate = a . event ?. startDate ?? a . task ?. dueDate ;
729
+ const bRelevantDate = b . event ?. startDate ?? b . task ?. dueDate ;
730
+
731
+ if ( aRelevantDate && bRelevantDate ) {
732
+ return orderBy === "asc"
733
+ ? aRelevantDate . getTime ( ) - bRelevantDate . getTime ( )
734
+ : bRelevantDate . getTime ( ) - aRelevantDate . getTime ( ) ;
735
+ }
736
+
737
+ return 0 ;
738
+ }
739
+ default : {
740
+ return 0 ;
741
+ }
742
+ }
743
+ } ) ;
689
744
}
0 commit comments