Tutorial: update Building a Graph Convolutional Network tutorial #4060

Merged Oct 11, 2019 (9 commits)

tutorials/frontend/build_gcn.py: 187 changes (93 additions, 94 deletions)
@@ -17,7 +17,8 @@
"""
Building a Graph Convolutional Network
======================================
**Author**: `Yulun Yao <https://yulunyao.io/>`_
**Author**: `Yulun Yao <https://yulunyao.io/>`_, \
`Chien-Yu Lin <https://homes.cs.washington.edu/~cyulin/>`_

This article is an introductory tutorial to build a Graph Convolutional Network (GCN) with Relay.

@@ -49,17 +50,25 @@
# = ((H * W)^t * A^t)^t
# = ((W^t * H^t) * A^t)^t
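# A quick numerical check of the transpose identity above (an illustrative
# sketch, independent of the tutorial code):
import numpy as np

rng = np.random.RandomState(0)
A = rng.rand(5, 5).astype('float32')  # stand-in adjacency matrix
H = rng.rand(5, 3).astype('float32')  # stand-in node features
W = rng.rand(3, 2).astype('float32')  # stand-in layer weights

lhs = A.dot(H).dot(W)          # A * H * W
rhs = W.T.dot(H.T).dot(A.T).T  # ((W^t * H^t) * A^t)^t
assert np.allclose(lhs, rhs, atol=1e-5)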
from tvm import relay
from tvm.contrib import graph_runtime
import tvm, dgl, scipy
import numpy as np
import networkx as nx
from collections import namedtuple
from dgl.data import load_data

def GraphConv(
layer_name,
input_dim,
output_dim,
adj,
input,
activation=None,
norm=None,
):
r"""
from tvm.contrib.download import download_testdata
import pickle

def GraphConv(layer_name,
input_dim,
output_dim,
adj,
input,
norm=None,
bias=True,
activation=None):
"""
Parameters
----------
layer_name: str
@@ -81,47 +90,48 @@ def GraphConv(
norm: relay.Expr,
Norm passed to this layer to normalize features before and after Convolution.

bias: bool
Set bias to True to add a bias term to the GCN layer

activation: <function relay.op.nn>,
Activation function applied to the output, e.g. relay.nn.{relu, sigmoid, log_softmax, softmax, leaky_relu}


Returns
----------
output: tvm.relay.Expr
The Output Tensor for this layer [num_nodes, output_dim]
"""
if norm is not None:
input = relay.multiply(input, norm)
weight = relay.var(layer_name + "_weight", shape=(input_dim, output_dim))
weight_transposed = relay.transpose(weight)
dense = relay.nn.dense(weight_transposed, input)

weight = relay.var(layer_name + ".weight", shape=(input_dim, output_dim))
weight_t = relay.transpose(weight)
dense = relay.nn.dense(weight_t, input)
output = relay.nn.sparse_dense(dense, adj)
output_transposed = relay.transpose(output)
output_t = relay.transpose(output)
if norm is not None:
output_transposed = relay.multiply(output_transposed, norm)
output_t = relay.multiply(output_t, norm)
if bias is True:
_bias = relay.var(layer_name + ".bias", shape=(output_dim, 1))
output_t = relay.nn.bias_add(output_t, _bias, axis=-1)
if activation is not None:
output_transposed = activation(output_transposed)
return output_transposed
output_t = activation(output_t)
return output_t
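# To see what a single layer lowers to, one can build a GraphConv over a toy
# graph and print the resulting Relay program. This is an illustrative sketch
# only: the 3-node all-ones adjacency and the toy_* names are assumptions,
# relying on the imports at the top of this script.
import scipy.sparse as sp

toy_csr = sp.csr_matrix(np.ones((3, 3), dtype='float32'))
ToyAdjacency = namedtuple('Adjacency', ['data', 'indices', 'indptr'])
toy_adj = ToyAdjacency(
    relay.Constant(tvm.nd.array(toy_csr.data)),
    relay.Constant(tvm.nd.array(toy_csr.indices.astype('int32'))),
    relay.Constant(tvm.nd.array(toy_csr.indptr.astype('int32'))))

toy_feats = relay.var('toy_feats', shape=(3, 4))
toy_out = GraphConv('toy', 4, 2, toy_adj, toy_feats,
                    bias=False, activation=relay.nn.relu)
print(relay.Function(relay.analysis.free_vars(toy_out), toy_out))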

######################################################################
# Load the dataset
# ------------------
# You may substitute this part with your own dataset; here we load data from DGL for benchmarking
import tvm, dgl, scipy
import numpy as np
import networkx as nx
from collections import namedtuple
from dgl.data import load_data

def load_dataset(dataset="cora"):
args = namedtuple("args", ["dataset"])
dataset = load_data(args(dataset))
data = load_data(args(dataset))

params = {}
params['infeats'] = dataset.features.astype('float32') # Only support float32 as feature for now
params['infeats'] = data.features.astype('float32') # Only float32 features are supported for now

# Remove self-loops to avoid duplicate passing of a node's feature to itself
g = dataset.graph
g = data.graph
g.remove_edges_from(g.selfloop_edges())
g.add_edges_from(zip(g.nodes, g.nodes))

@@ -136,26 +146,26 @@ def load_dataset(dataset="cora"):
params['norm'] = np.power(degs, -0.5).astype('float32')
params['norm'] = params['norm'].reshape((params['norm'].shape[0], 1))

return params
return data, params
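# The `norm` vector computed above implements the symmetric normalization
# D^-1/2 * A * D^-1/2 of the GCN propagation rule: scaling features by `norm`
# before and after the adjacency multiply (as GraphConv does) is equivalent.
# A small numpy sketch with toy values (illustrative only, not the
# tutorial's data):
A = np.array([[1., 1., 0.],
              [1., 1., 1.],
              [0., 1., 1.]], dtype='float32')  # toy adjacency with self-loops
H = np.ones((3, 2), dtype='float32')           # toy node features

degs = A.sum(axis=1)
toy_norm = np.power(degs, -0.5).astype('float32').reshape((-1, 1))
D_inv_sqrt = np.diag(toy_norm.flatten())

lhs = toy_norm * A.dot(toy_norm * H)            # scale before and after
rhs = D_inv_sqrt.dot(A).dot(D_inv_sqrt).dot(H)  # D^-1/2 * A * D^-1/2 * H
assert np.allclose(lhs, rhs, atol=1e-6)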

######################################################################
# Set up model Parameters
# ------------------

r"""
"""
Parameters
----------
num_hidden: int
number of hidden layers

hidden_dim: int
hidden_dim: list of int
input dimension of hidden layers

num_classes: int
dimension of model output (Number of classes)

target: str
currently only support llvm, GPU support will be added in next few weeks
currently only supports llvm

activation: <function relay.op.nn>,
Activation function applied to the output. e.g. relay.nn.{relu, sigmoid, log_softmax, softmax, leaky_relu}
@@ -164,15 +174,17 @@ def load_dataset(dataset="cora"):
Name of dataset. You can pick from ['cora', 'citeseer', 'pubmed'] or you can use your own.
"""

dataset = "cora"
data, params = load_dataset(dataset)

num_hidden = 1
hidden_dim = 16
num_classes = 7
hidden_dim = [16]
num_classes = data.num_labels
test_mask = data.test_mask
labels = data.labels
target = 'llvm'
activation = relay.nn.relu

dataset = "cora"
params = load_dataset(dataset)

# Check shape of features
assert len(params['infeats'].shape) == 2
nnodes, input_dim = params['infeats'].shape
@@ -185,8 +197,6 @@ def load_dataset(dataset="cora"):
# Put layers together
# ------------------

layers = []

# Define input features, norms, adjacency matrix
infeats = relay.var("infeats", shape=(nnodes, input_dim))

@@ -199,39 +209,29 @@ def load_dataset(dataset="cora"):
Adjacency = namedtuple('Adjacency', ['data', 'indices', 'indptr'])
adj = Adjacency(data, indices, indptr)
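# The CSR fields `data`, `indices` and `indptr` used above are elided in this
# diff view. A plausible construction is sketched below (illustrative only,
# assuming the networkx graph `data.graph` and `params` from load_dataset;
# the adj_* names are hypothetical, not necessarily the PR's exact code):
csr = nx.to_scipy_sparse_matrix(data.graph, format='csr').astype('float32')
norm = relay.Constant(tvm.nd.array(params['norm']))
adj_data = relay.Constant(tvm.nd.array(csr.data))
adj_indices = relay.Constant(tvm.nd.array(csr.indices.astype('int32')))
adj_indptr = relay.Constant(tvm.nd.array(csr.indptr.astype('int32')))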

# Generate Input Layer
# Construct a 2-layer GCN
layers = []

layers.append(GraphConv(
layer_name= 'in',
input_dim= input_dim,
output_dim= hidden_dim,
adj = adj,
input= infeats,
activation= activation,
norm= norm,
layer_name="layers.0",
input_dim=input_dim,
output_dim=hidden_dim[0],
adj=adj,
input=infeats,
norm=norm,
activation=activation
))

# Generate Hidden Layers
for i in range(num_hidden):
layers.append(GraphConv(
layer_name= str(i),
input_dim= hidden_dim,
output_dim= hidden_dim,
adj = adj,
input= layers[-1],
activation= activation,
norm= norm,
))

# Generate Output Layer
layers.append(GraphConv(
layer_name= 'out',
input_dim= hidden_dim,
output_dim= num_classes,
adj = adj,
input= layers[-1],
activation= activation,
norm= norm,
layer_name="layers.1",
input_dim=hidden_dim[0],
output_dim=num_classes,
adj=adj,
input=layers[-1],
norm=norm,
activation=activation
))

output = layers[-1]

# Analyze free variables and generate function
@@ -240,43 +240,42 @@ def load_dataset(dataset="cora"):
######################################################################
# Compile and run
# ------------------
# We achieved 6.5x speedup for this dataset against dgl given the same model parameters.
# Output numerical difference < 10e-4 %.
#
# DGL version: https://github.com/dmlc/dgl/blob/master/examples/mxnet/gcn/gcn.py
from tvm.contrib import graph_runtime
import time

# Set up weights. You can modify this part and use your own trained weights.
params['in_weight'] = np.ones((input_dim, hidden_dim), dtype='float32')
params['out_weight'] = np.ones((hidden_dim, num_classes), dtype='float32')
for i in range(num_hidden):
params["%s_weight"%(str(i))] = np.ones((hidden_dim, hidden_dim), dtype='float32')
# Download pretrained GCN model
model_url = "https://homes.cs.washington.edu/~cyulin/media/gcn_%s.pickle"%(dataset)
model_path = download_testdata(model_url, 'gcn.pickle', module='gcn_model')

with open(model_path, 'rb') as fp:
model_params = pickle.load(fp)

for i in range(num_hidden+1):
params["layers.%d.weight"%(i)] = model_params["layers.%d.weight"%(i)]
params["layers.%d.bias"%(i)] = model_params["layers.%d.bias"%(i)]

# Generate graph and library
# Build with relay
with relay.build_config(opt_level=0): # Currently only supports opt_level=0
graph, lib, params = relay.build(func, target, params=params)
lib.save("lib.o")

# Generate module for llvm
# Generate graph runtime
ctx = tvm.context(target, 0)
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**params)

print("finished compiling, testing inference time cost")
totaltime = 0
for i in range(30):
st = time.time()
# One forward pass on the entire network
m.run()
end = time.time()
# Retrieve output Tensor as numpy array
outval = m.get_output(0).asnumpy()

totaltime += (end-st)

if i == 0:
print("features of first five nodes \n %s" % outval[:5])
if i == 4:
print("5 Cycle Average Forward Pass Time ", totaltime/5)
print("30 Cycle Average Forward Pass Time ", totaltime/30)
# Run the model once and test for accuracy
m.run()
outval = m.get_output(0).asnumpy()
pred = outval.argmax(axis=1)
accuracy = ((pred == labels) * test_mask).sum() / test_mask.sum()
print("Test accuracy {:.2%}".format(accuracy))

# Evaluate the runtime
print("Evaluate inference time cost...")
timer = m.module.time_evaluator("run", ctx, number=1, repeat=10)
Contributor review comment on the line above: "Increase the number"
tcost = timer()
prof_res = tcost.results
prof_res = np.array(tcost.results) * 1000 # convert to milliseconds
print("Mean inference time (std dev): %.6f ms (%.6f ms)" %
(np.mean(prof_res), np.std(prof_res)))
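# Following up on the review comment above ("Increase the number"): averaging
# several runs per repeat gives a steadier estimate. An illustrative variant
# using the same time_evaluator API (these particular numbers are a
# suggestion, not what the PR merged):
timer = m.module.time_evaluator("run", ctx, number=10, repeat=3)
prof_res = np.array(timer().results) * 1000  # convert to milliseconds
print("Mean inference time (std dev): %.6f ms (%.6f ms)" %
      (np.mean(prof_res), np.std(prof_res)))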