58
58
except (ImportError , AttributeError ):
59
59
BaseArtifactService = Any
60
60
61
+ try :
62
+ from google .adk .memory import BaseMemoryService
63
+
64
+ BaseMemoryService = BaseMemoryService
65
+ except (ImportError , AttributeError ):
66
+ BaseMemoryService = Any
67
+
61
68
try :
62
69
from opentelemetry .sdk import trace
63
70
@@ -281,6 +288,7 @@ def __init__(
281
288
enable_tracing : bool = False ,
282
289
session_service_builder : Optional [Callable [..., "BaseSessionService" ]] = None ,
283
290
artifact_service_builder : Optional [Callable [..., "BaseArtifactService" ]] = None ,
291
+ memory_service_builder : Optional [Callable [..., "BaseMemoryService" ]] = None ,
284
292
env_vars : Optional [Dict [str , str ]] = None ,
285
293
):
286
294
"""An ADK Application."""
@@ -301,6 +309,7 @@ def __init__(
301
309
"enable_tracing" : enable_tracing ,
302
310
"session_service_builder" : session_service_builder ,
303
311
"artifact_service_builder" : artifact_service_builder ,
312
+ "memory_service_builder" : memory_service_builder ,
304
313
"app_name" : _DEFAULT_APP_NAME ,
305
314
"env_vars" : env_vars or {},
306
315
}
@@ -410,6 +419,7 @@ def clone(self):
410
419
enable_tracing = self ._tmpl_attrs .get ("enable_tracing" ),
411
420
session_service_builder = self ._tmpl_attrs .get ("session_service_builder" ),
412
421
artifact_service_builder = self ._tmpl_attrs .get ("artifact_service_builder" ),
422
+ memory_service_builder = self ._tmpl_attrs .get ("memory_service_builder" ),
413
423
env_vars = self ._tmpl_attrs .get ("env_vars" ),
414
424
)
415
425
@@ -421,6 +431,7 @@ def set_up(self):
421
431
from google .adk .artifacts .in_memory_artifact_service import (
422
432
InMemoryArtifactService ,
423
433
)
434
+ from google .adk .memory .in_memory_memory_service import InMemoryMemoryService
424
435
425
436
os .environ ["GOOGLE_GENAI_USE_VERTEXAI" ] = "1"
426
437
project = self ._tmpl_attrs .get ("project" )
@@ -460,18 +471,27 @@ def set_up(self):
460
471
else :
461
472
self ._tmpl_attrs ["session_service" ] = InMemorySessionService ()
462
473
474
+ memory_service_builder = self ._tmpl_attrs .get ("memory_service_builder" )
475
+ if memory_service_builder :
476
+ self ._tmpl_attrs ["memory_service" ] = memory_service_builder ()
477
+ else :
478
+ self ._tmpl_attrs ["memory_service" ] = InMemoryMemoryService ()
479
+
463
480
self ._tmpl_attrs ["runner" ] = Runner (
464
481
agent = self ._tmpl_attrs .get ("agent" ),
465
482
session_service = self ._tmpl_attrs .get ("session_service" ),
466
483
artifact_service = self ._tmpl_attrs .get ("artifact_service" ),
484
+ memory_service = self ._tmpl_attrs .get ("memory_service" ),
467
485
app_name = self ._tmpl_attrs .get ("app_name" ),
468
486
)
469
487
self ._tmpl_attrs ["in_memory_session_service" ] = InMemorySessionService ()
470
488
self ._tmpl_attrs ["in_memory_artifact_service" ] = InMemoryArtifactService ()
489
+ self ._tmpl_attrs ["in_memory_memory_service" ] = InMemoryMemoryService ()
471
490
self ._tmpl_attrs ["in_memory_runner" ] = Runner (
472
491
agent = self ._tmpl_attrs .get ("agent" ),
473
492
session_service = self ._tmpl_attrs .get ("in_memory_session_service" ),
474
493
artifact_service = self ._tmpl_attrs .get ("in_memory_artifact_service" ),
494
+ memory_service = self ._tmpl_attrs .get ("in_memory_memory_service" ),
475
495
app_name = self ._tmpl_attrs .get ("app_name" ),
476
496
)
477
497
0 commit comments