diff --git a/benchmark/loader/neighbor_loader.py b/benchmark/loader/neighbor_loader.py index cd8f7ed0186b..17e58d985ebb 100644 --- a/benchmark/loader/neighbor_loader.py +++ b/benchmark/loader/neighbor_loader.py @@ -3,6 +3,7 @@ from timeit import default_timer import tqdm +import ast from ogb.nodeproppred import PygNodePropPredDataset import torch_geometric.transforms as T @@ -83,10 +84,14 @@ def run(args: argparse.ArgumentParser) -> None: add('--device', default='cpu') add('--datasets', nargs="+", default=['arxiv', 'products', 'mag']) add('--root', default='../../data') - add('--batch-sizes', default=[8192, 4096, 2048, 1024, 512]) - add('--eval-batch-sizes', default=[16384, 8192, 4096, 2048, 1024, 512]) - add('--homo-neighbor_sizes', default=[[10, 5], [15, 10, 5], [20, 15, 10]]) - add('--hetero-neighbor_sizes', default=[[5], [10], [10, 5]], type=int) + add('--batch-sizes', default=[8192, 4096, 2048, 1024, 512], type=int, + nargs='+') + add('--eval-batch-sizes', default=[16384, 8192, 4096, 2048, 1024, 512], + type=int, nargs='+') + add('--homo-neighbor_sizes', default=[[10, 5], [15, 10, 5], [20, 15, 10]], + type=ast.literal_eval) + add('--hetero-neighbor_sizes', default=[[5], [10], [10, 5]], + type=ast.literal_eval) add('--num-workers', default=0) add('--runs', default=3)