mirror of
https://github.com/python/cpython.git
synced 2026-01-26 12:55:08 +00:00
gh-141860: Add on_error= keyword arg to multiprocessing.set_forkserver_preload (GH-141859)
Add a keyword-only `on_error` parameter to `multiprocessing.set_forkserver_preload()`. This allows the user to have exceptions during optional `forkserver` start method module preloading cause the forkserver subprocess to warn (generally to stderr) or exit with an error (preventing use of the forkserver) instead of being silently ignored. This _also_ fixes an oversight, errors when preloading a `__main__` module are now treated the similarly. Those would always raise unlike other modules in preload, but that had gone unnoticed as up until bug fix PR GH-135295 in 3.14.1 and 3.13.8, the `__main__` module was never actually preloaded. Based on original work by Nick Neumann @aggieNick02 in GH-99515.
This commit is contained in:
parent
54bedcf714
commit
c879b2a7a5
@ -1234,22 +1234,32 @@ Miscellaneous
|
||||
.. versionchanged:: 3.11
|
||||
Accepts a :term:`path-like object`.
|
||||
|
||||
.. function:: set_forkserver_preload(module_names)
|
||||
.. function:: set_forkserver_preload(module_names, *, on_error='ignore')
|
||||
|
||||
Set a list of module names for the forkserver main process to attempt to
|
||||
import so that their already imported state is inherited by forked
|
||||
processes. Any :exc:`ImportError` when doing so is silently ignored.
|
||||
This can be used as a performance enhancement to avoid repeated work
|
||||
in every process.
|
||||
processes. This can be used as a performance enhancement to avoid repeated
|
||||
work in every process.
|
||||
|
||||
For this to work, it must be called before the forkserver process has been
|
||||
launched (before creating a :class:`Pool` or starting a :class:`Process`).
|
||||
|
||||
The *on_error* parameter controls how :exc:`ImportError` exceptions during
|
||||
module preloading are handled: ``"ignore"`` (default) silently ignores
|
||||
failures, ``"warn"`` causes the forkserver subprocess to emit an
|
||||
:exc:`ImportWarning` to stderr, and ``"fail"`` causes the forkserver
|
||||
subprocess to exit with the exception traceback on stderr, making
|
||||
subsequent process creation fail with :exc:`EOFError` or
|
||||
:exc:`ConnectionError`.
|
||||
|
||||
Only meaningful when using the ``'forkserver'`` start method.
|
||||
See :ref:`multiprocessing-start-methods`.
|
||||
|
||||
.. versionadded:: 3.4
|
||||
|
||||
.. versionchanged:: next
|
||||
Added the *on_error* parameter.
|
||||
|
||||
.. function:: set_start_method(method, force=False)
|
||||
|
||||
Set the method which should be used to start child processes.
|
||||
|
||||
@ -177,12 +177,15 @@ class BaseContext(object):
|
||||
from .spawn import set_executable
|
||||
set_executable(executable)
|
||||
|
||||
def set_forkserver_preload(self, module_names):
|
||||
def set_forkserver_preload(self, module_names, *, on_error='ignore'):
|
||||
'''Set list of module names to try to load in forkserver process.
|
||||
This is really just a hint.
|
||||
|
||||
The on_error parameter controls how import failures are handled:
|
||||
"ignore" (default) silently ignores failures, "warn" emits warnings,
|
||||
and "fail" raises exceptions breaking the forkserver context.
|
||||
'''
|
||||
from .forkserver import set_forkserver_preload
|
||||
set_forkserver_preload(module_names)
|
||||
set_forkserver_preload(module_names, on_error=on_error)
|
||||
|
||||
def get_context(self, method=None):
|
||||
if method is None:
|
||||
|
||||
@ -42,6 +42,7 @@ class ForkServer(object):
|
||||
self._inherited_fds = None
|
||||
self._lock = threading.Lock()
|
||||
self._preload_modules = ['__main__']
|
||||
self._preload_on_error = 'ignore'
|
||||
|
||||
def _stop(self):
|
||||
# Method used by unit tests to stop the server
|
||||
@ -64,11 +65,22 @@ class ForkServer(object):
|
||||
self._forkserver_address = None
|
||||
self._forkserver_authkey = None
|
||||
|
||||
def set_forkserver_preload(self, modules_names):
|
||||
'''Set list of module names to try to load in forkserver process.'''
|
||||
def set_forkserver_preload(self, modules_names, *, on_error='ignore'):
|
||||
'''Set list of module names to try to load in forkserver process.
|
||||
|
||||
The on_error parameter controls how import failures are handled:
|
||||
"ignore" (default) silently ignores failures, "warn" emits warnings,
|
||||
and "fail" raises exceptions breaking the forkserver context.
|
||||
'''
|
||||
if not all(type(mod) is str for mod in modules_names):
|
||||
raise TypeError('module_names must be a list of strings')
|
||||
if on_error not in ('ignore', 'warn', 'fail'):
|
||||
raise ValueError(
|
||||
f"on_error must be 'ignore', 'warn', or 'fail', "
|
||||
f"not {on_error!r}"
|
||||
)
|
||||
self._preload_modules = modules_names
|
||||
self._preload_on_error = on_error
|
||||
|
||||
def get_inherited_fds(self):
|
||||
'''Return list of fds inherited from parent process.
|
||||
@ -107,6 +119,14 @@ class ForkServer(object):
|
||||
wrapped_client, self._forkserver_authkey)
|
||||
connection.deliver_challenge(
|
||||
wrapped_client, self._forkserver_authkey)
|
||||
except (EOFError, ConnectionError, BrokenPipeError) as exc:
|
||||
if (self._preload_modules and
|
||||
self._preload_on_error == 'fail'):
|
||||
exc.add_note(
|
||||
"Forkserver process may have crashed during module "
|
||||
"preloading. Check stderr."
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
wrapped_client._detach()
|
||||
del wrapped_client
|
||||
@ -154,6 +174,8 @@ class ForkServer(object):
|
||||
main_kws['main_path'] = data['init_main_from_path']
|
||||
if 'sys_argv' in data:
|
||||
main_kws['sys_argv'] = data['sys_argv']
|
||||
if self._preload_on_error != 'ignore':
|
||||
main_kws['on_error'] = self._preload_on_error
|
||||
|
||||
with socket.socket(socket.AF_UNIX) as listener:
|
||||
address = connection.arbitrary_address('AF_UNIX')
|
||||
@ -198,8 +220,69 @@ class ForkServer(object):
|
||||
#
|
||||
#
|
||||
|
||||
def _handle_import_error(on_error, modinfo, exc, *, warn_stacklevel):
|
||||
"""Handle an import error according to the on_error policy."""
|
||||
match on_error:
|
||||
case 'fail':
|
||||
raise
|
||||
case 'warn':
|
||||
warnings.warn(
|
||||
f"Failed to preload {modinfo}: {exc}",
|
||||
ImportWarning,
|
||||
stacklevel=warn_stacklevel + 1
|
||||
)
|
||||
case 'ignore':
|
||||
pass
|
||||
|
||||
|
||||
def _handle_preload(preload, main_path=None, sys_path=None, sys_argv=None,
|
||||
on_error='ignore'):
|
||||
"""Handle module preloading with configurable error handling.
|
||||
|
||||
Args:
|
||||
preload: List of module names to preload.
|
||||
main_path: Path to __main__ module if '__main__' is in preload.
|
||||
sys_path: sys.path to use for imports (None means use current).
|
||||
sys_argv: sys.argv to use (None means use current).
|
||||
on_error: How to handle import errors ("ignore", "warn", or "fail").
|
||||
"""
|
||||
if not preload:
|
||||
return
|
||||
|
||||
if sys_argv is not None:
|
||||
sys.argv[:] = sys_argv
|
||||
if sys_path is not None:
|
||||
sys.path[:] = sys_path
|
||||
|
||||
if '__main__' in preload and main_path is not None:
|
||||
process.current_process()._inheriting = True
|
||||
try:
|
||||
spawn.import_main_path(main_path)
|
||||
except Exception as e:
|
||||
# Catch broad Exception because import_main_path() uses
|
||||
# runpy.run_path() which executes the script and can raise
|
||||
# any exception, not just ImportError
|
||||
_handle_import_error(
|
||||
on_error, f"__main__ from {main_path!r}", e, warn_stacklevel=2
|
||||
)
|
||||
finally:
|
||||
del process.current_process()._inheriting
|
||||
|
||||
for modname in preload:
|
||||
try:
|
||||
__import__(modname)
|
||||
except ImportError as e:
|
||||
_handle_import_error(
|
||||
on_error, f"module {modname!r}", e, warn_stacklevel=2
|
||||
)
|
||||
|
||||
# gh-135335: flush stdout/stderr in case any of the preloaded modules
|
||||
# wrote to them, otherwise children might inherit buffered data
|
||||
util._flush_std_streams()
|
||||
|
||||
|
||||
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None,
|
||||
*, sys_argv=None, authkey_r=None):
|
||||
*, sys_argv=None, authkey_r=None, on_error='ignore'):
|
||||
"""Run forkserver."""
|
||||
if authkey_r is not None:
|
||||
try:
|
||||
@ -210,26 +293,7 @@ def main(listener_fd, alive_r, preload, main_path=None, sys_path=None,
|
||||
else:
|
||||
authkey = b''
|
||||
|
||||
if preload:
|
||||
if sys_argv is not None:
|
||||
sys.argv[:] = sys_argv
|
||||
if sys_path is not None:
|
||||
sys.path[:] = sys_path
|
||||
if '__main__' in preload and main_path is not None:
|
||||
process.current_process()._inheriting = True
|
||||
try:
|
||||
spawn.import_main_path(main_path)
|
||||
finally:
|
||||
del process.current_process()._inheriting
|
||||
for modname in preload:
|
||||
try:
|
||||
__import__(modname)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# gh-135335: flush stdout/stderr in case any of the preloaded modules
|
||||
# wrote to them, otherwise children might inherit buffered data
|
||||
util._flush_std_streams()
|
||||
_handle_preload(preload, main_path, sys_path, sys_argv, on_error)
|
||||
|
||||
util._close_stdin()
|
||||
|
||||
|
||||
@ -9,5 +9,8 @@ if support.PGO:
|
||||
if sys.platform == "win32":
|
||||
raise unittest.SkipTest("forkserver is not available on Windows")
|
||||
|
||||
if not support.has_fork_support:
|
||||
raise unittest.SkipTest("requires working os.fork()")
|
||||
|
||||
def load_tests(*args):
|
||||
return support.load_package_tests(os.path.dirname(__file__), *args)
|
||||
|
||||
230
Lib/test/test_multiprocessing_forkserver/test_preload.py
Normal file
230
Lib/test/test_multiprocessing_forkserver/test_preload.py
Normal file
@ -0,0 +1,230 @@
|
||||
"""Tests for forkserver preload functionality."""
|
||||
|
||||
import contextlib
|
||||
import multiprocessing
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from multiprocessing import forkserver, spawn
|
||||
|
||||
|
||||
class TestForkserverPreload(unittest.TestCase):
|
||||
"""Tests for forkserver preload functionality."""
|
||||
|
||||
def setUp(self):
|
||||
self._saved_warnoptions = sys.warnoptions.copy()
|
||||
# Remove warning options that would convert ImportWarning to errors:
|
||||
# - 'error' converts all warnings to errors
|
||||
# - 'error::ImportWarning' specifically converts ImportWarning
|
||||
# Keep other specific options like 'error::BytesWarning' that
|
||||
# subprocess's _args_from_interpreter_flags() expects to remove
|
||||
sys.warnoptions[:] = [
|
||||
opt for opt in sys.warnoptions
|
||||
if opt not in ('error', 'error::ImportWarning')
|
||||
]
|
||||
self.ctx = multiprocessing.get_context('forkserver')
|
||||
forkserver._forkserver._stop()
|
||||
|
||||
def tearDown(self):
|
||||
sys.warnoptions[:] = self._saved_warnoptions
|
||||
forkserver._forkserver._stop()
|
||||
|
||||
@staticmethod
|
||||
def _send_value(conn, value):
|
||||
"""Send value through connection. Static method to be picklable as Process target."""
|
||||
conn.send(value)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def capture_forkserver_stderr(self):
|
||||
"""Capture stderr from forkserver by preloading a module that redirects it.
|
||||
|
||||
Yields (module_name, capture_file_path). The capture file can be read
|
||||
after the forkserver has processed preloads. This works because
|
||||
forkserver.main() calls util._flush_std_streams() after preloading,
|
||||
ensuring captured output is written before we read it.
|
||||
"""
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
capture_module = os.path.join(tmpdir, '_capture_stderr.py')
|
||||
capture_file = os.path.join(tmpdir, 'stderr.txt')
|
||||
try:
|
||||
with open(capture_module, 'w') as f:
|
||||
# Use line buffering (buffering=1) to ensure warnings are written.
|
||||
# Enable ImportWarning since it's ignored by default.
|
||||
f.write(
|
||||
f'import sys, warnings; '
|
||||
f'sys.stderr = open({capture_file!r}, "w", buffering=1); '
|
||||
f'warnings.filterwarnings("always", category=ImportWarning)\n'
|
||||
)
|
||||
sys.path.insert(0, tmpdir)
|
||||
yield '_capture_stderr', capture_file
|
||||
finally:
|
||||
sys.path.remove(tmpdir)
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
|
||||
def test_preload_on_error_ignore_default(self):
|
||||
"""Test that invalid modules are silently ignored by default."""
|
||||
self.ctx.set_forkserver_preload(['nonexistent_module_xyz'])
|
||||
|
||||
r, w = self.ctx.Pipe(duplex=False)
|
||||
p = self.ctx.Process(target=self._send_value, args=(w, 42))
|
||||
p.start()
|
||||
w.close()
|
||||
result = r.recv()
|
||||
r.close()
|
||||
p.join()
|
||||
|
||||
self.assertEqual(result, 42)
|
||||
self.assertEqual(p.exitcode, 0)
|
||||
|
||||
def test_preload_on_error_ignore_explicit(self):
|
||||
"""Test that invalid modules are silently ignored with on_error='ignore'."""
|
||||
self.ctx.set_forkserver_preload(['nonexistent_module_xyz'], on_error='ignore')
|
||||
|
||||
r, w = self.ctx.Pipe(duplex=False)
|
||||
p = self.ctx.Process(target=self._send_value, args=(w, 99))
|
||||
p.start()
|
||||
w.close()
|
||||
result = r.recv()
|
||||
r.close()
|
||||
p.join()
|
||||
|
||||
self.assertEqual(result, 99)
|
||||
self.assertEqual(p.exitcode, 0)
|
||||
|
||||
def test_preload_on_error_warn(self):
|
||||
"""Test that invalid modules emit warnings with on_error='warn'."""
|
||||
with self.capture_forkserver_stderr() as (capture_mod, stderr_file):
|
||||
self.ctx.set_forkserver_preload(
|
||||
[capture_mod, 'nonexistent_module_xyz'], on_error='warn')
|
||||
|
||||
r, w = self.ctx.Pipe(duplex=False)
|
||||
p = self.ctx.Process(target=self._send_value, args=(w, 123))
|
||||
p.start()
|
||||
w.close()
|
||||
result = r.recv()
|
||||
r.close()
|
||||
p.join()
|
||||
|
||||
self.assertEqual(result, 123)
|
||||
self.assertEqual(p.exitcode, 0)
|
||||
|
||||
with open(stderr_file) as f:
|
||||
stderr_output = f.read()
|
||||
self.assertIn('nonexistent_module_xyz', stderr_output)
|
||||
self.assertIn('ImportWarning', stderr_output)
|
||||
|
||||
def test_preload_on_error_fail_breaks_context(self):
|
||||
"""Test that invalid modules with on_error='fail' breaks the forkserver."""
|
||||
with self.capture_forkserver_stderr() as (capture_mod, stderr_file):
|
||||
self.ctx.set_forkserver_preload(
|
||||
[capture_mod, 'nonexistent_module_xyz'], on_error='fail')
|
||||
|
||||
r, w = self.ctx.Pipe(duplex=False)
|
||||
try:
|
||||
p = self.ctx.Process(target=self._send_value, args=(w, 42))
|
||||
with self.assertRaises((EOFError, ConnectionError, BrokenPipeError)) as cm:
|
||||
p.start()
|
||||
notes = getattr(cm.exception, '__notes__', [])
|
||||
self.assertTrue(notes, "Expected exception to have __notes__")
|
||||
self.assertIn('Forkserver process may have crashed', notes[0])
|
||||
|
||||
with open(stderr_file) as f:
|
||||
stderr_output = f.read()
|
||||
self.assertIn('nonexistent_module_xyz', stderr_output)
|
||||
self.assertIn('ModuleNotFoundError', stderr_output)
|
||||
finally:
|
||||
w.close()
|
||||
r.close()
|
||||
|
||||
def test_preload_valid_modules_with_on_error_fail(self):
|
||||
"""Test that valid modules work fine with on_error='fail'."""
|
||||
self.ctx.set_forkserver_preload(['os', 'sys'], on_error='fail')
|
||||
|
||||
r, w = self.ctx.Pipe(duplex=False)
|
||||
p = self.ctx.Process(target=self._send_value, args=(w, 'success'))
|
||||
p.start()
|
||||
w.close()
|
||||
result = r.recv()
|
||||
r.close()
|
||||
p.join()
|
||||
|
||||
self.assertEqual(result, 'success')
|
||||
self.assertEqual(p.exitcode, 0)
|
||||
|
||||
def test_preload_invalid_on_error_value(self):
|
||||
"""Test that invalid on_error values raise ValueError."""
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
self.ctx.set_forkserver_preload(['os'], on_error='invalid')
|
||||
self.assertIn("on_error must be 'ignore', 'warn', or 'fail'", str(cm.exception))
|
||||
|
||||
|
||||
class TestHandlePreload(unittest.TestCase):
|
||||
"""Unit tests for _handle_preload() function."""
|
||||
|
||||
def setUp(self):
|
||||
self._saved_main = sys.modules['__main__']
|
||||
|
||||
def tearDown(self):
|
||||
spawn.old_main_modules.clear()
|
||||
sys.modules['__main__'] = self._saved_main
|
||||
|
||||
def test_handle_preload_main_on_error_fail(self):
|
||||
"""Test that __main__ import failures raise with on_error='fail'."""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.py') as f:
|
||||
f.write('raise RuntimeError("test error in __main__")\n')
|
||||
f.flush()
|
||||
with self.assertRaises(RuntimeError) as cm:
|
||||
forkserver._handle_preload(['__main__'], main_path=f.name, on_error='fail')
|
||||
self.assertIn("test error in __main__", str(cm.exception))
|
||||
|
||||
def test_handle_preload_main_on_error_warn(self):
|
||||
"""Test that __main__ import failures warn with on_error='warn'."""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.py') as f:
|
||||
f.write('raise ImportError("test import error")\n')
|
||||
f.flush()
|
||||
with self.assertWarns(ImportWarning) as cm:
|
||||
forkserver._handle_preload(['__main__'], main_path=f.name, on_error='warn')
|
||||
self.assertIn("Failed to preload __main__", str(cm.warning))
|
||||
self.assertIn("test import error", str(cm.warning))
|
||||
|
||||
def test_handle_preload_main_on_error_ignore(self):
|
||||
"""Test that __main__ import failures are ignored with on_error='ignore'."""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.py') as f:
|
||||
f.write('raise ImportError("test import error")\n')
|
||||
f.flush()
|
||||
forkserver._handle_preload(['__main__'], main_path=f.name, on_error='ignore')
|
||||
|
||||
def test_handle_preload_main_valid(self):
|
||||
"""Test that valid __main__ preload works."""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.py') as f:
|
||||
f.write('test_var = 42\n')
|
||||
f.flush()
|
||||
forkserver._handle_preload(['__main__'], main_path=f.name, on_error='fail')
|
||||
|
||||
def test_handle_preload_module_on_error_fail(self):
|
||||
"""Test that module import failures raise with on_error='fail'."""
|
||||
with self.assertRaises(ModuleNotFoundError):
|
||||
forkserver._handle_preload(['nonexistent_test_module_xyz'], on_error='fail')
|
||||
|
||||
def test_handle_preload_module_on_error_warn(self):
|
||||
"""Test that module import failures warn with on_error='warn'."""
|
||||
with self.assertWarns(ImportWarning) as cm:
|
||||
forkserver._handle_preload(['nonexistent_test_module_xyz'], on_error='warn')
|
||||
self.assertIn("Failed to preload module", str(cm.warning))
|
||||
|
||||
def test_handle_preload_module_on_error_ignore(self):
|
||||
"""Test that module import failures are ignored with on_error='ignore'."""
|
||||
forkserver._handle_preload(['nonexistent_test_module_xyz'], on_error='ignore')
|
||||
|
||||
def test_handle_preload_combined(self):
|
||||
"""Test preloading both __main__ and modules."""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.py') as f:
|
||||
f.write('import sys\n')
|
||||
f.flush()
|
||||
forkserver._handle_preload(['__main__', 'os', 'sys'], main_path=f.name, on_error='fail')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@ -1340,6 +1340,7 @@ Trent Nelson
|
||||
Andrew Nester
|
||||
Osvaldo Santana Neto
|
||||
Chad Netzer
|
||||
Nick Neumann
|
||||
Max Neunhöffer
|
||||
Anthon van der Neut
|
||||
George Neville-Neil
|
||||
|
||||
@ -0,0 +1,5 @@
|
||||
Add an ``on_error`` keyword-only parameter to
|
||||
:func:`multiprocessing.set_forkserver_preload` to control how import failures
|
||||
during module preloading are handled. Accepts ``'ignore'`` (default, silent),
|
||||
``'warn'`` (emit :exc:`ImportWarning`), or ``'fail'`` (raise exception).
|
||||
Contributed by Nick Neumann and Gregory P. Smith.
|
||||
Loading…
x
Reference in New Issue
Block a user