gh-141860: Add on_error= keyword arg to multiprocessing.set_forkserver_preload (GH-141859)

Add a keyword-only `on_error` parameter to `multiprocessing.set_forkserver_preload()`. This allows the user to have exceptions during optional `forkserver` start method module preloading cause the forkserver subprocess to warn (generally to stderr) or exit with an error (preventing use of the forkserver) instead of being silently ignored.

This _also_ fixes an oversight, errors when preloading a `__main__` module are now treated the similarly. Those would always raise unlike other modules in preload, but that had gone unnoticed as up until bug fix PR GH-135295 in 3.14.1 and 3.13.8, the `__main__` module was never actually preloaded.

Based on original work by Nick Neumann @aggieNick02 in GH-99515.
This commit is contained in:
Gregory P. Smith 2026-01-18 14:04:18 -08:00 committed by GitHub
parent 54bedcf714
commit c879b2a7a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 346 additions and 30 deletions

View File

@ -1234,22 +1234,32 @@ Miscellaneous
.. versionchanged:: 3.11
Accepts a :term:`path-like object`.
.. function:: set_forkserver_preload(module_names)
.. function:: set_forkserver_preload(module_names, *, on_error='ignore')
Set a list of module names for the forkserver main process to attempt to
import so that their already imported state is inherited by forked
processes. Any :exc:`ImportError` when doing so is silently ignored.
This can be used as a performance enhancement to avoid repeated work
in every process.
processes. This can be used as a performance enhancement to avoid repeated
work in every process.
For this to work, it must be called before the forkserver process has been
launched (before creating a :class:`Pool` or starting a :class:`Process`).
The *on_error* parameter controls how :exc:`ImportError` exceptions during
module preloading are handled: ``"ignore"`` (default) silently ignores
failures, ``"warn"`` causes the forkserver subprocess to emit an
:exc:`ImportWarning` to stderr, and ``"fail"`` causes the forkserver
subprocess to exit with the exception traceback on stderr, making
subsequent process creation fail with :exc:`EOFError` or
:exc:`ConnectionError`.
Only meaningful when using the ``'forkserver'`` start method.
See :ref:`multiprocessing-start-methods`.
.. versionadded:: 3.4
.. versionchanged:: next
Added the *on_error* parameter.
.. function:: set_start_method(method, force=False)
Set the method which should be used to start child processes.

View File

@ -177,12 +177,15 @@ class BaseContext(object):
from .spawn import set_executable
set_executable(executable)
def set_forkserver_preload(self, module_names):
def set_forkserver_preload(self, module_names, *, on_error='ignore'):
'''Set list of module names to try to load in forkserver process.
This is really just a hint.
The on_error parameter controls how import failures are handled:
"ignore" (default) silently ignores failures, "warn" emits warnings,
and "fail" raises exceptions breaking the forkserver context.
'''
from .forkserver import set_forkserver_preload
set_forkserver_preload(module_names)
set_forkserver_preload(module_names, on_error=on_error)
def get_context(self, method=None):
if method is None:

View File

