
Commit f37989c: Merge branch 'dev_report_crash'
2 parents: bbf4978 + f47726d

31 files changed, +810 -445 lines

.github/workflows/test.yml (+10 -3)

@@ -16,10 +16,10 @@ jobs:
   pylint:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - uses: mpi4py/setup-mpi@v1
       - name: Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
        with:
          python-version: "3.10"
      - name: Install dependencies
@@ -30,7 +30,7 @@ jobs:
           pip install -r *.egg-info/requires.txt
       - name: Analysing the code with pylint
         run: |
-          pylint --unsafe-load-any-extension=y --disable=fixme $(git ls-files '*.py') || true
+          pylint --unsafe-load-any-extension=y --disable=fixme $(git ls-files "pytest_parallel/*.py" "test/*.py") || true
 
   build:
     needs: [pylint]
@@ -67,12 +67,19 @@ jobs:
             mpi: intelmpi
           - os: ubuntu-latest
             mpi: msmpi
+          # mpich seems broken on Ubuntu
           - os: ubuntu-latest
             py-version: 3.8
             mpi: mpich
           - os: ubuntu-latest
             py-version: 3.9
             mpi: mpich
+          - os: ubuntu-latest
+            py-version: 3.10
+            mpi: mpich
+          - os: ubuntu-latest
+            py-version: 3.11
+            mpi: mpich
     name: ${{ matrix.mpi }} - ${{matrix.py-version}} - ${{matrix.os}}
     steps:
       - name: Checkout

.slurm_draft/worker.py (+3 -3)

@@ -11,13 +11,13 @@
 test_idx = int(sys.argv[3])
 
 comm = MPI.COMM_WORLD
-print(f'start at {scheduler_ip}@{server_port} test {test_idx} at rank {comm.Get_rank()}/{comm.Get_size()} exec on {socket.gethostname()} - ',datetime.datetime.now())
+print(f'start at {scheduler_ip}@{server_port} test {test_idx} at rank {comm.rank}/{comm.size} exec on {socket.gethostname()} - ',datetime.datetime.now())
 
-if comm.Get_rank() == 0:
+if comm.rank == 0:
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
         s.connect((scheduler_ip, server_port))
         #time.sleep(10+5*test_idx)
-        #msg = f'Hello from test {test_idx} at rank {comm.Get_rank()}/{comm.Get_size()} exec on {socket.gethostname()}'
+        #msg = f'Hello from test {test_idx} at rank {comm.rank}/{comm.size} exec on {socket.gethostname()}'
         #socket_utils.send(s, msg)
         info = {
             'test_idx': test_idx,
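Most of the changes in this commit are the same mechanical substitution: mpi4py's `Comm.rank` and `Comm.size` properties replace the `Get_rank()`/`Get_size()` method calls. A minimal sketch (not part of the diff) showing the equivalence:

```python
# Minimal illustration (not from this commit): Comm.rank and Comm.size are
# properties that return the same values as Get_rank() and Get_size().
from mpi4py import MPI

comm = MPI.COMM_WORLD
assert comm.rank == comm.Get_rank()
assert comm.size == comm.Get_size()
print(f'rank {comm.rank}/{comm.size}')
```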

CMakeLists.txt (+1 -1)

@@ -9,7 +9,7 @@ cmake_policy(SET CMP0074 NEW) # force find_package to take <PackageName>_ROOT va
 # Project
 # ----------------------------------------------------------------------
 project(
-  pytest_parallel VERSION 1.2.0
+  pytest_parallel VERSION 1.3.0
   DESCRIPTION "pytest_parallel extends PyTest to support parallel testing using mpi4py"
 )

README.md (+275 -41)

Large diffs are not rendered by default.

doc/images/test_fail.png (binary, 2.71 KB)

doc/images/test_skip.png (binary, 15.3 KB)

pyproject.toml (+2 -2)

@@ -18,7 +18,7 @@ authors = [
   {name = "Berenger Berthoul", email = "[email protected]"},
 ]
 maintainers = [
-  {name = "Bruno Maugars", email = "[email protected]"},
+  {name = "Berenger Berthoul", email = "[email protected]"},
 ]
 license = {text = "Mozilla Public License 2.0"}
 keywords = [
@@ -52,7 +52,7 @@ dependencies = [
   "mpi4py",
   "numpy",
 ]
-version = "1.2.0"
+version = "1.3.0"
 
 [project.urls]
 Homepage = "https://github.com/onera/pytest_parallel"

pytest_parallel/__init__.py (+1 -1)

@@ -1,3 +1,3 @@
-__version__ = "1.2"
+__version__ = "1.3"
 
 from . import mark

pytest_parallel/gather_report.py (+2 -2)

@@ -45,8 +45,8 @@ def gather_report_on_local_rank_0(report):
     del report.sub_comm # No need to keep it in the report
     # Furthermore we need to serialize the report
     # and mpi4py does not know how to serialize report.sub_comm
-    i_sub_rank = sub_comm.Get_rank()
-    n_sub_rank = sub_comm.Get_size()
+    i_sub_rank = sub_comm.rank
+    n_sub_rank = sub_comm.size
 
     if (
         report.outcome != "skipped"

pytest_parallel/mpi_reporter.py (+36 -51)
@@ -1,17 +1,18 @@
-import numpy as np
+import sys
+
 import pytest
 from mpi4py import MPI
 
 from .algo import partition, lower_bound
-from .utils import get_n_proc_for_test, add_n_procs, run_item_test, mark_original_index
-from .utils_mpi import number_of_working_processes, is_dyn_master_process
+from .utils.items import get_n_proc_for_test, add_n_procs, run_item_test, mark_original_index
+from .utils.mpi import number_of_working_processes, is_dyn_master_process
 from .gather_report import gather_report_on_local_rank_0
 from .static_scheduler_utils import group_items_by_parallel_steps
 
 
 def mark_skip(item):
     comm = MPI.COMM_WORLD
-    n_rank = comm.Get_size()
+    n_rank = comm.size
     n_proc_test = get_n_proc_for_test(item)
     skip_msg = f"Not enough procs to execute: {n_proc_test} required but only {n_rank} available"
     item.add_marker(pytest.mark.skip(reason=skip_msg), append=False)
@@ -28,7 +29,8 @@ def create_sub_comm_of_size(global_comm, n_proc, mpi_comm_creation_function):
     if mpi_comm_creation_function == 'MPI_Comm_create':
         return sub_comm_from_ranks(global_comm, range(0,n_proc))
     elif mpi_comm_creation_function == 'MPI_Comm_split':
-        if i_rank < n_proc_test:
+        i_rank = global_comm.rank
+        if i_rank < n_proc:
             color = 1
         else:
             color = MPI.UNDEFINED
@@ -37,8 +39,7 @@ def create_sub_comm_of_size(global_comm, n_proc, mpi_comm_creation_function):
         assert 0, 'Unknown MPI communicator creation function. Available: `MPI_Comm_create`, `MPI_Comm_split`'
 
 def create_sub_comms_for_each_size(global_comm, mpi_comm_creation_function):
-    i_rank = global_comm.Get_rank()
-    n_rank = global_comm.Get_size()
+    n_rank = global_comm.size
     sub_comms = [None] * n_rank
     for i in range(0,n_rank):
         n_proc = i+1
@@ -47,8 +48,7 @@ def create_sub_comms_for_each_size(global_comm, mpi_comm_creation_function):
 
 
 def add_sub_comm(items, global_comm, test_comm_creation, mpi_comm_creation_function):
-    i_rank = global_comm.Get_rank()
-    n_rank = global_comm.Get_size()
+    n_rank = global_comm.size
 
     # Strategy 'by_rank': create one sub-communicator by size, from sequential (size=1) to n_rank
     if test_comm_creation == 'by_rank':
@@ -71,12 +71,17 @@ def add_sub_comm(items, global_comm, test_comm_creation, mpi_comm_creation_function):
         assert 0, 'Unknown test MPI communicator creation strategy. Available: `by_rank`, `by_test`'
 
 class SequentialScheduler:
-    def __init__(self, global_comm, test_comm_creation='by_rank', mpi_comm_creation_function='MPI_Comm_create', barrier_at_test_start=True, barrier_at_test_end=True):
+    def __init__(self, global_comm):
         self.global_comm = global_comm.Dup() # ensure that all communications within the framework are private to the framework
-        self.test_comm_creation = test_comm_creation
-        self.mpi_comm_creation_function = mpi_comm_creation_function
-        self.barrier_at_test_start = barrier_at_test_start
-        self.barrier_at_test_end = barrier_at_test_end
+
+        # These parameters are not accessible through the API, but are left here for tweaking and experimenting
+        self.test_comm_creation = 'by_rank' # possible values: 'by_rank' | 'by_test'
+        self.mpi_comm_creation_function = 'MPI_Comm_create' # possible values: 'MPI_Comm_create' | 'MPI_Comm_split'
+        self.barrier_at_test_start = True
+        self.barrier_at_test_end = True
+        if sys.platform == "win32":
+            self.mpi_comm_creation_function = 'MPI_Comm_split' # because 'MPI_Comm_create' uses `Create_group`,
+                                                               # which is not implemented in mpi4py for Windows
 
     @pytest.hookimpl(trylast=True)
     def pytest_collection_modifyitems(self, config, items):
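Background on the `mpi_comm_creation_function` fallback introduced above: a sub-communicator of the first `n_proc` ranks can be built either from a group (`MPI_Comm_create` via `Create_group`, which mpi4py does not implement on Windows) or by splitting (`MPI_Comm_split`), where every rank participates. A hedged sketch of both paths follows; the package's actual `sub_comm_from_ranks` helper is not reproduced here and may differ:

```python
# Sketch of the two creation strategies named above (assumed semantics,
# not the package's real implementation).
from mpi4py import MPI

def sub_comm_of_size(global_comm, n_proc, method):
    if method == 'MPI_Comm_create':
        # Group-based: only the member ranks call Create_group
        # (unavailable in mpi4py on Windows, hence the fallback).
        group = global_comm.group.Incl(list(range(n_proc)))
        if global_comm.rank < n_proc:
            return global_comm.Create_group(group)
        return MPI.COMM_NULL
    if method == 'MPI_Comm_split':
        # Split-based: every rank participates; non-members pass
        # MPI.UNDEFINED and receive MPI.COMM_NULL back.
        color = 1 if global_comm.rank < n_proc else MPI.UNDEFINED
        return global_comm.Split(color, key=global_comm.rank)
    raise ValueError(method)
```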
@@ -86,20 +91,10 @@ def pytest_collection_modifyitems(self, config, items):
     def pytest_runtest_protocol(self, item, nextitem):
         if self.barrier_at_test_start:
             self.global_comm.barrier()
-        #print(f'pytest_runtest_protocol beg {MPI.COMM_WORLD.rank=}')
         _ = yield
-        #print(f'pytest_runtest_protocol end {MPI.COMM_WORLD.rank=}')
         if self.barrier_at_test_end:
             self.global_comm.barrier()
 
-    #@pytest.hookimpl(tryfirst=True)
-    #def pytest_runtest_protocol(self, item, nextitem):
-    #    if self.barrier_at_test_start:
-    #        self.global_comm.barrier()
-    #    print(f'pytest_runtest_protocol beg {MPI.COMM_WORLD.rank=}')
-    #    if item.sub_comm == MPI.COMM_NULL:
-    #        return True # for this hook, `firstresult=True` so returning a non-None will stop other hooks to run
-
     @pytest.hookimpl(tryfirst=True)
     def pytest_pyfunc_call(self, pyfuncitem):
         #print(f'pytest_pyfunc_call {MPI.COMM_WORLD.rank=}')
@@ -113,7 +108,7 @@ def pytest_runtestloop(self, session) -> bool:
         _ = yield
         # prevent return value being non-zero (ExitCode.NO_TESTS_COLLECTED)
         # when no test run on non-master
-        if self.global_comm.Get_rank() != 0 and session.testscollected == 0:
+        if self.global_comm.rank != 0 and session.testscollected == 0:
             session.testscollected = 1
         return True
 
@@ -136,7 +131,7 @@ def pytest_runtest_logreport(self, report):
 
 
 def prepare_items_to_run(items, comm):
-    i_rank = comm.Get_rank()
+    i_rank = comm.rank
 
     items_to_run = []
 
@@ -168,7 +163,7 @@ def prepare_items_to_run(items, comm):
 
 
 def items_to_run_on_this_proc(items_by_steps, items_to_skip, comm):
-    i_rank = comm.Get_rank()
+    i_rank = comm.rank
 
     items = []
 
@@ -204,14 +199,13 @@ def pytest_runtestloop(self, session) -> bool:
             and not session.config.option.continue_on_collection_errors
         ):
             raise session.Interrupted(
-                "%d error%s during collection"
-                % (session.testsfailed, "s" if session.testsfailed != 1 else "")
+                f"{session.testsfailed} error{'s' if session.testsfailed != 1 else ''} during collection"
             )
 
         if session.config.option.collectonly:
             return True
 
-        n_workers = self.global_comm.Get_size()
+        n_workers = self.global_comm.size
 
         add_n_procs(session.items)
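The `%`-formatting to f-string conversion above is behavior-preserving; a quick check under an assumed `testsfailed` value:

```python
# Both formulations produce identical messages (the value 2 is assumed).
testsfailed = 2
old = "%d error%s during collection" % (testsfailed, "s" if testsfailed != 1 else "")
new = f"{testsfailed} error{'s' if testsfailed != 1 else ''} during collection"
assert old == new == "2 errors during collection"
```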

@@ -221,20 +215,12 @@ def pytest_runtestloop(self, session) -> bool:
             items_by_steps, items_to_skip, self.global_comm
         )
 
-        for i, item in enumerate(items):
-            # nextitem = items[i + 1] if i + 1 < len(items) else None
-            # For optimization purposes, it would be nice to have the previous commented line
-            # (`nextitem` is only used internally by PyTest in _setupstate.teardown_exact)
-            # Here, it does not work:
-            # it seems that things are messed up on rank 0
-            # because the nextitem might not be run (see pytest_runtest_setup/call/teardown hooks just above)
-            # In practice though, it seems that it is not the main thing that slows things down...
-
+        for item in items:
             nextitem = None
             run_item_test(item, nextitem, session)
 
         # prevent return value being non-zero (ExitCode.NO_TESTS_COLLECTED) when no test run on non-master
-        if self.global_comm.Get_rank() != 0 and session.testscollected == 0:
+        if self.global_comm.rank != 0 and session.testscollected == 0:
             session.testscollected = 1
         return True
 
@@ -256,8 +242,8 @@ def pytest_runtest_logreport(self, report):
         gather_report_on_local_rank_0(report)
 
         # master ranks of each sub_comm must send their report to rank 0
-        if sub_comm.Get_rank() == 0: # only master are concerned
-            if self.global_comm.Get_rank() != 0: # if master is not global master, send
+        if sub_comm.rank == 0: # only master are concerned
+            if self.global_comm.rank != 0: # if master is not global master, send
                 self.global_comm.send(report, dest=0)
             elif report.master_running_proc != 0: # else, recv if test run remotely
                 # In the line below, MPI.ANY_TAG will NOT clash with communications outside the framework because self.global_comm is private
@@ -322,7 +308,7 @@ def schedule_test(item, available_procs, inter_comm):
 
     # mark the procs as busy
    for sub_rank in sub_ranks:
-        available_procs[sub_rank] = False
+        available_procs[sub_rank] = 0
 
     # TODO isend would be slightly better (less waiting)
     for sub_rank in sub_ranks:
@@ -354,19 +340,19 @@ def wait_test_to_complete(items_to_run, session, available_procs, inter_comm):
     for sub_rank in sub_ranks:
         if sub_rank != first_rank_done:
             rank_original_idx = inter_comm.recv(source=sub_rank, tag=WORK_DONE_TAG)
-            assert (rank_original_idx == original_idx) # sub_rank is supposed to have worked on the same test
+            assert rank_original_idx == original_idx # sub_rank is supposed to have worked on the same test
 
     # the procs are now available
     for sub_rank in sub_ranks:
-        available_procs[sub_rank] = True
+        available_procs[sub_rank] = 1
 
     # "run" the test (i.e. trigger PyTest pipeline but do not really run the code)
     nextitem = None # not known at this point
     run_item_test(item, nextitem, session)
 
 
 def wait_last_tests_to_complete(items_to_run, session, available_procs, inter_comm):
-    while np.sum(available_procs) < len(available_procs):
+    while sum(available_procs) < len(available_procs):
         wait_test_to_complete(items_to_run, session, available_procs, inter_comm)
 
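Together with the hunk further below, this completes the removal of numpy from the dynamic scheduler: `available_procs` becomes a plain list of 0/1 flags, with built-in `sum()` replacing `np.sum`. A standalone sketch of that bookkeeping (simplified; the helper names and values are assumed for illustration):

```python
# Availability bookkeeping, simplified from the scheduler code above:
# 1 means a worker rank is free, 0 means busy.
n_workers = 4
available_procs = [1] * n_workers

def mark_busy(sub_ranks):
    for sub_rank in sub_ranks:
        available_procs[sub_rank] = 0

def mark_available(sub_ranks):
    for sub_rank in sub_ranks:
        available_procs[sub_rank] = 1

mark_busy([0, 1])
assert sum(available_procs) == 2          # two ranks still free
mark_available([0, 1])
assert sum(available_procs) == n_workers  # all free: the waiting loop can exit
```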

@@ -418,8 +404,7 @@ def pytest_runtestloop(self, session) -> bool:
             and not session.config.option.continue_on_collection_errors
         ):
             raise session.Interrupted(
-                "%d error%s during collection"
-                % (session.testsfailed, "s" if session.testsfailed != 1 else "")
+                f"{session.testsfailed} error{'s' if session.testsfailed != 1 else ''} during collection"
             )
 
         if session.config.option.collectonly:
@@ -451,10 +436,10 @@ def pytest_runtestloop(self, session) -> bool:
 
         # schedule tests to run
         items_left_to_run = sorted(items_to_run, key=lambda item: item.n_proc)
-        available_procs = np.ones(n_workers, dtype=np.int8)
+        available_procs = [1] * n_workers
 
         while len(items_left_to_run) > 0:
-            n_av_procs = np.sum(available_procs)
+            n_av_procs = sum(available_procs)
 
             item_idx = item_with_biggest_admissible_n_proc(items_left_to_run, n_av_procs)
 
@@ -511,7 +496,7 @@ def pytest_runtest_logreport(self, report):
         sub_comm = report.sub_comm
         gather_report_on_local_rank_0(report)
 
-        if sub_comm.Get_rank() == 0: # if local master proc, send
+        if sub_comm.rank == 0: # if local master proc, send
             # The idea of the scheduler is the following:
             #     The server schedules test over clients
             #     A client executes the test then report to the server it is done
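The closing comment sketches the dynamic scheduler's design: a server schedules tests over clients, and each client reports back once its test is done. A self-contained illustration of such a handshake (the tag value and payload are assumptions for this sketch, not the package's actual protocol):

```python
# Runnable illustration of a server/client completion handshake, in the
# spirit of the comment above. Run with e.g. `mpirun -n 3 python demo.py`.
from mpi4py import MPI

WORK_DONE_TAG = 1  # assumed tag value
comm = MPI.COMM_WORLD

if comm.rank == 0:
    # server: collect one completion message per client rank
    for _ in range(comm.size - 1):
        status = MPI.Status()
        test_idx = comm.recv(source=MPI.ANY_SOURCE, tag=WORK_DONE_TAG, status=status)
        print(f'test {test_idx} done on rank {status.Get_source()}')
else:
    # client: pretend we ran test number `rank`, then report completion
    comm.send(comm.rank, dest=0, tag=WORK_DONE_TAG)
```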
