Skip to content

feat: support dataset update #1416

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 65 additions & 3 deletions google/cloud/aiplatform/datasets/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Google LLC
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -31,6 +31,7 @@
io as gca_io,
)
from google.cloud.aiplatform.datasets import _datasources
from google.protobuf import field_mask_pb2

_LOGGER = base.Logger(__name__)

Expand Down Expand Up @@ -597,8 +598,69 @@ def export_data(self, output_dir: str) -> Sequence[str]:

return export_data_response.exported_files

def update(self):
raise NotImplementedError("Update dataset has not been implemented yet")
def update(
    self,
    *,
    display_name: Optional[str] = None,
    labels: Optional[Dict[str, str]] = None,
    description: Optional[str] = None,
    update_request_timeout: Optional[float] = None,
) -> "_Dataset":
    """Update the dataset.

    Updatable fields:
        -  ``display_name``
        -  ``description``
        -  ``labels``

    Args:
        display_name (str):
            Optional. The user-defined name of the Dataset.
            The name can be up to 128 characters long and can consist
            of any UTF-8 characters.
        labels (Dict[str, str]):
            Optional. Labels with user-defined metadata to organize your
            Datasets. Label keys and values can be no longer than 64
            characters (Unicode codepoints), can only contain lowercase
            letters, numeric characters, underscores and dashes.
            International characters are allowed. No more than 64 user
            labels can be associated with one Dataset (System labels are
            excluded). See https://goo.gl/xmQnxf for more information
            and examples of labels. System reserved label keys are
            prefixed with "aiplatform.googleapis.com/" and are immutable.
        description (str):
            Optional. The description of the Dataset.
        update_request_timeout (float):
            Optional. The timeout for the update request in seconds.

    Returns:
        dataset (Dataset):
            Updated dataset.
    """
    # Only fields listed in the update mask are modified server-side;
    # arguments left falsy (None/empty) are not touched.
    update_mask = field_mask_pb2.FieldMask()
    if display_name:
        update_mask.paths.append("display_name")

    if labels:
        update_mask.paths.append("labels")

    if description:
        update_mask.paths.append("description")

    update_dataset = gca_dataset.Dataset(
        name=self.resource_name,
        display_name=display_name,
        description=description,
        labels=labels,
    )

    # Refresh the cached resource with the server's updated copy so that
    # subsequent property reads reflect the new field values.
    self._gca_resource = self.api_client.update_dataset(
        dataset=update_dataset,
        update_mask=update_mask,
        timeout=update_request_timeout,
    )

    return self

@classmethod
def list(
Expand Down
27 changes: 26 additions & 1 deletion tests/system/aiplatform/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Google LLC
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -50,6 +50,8 @@
"6203215905493614592" # permanent_text_entity_extraction_dataset
)
_TEST_DATASET_DISPLAY_NAME = "permanent_50_flowers_dataset"
_TEST_DATASET_LABELS = {"test": "labels"}
_TEST_DATASET_DESCRIPTION = "test description"
_TEST_TABULAR_CLASSIFICATION_GCS_SOURCE = "gs://ucaip-sample-resources/iris_1000.csv"
_TEST_FORECASTING_BQ_SOURCE = (
"bq://ucaip-sample-tests:ucaip_test_us_central1.2020_sales_train"
Expand Down Expand Up @@ -350,3 +352,26 @@ def test_export_data(self, storage_client, staging_bucket):
blob = bucket.get_blob(prefix)

assert blob # Verify the returned GCS export path exists

def test_update_dataset(self):
    """Create a new dataset and use update() to change its display_name,
    labels, and description, then confirm these fields were successfully
    modified."""

    # Bind up-front so the cleanup in `finally` cannot raise NameError
    # (and mask the real failure) if create() itself raises.
    dataset = None
    try:
        dataset = aiplatform.ImageDataset.create()
        labels = dataset.labels

        dataset = dataset.update(
            display_name=_TEST_DATASET_DISPLAY_NAME,
            labels=_TEST_DATASET_LABELS,
            description=_TEST_DATASET_DESCRIPTION,
            update_request_timeout=None,
        )
        # Expected labels are the pre-update labels plus the new ones.
        labels.update(_TEST_DATASET_LABELS)

        assert dataset.display_name == _TEST_DATASET_DISPLAY_NAME
        assert dataset.labels == labels
        assert dataset.gca_resource.description == _TEST_DATASET_DESCRIPTION

    finally:
        if dataset is not None:
            dataset.delete()
46 changes: 46 additions & 0 deletions tests/unit/aiplatform/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from google.cloud.aiplatform import schema
from google.cloud import bigquery
from google.cloud import storage
from google.protobuf import field_mask_pb2

from google.cloud.aiplatform.compat.services import dataset_service_client

Expand All @@ -59,6 +60,7 @@
_TEST_ID = "1028944691210842416"
_TEST_DISPLAY_NAME = "my_dataset_1234"
_TEST_DATA_LABEL_ITEMS = None
_TEST_DESCRIPTION = "test description"

_TEST_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/datasets/{_TEST_ID}"
_TEST_ALT_NAME = (
Expand Down Expand Up @@ -425,6 +427,20 @@ def export_data_mock():
yield export_data_mock


@pytest.fixture
def update_dataset_mock():
    # Canned response returned by the patched DatasetServiceClient.update_dataset.
    updated_resource = gca_dataset.Dataset(
        name=_TEST_NAME,
        display_name=f"update_{_TEST_DISPLAY_NAME}",
        labels=_TEST_LABELS,
        description=_TEST_DESCRIPTION,
    )
    with patch.object(
        dataset_service_client.DatasetServiceClient, "update_dataset"
    ) as mock_update:
        mock_update.return_value = updated_resource
        yield mock_update


@pytest.fixture
def list_datasets_mock():
with patch.object(
Expand Down Expand Up @@ -996,6 +1012,36 @@ def test_delete_dataset(self, delete_dataset_mock, sync):

delete_dataset_mock.assert_called_once_with(name=my_dataset.resource_name)

@pytest.mark.usefixtures("get_dataset_mock")
def test_update_dataset(self, update_dataset_mock):
    """update() forwards the expected Dataset and FieldMask to the API client."""
    aiplatform.init(project=_TEST_PROJECT)

    new_display_name = f"update_{_TEST_DISPLAY_NAME}"

    my_dataset = datasets._Dataset(dataset_name=_TEST_NAME).update(
        display_name=new_display_name,
        labels=_TEST_LABELS,
        description=_TEST_DESCRIPTION,
        update_request_timeout=None,
    )

    update_dataset_mock.assert_called_once_with(
        dataset=gca_dataset.Dataset(
            name=_TEST_NAME,
            display_name=new_display_name,
            labels=_TEST_LABELS,
            description=_TEST_DESCRIPTION,
        ),
        update_mask=field_mask_pb2.FieldMask(
            paths=["display_name", "labels", "description"]
        ),
        timeout=None,
    )


@pytest.mark.usefixtures("google_auth_mock")
class TestImageDataset:
Expand Down