# Patch summary: hand processed items to the builder's update_targets in
# chunks instead of in one call at the end of the run.
#
# Files touched:
#   examples/runner_sample.py
#       - sets process_chunk_size = 2 and passes it to MyDumbBuilder;
#       - update_targets now prints "Updated items: ..." using list(items) and
#         the "Received ..." print is commented out;
#       - spacing clean-ups (N = 10, get_chunk_size = 3, "# time.sleep").
#   maggma/runner.py
#       - imports grouper from maggma.utils;
#       - the abstract process() now takes builder_id;
#       - adds update_targets_in_chunks(builder_id, processed_items), which
#         reads builders[builder_id].process_chunk_size and calls that
#         builder's update_targets once per group produced by grouper();
#       - master(): calls update_targets_in_chunks on every processed chunk
#         (and on the leftovers) and no longer accumulates processed_items or
#         calls builder.update_targets once after the workers are killed;
#       - the process() that feeds self._queue: renames the local chunk_size to
#         get_chunk_size, only runs the batch step when n > 0, calls
#         update_targets_in_chunks after each batch and then empties
#         self.processed_items with `del self.processed_items[:]`; the final
#         builder.update_targets call is replaced by update_targets_in_chunks.
#   maggma/utils.py
#       - imports itertools and adds grouper(iterable, n, fillvalue=None),
#         built on itertools.zip_longest over n references to one iterator.
#
# NOTE(review): itertools.zip_longest only exists on Python 3 (Python 2 names
#   it izip_longest). maggma/utils.py imports six, so if Python 2 is still
#   supported this should presumably be six.moves.zip_longest - confirm.
# NOTE(review): update_targets_in_chunks makes no update_targets call when
#   process_chunk_size <= 0 (grouper yields nothing for n <= 0), so processed
#   items would be silently dropped - confirm builders always set a positive
#   value.
# NOTE(review): filter(None, pitems) strips grouper's None padding but also
#   any falsy processed item (None, {}, 0, ""), and on Python 3 it hands
#   update_targets a lazy filter object rather than a list (len() would fail,
#   which matches the len(items) print being commented out in the example) -
#   confirm every update_targets implementation accepts that.
# NOTE(review): after the final "leftovers" call in the queue-based process(),
#   self.processed_items is not emptied, unlike the in-loop call - verify it
#   is reset elsewhere before the next builder runs.
# NOTE(review): in this copy the patch's line breaks and indentation appear
#   collapsed, so it will not apply as-is; regenerate it from the commit
#   before use.
diff --git a/examples/runner_sample.py b/examples/runner_sample.py index a0c9a3e13..fa34299ad 100755 --- a/examples/runner_sample.py +++ b/examples/runner_sample.py @@ -20,7 +20,6 @@ def __init__(self, N, sources, targets, get_chunk_size, process_chunk_size=1): super(MyDumbBuilder, self).__init__(sources, targets, get_chunk_size, process_chunk_size) self.N = N - def get_items(self): for i in range(self.N): @@ -28,13 +27,13 @@ def get_items(self): def process_item(self, item): print("processing item: {}".format(item)) - #time.sleep(random.randint(0,3)) + # time.sleep(random.randint(0,3)) return {item: "processed"} def update_targets(self, items): print("Updating targets ...") - print("Received {} processed items".format(len(items))) - print("Processed items: {}".format(items)) + # print("Received {} processed items".format(len(items))) + print("Updated items: {}".format(list(items))) def finalize(self): print("Finalizing ...") @@ -42,14 +41,16 @@ def finalize(self): if __name__ == '__main__': - N=10 - get_chunk_size=3 + N = 10 + get_chunk_size = 3 + process_chunk_size = 2 stores = [MemoryStore(str(i)) for i in range(7)] sources = [stores[0], stores[1], stores[3]] targets = [stores[3], stores[6]] - mdb = MyDumbBuilder(N, sources, targets, get_chunk_size=get_chunk_size) + mdb = MyDumbBuilder(N, sources, targets, get_chunk_size=get_chunk_size, + process_chunk_size=process_chunk_size) builders = [mdb] diff --git a/maggma/runner.py b/maggma/runner.py index 7cddf532a..2322dad0c 100644 --- a/maggma/runner.py +++ b/maggma/runner.py @@ -9,6 +9,7 @@ from monty.json import MSONable from maggma.helpers import get_mpi +from maggma.utils import grouper logger = logging.getLogger(__name__) sh = logging.StreamHandler(stream=sys.stdout) @@ -33,10 +34,14 @@ def __init__(self, builders, num_workers=0): self.status = [] @abc.abstractmethod - def process(self): + def process(self, builder_id): """ Does the processing. e.g. 
send work to workers(in MPI) or start the processes in multiprocessing. + + Args: + builder_id (int): process the builder_id th builder i.e + process_item --> update_targets --> finalize """ pass @@ -47,6 +52,21 @@ def worker(self): """ pass + def update_targets_in_chunks(self, builder_id, processed_items): + """ + Run the builder's update_targets method on the list of processed items in chunks of size + 'process_chunk_size'. + + Args: + builder_id (int): + processed_items (list): list of items to be used to update the targets + """ + chunk_size = self.builders[builder_id].process_chunk_size + if chunk_size > 0: + print("updating targets in batches of {}".format(chunk_size)) + for pitems in grouper(processed_items, chunk_size): + self.builders[builder_id].update_targets(filter(None, pitems)) + class MPIProcessor(BaseProcessor): @@ -72,7 +92,6 @@ def process(self, builder_id): def master(self, builder_id): print("Building with MPI. {} workers in the pool.".format(self.size - 1)) - processed_items = [] builder = self.builders[builder_id] chunk_size = builder.get_chunk_size @@ -89,7 +108,7 @@ def master(self, builder_id): if n % chunk_size == 0: print("processing chunks of size {}".format(chunk_size)) processed_chunk = self._process_chunk(chunk_size, workers) - processed_items.extend(processed_chunk) + self.update_targets_in_chunks(builder_id, processed_chunk) packet = (builder_id, item) wid = next(worker_id) workers.append(wid) @@ -99,15 +118,12 @@ def master(self, builder_id): # in case the total number of items is not divisible by chunk_size, process the leftovers. 
if workers: processed_chunk = self._process_chunk(chunk_size, workers) - processed_items.extend(processed_chunk) + self.update_targets_in_chunks(builder_id, processed_chunk) # kill workers for _ in range(self.size - 1): self.comm.send(None, dest=next(worker_id)) - # update the targets - builder.update_targets(processed_items) - # finalize if all(self.status): builder.finalize() @@ -190,7 +206,7 @@ def process(self, builder_id): builder_id (int): the index of the builder in the builders list """ builder = self.builders[builder_id] - chunk_size = builder.get_chunk_size + get_chunk_size = builder.get_chunk_size # establish connection to the sources and targets builder.connect() @@ -198,18 +214,18 @@ def process(self, builder_id): n = 0 # send items to process for item in builder.get_items(): - if n % chunk_size == 0: - print("processing chunks of size {}".format(chunk_size)) + if n > 0 and n % get_chunk_size == 0: + print("processing batch of {} items".format(get_chunk_size)) self._process_chunk() + self.update_targets_in_chunks(builder_id, self.processed_items) + del self.processed_items[:] packet = (builder_id, item) self._queue.put(packet) n += 1 # handle the leftovers self._process_chunk() - - # update the targets - builder.update_targets(self.processed_items) + self.update_targets_in_chunks(builder_id, self.processed_items) # finalize if all(self.status): diff --git a/maggma/utils.py b/maggma/utils.py index ebad13837..f71a5f9d5 100644 --- a/maggma/utils.py +++ b/maggma/utils.py @@ -1,4 +1,5 @@ # coding: utf-8 +import itertools import six @@ -58,3 +59,10 @@ def recursive_update(d, u): d[k] = v else: d[k] = v + +def grouper(iterable, n, fillvalue=None): + """Collect data into fixed-length chunks or blocks.""" + # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx + args = [iter(iterable)] * n + return itertools.zip_longest(*args, fillvalue=fillvalue) +