Skip to content

Commit bb27619

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: Use Client.list_blobs instead of Bucket.list_blobs in CPR artifact downloader, to make sure that CPR works with custom service accounts on Vertex Prediction.
PiperOrigin-RevId: 504956857
1 parent 2e35263 commit bb27619

File tree

2 files changed

+6
-11
lines changed

2 files changed

+6
-11
lines changed

google/cloud/aiplatform/utils/prediction_utils.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,7 @@ def download_model_artifacts(artifact_uri: str) -> None:
135135
bucket_name, prefix = matches.groups()
136136

137137
gcs_client = storage.Client()
138-
bucket = gcs_client.get_bucket(bucket_name)
139-
blobs = bucket.list_blobs(prefix=prefix)
138+
blobs = gcs_client.list_blobs(bucket_name, prefix=prefix)
140139
for blob in blobs:
141140
name_without_prefix = blob.name[len(prefix) :]
142141
name_without_prefix = (

tests/unit/aiplatform/test_utils.py

+5-9
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,11 @@ def __init__(self, name):
7070
blob2 = mock.MagicMock()
7171
type(blob2).name = mock.PropertyMock(return_value=f"{GCS_PREFIX}/")
7272

73-
def get_blobs(prefix):
73+
def get_blobs(bucket_name, prefix=""):
7474
return [blob1, blob2]
7575

7676
with patch.object(storage, "Client") as mock_storage_client:
77-
get_bucket_mock = mock.Mock()
78-
get_bucket_mock.return_value.list_blobs.side_effect = get_blobs
79-
mock_storage_client.return_value.get_bucket.return_value = get_bucket_mock()
77+
mock_storage_client.return_value.list_blobs.side_effect = get_blobs
8078
yield mock_storage_client
8179

8280

@@ -806,16 +804,14 @@ def test_download_model_artifacts(self, mock_storage_client):
806804
prediction_utils.download_model_artifacts(f"gs://{GCS_BUCKET}/{GCS_PREFIX}")
807805

808806
assert mock_storage_client.called
809-
mock_storage_client().get_bucket.assert_called_once_with(GCS_BUCKET)
810-
mock_storage_client().get_bucket().list_blobs.assert_called_once_with(
811-
prefix=GCS_PREFIX
807+
mock_storage_client().list_blobs.assert_called_once_with(
808+
GCS_BUCKET, prefix=GCS_PREFIX
812809
)
813-
mock_storage_client().get_bucket().list_blobs.side_effect("")[
810+
mock_storage_client().list_blobs.side_effect("")[
814811
0
815812
].download_to_filename.assert_called_once_with(FAKE_FILENAME)
816813
assert (
817814
not mock_storage_client()
818-
.get_bucket()
819815
.list_blobs.side_effect("")[1]
820816
.download_to_filename.called
821817
)

0 commit comments

Comments
 (0)