|
16 | 16 | #
|
17 | 17 |
|
18 | 18 | from importlib import reload
|
| 19 | +import filecmp |
19 | 20 | import json
|
20 | 21 | import os
|
21 | 22 | import pytest
|
| 23 | +import tempfile |
22 | 24 |
|
23 | 25 | from google.cloud.aiplatform.training_utils import environment_variables
|
| 26 | +from google.cloud.aiplatform.utils import source_utils |
24 | 27 | from unittest import mock
|
25 | 28 |
|
26 | 29 | _TEST_TRAINING_DATA_URI = "gs://training-data-uri"
|
@@ -203,3 +206,77 @@ def test_http_handler_port(self):
|
203 | 206 | def test_http_handler_port_none(self):
|
204 | 207 | reload(environment_variables)
|
205 | 208 | assert environment_variables.http_handler_port is None
|
| 209 | + |
| 210 | + @pytest.fixture() |
| 211 | + def mock_temp_file_name(self): |
| 212 | + # Create random files |
| 213 | + # tmpdirname = tempfile.TemporaryDirectory() |
| 214 | + file = tempfile.NamedTemporaryFile() |
| 215 | + |
| 216 | + with open(file.name, "w") as handle: |
| 217 | + handle.write("test") |
| 218 | + |
| 219 | + yield file.name |
| 220 | + |
| 221 | + file.close() |
| 222 | + |
| 223 | + def test_package_file(self, mock_temp_file_name): |
| 224 | + # Test that the packager properly copies the source file to the destination file |
| 225 | + |
| 226 | + packager = source_utils._TrainingScriptPythonPackager( |
| 227 | + script_path=mock_temp_file_name |
| 228 | + ) |
| 229 | + |
| 230 | + with tempfile.TemporaryDirectory() as destination_directory_name: |
| 231 | + _ = packager.make_package(package_directory=destination_directory_name) |
| 232 | + |
| 233 | + # Check that contents of source_distribution_path is the same as destination_directory_name |
| 234 | + destination_inner_path = f"{destination_directory_name}/{packager._TRAINER_FOLDER}/{packager._ROOT_MODULE}/{packager.task_module_name}.py" |
| 235 | + |
| 236 | + assert filecmp.cmp( |
| 237 | + mock_temp_file_name, destination_inner_path, shallow=False |
| 238 | + ) |
| 239 | + |
| 240 | + @pytest.fixture() |
| 241 | + def mock_temp_folder_name(self): |
| 242 | + # Create random folder |
| 243 | + folder = tempfile.TemporaryDirectory() |
| 244 | + |
| 245 | + file = tempfile.NamedTemporaryFile(dir=folder.name) |
| 246 | + |
| 247 | + # Create random file in the folder |
| 248 | + with open(file.name, "w") as handle: |
| 249 | + handle.write("test") |
| 250 | + |
| 251 | + yield folder.name |
| 252 | + |
| 253 | + file.close() |
| 254 | + |
| 255 | + folder.cleanup() |
| 256 | + |
| 257 | + def test_package_folder(self, mock_temp_folder_name): |
| 258 | + # Test that the packager properly copies the source folder to the destination folder |
| 259 | + |
| 260 | + packager = source_utils._TrainingScriptPythonPackager( |
| 261 | + script_path=mock_temp_folder_name |
| 262 | + ) |
| 263 | + |
| 264 | + with tempfile.TemporaryDirectory() as destination_directory_name: |
| 265 | + # Add an existing file into the destination directory to check if it gets deleted |
| 266 | + existing_file = tempfile.NamedTemporaryFile(dir=destination_directory_name) |
| 267 | + |
| 268 | + with open(existing_file.name, "w") as handle: |
| 269 | + handle.write("existing") |
| 270 | + |
| 271 | + _ = packager.make_package(package_directory=destination_directory_name) |
| 272 | + |
| 273 | + # Check that contents of source_distribution_path is the same as destination_directory_name |
| 274 | + destination_inner_path = f"{destination_directory_name}/{packager._TRAINER_FOLDER}/{packager._ROOT_MODULE}" |
| 275 | + |
| 276 | + dcmp = filecmp.dircmp(mock_temp_folder_name, destination_inner_path) |
| 277 | + |
| 278 | + assert len(dcmp.diff_files) == 0 |
| 279 | + assert len(dcmp.left_only) == 0 |
| 280 | + assert len(dcmp.right_only) == 0 |
| 281 | + |
| 282 | + existing_file.close() |
0 commit comments