Skip to content

Commit 1fbf049

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: Fix dataset export system test
PiperOrigin-RevId: 597603710
1 parent bbdd9e2 commit 1fbf049

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

tests/system/aiplatform/test_dataset.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151

5252
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
5353
_TEST_API_ENDPOINT = f"{_TEST_LOCATION}-aiplatform.googleapis.com"
54-
_TEST_IMAGE_DATASET_ID = "1084241610289446912" # permanent_50_flowers_dataset
54+
_TEST_IMAGE_DATASET_ID = "1997950066622464000" # permanent_50_flowers_dataset
5555
_TEST_TEXT_DATASET_ID = (
5656
"6203215905493614592" # permanent_text_entity_extraction_dataset
5757
)
@@ -390,24 +390,24 @@ def test_export_data_for_custom_training(self, staging_bucket):
390390
# Custom training data export should be generic, hence using the base
391391
# _Dataset class here in test. In practice, users should be able to
392392
# use this function in any inherited classes of _Dataset.
393-
dataset = aiplatform.datasets._Dataset(dataset_name=_TEST_TEXT_DATASET_ID)
393+
dataset = aiplatform.datasets._Dataset(dataset_name=_TEST_IMAGE_DATASET_ID)
394394

395395
split = {
396-
"training_fraction": 0.6,
397-
"validation_fraction": 0.2,
398-
"test_fraction": 0.2,
396+
"training_filter": "labels.aiplatform.googleapis.com/ml_use=training",
397+
"validation_filter": "labels.aiplatform.googleapis.com/ml_use=validation",
398+
"test_filter": "labels.aiplatform.googleapis.com/ml_use=test",
399399
}
400400

401401
export_data_response = dataset.export_data_for_custom_training(
402402
output_dir=f"gs://{staging_bucket.name}",
403-
annotation_schema_uri="gs://google-cloud-aiplatform/schema/dataset/annotation/text_classification_1.0.0.yaml",
403+
annotation_schema_uri="gs://google-cloud-aiplatform/schema/dataset/annotation/image_classification_1.0.0.yaml",
404404
split=split,
405405
)
406406

407407
# Ensure three output paths (training, validation and test) are provided
408408
assert len(export_data_response["exported_files"]) == 3
409-
# Ensure data stats are calculated and present
410-
assert export_data_response["data_stats"]["training_data_items_count"] > 0
409+
# Ensure data stats are calculated and correct
410+
assert export_data_response["data_stats"]["training_data_items_count"] == 40
411411

412412
def test_update_dataset(self):
413413
"""Create a new dataset and use update() method to change its display_name, labels, and description.

0 commit comments

Comments
 (0)