|
51 | 51 | import importlib
|
52 | 52 | import itertools
|
53 | 53 | import multiprocessing
|
| 54 | +import os |
| 55 | +import signal |
54 | 56 | import sys
|
55 |
| -from typing import Any, Callable, List |
| 57 | +from types import FrameType |
| 58 | +from typing import Any, Callable, List, Optional |
56 | 59 |
|
57 | 60 | from twisted.internet.main import installReactor
|
58 | 61 |
|
| 62 | +# a list of the original signal handlers, before we installed our custom ones. |
| 63 | +# We restore these in our child processes. |
| 64 | +_original_signal_handlers: dict[int, Any] = {} |
| 65 | + |
59 | 66 |
|
60 | 67 | class ProxiedReactor:
|
61 | 68 | """
|
@@ -105,6 +112,11 @@ def _worker_entrypoint(
|
105 | 112 |
|
106 | 113 | sys.argv = args
|
107 | 114 |
|
| 115 | + # reset the custom signal handlers that we installed, so that the children start |
| 116 | + # from a clean slate. |
| 117 | + for sig, handler in _original_signal_handlers.items(): |
| 118 | + signal.signal(sig, handler) |
| 119 | + |
108 | 120 | from twisted.internet.epollreactor import EPollReactor
|
109 | 121 |
|
110 | 122 | proxy_reactor._install_real_reactor(EPollReactor())
|
@@ -167,13 +179,29 @@ def main() -> None:
|
167 | 179 | update_proc.join()
|
168 | 180 | print("===== PREPARED DATABASE =====", file=sys.stderr)
|
169 | 181 |
|
| 182 | + processes: List[multiprocessing.Process] = [] |
| 183 | + |
| 184 | + # Install signal handlers to propagate signals to all our children, so that they |
| 185 | + # shut down cleanly. This also inhibits our own exit, but that's good: we want to |
| 186 | + # wait until the children have exited. |
| 187 | + def handle_signal(signum: int, frame: Optional[FrameType]) -> None: |
| 188 | + print( |
| 189 | + f"complement_fork_starter: Caught signal {signum}. Stopping children.", |
| 190 | + file=sys.stderr, |
| 191 | + ) |
| 192 | + for p in processes: |
| 193 | + if p.pid: |
| 194 | + os.kill(p.pid, signum) |
| 195 | + |
| 196 | + for sig in (signal.SIGINT, signal.SIGTERM): |
| 197 | + _original_signal_handlers[sig] = signal.signal(sig, handle_signal) |
| 198 | + |
170 | 199 | # At this point, we've imported all the main entrypoints for all the workers.
|
171 | 200 | # Now we basically just fork() out to create the workers we need.
|
172 | 201 | # Because we're using fork(), all the workers get a clone of this launcher's
|
173 | 202 | # memory space and don't need to repeat the work of loading the code!
|
174 | 203 | # Instead of using fork() directly, we use the multiprocessing library,
|
175 | 204 | # which uses fork() on Unix platforms.
|
176 |
| - processes = [] |
177 | 205 | for (func, worker_args) in zip(worker_functions, args_by_worker):
|
178 | 206 | process = multiprocessing.Process(
|
179 | 207 | target=_worker_entrypoint, args=(func, proxy_reactor, worker_args)
|
|
0 commit comments