diff --git a/dask_cuda/benchmarks/local_cudf_merge.py b/dask_cuda/benchmarks/local_cudf_merge.py
index 17b5edfef..0c5e5ef4b 100644
--- a/dask_cuda/benchmarks/local_cudf_merge.py
+++ b/dask_cuda/benchmarks/local_cudf_merge.py
@@ -135,8 +135,15 @@ def get_random_ddf(chunk_size, num_chunks, frac_match, chunk_type, args):
 
 
 def merge(args, ddf1, ddf2, write_profile):
+
+    # Allow default broadcast behavior, unless
+    # "--shuffle-join" or "--broadcast-join" was
+    # specified (with "--shuffle-join" taking
+    # precedence)
+    broadcast = False if args.shuffle_join else (True if args.broadcast_join else None)
+
     # Lazy merge/join operation
-    ddf_join = ddf1.merge(ddf2, on=["key"], how="inner")
+    ddf_join = ddf1.merge(ddf2, on=["key"], how="inner", broadcast=broadcast,)
     if args.set_index:
         ddf_join = ddf_join.set_index("key")
 
@@ -163,10 +170,10 @@ def merge_explicit_comms(args, ddf1, ddf2):
 def run(client, args, n_workers, write_profile=None):
     # Generate random Dask dataframes
     ddf_base = get_random_ddf(
-        args.chunk_size, n_workers, args.frac_match, "build", args
+        args.chunk_size, args.base_chunks, args.frac_match, "build", args
     ).persist()
     ddf_other = get_random_ddf(
-        args.chunk_size, n_workers, args.frac_match, "other", args
+        args.chunk_size, args.other_chunks, args.frac_match, "other", args
     ).persist()
     wait(ddf_base)
     wait(ddf_other)
@@ -228,6 +235,11 @@ def main(args):
     n_workers = len(scheduler_workers)
     client.wait_for_workers(n_workers)
 
+    # Allow the number of chunks to vary between
+    # the "base" and "other" DataFrames
+    args.base_chunks = args.base_chunks or n_workers
+    args.other_chunks = args.other_chunks or n_workers
+
     if args.all_to_all:
         all_to_all(client)
 
@@ -258,6 +270,10 @@ def main(args):
         for (w1, w2), nb in total_nbytes.items()
     }
 
+    broadcast = (
+        "false" if args.shuffle_join else ("true" if args.broadcast_join else "default")
+    )
+
     t_runs = numpy.empty(len(took_list))
     if args.markdown:
         print("```")
@@ -266,6 +282,9 @@ def main(args):
     print(f"backend | {args.backend}")
     print(f"merge type | {args.type}")
     print(f"rows-per-chunk | {args.chunk_size}")
+    print(f"base-chunks | {args.base_chunks}")
+    print(f"other-chunks | {args.other_chunks}")
+    print(f"broadcast | {broadcast}")
     print(f"protocol | {args.protocol}")
     print(f"device(s) | {args.devs}")
     print(f"rmm-pool | {(not args.disable_rmm_pool)}")
@@ -334,6 +353,28 @@ def parse_args():
             "type": int,
             "help": "Chunk size (default 1_000_000)",
         },
+        {
+            "name": "--base-chunks",
+            "default": None,
+            "type": int,
+            "help": "Number of base-DataFrame partitions (default: n_workers)",
+        },
+        {
+            "name": "--other-chunks",
+            "default": None,
+            "type": int,
+            "help": "Number of other-DataFrame partitions (default: n_workers)",
+        },
+        {
+            "name": "--broadcast-join",
+            "action": "store_true",
+            "help": "Use broadcast join when possible.",
+        },
+        {
+            "name": "--shuffle-join",
+            "action": "store_true",
+            "help": "Use shuffle join (takes precedence over '--broadcast-join').",
+        },
         {
             "name": "--ignore-size",
             "default": "1 MiB",
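
The patch maps the two new CLI flags onto the `broadcast=` keyword forwarded to the dataframe `merge` call, with `--shuffle-join` taking precedence and `None` leaving the choice to Dask. Below is a minimal stand-alone sketch of that precedence logic; the plain `argparse` parser here is illustrative only and is not the `parse_benchmark_args` machinery the benchmark actually uses.

import argparse

# Hypothetical stand-alone parser mirroring the two new benchmark flags.
parser = argparse.ArgumentParser()
parser.add_argument("--broadcast-join", action="store_true")
parser.add_argument("--shuffle-join", action="store_true")

# "--shuffle-join" wins over "--broadcast-join"; with neither flag, None
# leaves the broadcast decision to the merge implementation's default.
for argv in ([], ["--broadcast-join"], ["--shuffle-join", "--broadcast-join"]):
    args = parser.parse_args(argv)
    broadcast = False if args.shuffle_join else (True if args.broadcast_join else None)
    print(argv, "->", broadcast)

Running the sketch prints `None`, `True`, and `False` for the three invocations, matching the three-way "default"/"true"/"false" value the patch adds to the benchmark's printed summary.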