@ -42,6 +42,7 @@ class ForkServer(object):
self._inherited_fds = None
self._lock = threading.Lock()
self._preload_modules = ['__main__']
self._preload_on_error = 'ignore'
def _stop(self):
# Method used by unit tests to stop the server
@ -64,11 +65,22 @@ class ForkServer(object):
self._forkserver_address = None
self._forkserver_authkey = None
def set_forkserver_preload(self, modules_names):
'''Set list of module names to try to load in forkserver process.'''
def set_forkserver_preload(self, modules_names, *, on_error='ignore'):
'''Set list of module names to try to load in forkserver process.
The on_error parameter controls how import failures are handled:
"ignore" (default) silently ignores failures, "warn" emits warnings,
and "fail" raises exceptions breaking the forkserver context.
'''
if not all(type(mod) is str for mod in modules_names):
raise TypeError('module_names must be a list of strings')
if on_error not in ('ignore', 'warn', 'fail'):
raise ValueError(
f"on_error must be 'ignore', 'warn', or 'fail', "
f"not {on_error!r}"
)
self._preload_modules = modules_names
self._preload_on_error = on_error
def get_inherited_fds(self):
'''Return list of fds inherited from parent process.
@ -107,6 +119,14 @@ class ForkServer(object):
wrapped_client, self._forkserver_authkey)
connection.deliver_challenge(
wrapped_client, self._forkserver_authkey)
except (EOFError, ConnectionError, BrokenPipeError) as exc:
if (self._preload_modules and
self._preload_on_error == 'fail'):
exc.add_note(
"Forkserver process may have crashed during module "
"preloading. Check stderr."
)
raise
finally:
wrapped_client._detach()
del wrapped_client
@ -154,6 +174,8 @@ class ForkServer(object):
main_kws['main_path'] = data['init_main_from_path']
if 'sys_argv' in data:
main_kws['sys_argv'] = data['sys_argv']
if self._preload_on_error != 'ignore':
main_kws['on_error'] = self._preload_on_error
with socket.socket(socket.AF_UNIX) as listener:
address = connection.arbitrary_address('AF_UNIX')
@ -198,8 +220,69 @@ class ForkServer(object):
#
#
def _handle_import_error(on_error, modinfo, exc, *, warn_stacklevel):
"""Handle an import error according to the on_error policy."""
match on_error:
case 'fail':
raise
case 'warn':
warnings.warn(
f"Failed to preload {modinfo}: {exc}",
ImportWarning,
stacklevel=warn_stacklevel + 1
)
case 'ignore':
pass
def _handle_preload(preload, main_path=None, sys_path=None, sys_argv=None,
on_error='ignore'):
"""Handle module preloading with configurable error handling.
Args:
preload: List of module names to preload.
main_path: Path to __main__ module if '__main__' is in preload.
sys_path: sys.path to use for imports (None means use current).
sys_argv: sys.argv to use (None means use current).
on_error: How to handle import errors ("ignore", "warn", or "fail").
"""
if not preload:
return
if sys_argv is not None:
sys.argv[:] = sys_argv
if sys_path is not None:
sys.path[:] = sys_path
if '__main__' in preload and main_path is not None:
process.current_process()._inheriting = True
try:
spawn.import_main_path(main_path)
except Exception as e:
# Catch broad Exception because import_main_path() uses
# runpy.run_path() which executes the script and can raise
# any exception, not just ImportError
_handle_import_error(
on_error, f"__main__ from {main_path!r}", e, warn_stacklevel=2
)
finally:
del process.current_process()._inheriting
for modname in preload:
try:
__import__(modname)
except ImportError as e:
_handle_import_error(
on_error, f"module {modname!r}", e, warn_stacklevel=2
)
# gh-135335: flush stdout/stderr in case any of the preloaded modules
# wrote to them, otherwise children might inherit buffered data
util._flush_std_streams()
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None,
*, sys_argv=None, authkey_r=None):
*, sys_argv=None, authkey_r=None, on_error='ignore'):
"""Run forkserver."""
if authkey_r is not None:
try:
@ -210,26 +293,7 @@ def main(listener_fd, alive_r, preload, main_path=None, sys_path=None,
else:
authkey = b''
if preload:
if sys_argv is not None:
sys.argv[:] = sys_argv
if sys_path is not None:
sys.path[:] = sys_path
if '__main__' in preload and main_path is not None:
process.current_process()._inheriting = True
try:
spawn.import_main_path(main_path)
finally:
del process.current_process()._inheriting
for modname in preload:
try:
__import__(modname)
except ImportError:
pass
# gh-135335: flush stdout/stderr in case any of the preloaded modules
# wrote to them, otherwise children might inherit buffered data
util._flush_std_streams()
_handle_preload(preload, main_path, sys_path, sys_argv, on_error)
util._close_stdin()

View File

@ -9,5 +9,8 @@ if support.PGO:
if sys.platform == "win32":
raise unittest.SkipTest("forkserver is not available on Windows")
if not support.has_fork_support:
raise unittest.SkipTest("requires working os.fork()")
def load_tests(*args):
return support.load_package_tests(os.path.dirname(__file__), *args)

View File

@ -0,0 +1,230 @@
"""Tests for forkserver preload functionality."""
import contextlib
import multiprocessing
import os
import shutil
import sys
import tempfile
import unittest
from multiprocessing import forkserver, spawn
class TestForkserverPreload(unittest.TestCase):
"""Tests for forkserver preload functionality."""
def setUp(self):
self._saved_warnoptions = sys.warnoptions.copy()
# Remove warning options that would convert ImportWarning to errors:
# - 'error' converts all warnings to errors
# - 'error::ImportWarning' specifically converts ImportWarning
# Keep other specific options like 'error::BytesWarning' that
# subprocess's _args_from_interpreter_flags() expects to remove
sys.warnoptions[:] = [
opt for opt in sys.warnoptions
if opt not in ('error', 'error::ImportWarning')
]
self.ctx = multiprocessing.get_context('forkserver')
forkserver._forkserver._stop()
def tearDown(self):
sys.warnoptions[:] = self._saved_warnoptions
forkserver._forkserver._stop()
@staticmethod
def _send_value(conn, value):
"""Send value through connection. Static method to be picklable as Process target."""
conn.send(value)
@contextlib.contextmanager
def capture_forkserver_stderr(self):
"""Capture stderr from forkserver by preloading a module that redirects it.
Yields (module_name, capture_file_path). The capture file can be read
after the forkserver has processed preloads. This works because
forkserver.main() calls util._flush_std_streams() after preloading,
ensuring captured output is written before we read it.
"""
tmpdir = tempfile.mkdtemp()
capture_module = os.path.join(tmpdir, '_capture_stderr.py')
capture_file = os.path.join(tmpdir, 'stderr.txt')
try:
with open(capture_module, 'w') as f:
# Use line buffering (buffering=1) to ensure warnings are written.
# Enable ImportWarning since it's ignored by default.
f.write(
f'import sys, warnings; '
f'sys.stderr = open({capture_file!r}, "w", buffering=1); '
f'warnings.filterwarnings("always", category=ImportWarning)\n'
)
sys.path.insert(0, tmpdir)
yield '_capture_stderr', capture_file
finally:
sys.path.remove(tmpdir)
shutil.rmtree(tmpdir, ignore_errors=True)
def test_preload_on_error_ignore_default(self):
"""Test that invalid modules are silently ignored by default."""
self.ctx.set_forkserver_preload(['nonexistent_module_xyz'])
r, w = self.ctx.Pipe(duplex=False)
p = self.ctx.Process(target=self._send_value, args=(w, 42))
p.start()
w.close()
result = r.recv()
r.close()
p.join()
self.assertEqual(result, 42)
self.assertEqual(p.exitcode, 0)
def test_preload_on_error_ignore_explicit(self):
"""Test that invalid modules are silently ignored with on_error='ignore'."""
self.ctx.set_forkserver_preload(['nonexistent_module_xyz'], on_error='ignore')
r, w = self.ctx.Pipe(duplex=False)
p = self.ctx.Process(target=self._send_value, args=(w, 99))
p.start()
w.close()
result = r.recv()
r.close()
p.join()
self.assertEqual(result, 99)
self.assertEqual(p.exitcode, 0)
def test_preload_on_error_warn(self):
"""Test that invalid modules emit warnings with on_error='warn'."""
with self.capture_forkserver_stderr() as (capture_mod, stderr_file):
self.ctx.set_forkserver_preload(
[capture_mod, 'nonexistent_module_xyz'], on_error='warn')
r, w = self.ctx.Pipe(duplex=False)
p = self.ctx.Process(target=self._send_value, args=(w, 123))
p.start()
w.close()
result = r.recv()
r.close()
p.join()
self.assertEqual(result, 123)
self.assertEqual(p.exitcode, 0)
with open(stderr_file) as f:
stderr_output = f.read()
self.assertIn('nonexistent_module_xyz', stderr_output)
self.assertIn('ImportWarning', stderr_output)
def test_preload_on_error_fail_breaks_context(self):
"""Test that invalid modules with on_error='fail' breaks the forkserver."""
with self.capture_forkserver_stderr() as (capture_mod, stderr_file):
self.ctx.set_forkserver_preload(
[capture_mod, 'nonexistent_module_xyz'], on_error='fail')
r, w = self.ctx.Pipe(duplex=False)
try:
p = self.ctx.Process(target=self._send_value, args=(w, 42))
with self.assertRaises((EOFError, ConnectionError, BrokenPipeError)) as cm:
p.start()
notes = getattr(cm.exception, '__notes__', [])
self.assertTrue(notes, "Expected exception to have __notes__")
self.assertIn('Forkserver process may have crashed', notes[0])
with open(stderr_file) as f:
stderr_output = f.read()
self.assertIn('nonexistent_module_xyz', stderr_output)
self.assertIn('ModuleNotFoundError', stderr_output)
finally:
w.close()
r.close()
def test_preload_valid_modules_with_on_error_fail(self):
"""Test that valid modules work fine with on_error='fail'."""
self.ctx.set_forkserver_preload(['os', 'sys'], on_error='fail')
r, w = self.ctx.Pipe(duplex=False)
p = self.ctx.Process(target=self._send_value, args=(w, 'success'))
p.start()
w.close()
result = r.recv()
r.close()
p.join()
self.assertEqual(result, 'success')
self.assertEqual(p.exitcode, 0)
def test_preload_invalid_on_error_value(self):
"""Test that invalid on_error values raise ValueError."""
with self.assertRaises(ValueError) as cm:
self.ctx.set_forkserver_preload(['os'], on_error='invalid')
self.assertIn("on_error must be 'ignore', 'warn', or 'fail'", str(cm.exception))
class TestHandlePreload(unittest.TestCase):
"""Unit tests for _handle_preload() function."""
def setUp(self):
self._saved_main = sys.modules['__main__']
def tearDown(self):
spawn.old_main_modules.clear()
sys.modules['__main__'] = self._saved_main
def test_handle_preload_main_on_error_fail(self):
"""Test that __main__ import failures raise with on_error='fail'."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.py') as f:
f.write('raise RuntimeError("test error in __main__")\n')
f.flush()
with self.assertRaises(RuntimeError) as cm:
forkserver._handle_preload(['__main__'], main_path=f.name, on_error='fail')
self.assertIn("test error in __main__", str(cm.exception))
def test_handle_preload_main_on_error_warn(self):
"""Test that __main__ import failures warn with on_error='warn'."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.py') as f:
f.write('raise ImportError("test import error")\n')
f.flush()
with self.assertWarns(ImportWarning) as cm:
forkserver._handle_preload(['__main__'], main_path=f.name, on_error='warn')
self.assertIn("Failed to preload __main__", str(cm.warning))
self.assertIn("test import error", str(cm.warning))
def test_handle_preload_main_on_error_ignore(self):
"""Test that __main__ import failures are ignored with on_error='ignore'."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.py') as f:
f.write('raise ImportError("test import error")\n')
f.flush()
forkserver._handle_preload(['__main__'], main_path=f.name, on_error='ignore')
def test_handle_preload_main_valid(self):
"""Test that valid __main__ preload works."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.py') as f:
f.write('test_var = 42\n')
f.flush()
forkserver._handle_preload(['__main__'], main_path=f.name, on_error='fail')
def test_handle_preload_module_on_error_fail(self):
"""Test that module import failures raise with on_error='fail'."""
with self.assertRaises(ModuleNotFoundError):
forkserver._handle_preload(['nonexistent_test_module_xyz'], on_error='fail')
def test_handle_preload_module_on_error_warn(self):
"""Test that module import failures warn with on_error='warn'."""
with self.assertWarns(ImportWarning) as cm:
forkserver._handle_preload(['nonexistent_test_module_xyz'], on_error='warn')
self.assertIn("Failed to preload module", str(cm.warning))
def test_handle_preload_module_on_error_ignore(self):
"""Test that module import failures are ignored with on_error='ignore'."""
forkserver._handle_preload(['nonexistent_test_module_xyz'], on_error='ignore')
def test_handle_preload_combined(self):
"""Test preloading both __main__ and modules."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.py') as f:
f.write('import sys\n')
f.flush()
forkserver._handle_preload(['__main__', 'os', 'sys'], main_path=f.name, on_error='fail')
if __name__ == '__main__':
unittest.main()

View File

@ -1340,6 +1340,7 @@ Trent Nelson
Andrew Nester
Osvaldo Santana Neto
Chad Netzer
Nick Neumann
Max Neunhöffer
Anthon van der Neut
George Neville-Neil

View File

@ -0,0 +1,5 @@
Add an ``on_error`` keyword-only parameter to
:func:`multiprocessing.set_forkserver_preload` to control how import failures
during module preloading are handled. Accepts ``'ignore'`` (default, silent),
``'warn'`` (emit :exc:`ImportWarning`), or ``'fail'`` (raise exception).
Contributed by Nick Neumann and Gregory P. Smith.