Skip to content

Commit d3d5f9a

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: Retry for etag errors on context update.
PiperOrigin-RevId: 537445290
1 parent 635ae9c commit d3d5f9a

File tree

2 files changed

+158
-0
lines changed

2 files changed

+158
-0
lines changed

google/cloud/aiplatform/metadata/context.py

+47
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Optional, Dict, List, Sequence
1919

2020
import proto
21+
import re
2122
import threading
2223

2324
from google.auth import credentials as auth_credentials
@@ -37,6 +38,12 @@
3738
from google.cloud.aiplatform.metadata import execution
3839
from google.cloud.aiplatform.metadata import metadata_store
3940
from google.cloud.aiplatform.metadata import resource
41+
from google.api_core.exceptions import Aborted
42+
43+
_ETAG_ERROR_MAX_RETRY_COUNT = 5
44+
_ETAG_ERROR_REGEX = re.compile(
45+
r"Specified Context \`etag\`: \`(\d+)\` does not match server \`etag\`: \`(\d+)\`"
46+
)
4047

4148

4249
class Context(resource._Resource):
@@ -278,6 +285,46 @@ def _create_resource(
278285
context_id=resource_id,
279286
)
280287

288+
def update(
289+
self,
290+
metadata: Optional[Dict] = None,
291+
description: Optional[str] = None,
292+
credentials: Optional[auth_credentials.Credentials] = None,
293+
):
294+
"""Updates an existing Metadata Context with new metadata.
295+
296+
This is implemented with retry on etag errors, up to
297+
_ETAG_ERROR_MAX_RETRY_COUNT times.
298+
Args:
299+
metadata (Dict):
300+
Optional. metadata contains the updated metadata information.
301+
description (str):
302+
Optional. Description describes the resource to be updated.
303+
credentials (auth_credentials.Credentials):
304+
Custom credentials to use to update this resource. Overrides
305+
credentials set in aiplatform.init.
306+
"""
307+
for _ in range(_ETAG_ERROR_MAX_RETRY_COUNT - 1):
308+
try:
309+
super().update(
310+
metadata=metadata, description=description, credentials=credentials
311+
)
312+
return
313+
except Aborted as aborted_exception:
314+
regex_match = _ETAG_ERROR_REGEX.match(aborted_exception.message)
315+
if regex_match:
316+
local_etag = regex_match.group(1)
317+
server_etag = regex_match.group(2)
318+
if local_etag < server_etag:
319+
self.sync_resource()
320+
continue
321+
raise aborted_exception
322+
323+
# Expose result/exception directly in the last retry.
324+
super().update(
325+
metadata=metadata, description=description, credentials=credentials
326+
)
327+
281328
@classmethod
282329
def _update_resource(
283330
cls,

tests/unit/aiplatform/test_metadata_resources.py

+111
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,53 @@ def update_context_mock():
153153
yield update_context_mock
154154

155155

156+
@pytest.fixture
157+
def update_context_with_errors_mock():
158+
with patch.object(
159+
MetadataServiceClient, "update_context"
160+
) as update_context_with_errors_mock:
161+
update_context_with_errors_mock.side_effect = [
162+
exceptions.Aborted(
163+
"Specified Context `etag`: `1` does not match server `etag`: `2`"
164+
),
165+
GapicContext(
166+
name=_TEST_CONTEXT_NAME,
167+
display_name=_TEST_DISPLAY_NAME,
168+
schema_title=_TEST_SCHEMA_TITLE,
169+
schema_version=_TEST_SCHEMA_VERSION,
170+
description=_TEST_DESCRIPTION,
171+
metadata=_TEST_UPDATED_METADATA,
172+
),
173+
]
174+
yield update_context_with_errors_mock
175+
176+
177+
@pytest.fixture
178+
def update_context_with_errors_mock_2():
179+
with patch.object(
180+
MetadataServiceClient, "update_context"
181+
) as update_context_with_errors_mock_2:
182+
update_context_with_errors_mock_2.side_effect = [
183+
exceptions.Aborted(
184+
"Specified Context `etag`: `2` does not match server `etag`: `1`"
185+
)
186+
]
187+
yield update_context_with_errors_mock_2
188+
189+
190+
@pytest.fixture
191+
def update_context_with_errors_mock_3():
192+
with patch.object(
193+
MetadataServiceClient, "update_context"
194+
) as update_context_with_errors_mock_3:
195+
update_context_with_errors_mock_3.side_effect = [
196+
exceptions.Aborted(
197+
"Specified Context `etag`: `1` does not match server `etag`: `2`"
198+
)
199+
] * 6
200+
yield update_context_with_errors_mock_2
201+
202+
156203
@pytest.fixture
157204
def add_context_artifacts_and_executions_mock():
158205
with patch.object(
@@ -482,6 +529,70 @@ def test_update_context(self, update_context_mock):
482529
update_context_mock.assert_called_once_with(context=updated_context)
483530
assert my_context._gca_resource == updated_context
484531

532+
@pytest.mark.usefixtures("get_context_mock")
533+
@pytest.mark.usefixtures("create_context_mock")
534+
def test_update_context_with_retry_success(self, update_context_with_errors_mock):
535+
aiplatform.init(project=_TEST_PROJECT)
536+
537+
my_context = context.Context._create(
538+
resource_id=_TEST_CONTEXT_ID,
539+
schema_title=_TEST_SCHEMA_TITLE,
540+
display_name=_TEST_DISPLAY_NAME,
541+
schema_version=_TEST_SCHEMA_VERSION,
542+
description=_TEST_DESCRIPTION,
543+
metadata=_TEST_METADATA,
544+
metadata_store_id=_TEST_METADATA_STORE,
545+
)
546+
my_context.update(_TEST_UPDATED_METADATA)
547+
548+
updated_context = GapicContext(
549+
name=_TEST_CONTEXT_NAME,
550+
schema_title=_TEST_SCHEMA_TITLE,
551+
schema_version=_TEST_SCHEMA_VERSION,
552+
display_name=_TEST_DISPLAY_NAME,
553+
description=_TEST_DESCRIPTION,
554+
metadata=_TEST_UPDATED_METADATA,
555+
)
556+
557+
update_context_with_errors_mock.assert_called_with(context=updated_context)
558+
assert my_context._gca_resource == updated_context
559+
560+
@pytest.mark.usefixtures("get_context_mock")
561+
@pytest.mark.usefixtures("create_context_mock")
562+
@pytest.mark.usefixtures("update_context_with_errors_mock_2")
563+
def test_update_context_with_retry_etag_order_failure(self):
564+
aiplatform.init(project=_TEST_PROJECT)
565+
566+
my_context = context.Context._create(
567+
resource_id=_TEST_CONTEXT_ID,
568+
schema_title=_TEST_SCHEMA_TITLE,
569+
display_name=_TEST_DISPLAY_NAME,
570+
schema_version=_TEST_SCHEMA_VERSION,
571+
description=_TEST_DESCRIPTION,
572+
metadata=_TEST_METADATA,
573+
metadata_store_id=_TEST_METADATA_STORE,
574+
)
575+
with pytest.raises(exceptions.Aborted):
576+
my_context.update(_TEST_UPDATED_METADATA)
577+
578+
@pytest.mark.usefixtures("get_context_mock")
579+
@pytest.mark.usefixtures("create_context_mock")
580+
@pytest.mark.usefixtures("update_context_with_errors_mock_3")
581+
def test_update_context_with_retry_too_many_error_failure(self):
582+
aiplatform.init(project=_TEST_PROJECT)
583+
584+
my_context = context.Context._create(
585+
resource_id=_TEST_CONTEXT_ID,
586+
schema_title=_TEST_SCHEMA_TITLE,
587+
display_name=_TEST_DISPLAY_NAME,
588+
schema_version=_TEST_SCHEMA_VERSION,
589+
description=_TEST_DESCRIPTION,
590+
metadata=_TEST_METADATA,
591+
metadata_store_id=_TEST_METADATA_STORE,
592+
)
593+
with pytest.raises(exceptions.Aborted):
594+
my_context.update(_TEST_UPDATED_METADATA)
595+
485596
@pytest.mark.usefixtures("get_context_mock")
486597
def test_list_contexts(self, list_contexts_mock):
487598
aiplatform.init(project=_TEST_PROJECT)

0 commit comments

Comments
 (0)