|
12 | 12 | import sys
|
13 | 13 | import os
|
14 | 14 | import gc
|
| 15 | +import importlib |
15 | 16 | import errno
|
16 | 17 | import functools
|
17 | 18 | import signal
|
|
20 | 21 | import socket
|
21 | 22 | import random
|
22 | 23 | import logging
|
| 24 | +import shutil |
23 | 25 | import subprocess
|
24 | 26 | import struct
|
| 27 | +import tempfile |
25 | 28 | import operator
|
26 | 29 | import pickle
|
27 | 30 | import weakref
|
@@ -6397,6 +6400,81 @@ def test_atexit(self):
|
6397 | 6400 | self.assertEqual(f.read(), 'deadbeef')
|
6398 | 6401 |
|
6399 | 6402 |
|
| 6403 | +class _TestSpawnedSysPath(BaseTestCase): |
| 6404 | + """Test that sys.path is setup in forkserver and spawn processes.""" |
| 6405 | + |
| 6406 | + ALLOWED_TYPES = ('processes',) |
| 6407 | + |
| 6408 | + def setUp(self): |
| 6409 | + self._orig_sys_path = list(sys.path) |
| 6410 | + self._temp_dir = tempfile.mkdtemp(prefix="test_sys_path-") |
| 6411 | + self._mod_name = "unique_test_mod" |
| 6412 | + module_path = os.path.join(self._temp_dir, f"{self._mod_name}.py") |
| 6413 | + with open(module_path, "w", encoding="utf-8") as mod: |
| 6414 | + mod.write("# A simple test module\n") |
| 6415 | + sys.path[:] = [p for p in sys.path if p] # remove any existing ""s |
| 6416 | + sys.path.insert(0, self._temp_dir) |
| 6417 | + sys.path.insert(0, "") # Replaced with an abspath in child. |
| 6418 | + try: |
| 6419 | + self._ctx_forkserver = multiprocessing.get_context("forkserver") |
| 6420 | + except ValueError: |
| 6421 | + self._ctx_forkserver = None |
| 6422 | + self._ctx_spawn = multiprocessing.get_context("spawn") |
| 6423 | + |
| 6424 | + def tearDown(self): |
| 6425 | + sys.path[:] = self._orig_sys_path |
| 6426 | + shutil.rmtree(self._temp_dir, ignore_errors=True) |
| 6427 | + |
| 6428 | + @staticmethod |
| 6429 | + def enq_imported_module_names(queue): |
| 6430 | + queue.put(tuple(sys.modules)) |
| 6431 | + |
| 6432 | + def test_forkserver_preload_imports_sys_path(self): |
| 6433 | + ctx = self._ctx_forkserver |
| 6434 | + if not ctx: |
| 6435 | + self.skipTest("requires forkserver start method.") |
| 6436 | + self.assertNotIn(self._mod_name, sys.modules) |
| 6437 | + multiprocessing.forkserver._forkserver._stop() # Must be fresh. |
| 6438 | + ctx.set_forkserver_preload( |
| 6439 | + ["test.test_multiprocessing_forkserver", self._mod_name]) |
| 6440 | + q = ctx.Queue() |
| 6441 | + proc = ctx.Process(target=self.enq_imported_module_names, args=(q,)) |
| 6442 | + proc.start() |
| 6443 | + proc.join() |
| 6444 | + child_imported_modules = q.get() |
| 6445 | + q.close() |
| 6446 | + self.assertIn(self._mod_name, child_imported_modules) |
| 6447 | + |
| 6448 | + @staticmethod |
| 6449 | + def enq_sys_path_and_import(queue, mod_name): |
| 6450 | + queue.put(sys.path) |
| 6451 | + try: |
| 6452 | + importlib.import_module(mod_name) |
| 6453 | + except ImportError as exc: |
| 6454 | + queue.put(exc) |
| 6455 | + else: |
| 6456 | + queue.put(None) |
| 6457 | + |
| 6458 | + def test_child_sys_path(self): |
| 6459 | + for ctx in (self._ctx_spawn, self._ctx_forkserver): |
| 6460 | + if not ctx: |
| 6461 | + continue |
| 6462 | + with self.subTest(f"{ctx.get_start_method()} start method"): |
| 6463 | + q = ctx.Queue() |
| 6464 | + proc = ctx.Process(target=self.enq_sys_path_and_import, |
| 6465 | + args=(q, self._mod_name)) |
| 6466 | + proc.start() |
| 6467 | + proc.join() |
| 6468 | + child_sys_path = q.get() |
| 6469 | + import_error = q.get() |
| 6470 | + q.close() |
| 6471 | + self.assertNotIn("", child_sys_path) # replaced by an abspath |
| 6472 | + self.assertIn(self._temp_dir, child_sys_path) # our addition |
| 6473 | + # ignore the first element, it is the absolute "" replacement |
| 6474 | + self.assertEqual(child_sys_path[1:], sys.path[1:]) |
| 6475 | + self.assertIsNone(import_error, msg=f"child could not import {self._mod_name}") |
| 6476 | + |
| 6477 | + |
6400 | 6478 | class MiscTestCase(unittest.TestCase):
|
6401 | 6479 | def test__all__(self):
|
6402 | 6480 | # Just make sure names in not_exported are excluded
|
|
0 commit comments