Skip to content

Commit 3a22663

Browse files
authored
Merge pull request #8320 from McSinyx/pools
2 parents 69a811c + 0a3b20f commit 3a22663

File tree

3 files changed

+173
-0
lines changed

3 files changed

+173
-0
lines changed

news/f91d42b8-8277-4918-94eb-031bc7be1c3f.trivial

Whitespace-only changes.

src/pip/_internal/utils/parallel.py

+107
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
"""Convenient parallelization of higher order functions.
2+
3+
This module provides two helper functions, with appropriate fallbacks on
4+
Python 2 and on systems lacking support for synchronization mechanisms:
5+
6+
- map_multiprocess
7+
- map_multithread
8+
9+
These helpers work like Python 3's map, with two differences:
10+
11+
- They don't guarantee the order of processing of
12+
the elements of the iterable.
13+
- The underlying process/thread pools chop the iterable into
14+
a number of chunks, so that for very long iterables using
15+
a large value for chunksize can make the job complete much faster
16+
than using the default value of 1.
17+
"""
18+
19+
__all__ = ['map_multiprocess', 'map_multithread']
20+
21+
from contextlib import contextmanager
22+
from multiprocessing import Pool as ProcessPool
23+
from multiprocessing.dummy import Pool as ThreadPool
24+
25+
from pip._vendor.requests.adapters import DEFAULT_POOLSIZE
26+
from pip._vendor.six import PY2
27+
from pip._vendor.six.moves import map
28+
29+
from pip._internal.utils.typing import MYPY_CHECK_RUNNING
30+
31+
if MYPY_CHECK_RUNNING:
32+
from typing import Callable, Iterable, Iterator, Union, TypeVar
33+
from multiprocessing import pool
34+
35+
Pool = Union[pool.Pool, pool.ThreadPool]
36+
S = TypeVar('S')
37+
T = TypeVar('T')
38+
39+
# On platforms without sem_open (e.g. some Android/BSD builds), the
# synchronization primitives backing multiprocessing[.dummy].Pool are
# missing and pool creation fails, so probe for them up front.
LACK_SEM_OPEN = False
try:
    import multiprocessing.synchronize  # noqa
except ImportError:
    LACK_SEM_OPEN = True

# Incredibly large timeout to work around bpo-8296 on Python 2.
# NOTE(review): not referenced anywhere else in this module — confirm
# whether a caller still needs it before removing.
TIMEOUT = 2000000
50+
51+
52+
@contextmanager
def closing(pool):
    # type: (Pool) -> Iterator[Pool]
    """Return a context manager making sure the pool closes properly.

    Unlike contextlib.closing, this also joins and terminates the pool
    on exit, whether or not the body raised.
    """
    try:
        yield pool
    finally:
        # For Pool.imap*, close and join are needed
        # for the returned iterator to begin yielding.
        pool.close()
        pool.join()
        pool.terminate()
64+
65+
66+
def _map_fallback(func, iterable, chunksize=1):
67+
# type: (Callable[[S], T], Iterable[S], int) -> Iterator[T]
68+
"""Make an iterator applying func to each element in iterable.
69+
70+
This function is the sequential fallback either on Python 2
71+
where Pool.imap* doesn't react to KeyboardInterrupt
72+
or when sem_open is unavailable.
73+
"""
74+
return map(func, iterable)
75+
76+
77+
def _map_multiprocess(func, iterable, chunksize=1):
78+
# type: (Callable[[S], T], Iterable[S], int) -> Iterator[T]
79+
"""Chop iterable into chunks and submit them to a process pool.
80+
81+
For very long iterables using a large value for chunksize can make
82+
the job complete much faster than using the default value of 1.
83+
84+
Return an unordered iterator of the results.
85+
"""
86+
with closing(ProcessPool()) as pool:
87+
return pool.imap_unordered(func, iterable, chunksize)
88+
89+
90+
def _map_multithread(func, iterable, chunksize=1):
    # type: (Callable[[S], T], Iterable[S], int) -> Iterator[T]
    """Chop iterable into chunks and submit them to a thread pool.

    For very long iterables using a large value for chunksize can make
    the job complete much faster than using the default value of 1.

    Return an unordered iterator of the results.
    """
    # DEFAULT_POOLSIZE matches requests' connection pool size.
    pool = ThreadPool(DEFAULT_POOLSIZE)
    try:
        return pool.imap_unordered(func, iterable, chunksize)
    finally:
        # Close and join before handing the iterator back: Pool.imap*
        # only starts yielding once the pool has been closed and
        # joined, so the work effectively completes before we return.
        pool.close()
        pool.join()
        pool.terminate()
