diff --git a/CRISPResso2/CRISPRessoMultiProcessing.py b/CRISPResso2/CRISPRessoMultiProcessing.py index 0ea1f813..dba4c1ad 100644 --- a/CRISPResso2/CRISPRessoMultiProcessing.py +++ b/CRISPResso2/CRISPRessoMultiProcessing.py @@ -29,17 +29,18 @@ def run_crispresso(crispresso_cmds, descriptor, idx): idx: index of the command to run """ crispresso_cmd=crispresso_cmds[idx] + logger = logging.getLogger(getmodule(stack()[1][0]).__name__) - logging.info('Running CRISPResso on %s #%d/%d: %s' % (descriptor, idx, len(crispresso_cmds), crispresso_cmd)) + logger.info('Running CRISPResso on %s #%d/%d: %s' % (descriptor, idx, len(crispresso_cmds), crispresso_cmd)) return_value = sb.call(crispresso_cmd, shell=True) if return_value == 137: - logging.warn('CRISPResso was killed by your system (return value %d) on %s #%d: "%s"\nPlease reduce the number of processes (-p) and run again.'%(return_value, descriptor, idx, crispresso_cmd)) + logger.warn('CRISPResso was killed by your system (return value %d) on %s #%d: "%s"\nPlease reduce the number of processes (-p) and run again.'%(return_value, descriptor, idx, crispresso_cmd)) elif return_value != 0: - logging.warn('CRISPResso command failed (return value %d) on %s #%d: "%s"'%(return_value, descriptor, idx, crispresso_cmd)) + logger.warn('CRISPResso command failed (return value %d) on %s #%d: "%s"'%(return_value, descriptor, idx, crispresso_cmd)) else: - logging.info('Finished CRISPResso %s #%d' %(descriptor, idx)) + logger.info('Finished CRISPResso %s #%d' %(descriptor, idx)) return return_value @@ -91,11 +92,12 @@ def run_crispresso_cmds(crispresso_cmds, n_processes="1", descriptor = 'region', int_n_processes = int(n_processes) logger.info("Running CRISPResso with %d processes" % int_n_processes) - pool = mp.Pool(processes=int_n_processes) + if int_n_processes > 1: + pool = mp.Pool(processes=int_n_processes) + pFunc = partial(run_crispresso, crispresso_cmds, descriptor) + p_wrapper = partial(wrapper, pFunc) idxs = range(len(crispresso_cmds)) ret_vals = [None] * len(crispresso_cmds) - pFunc = partial(run_crispresso, crispresso_cmds, descriptor) - p_wrapper = partial(wrapper, pFunc) if start_end_percent is not None: percent_complete_increment = start_end_percent[1] - start_end_percent[0] percent_complete_step = percent_complete_increment / len(crispresso_cmds) @@ -109,14 +111,24 @@ def run_crispresso_cmds(crispresso_cmds, n_processes="1", descriptor = 'region', signal.signal(signal.SIGINT, original_sigint_handler) try: completed = 0 - for idx, res in pool.imap_unordered(p_wrapper, enumerate(idxs)): - ret_vals[idx] = res - completed += 1 - percent_complete += percent_complete_step - logger.info( - "Completed {0}/{1} runs".format(completed, len(crispresso_cmds)), - {'percent_complete': percent_complete}, - ) + if int_n_processes == 1: + for idx, cmd in enumerate(crispresso_cmds): + ret_vals[idx] = run_crispresso(crispresso_cmds, descriptor, idx) + completed += 1 + percent_complete += percent_complete_step + logger.info( + "Completed {0}/{1} runs".format(completed, len(crispresso_cmds)), + {'percent_complete': percent_complete}, + ) + else: + for idx, res in pool.imap_unordered(p_wrapper, enumerate(idxs)): + ret_vals[idx] = res + completed += 1 + percent_complete += percent_complete_step + logger.info( + "Completed {0}/{1} runs".format(completed, len(crispresso_cmds)), + {'percent_complete': percent_complete}, + ) for idx, ret in enumerate(ret_vals): if ret == 137: raise Exception('CRISPResso %s #%d was killed by your system. Please decrease the number of processes (-p) and run again.'%(descriptor, idx)) @@ -135,8 +147,10 @@ def run_crispresso_cmds(crispresso_cmds, n_processes="1", descriptor = 'region', if descriptor.endswith("ch") or descriptor.endswith("sh"): plural = descriptor+"es" logger.info("Finished all " + plural) - pool.close() - pool.join() + if int_n_processes > 1: + pool.close() + if int_n_processes > 1: + pool.join() def run_pandas_apply_parallel(input_df, input_function_chunk, n_processes=1): """ @@ -163,7 +177,10 @@ def input_function_chunk(df): #shuffle the dataset to avoid finishing all the ones on top while leaving the ones on the bottom unfinished n_splits = min(n_processes, len(input_df)) df_split = np.array_split(input_df.sample(frac=1), n_splits) - pool = mp.Pool(processes = n_splits) + if n_processes > 1: + pool = mp.Pool(processes = n_splits) + else: + return input_function_chunk(input_df) #handle signals -- bug in python 2.7 (https://stackoverflow.com/questions/11312525/catch-ctrlc-sigint-and-exit-multiprocesses-gracefully-in-python) original_sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN) @@ -193,40 +210,55 @@ def run_function_on_array_chunk_parallel(input_array, input_function, n_processe input_function: function to run on chunks of the array input_function should take in a smaller array of objects """ - pool = mp.Pool(processes = n_processes) - - #handle signals -- bug in python 2.7 (https://stackoverflow.com/questions/11312525/catch-ctrlc-sigint-and-exit-multiprocesses-gracefully-in-python) - original_sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN) - signal.signal(signal.SIGINT, original_sigint_handler) - try: - n = int(max(10, len(input_array)/n_processes)) #don't parallelize unless at least 10 tasks - input_chunks = [input_array[i * n:(i + 1) * n] for i in range((len(input_array) + n - 1) // n )] - r = pool.map_async(input_function, input_chunks) - results = r.get(60*60*60) # Without the timeout this blocking call ignores all signals. - except KeyboardInterrupt: - pool.terminate() - logging.warn('Caught SIGINT. Program Terminated') - raise Exception('CRISPResso2 Terminated') - exit (0) - except Exception as e: - print('CRISPResso2 failed') - raise e + if n_processes == 1: + try: + results = input_function(input_array) + except Exception as e: + print('CRISPResso2 failed') + raise e + return results else: - pool.close() - pool.join() - return [y for x in results for y in x] + pool = mp.Pool(processes = n_processes) + + #handle signals -- bug in python 2.7 (https://stackoverflow.com/questions/11312525/catch-ctrlc-sigint-and-exit-multiprocesses-gracefully-in-python) + original_sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN) + signal.signal(signal.SIGINT, original_sigint_handler) + try: + n = int(max(10, len(input_array)/n_processes)) #don't parallelize unless at least 10 tasks + input_chunks = [input_array[i * n:(i + 1) * n] for i in range((len(input_array) + n - 1) // n )] + r = pool.map_async(input_function, input_chunks) + results = r.get(60*60*60) # Without the timeout this blocking call ignores all signals. + except KeyboardInterrupt: + pool.terminate() + logging.warn('Caught SIGINT. Program Terminated') + raise Exception('CRISPResso2 Terminated') + exit (0) + except Exception as e: + print('CRISPResso2 failed') + raise e + else: + pool.close() + pool.join() + return [y for x in results for y in x] def run_subprocess(cmd): return sb.call(cmd, shell=True) -def run_parallel_commands(commands_arr,n_processes=1,descriptor='CRISPResso2',continue_on_fail=False): +def run_parallel_commands(commands_arr, n_processes=1, descriptor='CRISPResso2', continue_on_fail=False): """ input: commands_arr: list of shell commands to run descriptor: string to print out to user describing run """ - pool = mp.Pool(processes = n_processes) + if n_processes > 1: + pool = mp.Pool(processes = n_processes) + else: + for idx, command in enumerate(commands_arr): + return_value = run_subprocess(command) + if return_value != 0 and not continue_on_fail: + raise Exception(f'{descriptor} #{idx} was failed') + return #handle signals -- bug in python 2.7 (https://stackoverflow.com/questions/11312525/catch-ctrlc-sigint-and-exit-multiprocesses-gracefully-in-python) original_sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN) @@ -281,4 +313,3 @@ def run_plot(plot_func, plot_args, num_processes, process_futures, process_pool) except Exception as e: logger.warn(f"Plot error {e}, skipping plot \n") logger.debug(traceback.format_exc()) -