Skip to content

Commit

Permalink
Merge pull request #4 from VibhuJawa/cugraph-dgl-dlfw-patch
Browse files Browse the repository at this point in the history
Fix dlpack boolean errors
  • Loading branch information
VibhuJawa authored Jan 9, 2023
2 parents c4afa53 + 680e751 commit 723fb5b
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions python/cugraph-dgl/examples/graphsage/node-classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

# Timing Imports
import time
# Ignore Warning
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
Expand Down Expand Up @@ -216,6 +219,16 @@ def train(args, device, g, dataset, model):
pool_allocator=True, initial_pool_size=5e9, maximum_pool_size=25e9
)

# Work around for DLFW container issues
# where dlpack conversion of boolean fails
# on the first run
# Similar to issue https://github.com/dmlc/dgl/issues/3591
if 'train_mask' in g.ndata:
g.ndata['train_mask'] = g.ndata['train_mask'].int()
if 'train_mask' in g.ndata:
g.ndata['test_mask'] = g.ndata['test_mask'].int()
if 'val_mask' in g.ndata:
g.ndata['val_mask'] = g.ndata['val_mask'].int()
g = cugraph_dgl.cugraph_storage_from_heterograph(g)
del dataset.g

Expand Down

0 comments on commit 723fb5b

Please sign in to comment.