Skip to content

Commit 6be874a

Browse files
Zhenyi Qicopybara-github
Zhenyi Qi
authored andcommitted
feat: GenAI - Context Caching - add get() classmethod and refresh() instance method
PiperOrigin-RevId: 644141561
1 parent 62f7af5 commit 6be874a

File tree

2 files changed

+67
-18
lines changed

2 files changed

+67
-18
lines changed

tests/unit/vertexai/test_caching.py

+45
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,32 @@ def get_cached_content(self, name, retry=None):
8181
yield get_cached_content
8282

8383

84+
@pytest.fixture
85+
def mock_list_cached_contents():
86+
"""Mocks GenAiCacheServiceClient.get_cached_content()."""
87+
88+
def list_cached_contents(self, request):
89+
del self, request
90+
response = [
91+
GapicCachedContent(
92+
name="cached_content1_from_list_request",
93+
model="model-name1",
94+
),
95+
GapicCachedContent(
96+
name="cached_content2_from_list_request",
97+
model="model-name2",
98+
),
99+
]
100+
return response
101+
102+
with mock.patch.object(
103+
gen_ai_cache_service.client.GenAiCacheServiceClient,
104+
"list_cached_contents",
105+
new=list_cached_contents,
106+
) as list_cached_contents:
107+
yield list_cached_contents
108+
109+
84110
@pytest.mark.usefixtures("google_auth_mock")
85111
class TestCaching:
86112
"""Unit tests for caching.CachedContent."""
@@ -118,6 +144,19 @@ def test_constructor_with_only_content_id(self, mock_get_cached_content):
118144
)
119145
assert cache.model_name == "model-name"
120146

147+
def test_get_with_content_id(self, mock_get_cached_content):
148+
partial_resource_name = "contents-id"
149+
150+
cache = caching.CachedContent.get(
151+
cached_content_name=partial_resource_name,
152+
)
153+
154+
assert cache.name == "contents-id"
155+
assert cache.resource_name == (
156+
f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/cachedContents/contents-id"
157+
)
158+
assert cache.model_name == "model-name"
159+
121160
def test_create_with_real_payload(
122161
self, mock_create_cached_content, mock_get_cached_content
123162
):
@@ -162,3 +201,9 @@ def test_create_with_real_payload_and_wrapped_type(
162201
== f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/model-name"
163202
)
164203
assert cache.name == _CREATED_CONTENT_ID
204+
205+
def test_list(self, mock_list_cached_contents):
206+
cached_contents = caching.CachedContent.list()
207+
for i, cached_content in enumerate(cached_contents):
208+
assert cached_content.name == f"cached_content{i + 1}_from_list_request"
209+
assert cached_content.model_name == f"model-name{i + 1}"

vertexai/caching/_caching.py

+22-18
Original file line numberDiff line numberDiff line change
@@ -135,28 +135,15 @@ def __init__(self, cached_content_name: str):
135135
"456".
136136
"""
137137
super().__init__(resource_name=cached_content_name)
138-
139-
resource_name = aiplatform_utils.full_resource_name(
140-
resource_name=cached_content_name,
141-
resource_noun=self._resource_noun,
142-
parse_resource_name_method=self._parse_resource_name,
143-
format_resource_name_method=self._format_resource_name,
144-
project=self.project,
145-
location=self.location,
146-
parent_resource_name_fields=None,
147-
resource_id_validator=self._resource_id_validator,
148-
)
149-
self._gca_resource = gca_cached_content.CachedContent(name=resource_name)
138+
self._gca_resource = self._get_gca_resource(cached_content_name)
150139

151140
@property
152141
def _raw_cached_content(self) -> gca_cached_content.CachedContent:
153142
return self._gca_resource
154143

155144
@property
156145
def model_name(self) -> str:
157-
if not self._raw_cached_content.model:
158-
self._sync_gca_resource()
159-
return self._raw_cached_content.model
146+
return self._gca_resource.model
160147

161148
@classmethod
162149
def create(
@@ -235,6 +222,10 @@ def create(
235222
obj._gca_resource = cached_content_resource
236223
return obj
237224

225+
def refresh(self):
226+
"""Syncs the local cached content with the remote resource."""
227+
self._sync_gca_resource()
228+
238229
def update(
239230
self,
240231
*,
@@ -265,15 +256,28 @@ def update(
265256

266257
@property
267258
def expire_time(self) -> datetime.datetime:
268-
"""Time this resource was last updated."""
269-
self._sync_gca_resource()
259+
"""Time this resource is considered expired.
260+
261+
The returned value may be stale. Use refresh() to get the latest value.
262+
263+
Returns:
264+
The expiration time of the cached content resource.
265+
"""
270266
return self._gca_resource.expire_time
271267

272268
def delete(self):
269+
"""Deletes the current cached content resource."""
273270
self._delete()
274271

275272
@classmethod
276-
def list(cls):
273+
def list(cls) -> List["CachedContent"]:
274+
"""Lists the active cached content resources."""
277275
# TODO(b/345326114): Make list() interface richer after aligning with
278276
# Google AI SDK
279277
return cls._list()
278+
279+
@classmethod
280+
def get(cls, cached_content_name: str) -> "CachedContent":
281+
"""Retrieves an existing cached content resource."""
282+
cache = cls(cached_content_name)
283+
return cache

0 commit comments

Comments
 (0)