@@ -1,17 +1,18 @@
-import numpy as np
+import sys
+
 import pytest
 from mpi4py import MPI
 
 from .algo import partition, lower_bound
-from .utils import get_n_proc_for_test, add_n_procs, run_item_test, mark_original_index
-from .utils_mpi import number_of_working_processes, is_dyn_master_process
+from .utils.items import get_n_proc_for_test, add_n_procs, run_item_test, mark_original_index
+from .utils.mpi import number_of_working_processes, is_dyn_master_process
 from .gather_report import gather_report_on_local_rank_0
 from .static_scheduler_utils import group_items_by_parallel_steps
 
 
 def mark_skip(item):
     comm = MPI.COMM_WORLD
-    n_rank = comm.Get_size()
+    n_rank = comm.size
     n_proc_test = get_n_proc_for_test(item)
     skip_msg = f"Not enough procs to execute: {n_proc_test} required but only {n_rank} available"
     item.add_marker(pytest.mark.skip(reason=skip_msg), append=False)
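Most of the mechanical edits in this commit replace mpi4py's explicit getters (`Get_rank()`, `Get_size()`) with the equivalent `rank`/`size` communicator properties, as in the `mark_skip` change above. A quick sanity sketch of that equivalence (illustration only, not part of the commit):

    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    # In mpi4py, Comm.rank and Comm.size are properties backed by Get_rank()/Get_size(),
    # so the two spellings are interchangeable.
    assert comm.rank == comm.Get_rank()
    assert comm.size == comm.Get_size()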
@@ -28,7 +29,8 @@ def create_sub_comm_of_size(global_comm, n_proc, mpi_comm_creation_function):
     if mpi_comm_creation_function == 'MPI_Comm_create':
         return sub_comm_from_ranks(global_comm, range(0,n_proc))
     elif mpi_comm_creation_function == 'MPI_Comm_split':
-        if i_rank < n_proc_test:
+        i_rank = global_comm.rank
+        if i_rank < n_proc:
             color = 1
         else:
             color = MPI.UNDEFINED
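The two communicator-creation strategies handled above differ in which ranks must participate: `Comm.Split` is collective over the whole communicator (hence the `MPI.UNDEFINED` color for excluded ranks), while the `MPI_Comm_create` path goes through `sub_comm_from_ranks`, whose body is not shown here. A minimal standalone sketch of both paths, assuming the latter wraps `Create_group` (an assumption, consistent with the Windows comment later in this diff):

    from mpi4py import MPI

    def sub_comm_of_size_split(global_comm, n_proc):
        # Collective over all of global_comm: excluded ranks pass MPI.UNDEFINED
        # and receive MPI.COMM_NULL.
        color = 1 if global_comm.rank < n_proc else MPI.UNDEFINED
        return global_comm.Split(color, key=global_comm.rank)

    def sub_comm_of_size_create(global_comm, n_proc):
        # Group-based path (assumed equivalent of `sub_comm_from_ranks`):
        # Create_group is collective only over the ranks listed in `group`.
        group = global_comm.Get_group().Incl(list(range(n_proc)))
        return global_comm.Create_group(group)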
@@ -37,8 +39,7 @@ def create_sub_comm_of_size(global_comm, n_proc, mpi_comm_creation_function):
         assert 0, 'Unknown MPI communicator creation function. Available: `MPI_Comm_create`, `MPI_Comm_split`'
 
 def create_sub_comms_for_each_size(global_comm, mpi_comm_creation_function):
-    i_rank = global_comm.Get_rank()
-    n_rank = global_comm.Get_size()
+    n_rank = global_comm.size
     sub_comms = [None] * n_rank
     for i in range(0,n_rank):
         n_proc = i + 1
@@ -47,8 +48,7 @@ def create_sub_comms_for_each_size(global_comm, mpi_comm_creation_function):
 
 
 def add_sub_comm(items, global_comm, test_comm_creation, mpi_comm_creation_function):
-    i_rank = global_comm.Get_rank()
-    n_rank = global_comm.Get_size()
+    n_rank = global_comm.size
 
     # Strategy 'by_rank': create one sub-communicator by size, from sequential (size=1) to n_rank
     if test_comm_creation == 'by_rank':
@@ -71,12 +71,17 @@ def add_sub_comm(items, global_comm, test_comm_creation, mpi_comm_creation_funct
         assert 0, 'Unknown test MPI communicator creation strategy. Available: `by_rank`, `by_test`'
 
 class SequentialScheduler:
-    def __init__(self, global_comm, test_comm_creation='by_rank', mpi_comm_creation_function='MPI_Comm_create', barrier_at_test_start=True, barrier_at_test_end=True):
+    def __init__(self, global_comm):
         self.global_comm = global_comm.Dup() # ensure that all communications within the framework are private to the framework
-        self.test_comm_creation = test_comm_creation
-        self.mpi_comm_creation_function = mpi_comm_creation_function
-        self.barrier_at_test_start = barrier_at_test_start
-        self.barrier_at_test_end = barrier_at_test_end
+
+        # These parameters are not accessible through the API, but are left here for tweaking and experimenting
+        self.test_comm_creation = 'by_rank' # possible values: 'by_rank' | 'by_test'
+        self.mpi_comm_creation_function = 'MPI_Comm_create' # possible values: 'MPI_Comm_create' | 'MPI_Comm_split'
+        self.barrier_at_test_start = True
+        self.barrier_at_test_end = True
+        if sys.platform == "win32":
+            self.mpi_comm_creation_function = 'MPI_Comm_split' # because 'MPI_Comm_create' uses `Create_group`,
+                                                               # that is not implemented in mpi4py for Windows
 
     @pytest.hookimpl(trylast=True)
     def pytest_collection_modifyitems(self, config, items):
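With the narrowed constructor above, these knobs are no longer reachable through the public API; a caller only supplies the communicator, and on Windows the scheduler falls back to the split-based creation. A short hypothetical usage sketch (the plugin wiring and module path are not shown in this diff):

    import sys
    from mpi4py import MPI
    # from <module in this diff> import SequentialScheduler  # exact module path not shown here

    scheduler = SequentialScheduler(MPI.COMM_WORLD)

    # Per the comment in the diff, `Create_group` is not implemented in mpi4py on Windows,
    # so the split path is selected there; elsewhere the group-based default is kept.
    expected = 'MPI_Comm_split' if sys.platform == 'win32' else 'MPI_Comm_create'
    assert scheduler.mpi_comm_creation_function == expected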
@@ -86,20 +91,10 @@ def pytest_collection_modifyitems(self, config, items):
     def pytest_runtest_protocol(self, item, nextitem):
         if self.barrier_at_test_start:
             self.global_comm.barrier()
-        #print(f'pytest_runtest_protocol beg {MPI.COMM_WORLD.rank=}')
         _ = yield
-        #print(f'pytest_runtest_protocol end {MPI.COMM_WORLD.rank=}')
         if self.barrier_at_test_end:
             self.global_comm.barrier()
 
-    #@pytest.hookimpl(tryfirst=True)
-    #def pytest_runtest_protocol(self, item, nextitem):
-    #    if self.barrier_at_test_start:
-    #        self.global_comm.barrier()
-    #    print(f'pytest_runtest_protocol beg {MPI.COMM_WORLD.rank=}')
-    #    if item.sub_comm == MPI.COMM_NULL:
-    #        return True # for this hook, `firstresult=True` so returning a non-None will stop other hooks to run
-
     @pytest.hookimpl(tryfirst=True)
     def pytest_pyfunc_call(self, pyfuncitem):
         #print(f'pytest_pyfunc_call {MPI.COMM_WORLD.rank=}')
@@ -113,7 +108,7 @@ def pytest_runtestloop(self, session) -> bool:
         _ = yield
         # prevent return value being non-zero (ExitCode.NO_TESTS_COLLECTED)
         # when no test run on non-master
-        if self.global_comm.Get_rank() != 0 and session.testscollected == 0:
+        if self.global_comm.rank != 0 and session.testscollected == 0:
             session.testscollected = 1
         return True
 
@@ -136,7 +131,7 @@ def pytest_runtest_logreport(self, report):
 
 
 def prepare_items_to_run(items, comm):
-    i_rank = comm.Get_rank()
+    i_rank = comm.rank
 
     items_to_run = []
 
@@ -168,7 +163,7 @@ def prepare_items_to_run(items, comm):
 
 
 def items_to_run_on_this_proc(items_by_steps, items_to_skip, comm):
-    i_rank = comm.Get_rank()
+    i_rank = comm.rank
 
     items = []
 
@@ -204,14 +199,13 @@ def pytest_runtestloop(self, session) -> bool:
             and not session.config.option.continue_on_collection_errors
         ):
             raise session.Interrupted(
-                "%d error%s during collection"
-                % (session.testsfailed, "s" if session.testsfailed != 1 else "")
+                f"{session.testsfailed} error{'s' if session.testsfailed != 1 else ''} during collection"
             )
 
         if session.config.option.collectonly:
             return True
 
-        n_workers = self.global_comm.Get_size()
+        n_workers = self.global_comm.size
 
         add_n_procs(session.items)
 
@@ -221,20 +215,12 @@ def pytest_runtestloop(self, session) -> bool:
             items_by_steps, items_to_skip, self.global_comm
         )
 
-        for i, item in enumerate(items):
-            # nextitem = items[i + 1] if i + 1 < len(items) else None
-            # For optimization purposes, it would be nice to have the previous commented line
-            # (`nextitem` is only used internally by PyTest in _setupstate.teardown_exact)
-            # Here, it does not work:
-            # it seems that things are messed up on rank 0
-            # because the nextitem might not be run (see pytest_runtest_setup/call/teardown hooks just above)
-            # In practice though, it seems that it is not the main thing that slows things down...
-
+        for item in items:
             nextitem = None
             run_item_test(item, nextitem, session)
 
         # prevent return value being non-zero (ExitCode.NO_TESTS_COLLECTED) when no test run on non-master
-        if self.global_comm.Get_rank() != 0 and session.testscollected == 0:
+        if self.global_comm.rank != 0 and session.testscollected == 0:
             session.testscollected = 1
         return True
 
@@ -256,8 +242,8 @@ def pytest_runtest_logreport(self, report):
         gather_report_on_local_rank_0(report)
 
         # master ranks of each sub_comm must send their report to rank 0
-        if sub_comm.Get_rank() == 0: # only master are concerned
-            if self.global_comm.Get_rank() != 0: # if master is not global master, send
+        if sub_comm.rank == 0: # only master are concerned
+            if self.global_comm.rank != 0: # if master is not global master, send
                 self.global_comm.send(report, dest=0)
             elif report.master_running_proc != 0: # else, recv if test run remotely
                 # In the line below, MPI.ANY_TAG will NOT clash with communications outside the framework because self.global_comm is private
@@ -322,7 +308,7 @@ def schedule_test(item, available_procs, inter_comm):
 
     # mark the procs as busy
     for sub_rank in sub_ranks:
-        available_procs[sub_rank] = False
+        available_procs[sub_rank] = 0
 
     # TODO isend would be slightly better (less waiting)
     for sub_rank in sub_ranks:
@@ -354,19 +340,19 @@ def wait_test_to_complete(items_to_run, session, available_procs, inter_comm):
     for sub_rank in sub_ranks:
         if sub_rank != first_rank_done:
             rank_original_idx = inter_comm.recv(source=sub_rank, tag=WORK_DONE_TAG)
-            assert (rank_original_idx == original_idx) # sub_rank is supposed to have worked on the same test
+            assert rank_original_idx == original_idx # sub_rank is supposed to have worked on the same test
 
     # the procs are now available
     for sub_rank in sub_ranks:
-        available_procs[sub_rank] = True
+        available_procs[sub_rank] = 1
 
     # "run" the test (i.e. trigger PyTest pipeline but do not really run the code)
     nextitem = None # not known at this point
     run_item_test(item, nextitem, session)
 
 
 def wait_last_tests_to_complete(items_to_run, session, available_procs, inter_comm):
-    while np.sum(available_procs) < len(available_procs):
+    while sum(available_procs) < len(available_procs):
         wait_test_to_complete(items_to_run, session, available_procs, inter_comm)
 
 
@@ -418,8 +404,7 @@ def pytest_runtestloop(self, session) -> bool:
             and not session.config.option.continue_on_collection_errors
         ):
             raise session.Interrupted(
-                "%d error%s during collection"
-                % (session.testsfailed, "s" if session.testsfailed != 1 else "")
+                f"{session.testsfailed} error{'s' if session.testsfailed != 1 else ''} during collection"
             )
 
         if session.config.option.collectonly:
@@ -451,10 +436,10 @@ def pytest_runtestloop(self, session) -> bool:
 
         # schedule tests to run
        items_left_to_run = sorted(items_to_run, key=lambda item: item.n_proc)
-        available_procs = np.ones(n_workers, dtype=np.int8)
+        available_procs = [1] * n_workers
 
         while len(items_left_to_run) > 0:
-            n_av_procs = np.sum(available_procs)
+            n_av_procs = sum(available_procs)
 
             item_idx = item_with_biggest_admissible_n_proc(items_left_to_run, n_av_procs)
 
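Dropping numpy here works because the availability bookkeeping only needs a list of 0/1 flags and the builtin `sum()`. A small illustration of the invariant the scheduler relies on (hypothetical values, not part of the commit):

    n_workers = 4
    available_procs = [1] * n_workers          # 1 = free, 0 = busy

    # schedule_test marks the ranks assigned to a 2-proc test as busy:
    for sub_rank in (0, 1):
        available_procs[sub_rank] = 0

    n_av_procs = sum(available_procs)          # number of free workers, here 2
    assert n_av_procs == 2
    # loop condition used by wait_last_tests_to_complete:
    assert sum(available_procs) < len(available_procs)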
@@ -511,7 +496,7 @@ def pytest_runtest_logreport(self, report):
         sub_comm = report.sub_comm
         gather_report_on_local_rank_0(report)
 
-        if sub_comm.Get_rank() == 0: # if local master proc, send
+        if sub_comm.rank == 0: # if local master proc, send
             # The idea of the scheduler is the following:
             # The server schedules test over clients
             # A client executes the test then report to the server it is done
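The comment above summarizes the dynamic scheduler's protocol: the server (global rank 0) schedules tests over clients, and each client runs its test on a private sub-communicator and reports back. A condensed sketch of the report path described here and in the earlier logreport hunk (the recv arguments are an assumption based on the MPI.ANY_TAG comment in this file):

    from mpi4py import MPI

    def forward_report(report, sub_comm, global_comm):
        # Client side: only the master of the sub-communicator that ran the test forwards
        # its report, and only if it is not already the global master.
        if sub_comm.rank == 0 and global_comm.rank != 0:
            global_comm.send(report, dest=0)

    def collect_report(report, global_comm):
        # Server side (global rank 0): if the test ran remotely, receive the forwarded report.
        if global_comm.rank == 0 and report.master_running_proc != 0:
            return global_comm.recv(source=report.master_running_proc, tag=MPI.ANY_TAG)
        return report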