Skip to content

Commit 733fddd

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add support for ADK memory service to AdkApp template
PiperOrigin-RevId: 775323095
1 parent beae2e3 commit 733fddd

File tree

1 file changed

+20
-0
lines changed
  • vertexai/preview/reasoning_engines/templates

1 file changed

+20
-0
lines changed

vertexai/preview/reasoning_engines/templates/adk.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,13 @@
5858
except (ImportError, AttributeError):
5959
BaseArtifactService = Any
6060

61+
try:
62+
from google.adk.memory import BaseMemoryService
63+
64+
BaseMemoryService = BaseMemoryService
65+
except (ImportError, AttributeError):
66+
BaseMemoryService = Any
67+
6168
try:
6269
from opentelemetry.sdk import trace
6370

@@ -281,6 +288,7 @@ def __init__(
281288
enable_tracing: bool = False,
282289
session_service_builder: Optional[Callable[..., "BaseSessionService"]] = None,
283290
artifact_service_builder: Optional[Callable[..., "BaseArtifactService"]] = None,
291+
memory_service_builder: Optional[Callable[..., "BaseMemoryService"]] = None,
284292
env_vars: Optional[Dict[str, str]] = None,
285293
):
286294
"""An ADK Application."""
@@ -301,6 +309,7 @@ def __init__(
301309
"enable_tracing": enable_tracing,
302310
"session_service_builder": session_service_builder,
303311
"artifact_service_builder": artifact_service_builder,
312+
"memory_service_builder": memory_service_builder,
304313
"app_name": _DEFAULT_APP_NAME,
305314
"env_vars": env_vars or {},
306315
}
@@ -410,6 +419,7 @@ def clone(self):
410419
enable_tracing=self._tmpl_attrs.get("enable_tracing"),
411420
session_service_builder=self._tmpl_attrs.get("session_service_builder"),
412421
artifact_service_builder=self._tmpl_attrs.get("artifact_service_builder"),
422+
memory_service_builder=self._tmpl_attrs.get("memory_service_builder"),
413423
env_vars=self._tmpl_attrs.get("env_vars"),
414424
)
415425

@@ -421,6 +431,7 @@ def set_up(self):
421431
from google.adk.artifacts.in_memory_artifact_service import (
422432
InMemoryArtifactService,
423433
)
434+
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
424435

425436
os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1"
426437
project = self._tmpl_attrs.get("project")
@@ -460,18 +471,27 @@ def set_up(self):
460471
else:
461472
self._tmpl_attrs["session_service"] = InMemorySessionService()
462473

474+
memory_service_builder = self._tmpl_attrs.get("memory_service_builder")
475+
if memory_service_builder:
476+
self._tmpl_attrs["memory_service"] = memory_service_builder()
477+
else:
478+
self._tmpl_attrs["memory_service"] = InMemoryMemoryService()
479+
463480
self._tmpl_attrs["runner"] = Runner(
464481
agent=self._tmpl_attrs.get("agent"),
465482
session_service=self._tmpl_attrs.get("session_service"),
466483
artifact_service=self._tmpl_attrs.get("artifact_service"),
484+
memory_service=self._tmpl_attrs.get("memory_service"),
467485
app_name=self._tmpl_attrs.get("app_name"),
468486
)
469487
self._tmpl_attrs["in_memory_session_service"] = InMemorySessionService()
470488
self._tmpl_attrs["in_memory_artifact_service"] = InMemoryArtifactService()
489+
self._tmpl_attrs["in_memory_memory_service"] = InMemoryMemoryService()
471490
self._tmpl_attrs["in_memory_runner"] = Runner(
472491
agent=self._tmpl_attrs.get("agent"),
473492
session_service=self._tmpl_attrs.get("in_memory_session_service"),
474493
artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"),
494+
memory_service=self._tmpl_attrs.get("in_memory_memory_service"),
475495
app_name=self._tmpl_attrs.get("app_name"),
476496
)
477497

0 commit comments

Comments
 (0)