101+
102+
103+
# Select the public implementations: fall back to the sequential map
# when synchronization primitives are missing (no sem_open) or on
# Python 2, where Pool.imap* doesn't react to KeyboardInterrupt.
_use_fallback = LACK_SEM_OPEN or PY2
map_multiprocess = _map_fallback if _use_fallback else _map_multiprocess
map_multithread = _map_fallback if _use_fallback else _map_multithread

tests/unit/test_utils_parallel.py

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""Test multiprocessing/multithreading higher-order functions."""
2+
3+
from importlib import import_module
4+
from math import factorial
5+
from sys import modules
6+
7+
from pip._vendor.six import PY2
8+
from pip._vendor.six.moves import map
9+
from pytest import mark
10+
11+
# Dotted path to the builtin __import__, used as a monkeypatch target.
DUNDER_IMPORT = '__builtin__.__import__' if PY2 else 'builtins.__import__'
# Workload for the correctness test: factorial over 0..41.
FUNC = factorial
ITERABLE = range(42)
MAPS = ('map_multiprocess', 'map_multithread')
# Keep a reference to the real __import__ before any test patches it.
_import = __import__
15+
16+
17+
def reload_parallel():
    """Drop any cached pip._internal.utils.parallel and import it anew.

    Reloading is required so each test observes the module-level
    fallback selection under its own patched __import__.
    """
    modules.pop('pip._internal.utils.parallel', None)
    return import_module('pip._internal.utils.parallel')
23+
24+
25+
def lack_sem_open(name, *args, **kwargs):
    """Simulate a platform without sem_open support.

    Any import ending in 'synchronize' (i.e. multiprocessing.synchronize)
    raises ImportError; everything else is delegated to the real
    __import__.
    """
    if not name.endswith('synchronize'):
        return _import(name, *args, **kwargs)
    raise ImportError
30+
31+
32+
def have_sem_open(name, *args, **kwargs):
    """Force imports of multiprocessing.synchronize to 'succeed'.

    The returned value is irrelevant since the tests never use the
    pool with this import; everything else is delegated to the real
    __import__.
    """
    if name.endswith('synchronize'):
        return None
    return _import(name, *args, **kwargs)
39+
40+
41+
@mark.parametrize('name', MAPS)
def test_lack_sem_open(name, monkeypatch):
    """Without sem_open, both maps must fall back to the sequential map.

    multiprocessing[.dummy].Pool cannot be created in that case, so the
    reloaded module should expose _map_fallback under each public name.
    """
    monkeypatch.setattr(DUNDER_IMPORT, lack_sem_open)
    module = reload_parallel()
    assert getattr(module, name) is module._map_fallback
51+
52+
53+
@mark.parametrize('name', MAPS)
def test_have_sem_open(name, monkeypatch):
    """With sem_open available, the pool-backed maps are selected.

    On Python 2 the sequential fallback is still expected, since
    Pool.imap* doesn't react to KeyboardInterrupt there.
    """
    monkeypatch.setattr(DUNDER_IMPORT, have_sem_open)
    module = reload_parallel()
    expected = '_map_fallback' if PY2 else '_' + name
    assert getattr(module, name) is getattr(module, expected)
60+
61+
62+
@mark.parametrize('name', MAPS)
def test_map(name):
    """Each map variant must yield the same set of results as plain map.

    Only set equality is checked because the parallel maps don't
    guarantee ordering.
    """
    parallel = reload_parallel()
    map_async = getattr(parallel, name)
    assert set(map_async(FUNC, ITERABLE)) == set(map(FUNC, ITERABLE))

0 commit comments

Comments
 (0)