Skip to content

Commit 50d4129

Browse files
ivanmkcrosiezou
andauthored
fix: Removed dirs_exist_ok parameter as it's not backwards compatible (#1170)
* fix: Removed dirs_exist_ok parameter as it's not backwards compatible * Added unit tests and fixed bug * Removed unneeded import Co-authored-by: Rosie Zou <[email protected]>
1 parent 9ef057a commit 50d4129

File tree

2 files changed

+82
-1
lines changed

2 files changed

+82
-1
lines changed

google/cloud/aiplatform/utils/source_utils.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,11 @@ def make_package(self, package_directory: str) -> str:
171171
fp.write(setup_py_output)
172172

173173
if os.path.isdir(self.script_path):
174-
shutil.copytree(self.script_path, trainer_path, dirs_exist_ok=True)
174+
# Remove destination path if it already exists
175+
shutil.rmtree(trainer_path)
176+
177+
# Copy folder recursively
178+
shutil.copytree(src=self.script_path, dst=trainer_path)
175179
else:
176180
# The module that will contain the script
177181
script_out_path = trainer_path / f"{self.task_module_name}.py"

tests/unit/aiplatform/test_training_utils.py

+77
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,14 @@
1616
#
1717

1818
from importlib import reload
19+
import filecmp
1920
import json
2021
import os
2122
import pytest
23+
import tempfile
2224

2325
from google.cloud.aiplatform.training_utils import environment_variables
26+
from google.cloud.aiplatform.utils import source_utils
2427
from unittest import mock
2528

2629
_TEST_TRAINING_DATA_URI = "gs://training-data-uri"
@@ -203,3 +206,77 @@ def test_http_handler_port(self):
203206
def test_http_handler_port_none(self):
204207
reload(environment_variables)
205208
assert environment_variables.http_handler_port is None
209+
210+
@pytest.fixture()
211+
def mock_temp_file_name(self):
212+
# Create random files
213+
# tmpdirname = tempfile.TemporaryDirectory()
214+
file = tempfile.NamedTemporaryFile()
215+
216+
with open(file.name, "w") as handle:
217+
handle.write("test")
218+
219+
yield file.name
220+
221+
file.close()
222+
223+
def test_package_file(self, mock_temp_file_name):
224+
# Test that the packager properly copies the source file to the destination file
225+
226+
packager = source_utils._TrainingScriptPythonPackager(
227+
script_path=mock_temp_file_name
228+
)
229+
230+
with tempfile.TemporaryDirectory() as destination_directory_name:
231+
_ = packager.make_package(package_directory=destination_directory_name)
232+
233+
# Check that contents of source_distribution_path is the same as destination_directory_name
234+
destination_inner_path = f"{destination_directory_name}/{packager._TRAINER_FOLDER}/{packager._ROOT_MODULE}/{packager.task_module_name}.py"
235+
236+
assert filecmp.cmp(
237+
mock_temp_file_name, destination_inner_path, shallow=False
238+
)
239+
240+
@pytest.fixture()
241+
def mock_temp_folder_name(self):
242+
# Create random folder
243+
folder = tempfile.TemporaryDirectory()
244+
245+
file = tempfile.NamedTemporaryFile(dir=folder.name)
246+
247+
# Create random file in the folder
248+
with open(file.name, "w") as handle:
249+
handle.write("test")
250+
251+
yield folder.name
252+
253+
file.close()
254+
255+
folder.cleanup()
256+
257+
def test_package_folder(self, mock_temp_folder_name):
258+
# Test that the packager properly copies the source folder to the destination folder
259+
260+
packager = source_utils._TrainingScriptPythonPackager(
261+
script_path=mock_temp_folder_name
262+
)
263+
264+
with tempfile.TemporaryDirectory() as destination_directory_name:
265+
# Add an existing file into the destination directory to check if it gets deleted
266+
existing_file = tempfile.NamedTemporaryFile(dir=destination_directory_name)
267+
268+
with open(existing_file.name, "w") as handle:
269+
handle.write("existing")
270+
271+
_ = packager.make_package(package_directory=destination_directory_name)
272+
273+
# Check that contents of source_distribution_path is the same as destination_directory_name
274+
destination_inner_path = f"{destination_directory_name}/{packager._TRAINER_FOLDER}/{packager._ROOT_MODULE}"
275+
276+
dcmp = filecmp.dircmp(mock_temp_folder_name, destination_inner_path)
277+
278+
assert len(dcmp.diff_files) == 0
279+
assert len(dcmp.left_only) == 0
280+
assert len(dcmp.right_only) == 0
281+
282+
existing_file.close()

0 commit comments

Comments
 (0)