From 96fbf43f5f9c3d526af0a04b4ebd3ec5164caf55 Mon Sep 17 00:00:00 2001
From: YanbingJiang
Date: Tue, 30 Aug 2022 22:12:15 +0800
Subject: [PATCH] Enable `bf16` support in the benchmark scripts (#5293)

* Enable bf16 support for benchmark
* Enable bf16 support for kernel, points and inference
* Support autocast both for cpu and cuda
* Update benchmark/citation/train_eval.py
* Update benchmark/points/train_eval.py
* Update benchmark/kernel/main_performance.py
* Update CHANGELOG.md

Co-authored-by: Matthias Fey
---
 CHANGELOG.md                               |  1 +
 benchmark/citation/appnp.py                |  4 +-
 benchmark/citation/arma.py                 |  4 +-
 benchmark/citation/cheb.py                 |  4 +-
 benchmark/citation/gat.py                  |  4 +-
 benchmark/citation/gcn.py                  |  4 +-
 benchmark/citation/sgc.py                  |  4 +-
 benchmark/citation/train_eval.py           | 31 +++++++++-----
 benchmark/inference/inference_benchmark.py | 48 ++++++++++++++--------
 benchmark/kernel/main_performance.py       | 30 ++++++++------
 benchmark/kernel/train_eval.py             |  4 +-
 benchmark/points/edge_cnn.py               |  3 +-
 benchmark/points/mpnn.py                   |  3 +-
 benchmark/points/point_cnn.py              |  3 +-
 benchmark/points/point_net.py              |  3 +-
 benchmark/points/spline_cnn.py             |  3 +-
 benchmark/points/train_eval.py             | 37 ++++++++++-------
 torch_geometric/nn/conv/eg_conv.py         |  4 +-
 torch_geometric/nn/conv/gcn_conv.py        |  4 +-
 torch_geometric/nn/conv/pdn_conv.py        |  5 ++-
 torch_geometric/nn/models/basic_gnn.py     |  3 +-
 21 files changed, 132 insertions(+), 74 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b3f81e9c99a1..79a055c8fae7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [2.2.0] - 2022-MM-DD
 ### Added
+- Enabled `bf16` support in benchmark scripts ([#5293](https://github.com/pyg-team/pytorch_geometric/pull/5293))
 - Added `Aggregation.set_validate_args` option to skip validation of `dim_size` ([#5290](https://github.com/pyg-team/pytorch_geometric/pull/5290))
 - Added `SparseTensor` support to inference benchmark suite ([#5242](https://github.com/pyg-team/pytorch_geometric/pull/5242), [#5258](https://github.com/pyg-team/pytorch_geometric/pull/5258))
 - Added experimental mode in inference benchmarks ([#5254](https://github.com/pyg-team/pytorch_geometric/pull/5254))
diff --git a/benchmark/citation/appnp.py b/benchmark/citation/appnp.py
index 03573cacbccf..2f4a60bfa134 100644
--- a/benchmark/citation/appnp.py
+++ b/benchmark/citation/appnp.py
@@ -23,6 +23,7 @@
 parser.add_argument('--alpha', type=float, default=0.1)
 parser.add_argument('--inference', action='store_true')
 parser.add_argument('--profile', action='store_true')
+parser.add_argument('--bf16', action='store_true')
 args = parser.parse_args()
@@ -50,7 +51,8 @@ def forward(self, data):
 dataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features)
 permute_masks = random_planetoid_splits if args.random_splits else None
 run(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay,
-    args.early_stopping, args.inference, args.profile, permute_masks)
+    args.early_stopping, args.inference, args.profile, args.bf16,
+    permute_masks)
 
 if args.profile:
     rename_profile_file('citation', APPNP.__name__, args.dataset,
diff --git a/benchmark/citation/arma.py b/benchmark/citation/arma.py
index 65cc9029fe5f..2a756fe31415 100644
--- a/benchmark/citation/arma.py
+++ b/benchmark/citation/arma.py
@@ -24,6 +24,7 @@
 parser.add_argument('--skip_dropout', type=float, default=0.75)
 parser.add_argument('--inference', action='store_true')
 parser.add_argument('--profile', action='store_true')
+parser.add_argument('--bf16', action='store_true')
 args = parser.parse_args()
@@ -52,7 +53,8 @@ def forward(self, data):
 dataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features)
 permute_masks = random_planetoid_splits if args.random_splits else None
 run(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay,
-    args.early_stopping, args.inference, args.profile, permute_masks)
+    args.early_stopping, args.inference, args.profile, args.bf16,
+    permute_masks)
 
 if args.profile:
     rename_profile_file('citation', ARMAConv.__name__, args.dataset,
diff --git a/benchmark/citation/cheb.py b/benchmark/citation/cheb.py
index 79f0182adc2c..199b9444279b 100644
--- a/benchmark/citation/cheb.py
+++ b/benchmark/citation/cheb.py
@@ -21,6 +21,7 @@
 parser.add_argument('--num_hops', type=int, default=3)
 parser.add_argument('--inference', action='store_true')
 parser.add_argument('--profile', action='store_true')
+parser.add_argument('--bf16', action='store_true')
 args = parser.parse_args()
@@ -45,7 +46,8 @@ def forward(self, data):
 dataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features)
 permute_masks = random_planetoid_splits if args.random_splits else None
 run(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay,
-    args.early_stopping, args.inference, args.profile, permute_masks)
+    args.early_stopping, args.inference, args.profile, args.bf16,
+    permute_masks)
 
 if args.profile:
     rename_profile_file('citation', ChebConv.__name__, args.dataset,
diff --git a/benchmark/citation/gat.py b/benchmark/citation/gat.py
index 6d171370c269..214e6592b9bd 100644
--- a/benchmark/citation/gat.py
+++ b/benchmark/citation/gat.py
@@ -22,6 +22,7 @@
 parser.add_argument('--output_heads', type=int, default=1)
 parser.add_argument('--inference', action='store_true')
 parser.add_argument('--profile', action='store_true')
+parser.add_argument('--bf16', action='store_true')
 args = parser.parse_args()
@@ -50,7 +51,8 @@ def forward(self, data):
 dataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features)
 permute_masks = random_planetoid_splits if args.random_splits else None
 run(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay,
-    args.early_stopping, args.inference, args.profile, permute_masks)
+    args.early_stopping, args.inference, args.profile, args.bf16,
+    permute_masks)
 
 if args.profile:
     rename_profile_file('citation', GATConv.__name__, args.dataset,
diff --git a/benchmark/citation/gcn.py b/benchmark/citation/gcn.py
index b42b531a2fdf..a9dc840a7e16 100644
--- a/benchmark/citation/gcn.py
+++ b/benchmark/citation/gcn.py
@@ -20,6 +20,7 @@
 parser.add_argument('--no_normalize_features', action='store_true')
 parser.add_argument('--inference', action='store_true')
 parser.add_argument('--profile', action='store_true')
+parser.add_argument('--bf16', action='store_true')
 args = parser.parse_args()
@@ -44,7 +45,8 @@ def forward(self, data):
 dataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features)
 permute_masks = random_planetoid_splits if args.random_splits else None
 run(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay,
-    args.early_stopping, args.inference, args.profile, permute_masks)
+    args.early_stopping, args.inference, args.profile, args.bf16,
+    permute_masks)
 
 if args.profile:
     rename_profile_file('citation', GCNConv.__name__, args.dataset,
diff --git a/benchmark/citation/sgc.py b/benchmark/citation/sgc.py
index 633ffb208d25..07ccad52c5ae 100644
--- a/benchmark/citation/sgc.py
+++ b/benchmark/citation/sgc.py
@@ -19,6 +19,7 @@
 parser.add_argument('--K', type=int, default=2)
 parser.add_argument('--inference', action='store_true')
 parser.add_argument('--profile', action='store_true')
+parser.add_argument('--bf16', action='store_true')
 args = parser.parse_args()
@@ -40,7 +41,8 @@ def forward(self, data):
 dataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features)
 permute_masks = random_planetoid_splits if args.random_splits else None
 run(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay,
-    args.early_stopping, args.inference, args.profile, permute_masks)
+    args.early_stopping, args.inference, args.profile, args.bf16,
+    permute_masks)
 
 if args.profile:
     rename_profile_file('citation', SGConv.__name__, args.dataset,
diff --git a/benchmark/citation/train_eval.py b/benchmark/citation/train_eval.py
index d8fbab1d5c37..cddf6c75fb2d 100644
--- a/benchmark/citation/train_eval.py
+++ b/benchmark/citation/train_eval.py
@@ -90,7 +90,7 @@ def run_train(dataset, model, runs, epochs, lr, weight_decay, early_stopping,
 @torch.no_grad()
-def run_inference(dataset, model, epochs, profiling, permute_masks=None,
+def run_inference(dataset, model, epochs, profiling, bf16, permute_masks=None,
                   logger=None):
     data = dataset[0]
     if permute_masks is not None:
@@ -99,25 +99,34 @@ def run_inference(dataset, model, epochs, profiling, permute_masks=None,
     model.to(device).reset_parameters()
 
-    for epoch in range(1, epochs + 1):
-        if epoch == epochs:
-            with timeit():
+    if torch.cuda.is_available():
+        amp = torch.cuda.amp.autocast(enabled=False)
+    else:
+        amp = torch.cpu.amp.autocast(enabled=bf16)
+    if bf16:
+        data.x = data.x.to(torch.bfloat16)
+
+    with amp:
+        for epoch in range(1, epochs + 1):
+            if epoch == epochs:
+                with timeit():
+                    inference(model, data)
+            else:
                 inference(model, data)
-        else:
-            inference(model, data)
 
-    if profiling:
-        with torch_profile():
-            inference(model, data)
+        if profiling:
+            with torch_profile():
+                inference(model, data)
 
 
 def run(dataset, model, runs, epochs, lr, weight_decay, early_stopping,
-        inference, profiling, permute_masks=None, logger=None):
+        inference, profiling, bf16, permute_masks=None, logger=None):
     if not inference:
         run_train(dataset, model, runs, epochs, lr, weight_decay,
                   early_stopping, permute_masks, logger)
     else:
-        run_inference(dataset, model, epochs, profiling, permute_masks, logger)
+        run_inference(dataset, model, epochs, profiling, bf16, permute_masks,
+                      logger)
 
 
 def train(model, optimizer, data):
diff --git a/benchmark/inference/inference_benchmark.py b/benchmark/inference/inference_benchmark.py
index fb1135999bc4..ded1629a2402 100644
--- a/benchmark/inference/inference_benchmark.py
+++ b/benchmark/inference/inference_benchmark.py
@@ -30,6 +30,13 @@ def run(args: argparse.ArgumentParser) -> None:
         hetero = True if dataset_name == 'ogbn-mag' else False
         mask = ('paper', None) if dataset_name == 'ogbn-mag' else None
         degree = None
+        if torch.cuda.is_available():
+            amp = torch.cuda.amp.autocast(enabled=False)
+        else:
+            amp = torch.cpu.amp.autocast(enabled=args.bf16)
+        dtype = torch.float
+        if args.bf16:
+            dtype = torch.bfloat16
 
         inputs_channels = data[
             'paper'].num_features if dataset_name == 'ogbn-mag' \
@@ -93,27 +100,31 @@ def run(args: argparse.ArgumentParser) -> None:
                         model = model.to(device)
                         model.eval()
 
-                        for _ in range(args.warmup):
-                            model.inference(subgraph_loader, device,
-                                            progress_bar=True)
-                        if args.experimental_mode:
-                            with torch_geometric.experimental_mode():
+                        with amp:
+                            for _ in range(args.warmup):
+                                model.inference(subgraph_loader, device,
+                                                progress_bar=True, dtype=dtype)
+                            if args.experimental_mode:
+                                with torch_geometric.experimental_mode():
+                                    with timeit():
+                                        model.inference(
+                                            subgraph_loader, device,
+                                            progress_bar=True, dtype=dtype)
+                            else:
                                 with timeit():
                                     model.inference(subgraph_loader, device,
-                                                    progress_bar=True)
-                        else:
-                            with timeit():
-                                model.inference(subgraph_loader, device,
-                                                progress_bar=True)
+                                                    progress_bar=True,
+                                                    dtype=dtype)
 
-                        if args.profile:
-                            with torch_profile():
-                                model.inference(subgraph_loader, device,
-                                                progress_bar=True)
-                            rename_profile_file(
-                                model_name, dataset_name, str(batch_size),
-                                str(layers), str(hidden_channels),
-                                str(subgraph_loader.num_neighbors))
+                            if args.profile:
+                                with torch_profile():
+                                    model.inference(subgraph_loader, device,
+                                                    progress_bar=True,
+                                                    dtype=dtype)
+                                rename_profile_file(
+                                    model_name, dataset_name, str(batch_size),
+                                    str(layers), str(hidden_channels),
+                                    str(subgraph_loader.num_neighbors))
 
 
 if __name__ == '__main__':
@@ -145,6 +156,7 @@ def run(args: argparse.ArgumentParser) -> None:
                            help='use experimental mode')
     argparser.add_argument('--warmup', default=1, type=int)
     argparser.add_argument('--profile', action='store_true')
+    argparser.add_argument('--bf16', action='store_true')
 
     args = argparser.parse_args()
diff --git a/benchmark/kernel/main_performance.py b/benchmark/kernel/main_performance.py
index 3eb3f7c15d95..609d5d91eaf6 100644
--- a/benchmark/kernel/main_performance.py
+++ b/benchmark/kernel/main_performance.py
@@ -30,6 +30,7 @@
                     help='The goal test accuracy')
 parser.add_argument('--inference', action='store_true')
 parser.add_argument('--profile', action='store_true')
+parser.add_argument('--bf16', action='store_true')
 args = parser.parse_args()
 
 layers = [1, 2, 3]
@@ -44,6 +45,10 @@
 ]
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if torch.cuda.is_available():
+    amp = torch.cuda.amp.autocast(enabled=False)
+else:
+    amp = torch.cpu.amp.autocast(enabled=args.bf16)
 
 
 def prepare_dataloader(dataset_name):
@@ -108,18 +113,19 @@ def run_inference():
             model = Net(dataset, num_layers, hidden).to(device)
 
-            for epoch in range(1, args.epochs + 1):
-                if epoch == args.epochs:
-                    with timeit():
-                        inference_run(model, test_loader)
-                else:
-                    inference_run(model, test_loader)
-
-            if args.profile:
-                with torch_profile():
-                    inference_run(model, test_loader)
-                rename_profile_file(Net.__name__, dataset_name,
-                                    str(num_layers), str(hidden))
+            with amp:
+                for epoch in range(1, args.epochs + 1):
+                    if epoch == args.epochs:
+                        with timeit():
+                            inference_run(model, test_loader, args.bf16)
+                    else:
+                        inference_run(model, test_loader, args.bf16)
+
+                if args.profile:
+                    with torch_profile():
+                        inference_run(model, test_loader, args.bf16)
+                    rename_profile_file(Net.__name__, dataset_name,
+                                        str(num_layers), str(hidden))
 
 
 if not args.inference:
diff --git a/benchmark/kernel/train_eval.py b/benchmark/kernel/train_eval.py
index 379023e994fa..6ec11954ba24 100644
--- a/benchmark/kernel/train_eval.py
+++ b/benchmark/kernel/train_eval.py
@@ -146,8 +146,10 @@ def eval_loss(model, loader):
 @torch.no_grad()
-def inference_run(model, loader):
+def inference_run(model, loader, bf16):
     model.eval()
     for data in loader:
         data = data.to(device)
+        if bf16:
+            data.x = data.x.to(torch.bfloat16)
         model(data)
diff --git a/benchmark/points/edge_cnn.py b/benchmark/points/edge_cnn.py
index 6aed1ae71a17..0a0bcbfeac6f 100644
--- a/benchmark/points/edge_cnn.py
+++ b/benchmark/points/edge_cnn.py
@@ -20,6 +20,7 @@
 parser.add_argument('--weight_decay', type=float, default=0)
 parser.add_argument('--inference', action='store_true')
 parser.add_argument('--profile', action='store_true')
+parser.add_argument('--bf16', action='store_true')
 args = parser.parse_args()
@@ -59,7 +60,7 @@ def forward(self, pos, batch):
 model = Net(train_dataset.num_classes)
 run(train_dataset, test_dataset, model, args.epochs, args.batch_size, args.lr,
     args.lr_decay_factor, args.lr_decay_step_size, args.weight_decay,
-    args.inference, args.profile)
+    args.inference, args.profile, args.bf16)
 
 if args.profile:
     rename_profile_file('points', DynamicEdgeConv.__name__)
diff --git a/benchmark/points/mpnn.py b/benchmark/points/mpnn.py
index 8bf4633d3003..68c66c683e60 100644
--- a/benchmark/points/mpnn.py
+++ b/benchmark/points/mpnn.py
@@ -20,6 +20,7 @@
 parser.add_argument('--weight_decay', type=float, default=0)
 parser.add_argument('--inference', action='store_true')
 parser.add_argument('--profile', action='store_true')
+parser.add_argument('--bf16', action='store_true')
 args = parser.parse_args()
@@ -76,7 +77,7 @@ def forward(self, pos, batch):
 model = Net(train_dataset.num_classes)
 run(train_dataset, test_dataset, model, args.epochs, args.batch_size, args.lr,
     args.lr_decay_factor, args.lr_decay_step_size, args.weight_decay,
-    args.inference, args.profile)
+    args.inference, args.profile, args.bf16)
 
 if args.profile:
     rename_profile_file('points', NNConv.__name__)
diff --git a/benchmark/points/point_cnn.py b/benchmark/points/point_cnn.py
index 59501cc9a63a..2b804e4effea 100644
--- a/benchmark/points/point_cnn.py
+++ b/benchmark/points/point_cnn.py
@@ -18,6 +18,7 @@
 parser.add_argument('--weight_decay', type=float, default=0)
 parser.add_argument('--inference', action='store_true')
 parser.add_argument('--profile', action='store_true')
+parser.add_argument('--bf16', action='store_true')
 args = parser.parse_args()
@@ -64,7 +65,7 @@ def forward(self, pos, batch):
 model = Net(train_dataset.num_classes)
 run(train_dataset, test_dataset, model, args.epochs, args.batch_size, args.lr,
     args.lr_decay_factor, args.lr_decay_step_size, args.weight_decay,
-    args.inference, args.profile)
+    args.inference, args.profile, args.bf16)
 
 if args.profile:
     rename_profile_file('points', XConv.__name__)
diff --git a/benchmark/points/point_net.py b/benchmark/points/point_net.py
index 9ee7546e0f29..aee4e58044c2 100644
--- a/benchmark/points/point_net.py
+++ b/benchmark/points/point_net.py
@@ -20,6 +20,7 @@
 parser.add_argument('--weight_decay', type=float, default=0)
 parser.add_argument('--inference', action='store_true')
 parser.add_argument('--profile', action='store_true')
+parser.add_argument('--bf16', action='store_true')
 args = parser.parse_args()
@@ -72,7 +73,7 @@ def forward(self, pos, batch):
 model = Net(train_dataset.num_classes)
 run(train_dataset, test_dataset, model, args.epochs, args.batch_size, args.lr,
     args.lr_decay_factor, args.lr_decay_step_size, args.weight_decay,
-    args.inference, args.profile)
+    args.inference, args.profile, args.bf16)
 
 if args.profile:
     rename_profile_file('points', PointConv.__name__)
diff --git a/benchmark/points/spline_cnn.py b/benchmark/points/spline_cnn.py
index 383195a2d871..5ab5b010ae53 100644
--- a/benchmark/points/spline_cnn.py
+++ b/benchmark/points/spline_cnn.py
@@ -18,6 +18,7 @@
 parser.add_argument('--weight_decay', type=float, default=0)
 parser.add_argument('--inference', action='store_true')
 parser.add_argument('--profile', action='store_true')
+parser.add_argument('--bf16', action='store_true')
 args = parser.parse_args()
@@ -73,7 +74,7 @@ def forward(self, pos, batch):
 model = Net(train_dataset.num_classes)
 run(train_dataset, test_dataset, model, args.epochs, args.batch_size, args.lr,
     args.lr_decay_factor, args.lr_decay_step_size, args.weight_decay,
-    args.inference, args.profile)
+    args.inference, args.profile, args.bf16)
 
 if args.profile:
     rename_profile_file('points', SplineConv.__name__)
diff --git a/benchmark/points/train_eval.py b/benchmark/points/train_eval.py
index 31d666306333..9640debd3346 100644
--- a/benchmark/points/train_eval.py
+++ b/benchmark/points/train_eval.py
@@ -41,31 +41,37 @@ def run_train(train_dataset, test_dataset, model, epochs, batch_size, lr,
 @torch.no_grad()
-def run_inference(test_dataset, model, epochs, batch_size, profiling):
+def run_inference(test_dataset, model, epochs, batch_size, profiling, bf16):
     model = model.to(device)
     test_loader = DataLoader(test_dataset, batch_size, shuffle=False)
 
-    for epoch in range(1, epochs + 1):
-        print("Epoch: ", epoch)
-        if epoch == epochs:
-            with timeit():
-                inference(model, test_loader, device)
-        else:
-            inference(model, test_loader, device)
+    if torch.cuda.is_available():
+        amp = torch.cuda.amp.autocast(enabled=False)
+    else:
+        amp = torch.cpu.amp.autocast(enabled=bf16)
+
+    with amp:
+        for epoch in range(1, epochs + 1):
+            print("Epoch: ", epoch)
+            if epoch == epochs:
+                with timeit():
+                    inference(model, test_loader, device, bf16)
+            else:
+                inference(model, test_loader, device, bf16)
 
-    if profiling:
-        with torch_profile():
-            inference(model, test_loader, device)
+        if profiling:
+            with torch_profile():
+                inference(model, test_loader, device, bf16)
 
 
 def run(train_dataset, test_dataset, model, epochs, batch_size, lr,
         lr_decay_factor, lr_decay_step_size, weight_decay, inference,
-        profiling):
+        profiling, bf16):
     if not inference:
         run_train(train_dataset, test_dataset, model, epochs, batch_size, lr,
                   lr_decay_factor, lr_decay_step_size, weight_decay)
     else:
-        run_inference(test_dataset, model, epochs, batch_size, profiling)
+        run_inference(test_dataset, model, epochs, batch_size, profiling, bf16)
 
 
 def train(model, optimizer, train_loader, device):
@@ -95,8 +101,11 @@ def test(model, test_loader, device):
 @torch.no_grad()
-def inference(model, test_loader, device):
+def inference(model, test_loader, device, bf16):
     model.eval()
     for data in test_loader:
         data = data.to(device)
+        if bf16:
+            data.pos = data.pos.to(torch.bfloat16)
+            model = model.to(torch.bfloat16)
         model(data.pos, data.batch)
diff --git a/torch_geometric/nn/conv/eg_conv.py b/torch_geometric/nn/conv/eg_conv.py
index 21f71c6abc8c..330081f8d4ed 100644
--- a/torch_geometric/nn/conv/eg_conv.py
+++ b/torch_geometric/nn/conv/eg_conv.py
@@ -131,7 +131,7 @@ def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
                 edge_index, symnorm_weight = gcn_norm(  # yapf: disable
                     edge_index, None, num_nodes=x.size(self.node_dim),
                     improved=False, add_self_loops=self.add_self_loops,
-                    flow=self.flow)
+                    flow=self.flow, dtype=x.dtype)
                 if self.cached:
                     self._cached_edge_index = (edge_index, symnorm_weight)
             else:
@@ -143,7 +143,7 @@ def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
                 edge_index = gcn_norm(  # yapf: disable
                     edge_index, None, num_nodes=x.size(self.node_dim),
                     improved=False, add_self_loops=self.add_self_loops,
-                    flow=self.flow)
+                    flow=self.flow, dtype=x.dtype)
                 if self.cached:
                     self._cached_adj_t = edge_index
             else:
diff --git a/torch_geometric/nn/conv/gcn_conv.py b/torch_geometric/nn/conv/gcn_conv.py
index 08ce3470951e..eab695f4a256 100644
--- a/torch_geometric/nn/conv/gcn_conv.py
+++ b/torch_geometric/nn/conv/gcn_conv.py
@@ -174,7 +174,7 @@ def forward(self, x: Tensor, edge_index: Adj,
                 if cache is None:
                     edge_index, edge_weight = gcn_norm(  # yapf: disable
                         edge_index, edge_weight, x.size(self.node_dim),
-                        self.improved, self.add_self_loops, self.flow)
+                        self.improved, self.add_self_loops, self.flow, x.dtype)
                     if self.cached:
                         self._cached_edge_index = (edge_index, edge_weight)
                 else:
@@ -185,7 +185,7 @@ def forward(self, x: Tensor, edge_index: Adj,
                 if cache is None:
                     edge_index = gcn_norm(  # yapf: disable
                         edge_index, edge_weight, x.size(self.node_dim),
-                        self.improved, self.add_self_loops, self.flow)
+                        self.improved, self.add_self_loops, self.flow, x.dtype)
                     if self.cached:
                         self._cached_adj_t = edge_index
                 else:
diff --git a/torch_geometric/nn/conv/pdn_conv.py b/torch_geometric/nn/conv/pdn_conv.py
index 86a5710248c2..4a64c9f683a9 100644
--- a/torch_geometric/nn/conv/pdn_conv.py
+++ b/torch_geometric/nn/conv/pdn_conv.py
@@ -101,10 +101,11 @@ def forward(self, x: Tensor, edge_index: Adj,
                 edge_index, edge_attr = gcn_norm(edge_index, edge_attr,
                                                  x.size(self.node_dim), False,
                                                  self.add_self_loops,
-                                                 self.flow)
+                                                 self.flow, x.dtype)
             elif isinstance(edge_index, SparseTensor):
                 edge_index = gcn_norm(edge_index, None, x.size(self.node_dim),
-                                      False, self.add_self_loops, self.flow)
+                                      False, self.add_self_loops, self.flow,
+                                      x.dtype)
 
         x = self.lin(x)
diff --git a/torch_geometric/nn/models/basic_gnn.py b/torch_geometric/nn/models/basic_gnn.py
index 9d2d8d717a04..dbb22c254f1f 100644
--- a/torch_geometric/nn/models/basic_gnn.py
+++ b/torch_geometric/nn/models/basic_gnn.py
@@ -196,7 +196,7 @@ def forward(
     @torch.no_grad()
     def inference(self, loader: NeighborLoader,
                   device: Optional[torch.device] = None,
-                  progress_bar: bool = False) -> Tensor:
+                  progress_bar: bool = False, dtype=torch.float) -> Tensor:
         r"""Performs layer-wise inference on large-graphs using
         :class:`~torch_geometric.loader.NeighborLoader`.
         :class:`~torch_geometric.loader.NeighborLoader` should sample the the
@@ -216,6 +216,7 @@ def inference(self, loader: NeighborLoader,
             pbar.set_description('Inference')
 
         x_all = loader.data.x.cpu()
+        x_all = x_all.to(dtype)
         loader.data.n_id = torch.arange(x_all.size(0))
 
         for i in range(self.num_layers):
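
Usage note (not part of the patch): after this change the benchmark scripts are driven via their CLI, e.g.

    python benchmark/citation/gcn.py --dataset Cora --inference --bf16

The Python sketch below illustrates, outside the benchmark harness, the bf16 inference pattern the scripts now follow: cast the inputs to bfloat16 and run the model under CPU autocast, keeping CUDA autocast disabled. It is a minimal sketch only; the two-layer GCN stand-in and the '/tmp/Cora' dataset root are assumptions for illustration, not part of the patch.

    # Minimal sketch of the bf16 CPU-inference pattern used by the benchmarks.
    import torch
    from torch_geometric.datasets import Planetoid
    from torch_geometric.nn import GCN

    dataset = Planetoid(root='/tmp/Cora', name='Cora')  # assumed dataset location
    data = dataset[0]
    model = GCN(dataset.num_features, hidden_channels=16, num_layers=2,
                out_channels=dataset.num_classes)
    model.eval()

    use_bf16 = True  # plays the role of the new --bf16 flag
    if use_bf16:
        # Mirrors train_eval.py: cast the node features to bfloat16.
        data.x = data.x.to(torch.bfloat16)

    # Autocast is only enabled on CPU; on CUDA the scripts keep it disabled.
    if torch.cuda.is_available():
        amp = torch.cuda.amp.autocast(enabled=False)
    else:
        amp = torch.cpu.amp.autocast(enabled=use_bf16)  # CPU autocast defaults to bfloat16

    with amp, torch.no_grad():
        out = model(data.x, data.edge_index)
    print(out.dtype)  # typically torch.bfloat16 when CPU autocast is active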