Skip to content

Commit 3aec6a7

Browse files
authored
feat: _TrainingScriptPythonPackager to support folders (#812)
* _TrainingScriptPythonPackager to support folders

  Allow _TrainingScriptPythonPackager to support a folder in addition to a single script.
* Update source_utils.py
* Fixed tests
* More fixes to test
* Added missing imports
* Removed unused import
1 parent 153578f commit 3aec6a7

File tree

3 files changed, with 35 additions and 22 deletions

google/cloud/aiplatform/utils/source_utils.py

+22 -11
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616

1717

1818
import functools
19+
import os
1920
import pathlib
2021
import shutil
2122
import subprocess
@@ -62,7 +63,7 @@ class _TrainingScriptPythonPackager:
6263
Constant command to generate the source distribution package.
6364
6465
Attributes:
65-
script_path: local path of script to package
66+
script_path: local path of script or folder to package
6667
requirements: list of Python dependencies to add to package
6768
6869
Usage:
@@ -79,7 +80,6 @@ class _TrainingScriptPythonPackager:
7980

8081
_TRAINER_FOLDER = "trainer"
8182
_ROOT_MODULE = "aiplatform_custom_trainer_script"
82-
_TASK_MODULE_NAME = "task"
8383
_SETUP_PY_VERSION = "0.1"
8484

8585
_SETUP_PY_TEMPLATE = """from setuptools import find_packages
@@ -96,10 +96,12 @@ class _TrainingScriptPythonPackager:
9696

9797
_SETUP_PY_SOURCE_DISTRIBUTION_CMD = "setup.py sdist --formats=gztar"
9898

99-
# Module name that can be executed during training. ie. python -m
100-
module_name = f"{_ROOT_MODULE}.{_TASK_MODULE_NAME}"
101-
102-
def __init__(self, script_path: str, requirements: Optional[Sequence[str]] = None):
99+
def __init__(
100+
self,
101+
script_path: str,
102+
task_module_name: str = "task",
103+
requirements: Optional[Sequence[str]] = None,
104+
):
103105
"""Initializes packager.
104106
105107
Args:
@@ -109,8 +111,14 @@ def __init__(self, script_path: str, requirements: Optional[Sequence[str]] = Non
109111
"""
110112

111113
self.script_path = script_path
114+
self.task_module_name = task_module_name
112115
self.requirements = requirements or []
113116

117+
@property
118+
def module_name(self) -> str:
119+
# Module name that can be executed during training. ie. python -m
120+
return f"{self._ROOT_MODULE}.{self.task_module_name}"
121+
114122
def make_package(self, package_directory: str) -> str:
115123
"""Converts script into a Python package suitable for python module
116124
execution.
@@ -134,9 +142,6 @@ def make_package(self, package_directory: str) -> str:
134142
# __init__.py path in root module
135143
init_path = trainer_path / "__init__.py"
136144

137-
# The module that will contain the script
138-
script_out_path = trainer_path / f"{self._TASK_MODULE_NAME}.py"
139-
140145
# The path to setup.py in the package.
141146
setup_py_path = trainer_root_path / "setup.py"
142147

@@ -165,8 +170,14 @@ def make_package(self, package_directory: str) -> str:
165170
with setup_py_path.open("w") as fp:
166171
fp.write(setup_py_output)
167172

168-
# Copy script as module of python package.
169-
shutil.copy(self.script_path, script_out_path)
173+
if os.path.isdir(self.script_path):
174+
shutil.copytree(self.script_path, trainer_path, dirs_exist_ok=True)
175+
else:
176+
# The module that will contain the script
177+
script_out_path = trainer_path / f"{self.task_module_name}.py"
178+
179+
# Copy script as module of python package.
180+
shutil.copy(self.script_path, script_out_path)
170181

171182
# Run setup.py to create the source distribution.
172183
setup_cmd = [

tests/unit/aiplatform/test_end_to_end.py

+2 -3
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,6 @@
2424
from google.cloud.aiplatform import initializer
2525
from google.cloud.aiplatform import models
2626
from google.cloud.aiplatform import schema
27-
from google.cloud.aiplatform.utils import source_utils
2827

2928
from google.cloud.aiplatform_v1.types import (
3029
dataset as gca_dataset,
@@ -224,7 +223,7 @@ def test_dataset_create_to_model_predict(
224223
},
225224
"python_package_spec": {
226225
"executor_image_uri": test_training_jobs._TEST_TRAINING_CONTAINER_IMAGE,
227-
"python_module": source_utils._TrainingScriptPythonPackager.module_name,
226+
"python_module": test_training_jobs._TEST_MODULE_NAME,
228227
"package_uris": [test_training_jobs._TEST_OUTPUT_PYTHON_PACKAGE_PATH],
229228
"args": true_args,
230229
},
@@ -411,7 +410,7 @@ def test_dataset_create_to_model_predict_with_pipeline_fail(
411410
},
412411
"python_package_spec": {
413412
"executor_image_uri": test_training_jobs._TEST_TRAINING_CONTAINER_IMAGE,
414-
"python_module": source_utils._TrainingScriptPythonPackager.module_name,
413+
"python_module": test_training_jobs._TEST_MODULE_NAME,
415414
"package_uris": [test_training_jobs._TEST_OUTPUT_PYTHON_PACKAGE_PATH],
416415
"args": true_args,
417416
},

tests/unit/aiplatform/test_training_jobs.py

+11 -8
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,8 @@
3131
from unittest import mock
3232
from unittest.mock import patch
3333

34+
import test_training_jobs
35+
3436
from google.auth import credentials as auth_credentials
3537

3638
from google.cloud import aiplatform
@@ -89,6 +91,7 @@
8991
_TEST_SERVING_CONTAINER_IMAGE = "gcr.io/test-serving/container:image"
9092
_TEST_SERVING_CONTAINER_PREDICTION_ROUTE = "predict"
9193
_TEST_SERVING_CONTAINER_HEALTH_ROUTE = "metadata"
94+
_TEST_MODULE_NAME = f"{source_utils._TrainingScriptPythonPackager._ROOT_MODULE}.task"
9295

9396
_TEST_METADATA_SCHEMA_URI_NONTABULAR = schema.dataset.metadata.image
9497
_TEST_ANNOTATION_SCHEMA_URI = schema.dataset.annotation.image.classification
@@ -827,7 +830,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
827830
},
828831
"python_package_spec": {
829832
"executor_image_uri": _TEST_TRAINING_CONTAINER_IMAGE,
830-
"python_module": source_utils._TrainingScriptPythonPackager.module_name,
833+
"python_module": _TEST_MODULE_NAME,
831834
"package_uris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH],
832835
"args": true_args,
833836
"env": true_env,
@@ -995,7 +998,7 @@ def test_run_call_pipeline_service_create_with_bigquery_destination(
995998
},
996999
"python_package_spec": {
9971000
"executor_image_uri": _TEST_TRAINING_CONTAINER_IMAGE,
998-
"python_module": source_utils._TrainingScriptPythonPackager.module_name,
1001+
"python_module": test_training_jobs._TEST_MODULE_NAME,
9991002
"package_uris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH],
10001003
"args": true_args,
10011004
"env": true_env,
@@ -1303,7 +1306,7 @@ def test_run_call_pipeline_service_create_with_no_dataset(
13031306
},
13041307
"python_package_spec": {
13051308
"executor_image_uri": _TEST_TRAINING_CONTAINER_IMAGE,
1306-
"python_module": source_utils._TrainingScriptPythonPackager.module_name,
1309+
"python_module": test_training_jobs._TEST_MODULE_NAME,
13071310
"package_uris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH],
13081311
"args": true_args,
13091312
"env": true_env,
@@ -1606,7 +1609,7 @@ def test_run_call_pipeline_service_create_distributed_training(
16061609
},
16071610
"python_package_spec": {
16081611
"executor_image_uri": _TEST_TRAINING_CONTAINER_IMAGE,
1609-
"python_module": source_utils._TrainingScriptPythonPackager.module_name,
1612+
"python_module": test_training_jobs._TEST_MODULE_NAME,
16101613
"package_uris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH],
16111614
"args": true_args,
16121615
"env": true_env,
@@ -1625,7 +1628,7 @@ def test_run_call_pipeline_service_create_distributed_training(
16251628
},
16261629
"python_package_spec": {
16271630
"executor_image_uri": _TEST_TRAINING_CONTAINER_IMAGE,
1628-
"python_module": source_utils._TrainingScriptPythonPackager.module_name,
1631+
"python_module": _TEST_MODULE_NAME,
16291632
"package_uris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH],
16301633
"args": true_args,
16311634
"env": true_env,
@@ -1756,7 +1759,7 @@ def test_run_call_pipeline_service_create_distributed_training_with_reduction_se
17561759
},
17571760
"python_package_spec": {
17581761
"executor_image_uri": _TEST_TRAINING_CONTAINER_IMAGE,
1759-
"python_module": source_utils._TrainingScriptPythonPackager.module_name,
1762+
"python_module": _TEST_MODULE_NAME,
17601763
"package_uris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH],
17611764
"args": true_args,
17621765
"env": true_env,
@@ -1775,7 +1778,7 @@ def test_run_call_pipeline_service_create_distributed_training_with_reduction_se
17751778
},
17761779
"python_package_spec": {
17771780
"executor_image_uri": _TEST_TRAINING_CONTAINER_IMAGE,
1778-
"python_module": source_utils._TrainingScriptPythonPackager.module_name,
1781+
"python_module": test_training_jobs._TEST_MODULE_NAME,
17791782
"package_uris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH],
17801783
"args": true_args,
17811784
"env": true_env,
@@ -2013,7 +2016,7 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset_without_model_
20132016
},
20142017
"python_package_spec": {
20152018
"executor_image_uri": _TEST_TRAINING_CONTAINER_IMAGE,
2016-
"python_module": source_utils._TrainingScriptPythonPackager.module_name,
2019+
"python_module": test_training_jobs._TEST_MODULE_NAME,
20172020
"package_uris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH],
20182021
"args": true_args,
20192022
},

0 commit comments

Comments
 (0)