
Commit

sequential scheduler: add variability on `test_comm_creation`, `mpi_comm_creation_function`, `barrier_at_test_start`, `barrier_at_test_end`
BerengerBerthoul committed Aug 29, 2024
1 parent 17f4913 commit f2d3eef
Showing 1 changed file with 91 additions and 40 deletions.
131 changes: 91 additions & 40 deletions pytest_parallel/mpi_reporter.py
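For orientation, here is a minimal construction sketch, not part of the commit, assuming mpi4py is installed, that pytest_parallel is importable as a package, and that the scheduler is handed MPI.COMM_WORLD; how the scheduler object is registered with pytest is outside this diff. The parameter names and default values come from the new SequentialScheduler.__init__ signature shown below.

from mpi4py import MPI
from pytest_parallel.mpi_reporter import SequentialScheduler

scheduler = SequentialScheduler(
    MPI.COMM_WORLD,                               # duplicated internally by the scheduler
    test_comm_creation='by_test',                 # or 'by_rank' (default): pre-create one sub-comm per size and reuse it
    mpi_comm_creation_function='MPI_Comm_split',  # or 'MPI_Comm_create' (default)
    barrier_at_test_start=True,                   # default True
    barrier_at_test_end=False,                    # default False
)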
@@ -23,28 +23,30 @@ def sub_comm_from_ranks(global_comm, sub_ranks):
sub_comm = global_comm.Create_group(sub_group)
return sub_comm

def create_sub_comm_of_size(global_comm, n_proc, comm_creation_strategy):
if comm_creation_strategy == 'MPI_Comm_create':
def create_sub_comm_of_size(global_comm, n_proc, mpi_comm_creation_function):
if mpi_comm_creation_function == 'MPI_Comm_create':
return sub_comm_from_ranks(global_comm, range(0,n_proc))
elif comm_creation_strategy == 'MPI_Comm_split':
elif mpi_comm_creation_function == 'MPI_Comm_split':
i_rank = global_comm.Get_rank()
if i_rank < n_proc: # ranks [0, n_proc) participate in the sub-communicator
color = 1
else:
color = MPI.UNDEFINED
return global_comm.Split(color, key=i_rank)
else:
assert 0, 'unknown MPI communicator creation strategy'
assert 0, 'unknown MPI communicator creation function'

def create_sub_comms_for_each_size(global_comm, comm_creation_strategy='MPI_Comm_create'):
def create_sub_comms_for_each_size(global_comm, mpi_comm_creation_function):
i_rank = global_comm.Get_rank()
n_rank = global_comm.Get_size()
sub_comms = [None] * n_rank
for i in range(0,n_rank):
n_proc = i+1
sub_comms[i] = create_sub_comm_of_size(global_comm, n_proc, comm_creation_strategy)
sub_comms[i] = create_sub_comm_of_size(global_comm, n_proc, mpi_comm_creation_function)
return sub_comms

def filter_and_add_sub_comm(items, global_comm):


def filter_and_add_sub_comm__old(items, global_comm):
i_rank = global_comm.Get_rank()
n_workers = global_comm.Get_size()

@@ -60,55 +62,102 @@ def filter_and_add_sub_comm(items, global_comm):
else:
item.sub_comm = MPI.COMM_NULL # TODO this should not be needed
else:
#if i_rank < n_proc_test:
# sub_comm = sub_comm_from_ranks(global_comm, range(0,n_proc_test))
#else:
# sub_comm = MPI.COMM_NULL
#print(f"\nTOTO {item=} {sub_comm=}")
#if i_rank < n_proc_test:
# color = 1
#else:
# color = MPI.UNDEFINED

#print(f"{i_rank=} {item=} {item.name=} {color=}")
#sub_comm = global_comm.Split(color, key=i_rank)
if i_rank < n_proc_test:
color = 1
else:
color = MPI.UNDEFINED

##if sub_comm != MPI.COMM_NULL:
## item.sub_comm = sub_comm
#item.sub_comm = sub_comm
sub_comm = global_comm.Split(color)

filtered_items += [item]
if sub_comm != MPI.COMM_NULL:
item.sub_comm = sub_comm
filtered_items += [item]

return filtered_items

def add_sub_comm(items, global_comm, test_comm_creation, mpi_comm_creation_function):
i_rank = global_comm.Get_rank()
n_rank = global_comm.Get_size()

# Strategy 'by_rank': create one sub-communicator per size, from sequential (size=1) up to n_rank
if test_comm_creation == 'by_rank':
sub_comms = create_sub_comms_for_each_size(global_comm, mpi_comm_creation_function)
# Strategy 'by_test': create one sub-communicator per test (communicators are not reused between tests)
## Note: the sub-comms are created below (inside the item loop)

for item in items:
n_proc_test = get_n_proc_for_test(item)

if n_proc_test > n_rank: # not enough procs: mark the test to be skipped
mark_skip(item)
item.sub_comm = MPI.COMM_NULL
#if n_proc_test > n_workers: # not enough procs: will be skipped
# if global_comm.Get_rank() == 0:
# item.sub_comm = MPI.COMM_SELF
# mark_skip(item)
# else:
# item.sub_comm = MPI.COMM_NULL # TODO this should not be needed
else:
if test_comm_creation == 'by_rank':
item.sub_comm = sub_comms[n_proc_test-1]
elif test_comm_creation == 'by_test':
item.sub_comm = create_sub_comm_of_size(global_comm, n_proc_test, mpi_comm_creation_function)
else:
assert 0, 'unknown test MPI communicator creation strategy'


class SequentialScheduler:
def __init__(self, global_comm):
def __init__(self, global_comm, test_comm_creation='by_rank', mpi_comm_creation_function='MPI_Comm_create', barrier_at_test_start=True, barrier_at_test_end=False):
self.global_comm = global_comm.Dup() # ensure that all communications within the framework are private to the framework
self.sub_comms = create_sub_comms_for_each_size(self.global_comm)
self.test_comm_creation = test_comm_creation
self.mpi_comm_creation_function = mpi_comm_creation_function

self.barrier_at_test_start = barrier_at_test_start
self.barrier_at_test_end = barrier_at_test_end

@pytest.hookimpl(trylast=True)
def pytest_collection_modifyitems(self, config, items):
items[:] = filter_and_add_sub_comm(items, self.global_comm)
add_sub_comm(items, self.global_comm, self.test_comm_creation, self.mpi_comm_creation_function)

#@pytest.hookimpl(tryfirst=True)
#def pytest_runtest_protocol(self, item, nextitem):
# #i_rank = self.global_comm.Get_rank()
# #n_proc_test = get_n_proc_for_test(item)
# #if i_rank < n_proc_test:
# # sub_comm = sub_comm_from_ranks(self.global_comm, range(0,n_proc_test))
# #else:
# # sub_comm = MPI.COMM_NULL
# #item.sub_comm = sub_comm
# n_proc_test = get_n_proc_for_test(item)
# #if n_proc_test <= self.global_comm.Get_size():
# #if n_proc_test < self.global_comm.rank:
# item.sub_comm = self.sub_comms[n_proc_test-1]
# #else:
# # item.sub_comm = MPI.COMM_NULL

@pytest.hookimpl(tryfirst=True)
@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_runtest_protocol(self, item, nextitem):
#i_rank = self.global_comm.Get_rank()
#n_proc_test = get_n_proc_for_test(item)
#if i_rank < n_proc_test:
# sub_comm = sub_comm_from_ranks(self.global_comm, range(0,n_proc_test))
#else:
# sub_comm = MPI.COMM_NULL
#item.sub_comm = sub_comm
n_proc_test = get_n_proc_for_test(item)
item.sub_comm = self.sub_comms[n_proc_test-1]

if self.barrier_at_test_start:
self.global_comm.barrier()
_ = yield
if self.barrier_at_test_end:
self.global_comm.barrier()

#@pytest.hookimpl(tryfirst=True)
#def pytest_runtest_protocol(self, item, nextitem):
# pass
# #return True
# #if item.sub_comm != MPI.COMM_NULL:
# # _ = yield
# #else:
# # return True

@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_pyfunc_call(self, pyfuncitem):
#self.global_comm.barrier()
if pyfuncitem.sub_comm != MPI.COMM_NULL:
_ = yield
# else: the rank does not participate in the test, so do nothing
#self.global_comm.barrier()
else: # the rank does not participate in the test, so do nothing
return True

@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_runtestloop(self, session) -> bool:
@@ -133,6 +182,8 @@ def pytest_runtest_makereport(self, item):
def pytest_runtest_logreport(self, report):
if report.sub_comm != MPI.COMM_NULL:
gather_report_on_local_rank_0(report)
else:
return True # ranks that don't participate in the tests don't have to report anything


def group_items_by_parallel_steps(items, n_workers):

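As a closing illustration, a standalone mpi4py sketch, not part of the diff, contrasting the two communicator-creation paths that create_sub_comm_of_size dispatches on: 'MPI_Comm_create', which goes through a group and Create_group, and 'MPI_Comm_split', which goes through Split with a color. The script name, the mpirun invocation and the value of n_proc are illustrative, and here only the participating ranks call Create_group; the property the scheduler relies on is that participating ranks get a communicator of the requested size and the other ranks get MPI.COMM_NULL.

# run with, e.g.: mpirun -np 4 python comm_creation_demo.py   (file name is illustrative)
from mpi4py import MPI

comm   = MPI.COMM_WORLD
i_rank = comm.Get_rank()
n_proc = 2  # requested sub-communicator size, must not exceed the number of launched ranks

# 'MPI_Comm_create' path: build a group from explicit ranks, then a communicator from it.
# Create_group (MPI_Comm_create_group) is collective over the ranks of the group only.
if i_rank < n_proc:
    sub_group  = comm.Get_group().Incl(list(range(n_proc)))
    sub_create = comm.Create_group(sub_group)
else:
    sub_create = MPI.COMM_NULL

# 'MPI_Comm_split' path: every rank calls Split; ranks that pass MPI.UNDEFINED as the color
# are left out and receive MPI.COMM_NULL.
color     = 1 if i_rank < n_proc else MPI.UNDEFINED
sub_split = comm.Split(color, key=i_rank)

# On participating ranks, both paths yield a communicator of exactly n_proc ranks.
for sub in (sub_create, sub_split):
    if sub != MPI.COMM_NULL:
        assert sub.Get_size() == n_proc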