|
51 | 51 |
|
52 | 52 | _TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
|
53 | 53 | _TEST_API_ENDPOINT = f"{_TEST_LOCATION}-aiplatform.googleapis.com"
|
54 |
| -_TEST_IMAGE_DATASET_ID = "1084241610289446912" # permanent_50_flowers_dataset |
| 54 | +_TEST_IMAGE_DATASET_ID = "1997950066622464000" # permanent_50_flowers_dataset |
55 | 55 | _TEST_TEXT_DATASET_ID = (
|
56 | 56 | "6203215905493614592" # permanent_text_entity_extraction_dataset
|
57 | 57 | )
|
@@ -390,24 +390,24 @@ def test_export_data_for_custom_training(self, staging_bucket):
|
390 | 390 | # Custom training data export should be generic, hence using the base
|
391 | 391 | # _Dataset class here in test. In practice, users shuold be able to
|
392 | 392 | # use this function in any inhericted classes of _Dataset.
|
393 |
| - dataset = aiplatform.datasets._Dataset(dataset_name=_TEST_TEXT_DATASET_ID) |
| 393 | + dataset = aiplatform.datasets._Dataset(dataset_name=_TEST_IMAGE_DATASET_ID) |
394 | 394 |
|
395 | 395 | split = {
|
396 |
| - "training_fraction": 0.6, |
397 |
| - "validation_fraction": 0.2, |
398 |
| - "test_fraction": 0.2, |
| 396 | + "training_filter": "labels.aiplatform.googleapis.com/ml_use=training", |
| 397 | + "validation_filter": "labels.aiplatform.googleapis.com/ml_use=validation", |
| 398 | + "test_filter": "labels.aiplatform.googleapis.com/ml_use=test", |
399 | 399 | }
|
400 | 400 |
|
401 | 401 | export_data_response = dataset.export_data_for_custom_training(
|
402 | 402 | output_dir=f"gs://{staging_bucket.name}",
|
403 |
| - annotation_schema_uri="gs://google-cloud-aiplatform/schema/dataset/annotation/text_classification_1.0.0.yaml", |
| 403 | + annotation_schema_uri="gs://google-cloud-aiplatform/schema/dataset/annotation/image_classification_1.0.0.yaml", |
404 | 404 | split=split,
|
405 | 405 | )
|
406 | 406 |
|
407 | 407 | # Ensure three output paths (training, validation and test) are provided
|
408 | 408 | assert len(export_data_response["exported_files"]) == 3
|
409 |
| - # Ensure data stats are calculated and present |
410 |
| - assert export_data_response["data_stats"]["training_data_items_count"] > 0 |
| 409 | + # Ensure data stats are calculated and correct |
| 410 | + assert export_data_response["data_stats"]["training_data_items_count"] == 40 |
411 | 411 |
|
412 | 412 | def test_update_dataset(self):
|
413 | 413 | """Create a new dataset and use update() method to change its display_name, labels, and description.
|
|
0 commit comments