Enable bf16 support in the benchmark scripts (#5293)
* Enable bf16 support for benchmark

* Enable bf16 support for kernel, points and inference

* Support autocast both for cpu and cuda

* Update benchmark/citation/train_eval.py

* Update benchmark/points/train_eval.py

* Update benchmark/kernel/main_performance.py

* Update CHANGELOG.md

Co-authored-by: Matthias Fey <[email protected]>
yanbing-j and rusty1s authored Aug 30, 2022
1 parent e2b602d commit 96fbf43
Showing 21 changed files with 132 additions and 74 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.2.0] - 2022-MM-DD
### Added
- Enabled `bf16` support in benchmark scripts ([#5293](https://github.com/pyg-team/pytorch_geometric/pull/5293))
- Added `Aggregation.set_validate_args` option to skip validation of `dim_size` ([#5290](https://github.com/pyg-team/pytorch_geometric/pull/5290))
- Added `SparseTensor` support to inference benchmark suite ([#5242](https://github.com/pyg-team/pytorch_geometric/pull/5242), [#5258](https://github.com/pyg-team/pytorch_geometric/pull/5258))
- Added experimental mode in inference benchmarks ([#5254](https://github.com/pyg-team/pytorch_geometric/pull/5254))
4 changes: 3 additions & 1 deletion benchmark/citation/appnp.py
@@ -23,6 +23,7 @@
parser.add_argument('--alpha', type=float, default=0.1)
parser.add_argument('--inference', action='store_true')
parser.add_argument('--profile', action='store_true')
parser.add_argument('--bf16', action='store_true')
args = parser.parse_args()


@@ -50,7 +51,8 @@ def forward(self, data):
dataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features)
permute_masks = random_planetoid_splits if args.random_splits else None
run(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay,
-    args.early_stopping, args.inference, args.profile, permute_masks)
+    args.early_stopping, args.inference, args.profile, args.bf16,
+    permute_masks)

if args.profile:
rename_profile_file('citation', APPNP.__name__, args.dataset,
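appnp.py and the five citation scripts that follow wire the new flag identically, so a bf16 inference run only needs the extra switch. A typical invocation (a sketch, assuming the remaining arguments keep their defaults; Cora, CiteSeer and PubMed are the usual Planetoid choices):

    python benchmark/citation/appnp.py --dataset Cora --inference --bf16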
4 changes: 3 additions & 1 deletion benchmark/citation/arma.py
@@ -24,6 +24,7 @@
parser.add_argument('--skip_dropout', type=float, default=0.75)
parser.add_argument('--inference', action='store_true')
parser.add_argument('--profile', action='store_true')
parser.add_argument('--bf16', action='store_true')
args = parser.parse_args()


@@ -52,7 +53,8 @@ def forward(self, data):
dataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features)
permute_masks = random_planetoid_splits if args.random_splits else None
run(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay,
-    args.early_stopping, args.inference, args.profile, permute_masks)
+    args.early_stopping, args.inference, args.profile, args.bf16,
+    permute_masks)

if args.profile:
rename_profile_file('citation', ARMAConv.__name__, args.dataset,
4 changes: 3 additions & 1 deletion benchmark/citation/cheb.py
@@ -21,6 +21,7 @@
parser.add_argument('--num_hops', type=int, default=3)
parser.add_argument('--inference', action='store_true')
parser.add_argument('--profile', action='store_true')
parser.add_argument('--bf16', action='store_true')
args = parser.parse_args()


@@ -45,7 +46,8 @@ def forward(self, data):
dataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features)
permute_masks = random_planetoid_splits if args.random_splits else None
run(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay,
-    args.early_stopping, args.inference, args.profile, permute_masks)
+    args.early_stopping, args.inference, args.profile, args.bf16,
+    permute_masks)

if args.profile:
rename_profile_file('citation', ChebConv.__name__, args.dataset,
4 changes: 3 additions & 1 deletion benchmark/citation/gat.py
@@ -22,6 +22,7 @@
parser.add_argument('--output_heads', type=int, default=1)
parser.add_argument('--inference', action='store_true')
parser.add_argument('--profile', action='store_true')
parser.add_argument('--bf16', action='store_true')
args = parser.parse_args()


@@ -50,7 +51,8 @@ def forward(self, data):
dataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features)
permute_masks = random_planetoid_splits if args.random_splits else None
run(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay,
-    args.early_stopping, args.inference, args.profile, permute_masks)
+    args.early_stopping, args.inference, args.profile, args.bf16,
+    permute_masks)

if args.profile:
rename_profile_file('citation', GATConv.__name__, args.dataset,
4 changes: 3 additions & 1 deletion benchmark/citation/gcn.py
@@ -20,6 +20,7 @@
parser.add_argument('--no_normalize_features', action='store_true')
parser.add_argument('--inference', action='store_true')
parser.add_argument('--profile', action='store_true')
parser.add_argument('--bf16', action='store_true')
args = parser.parse_args()


@@ -44,7 +45,8 @@ def forward(self, data):
dataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features)
permute_masks = random_planetoid_splits if args.random_splits else None
run(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay,
-    args.early_stopping, args.inference, args.profile, permute_masks)
+    args.early_stopping, args.inference, args.profile, args.bf16,
+    permute_masks)

if args.profile:
rename_profile_file('citation', GCNConv.__name__, args.dataset,
4 changes: 3 additions & 1 deletion benchmark/citation/sgc.py
@@ -19,6 +19,7 @@
parser.add_argument('--K', type=int, default=2)
parser.add_argument('--inference', action='store_true')
parser.add_argument('--profile', action='store_true')
parser.add_argument('--bf16', action='store_true')
args = parser.parse_args()


@@ -40,7 +41,8 @@ def forward(self, data):
dataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features)
permute_masks = random_planetoid_splits if args.random_splits else None
run(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay,
-    args.early_stopping, args.inference, args.profile, permute_masks)
+    args.early_stopping, args.inference, args.profile, args.bf16,
+    permute_masks)

if args.profile:
rename_profile_file('citation', SGConv.__name__, args.dataset,
31 changes: 20 additions & 11 deletions benchmark/citation/train_eval.py
@@ -90,7 +90,7 @@ def run_train(dataset, model, runs, epochs, lr, weight_decay, early_stopping,


@torch.no_grad()
-def run_inference(dataset, model, epochs, profiling, permute_masks=None,
+def run_inference(dataset, model, epochs, profiling, bf16, permute_masks=None,
                   logger=None):
    data = dataset[0]
    if permute_masks is not None:
@@ -99,25 +99,34 @@ def run_inference(dataset, model, epochs, profiling, permute_masks=None,

    model.to(device).reset_parameters()

-    for epoch in range(1, epochs + 1):
-        if epoch == epochs:
-            with timeit():
-                inference(model, data)
-        else:
-            inference(model, data)
-
-    if profiling:
-        with torch_profile():
-            inference(model, data)
+    if torch.cuda.is_available():
+        amp = torch.cuda.amp.autocast(enabled=False)
+    else:
+        amp = torch.cpu.amp.autocast(enabled=bf16)
+    if bf16:
+        data.x = data.x.to(torch.bfloat16)
+
+    with amp:
+        for epoch in range(1, epochs + 1):
+            if epoch == epochs:
+                with timeit():
+                    inference(model, data)
+            else:
+                inference(model, data)
+
+        if profiling:
+            with torch_profile():
+                inference(model, data)


def run(dataset, model, runs, epochs, lr, weight_decay, early_stopping,
-        inference, profiling, permute_masks=None, logger=None):
+        inference, profiling, bf16, permute_masks=None, logger=None):
    if not inference:
        run_train(dataset, model, runs, epochs, lr, weight_decay,
                  early_stopping, permute_masks, logger)
    else:
-        run_inference(dataset, model, epochs, profiling, permute_masks, logger)
+        run_inference(dataset, model, epochs, profiling, bf16, permute_masks,
+                      logger)


def train(model, optimizer, data):
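The device check added above is repeated in every script this commit touches. Stripped of the benchmark plumbing, the pattern looks roughly like this (a minimal sketch, not code from this commit; make_autocast is a hypothetical helper name):

    import torch

    def make_autocast(bf16: bool):
        # CUDA autocast is explicitly disabled in these benchmarks, so the
        # --bf16 switch only takes effect on the CPU autocast path.
        if torch.cuda.is_available():
            return torch.cuda.amp.autocast(enabled=False)
        return torch.cpu.amp.autocast(enabled=bf16)

On top of the autocast context, the scripts also cast the inputs by hand (data.x = data.x.to(torch.bfloat16)) when the flag is set, as in run_inference() above.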
48 changes: 30 additions & 18 deletions benchmark/inference/inference_benchmark.py
@@ -30,6 +30,13 @@ def run(args: argparse.ArgumentParser) -> None:
hetero = True if dataset_name == 'ogbn-mag' else False
mask = ('paper', None) if dataset_name == 'ogbn-mag' else None
degree = None
if torch.cuda.is_available():
amp = torch.cuda.amp.autocast(enabled=False)
else:
amp = torch.cpu.amp.autocast(enabled=args.bf16)
dtype = torch.float
if args.bf16:
dtype = torch.bfloat16

inputs_channels = data[
'paper'].num_features if dataset_name == 'ogbn-mag' \
@@ -93,27 +100,31 @@ def run(args: argparse.ArgumentParser) -> None:
model = model.to(device)
model.eval()

-for _ in range(args.warmup):
-    model.inference(subgraph_loader, device,
-                    progress_bar=True)
-if args.experimental_mode:
-    with torch_geometric.experimental_mode():
-        with timeit():
-            model.inference(subgraph_loader, device,
-                            progress_bar=True)
-else:
-    with timeit():
-        model.inference(subgraph_loader, device,
-                        progress_bar=True)
+with amp:
+    for _ in range(args.warmup):
+        model.inference(subgraph_loader, device,
+                        progress_bar=True, dtype=dtype)
+    if args.experimental_mode:
+        with torch_geometric.experimental_mode():
+            with timeit():
+                model.inference(
+                    subgraph_loader, device,
+                    progress_bar=True, dtype=dtype)
+    else:
+        with timeit():
+            model.inference(subgraph_loader, device,
+                            progress_bar=True,
+                            dtype=dtype)

-if args.profile:
-    with torch_profile():
-        model.inference(subgraph_loader, device,
-                        progress_bar=True)
-    rename_profile_file(
-        model_name, dataset_name, str(batch_size),
-        str(layers), str(hidden_channels),
-        str(subgraph_loader.num_neighbors))
+    if args.profile:
+        with torch_profile():
+            model.inference(subgraph_loader, device,
+                            progress_bar=True,
+                            dtype=dtype)
+        rename_profile_file(
+            model_name, dataset_name, str(batch_size),
+            str(layers), str(hidden_channels),
+            str(subgraph_loader.num_neighbors))


if __name__ == '__main__':
@@ -145,6 +156,7 @@ def run(args: argparse.ArgumentParser) -> None:
help='use experimental mode')
argparser.add_argument('--warmup', default=1, type=int)
argparser.add_argument('--profile', action='store_true')
argparser.add_argument('--bf16', action='store_true')

args = argparser.parse_args()

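The model.inference() implementations that receive the new dtype keyword live in files not shown on this page, so their exact signature is not visible here. A hypothetical sketch of how a layer-wise inference loop might consume it (names and structure are illustrative only):

    import torch

    @torch.no_grad()
    def inference(model, loader, device, dtype=torch.float):
        # Illustrative sketch: cast each mini-batch's node features to the
        # requested dtype (torch.bfloat16 when --bf16 is set) before the
        # forward pass, mirroring the explicit data.x casts elsewhere in
        # this commit.
        model.eval()
        for batch in loader:
            batch = batch.to(device)
            batch.x = batch.x.to(dtype)
            model(batch.x, batch.edge_index)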
30 changes: 18 additions & 12 deletions benchmark/kernel/main_performance.py
@@ -30,6 +30,7 @@
help='The goal test accuracy')
parser.add_argument('--inference', action='store_true')
parser.add_argument('--profile', action='store_true')
parser.add_argument('--bf16', action='store_true')
args = parser.parse_args()

layers = [1, 2, 3]
@@ -44,6 +45,10 @@
]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
amp = torch.cuda.amp.autocast(enabled=False)
else:
amp = torch.cpu.amp.autocast(enabled=args.bf16)


def prepare_dataloader(dataset_name):
@@ -108,18 +113,19 @@ def run_inference():

model = Net(dataset, num_layers, hidden).to(device)

-for epoch in range(1, args.epochs + 1):
-    if epoch == args.epochs:
-        with timeit():
-            inference_run(model, test_loader)
-    else:
-        inference_run(model, test_loader)
-
-if args.profile:
-    with torch_profile():
-        inference_run(model, test_loader)
-    rename_profile_file(Net.__name__, dataset_name,
-                        str(num_layers), str(hidden))
+with amp:
+    for epoch in range(1, args.epochs + 1):
+        if epoch == args.epochs:
+            with timeit():
+                inference_run(model, test_loader, args.bf16)
+        else:
+            inference_run(model, test_loader, args.bf16)
+
+    if args.profile:
+        with torch_profile():
+            inference_run(model, test_loader, args.bf16)
+        rename_profile_file(Net.__name__, dataset_name,
+                            str(num_layers), str(hidden))


if not args.inference:
4 changes: 3 additions & 1 deletion benchmark/kernel/train_eval.py
@@ -146,8 +146,10 @@ def eval_loss(model, loader):


@torch.no_grad()
-def inference_run(model, loader):
+def inference_run(model, loader, bf16):
    model.eval()
    for data in loader:
        data = data.to(device)
+        if bf16:
+            data.x = data.x.to(torch.bfloat16)
        model(data)
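With the amp context and the bf16 flag threaded through inference_run(), a CPU bf16 run of the kernel benchmark would be invoked along these lines (a sketch, assuming the remaining arguments keep their defaults):

    python benchmark/kernel/main_performance.py --inference --bf16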
3 changes: 2 additions & 1 deletion benchmark/points/edge_cnn.py
@@ -20,6 +20,7 @@
parser.add_argument('--weight_decay', type=float, default=0)
parser.add_argument('--inference', action='store_true')
parser.add_argument('--profile', action='store_true')
parser.add_argument('--bf16', action='store_true')
args = parser.parse_args()


@@ -59,7 +60,7 @@ def forward(self, pos, batch):
model = Net(train_dataset.num_classes)
run(train_dataset, test_dataset, model, args.epochs, args.batch_size, args.lr,
args.lr_decay_factor, args.lr_decay_step_size, args.weight_decay,
-    args.inference, args.profile)
+    args.inference, args.profile, args.bf16)

if args.profile:
rename_profile_file('points', DynamicEdgeConv.__name__)
3 changes: 2 additions & 1 deletion benchmark/points/mpnn.py
@@ -20,6 +20,7 @@
parser.add_argument('--weight_decay', type=float, default=0)
parser.add_argument('--inference', action='store_true')
parser.add_argument('--profile', action='store_true')
parser.add_argument('--bf16', action='store_true')
args = parser.parse_args()


@@ -76,7 +77,7 @@ def forward(self, pos, batch):
model = Net(train_dataset.num_classes)
run(train_dataset, test_dataset, model, args.epochs, args.batch_size, args.lr,
args.lr_decay_factor, args.lr_decay_step_size, args.weight_decay,
-    args.inference, args.profile)
+    args.inference, args.profile, args.bf16)

if args.profile:
rename_profile_file('points', NNConv.__name__)
3 changes: 2 additions & 1 deletion benchmark/points/point_cnn.py
@@ -18,6 +18,7 @@
parser.add_argument('--weight_decay', type=float, default=0)
parser.add_argument('--inference', action='store_true')
parser.add_argument('--profile', action='store_true')
parser.add_argument('--bf16', action='store_true')
args = parser.parse_args()


@@ -64,7 +65,7 @@ def forward(self, pos, batch):
model = Net(train_dataset.num_classes)
run(train_dataset, test_dataset, model, args.epochs, args.batch_size, args.lr,
args.lr_decay_factor, args.lr_decay_step_size, args.weight_decay,
-    args.inference, args.profile)
+    args.inference, args.profile, args.bf16)

if args.profile:
rename_profile_file('points', XConv.__name__)
3 changes: 2 additions & 1 deletion benchmark/points/point_net.py
@@ -20,6 +20,7 @@
parser.add_argument('--weight_decay', type=float, default=0)
parser.add_argument('--inference', action='store_true')
parser.add_argument('--profile', action='store_true')
parser.add_argument('--bf16', action='store_true')
args = parser.parse_args()


@@ -72,7 +73,7 @@ def forward(self, pos, batch):
model = Net(train_dataset.num_classes)
run(train_dataset, test_dataset, model, args.epochs, args.batch_size, args.lr,
args.lr_decay_factor, args.lr_decay_step_size, args.weight_decay,
-    args.inference, args.profile)
+    args.inference, args.profile, args.bf16)

if args.profile:
rename_profile_file('points', PointConv.__name__)
3 changes: 2 additions & 1 deletion benchmark/points/spline_cnn.py
@@ -18,6 +18,7 @@
parser.add_argument('--weight_decay', type=float, default=0)
parser.add_argument('--inference', action='store_true')
parser.add_argument('--profile', action='store_true')
parser.add_argument('--bf16', action='store_true')
args = parser.parse_args()


@@ -73,7 +74,7 @@ def forward(self, pos, batch):
model = Net(train_dataset.num_classes)
run(train_dataset, test_dataset, model, args.epochs, args.batch_size, args.lr,
args.lr_decay_factor, args.lr_decay_step_size, args.weight_decay,
-    args.inference, args.profile)
+    args.inference, args.profile, args.bf16)

if args.profile:
rename_profile_file('points', SplineConv.__name__)
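The five point-cloud scripts forward args.bf16 to run() in benchmark/points/train_eval.py, which is among the files not shown on this page. A typical bf16 inference run (a sketch, assuming the remaining arguments keep their defaults):

    python benchmark/points/edge_cnn.py --inference --bf16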
(Remaining changed files not shown.)

0 comments on commit 96fbf43
