From e428bd4de3a79657917518d8e9ffdfcf679adce6 Mon Sep 17 00:00:00 2001 From: Chien-Yu Lin Date: Fri, 4 Oct 2019 15:16:40 -0700 Subject: [PATCH 1/9] update build_gcn.py tutorial updates * support bias in GCN layer * download pretrained gcn model * verify model accuracy * use time_evaluator to measure runtime --- tutorials/frontend/build_gcn.py | 184 ++++++++++++++++---------------- 1 file changed, 93 insertions(+), 91 deletions(-) diff --git a/tutorials/frontend/build_gcn.py b/tutorials/frontend/build_gcn.py index acbd27f9c849..0ec6c9154b14 100644 --- a/tutorials/frontend/build_gcn.py +++ b/tutorials/frontend/build_gcn.py @@ -17,7 +17,8 @@ """ Building a Graph Convolutional Network ===================== -**Author**: `Yulun Yao `_ +**Author**: `Yulun Yao `_, \ + `Chien-Yu Lin `_ This article is an introductory tutorial to build a Graph Convolutional Network (GCN) with Relay. @@ -49,17 +50,25 @@ # = ((H * W)^t * A^t)^t # = ((W^t * H^t) * A^t)^t from tvm import relay +from tvm.contrib import graph_runtime +import tvm, dgl, scipy +import numpy as np +import networkx as nx +from collections import namedtuple +from dgl.data import load_data -def GraphConv( - layer_name, - input_dim, - output_dim, - adj, - input, - activation=None, - norm=None, - ): - r""" +from tvm.contrib.download import download_testdata +import pickle + +def GraphConv(layer_name, + input_dim, + output_dim, + adj, + input, + norm=None, + bias=True, + activation=None): + """ Parameters ---------- layer_name: str @@ -90,38 +99,39 @@ def GraphConv( output: tvm.relay.Expr The Output Tensor for this layer [num_nodes, output_dim] """ + if bias is True: + _bias = relay.var(layer_name + ".bias", shape=(output_dim, 1)) + if norm is not None: input = relay.multiply(input, norm) - weight = relay.var(layer_name + "_weight", shape=(input_dim, output_dim)) - weight_transposed = relay.transpose(weight) - dense = relay.nn.dense(weight_transposed, input) + + weight = relay.var(layer_name + ".weight", shape=(input_dim, 
output_dim)) + weight_t = relay.transpose(weight) + dense = relay.nn.dense(weight_t, input) output = relay.nn.sparse_dense(dense, adj) - output_transposed = relay.transpose(output) + output_t = relay.transpose(output) if norm is not None: - output_transposed = relay.multiply(output_transposed, norm) + output_t = relay.multiply(output_t, norm) + if bias is True: + output_t = relay.nn.bias_add(output_t, _bias, axis=-1) if activation is not None: - output_transposed = activation(output_transposed) - return output_transposed + output_t = activation(output_t) + return output_t ###################################################################### # Load the dataset # ------------------ # You may substitute this part with your own dataset, here we load data from DGL to benchmark -import tvm, dgl, scipy -import numpy as np -import networkx as nx -from collections import namedtuple -from dgl.data import load_data def load_dataset(dataset="cora"): args = namedtuple("args", ["dataset"]) - dataset = load_data(args(dataset)) + data = load_data(args(dataset)) params = {} - params['infeats'] = dataset.features.astype('float32') # Only support float32 as feature for now + params['infeats'] = data.features.astype('float32') # Only support float32 as feature for now # Remove self-loops to avoid duplicate passing of a node's feature to itself - g = dataset.graph + g = data.graph g.remove_edges_from(g.selfloop_edges()) g.add_edges_from(zip(g.nodes, g.nodes)) @@ -136,13 +146,13 @@ def load_dataset(dataset="cora"): params['norm'] = np.power(degs, -0.5).astype('float32') params['norm'] = params['norm'].reshape((params['norm'].shape[0], 1)) - return params + return data, params ###################################################################### # Set up model Parameters # ------------------ -r""" +""" Parameters ---------- num_hidden: int @@ -164,15 +174,18 @@ def load_dataset(dataset="cora"): Name of dataset. You can pick from ['cora', 'citeseer', 'pubmed'] or you can use your own. 
""" +dataset = "cora" +data, params = load_dataset(dataset) + num_hidden = 1 -hidden_dim = 16 -num_classes = 7 +hidden_dim = [16] +num_classes = data.num_labels +bias = True +test_mask = data.test_mask +labels = data.labels target = 'llvm' activation = relay.nn.relu -dataset = "cora" -params = load_dataset(dataset) - # Check shape of features assert len(params['infeats'].shape) == 2 nnodes, input_dim = params['infeats'].shape @@ -185,8 +198,6 @@ def load_dataset(dataset="cora"): # Put layers together # ------------------ -layers = [] - # Define input features, norms, adjacency matrix infeats = relay.var("infeats", shape=(nnodes, input_dim)) @@ -199,39 +210,31 @@ def load_dataset(dataset="cora"): Adjacency = namedtuple('Adjacency', ['data', 'indices', 'indptr']) adj = Adjacency(data, indices, indptr) -# Generate Input Layer +# Construct a 2-layer GCN +layers = [] + layers.append(GraphConv( - layer_name= 'in', - input_dim= input_dim, - output_dim= hidden_dim, - adj = adj, - input= infeats, - activation= activation, - norm= norm, + layer_name="layers.0", + input_dim=input_dim, + output_dim=hidden_dim[0], + adj=adj, + input=infeats, + norm=norm, + bias=bias, + activation=activation )) -# Generate Hidden Layers -for i in range(num_hidden): - layers.append(GraphConv( - layer_name= str(i), - input_dim= hidden_dim, - output_dim= hidden_dim, - adj = adj, - input= layers[-1], - activation= activation, - norm= norm, - )) - -# Generate Output Layer layers.append(GraphConv( - layer_name= 'out', - input_dim= hidden_dim, - output_dim= num_classes, - adj = adj, - input= layers[-1], - activation= activation, - norm= norm, + layer_name="layers.1", + input_dim=hidden_dim[0], + output_dim=num_classes, + adj=adj, + input=layers[-1], + norm=norm, + bias=bias, + activation=activation )) + output = layers[-1] # Analyze free variables and generate function @@ -240,43 +243,42 @@ def load_dataset(dataset="cora"): ###################################################################### # 
Compile and run # ------------------ -# We achieved 6.5x speedup for this dataset against dgl given the same model parameters. -# Output numerical difference < 10e-4 %. # # DGL version: https://github.com/dmlc/dgl/blob/master/examples/mxnet/gcn/gcn.py -from tvm.contrib import graph_runtime -import time -# Set up weights. You can modify this part and use your own trained weights. -params['in_weight'] = np.ones((input_dim, hidden_dim), dtype='float32') -params['out_weight'] = np.ones((hidden_dim, num_classes), dtype='float32') -for i in range(num_hidden): - params["%s_weight"%(str(i))] = np.ones((hidden_dim, hidden_dim), dtype='float32') +# Download pretrained GCN model +model_url = "https://homes.cs.washington.edu/~cyulin/media/gcn_%s.pickle"%(dataset) +model_path = download_testdata(model_url, 'gcn.pickle', module='gcn_model') + +with open(model_path, 'rb') as fp: + model_params = pickle.load(fp) + +for i in range(num_hidden+1): + params["layers.%d.weight"%(i)] = model_params["layers.%d.weight"%(i)] + params["layers.%d.bias"%(i)] = model_params["layers.%d.bias"%(i)] -# Generate graph and library +# Build with relay with relay.build_config(opt_level=0): # Currently only support opt_level=0 graph, lib, params = relay.build(func, target, params=params) lib.save("lib.o") -# Generate module for llvm +# Generate graph runtime ctx = tvm.context(target, 0) m = graph_runtime.create(graph, lib, ctx) m.set_input(**params) -print("finished compiling, testing inference time cost") -totaltime = 0 -for i in range(30): - st = time.time() - # One forward pass on the entire network - m.run() - end = time.time() - # Retrieve output Tensor as numpy array - outval = m.get_output(0).asnumpy() - - totaltime += (end-st) - - if i == 0: - print("features of first five nodes \n %s" % outval[:5]) - if i == 4: - print("5 Cycle Average Forward Pass Time ", totaltime/5) -print("30 Cycle Average Forward Pass Time ", totaltime/30) +# Run the model for one time and test for accuracy +m.run() 
+outval = m.get_output(0).asnumpy() +pred = outval.argmax(axis=1) +accuracy = ((pred == labels) * test_mask).sum() / test_mask.sum() +print("Test accuracy {:.2%}".format(accuracy)) + +# Evaluate the runtime +print("Evaluate inference time cost...") +timer = m.module.time_evaluator("run", ctx, number=1, repeat=10) +tcost = timer() +prof_res = tcost.results +prof_res = np.array(tcost.results) * 1000 # convert to millisecond +print("Mean inference time (std dev): %.6f ms (%.6f ms)" % + (np.mean(prof_res), np.std(prof_res))) From e5f9bad9a8ede53e8a69861ccf69720f1fd548e7 Mon Sep 17 00:00:00 2001 From: Chien-Yu Lin Date: Fri, 4 Oct 2019 15:41:41 -0700 Subject: [PATCH 2/9] fix adding bias in gcn layer --- tutorials/frontend/build_gcn.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tutorials/frontend/build_gcn.py b/tutorials/frontend/build_gcn.py index 0ec6c9154b14..d8796949c553 100644 --- a/tutorials/frontend/build_gcn.py +++ b/tutorials/frontend/build_gcn.py @@ -90,18 +90,17 @@ def GraphConv(layer_name, norm: relay.Expr, Norm passed to this layer to normalize features before and after Convolution. + bias: bool + Set bias to True to add bias when doing gcn layer + activation: , Activation function applies to the output. e.g. 
relay.nn.{relu, sigmoid, log_softmax, softmax, leaky_relu} - Returns ---------- output: tvm.relay.Expr The Output Tensor for this layer [num_nodes, output_dim] """ - if bias is True: - _bias = relay.var(layer_name + ".bias", shape=(output_dim, 1)) - if norm is not None: input = relay.multiply(input, norm) @@ -113,6 +112,7 @@ def GraphConv(layer_name, if norm is not None: output_t = relay.multiply(output_t, norm) if bias is True: + _bias = relay.var(layer_name + ".bias", shape=(output_dim, 1)) output_t = relay.nn.bias_add(output_t, _bias, axis=-1) if activation is not None: output_t = activation(output_t) @@ -273,6 +273,7 @@ def load_dataset(dataset="cora"): pred = outval.argmax(axis=1) accuracy = ((pred == labels) * test_mask).sum() / test_mask.sum() print("Test accuracy {:.2%}".format(accuracy)) +print(outval[:5]) # Evaluate the runtime print("Evaluate inference time cost...") From 6a78998d9310025c85129be7c460da2124f924a3 Mon Sep 17 00:00:00 2001 From: Chien-Yu Lin Date: Fri, 4 Oct 2019 15:47:24 -0700 Subject: [PATCH 3/9] remove printing output --- tutorials/frontend/build_gcn.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tutorials/frontend/build_gcn.py b/tutorials/frontend/build_gcn.py index d8796949c553..9d6c18e1e1c3 100644 --- a/tutorials/frontend/build_gcn.py +++ b/tutorials/frontend/build_gcn.py @@ -158,14 +158,14 @@ def load_dataset(dataset="cora"): num_hidden: int number of hidden layers -hidden_dim: int +hidden_dim: list of int input dimension of hidden layers num_classes: int dimension of model output (Number of classes) target: str - currently only support llvm, GPU support will be added in next few weeks + currently only support llvm activation: , Activation function applied to the output. e.g. 
relay.nn.{relu, sigmoid, log_softmax, softmax, leaky_relu} @@ -180,7 +180,6 @@ def load_dataset(dataset="cora"): num_hidden = 1 hidden_dim = [16] num_classes = data.num_labels -bias = True test_mask = data.test_mask labels = data.labels target = 'llvm' @@ -273,7 +272,6 @@ def load_dataset(dataset="cora"): pred = outval.argmax(axis=1) accuracy = ((pred == labels) * test_mask).sum() / test_mask.sum() print("Test accuracy {:.2%}".format(accuracy)) -print(outval[:5]) # Evaluate the runtime print("Evaluate inference time cost...") From c39bb0936afed7ac3fe8a01b30f546da66fbed70 Mon Sep 17 00:00:00 2001 From: Chien-Yu Lin Date: Sat, 5 Oct 2019 11:48:04 -0700 Subject: [PATCH 4/9] fix small bug --- tutorials/frontend/build_gcn.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tutorials/frontend/build_gcn.py b/tutorials/frontend/build_gcn.py index 9d6c18e1e1c3..d3c35dec78f9 100644 --- a/tutorials/frontend/build_gcn.py +++ b/tutorials/frontend/build_gcn.py @@ -219,7 +219,6 @@ def load_dataset(dataset="cora"): adj=adj, input=infeats, norm=norm, - bias=bias, activation=activation )) @@ -230,7 +229,6 @@ def load_dataset(dataset="cora"): adj=adj, input=layers[-1], norm=norm, - bias=bias, activation=activation )) From d7e450e2216782eed0707bb117e544e67bcf7372 Mon Sep 17 00:00:00 2001 From: Chien-Yu Lin Date: Tue, 8 Oct 2019 13:20:38 -0700 Subject: [PATCH 5/9] add DGL-PyTorch comparison into the build_gcn tutorial --- tutorials/frontend/build_gcn.py | 269 +++++++++++++++++++------------- 1 file changed, 161 insertions(+), 108 deletions(-) diff --git a/tutorials/frontend/build_gcn.py b/tutorials/frontend/build_gcn.py index d3c35dec78f9..1817d64f9274 100644 --- a/tutorials/frontend/build_gcn.py +++ b/tutorials/frontend/build_gcn.py @@ -28,14 +28,132 @@ We directly load the dataset from DGL library to do the apples to apples comparison against DGL. 
-Please refer to DGL tutorial on installation at +Please refer to DGL doc for DGL installation at https://docs.dgl.ai/install/index.html -GPU support and more sparse operators will soon follow. +and refer to PyTorch guide for PyTorch installation at +https://pytorch.org/get-started/locally/ """ + +###################################################################### +# Define GCN in DGL with PyTorch backend +# ------------------ +# +# DGL example: https://github.com/dmlc/dgl/tree/master/examples/pytorch/gcn +# This part reuses the code from the above example + +import torch +import torch.nn as nn +import torch.nn.functional as F +from dgl.nn.pytorch import GraphConv + +class GCN(nn.Module): + def __init__(self, + g, + n_infeat, + n_hidden, + n_classes, + n_layers, + activation): + super(GCN, self).__init__() + self.g = g + self.layers = nn.ModuleList() + self.layers.append(GraphConv(n_infeat, n_hidden, activation=activation)) + for i in range(n_layers - 1): + self.layers.append(GraphConv(n_hidden, n_hidden, activation=activation)) + self.layers.append(GraphConv(n_hidden, n_classes)) + + def forward(self, features): + h = features + for i, layer in enumerate(self.layers): + h = layer(self.g, h) + return h + + +###################################################################### +# Load the dataset with DGL utilities +# ------------------ +# You may substitute this part with your own dataset, here we load data from DGL + +from dgl import DGLGraph +from dgl.data import load_data +from collections import namedtuple + +def load_dataset(dataset="cora"): + args = namedtuple("args", ["dataset"]) + data = load_data(args(dataset)) + + # Remove self-loops to avoid duplicate passing of a node's feature to itself + g = data.graph + g.remove_edges_from(g.selfloop_edges()) + g.add_edges_from(zip(g.nodes, g.nodes)) + + return g, data + +###################################################################### +# Set up model Parameters +# ------------------ +""" +Parameters 
+---------- +dataset: str + Name of dataset. You can choose from ['cora', 'citeseer', 'pubmed']. + +num_layers: int + number of hidden layers + +num_hidden: int + number of the hidden units in the hidden layer + +infeat_dim: int + dimension of the input features + +num_classes: int + dimension of model output (Number of classes) +""" +dataset = "cora" + +g, data = load_dataset(dataset) + +num_layers = 1 +num_hidden = 16 +infeat_dim = data.features.shape[1] +num_classes = data.num_labels + +###################################################################### +# Set up the DGL-PyTorch model and get the golden results +# ------------------ +# +# The weights are trained with https://github.com/dmlc/dgl/blob/master/examples/pytorch/gcn/train.py +from tvm.contrib.download import download_testdata + + +features = torch.FloatTensor(data.features) +dgl_g = DGLGraph(g) + +torch_model = GCN(dgl_g, + infeat_dim, + num_hidden, + num_classes, + num_layers, + F.relu) + +# Download the pretrained weights +model_url = "https://homes.cs.washington.edu/~cyulin/media/gcn_%s.torch"%(dataset) +model_path = download_testdata(model_url, "gcn_%s.pickle"%(dataset), module='gcn_model') + +# Load the weights into the model +torch_model.load_state_dict(torch.load(model_path)) + +# Run the DGL model +torch_model.eval() +with torch.no_grad(): + logits_torch = torch_model(features) +print("Print the first five outputs from DGL-PyTorch execution\n", logits_torch[:5]) + ###################################################################### -# Define Graph Convolution Layer +# Define Graph Convolution Layer in Relay # ---------------------------- # To run GCN on TVM, we first need to implement Graph Convolution Layer. 
# @@ -51,14 +169,7 @@ # = ((W^t * H^t) * A^t)^t from tvm import relay from tvm.contrib import graph_runtime -import tvm, dgl, scipy -import numpy as np -import networkx as nx -from collections import namedtuple -from dgl.data import load_data - -from tvm.contrib.download import download_testdata -import pickle +import tvm def GraphConv(layer_name, input_dim, @@ -91,7 +202,7 @@ def GraphConv(layer_name, Norm passed to this layer to normalize features before and after Convolution. bias: bool - Set bias to True to add bias when doing gcn layer + Set bias to True to add bias when doing GCN layer activation: , Activation function applies to the output. e.g. relay.nn.{relu, sigmoid, log_softmax, softmax, leaky_relu} @@ -119,25 +230,19 @@ def GraphConv(layer_name, return output_t ###################################################################### -# Load the dataset +# Prepare the parameters needed in the GraphConv layers # ------------------ -# You may substitute this part with your own dataset, here we load data from DGL to benchmark - -def load_dataset(dataset="cora"): - args = namedtuple("args", ["dataset"]) - data = load_data(args(dataset)) +# +import numpy as np +import networkx as nx +def prepare_params(g, data): params = {} params['infeats'] = data.features.astype('float32') # Only support float32 as feature for now - # Remove self-loops to avoid duplicate passing of a node's feature to itself - g = data.graph - g.remove_edges_from(g.selfloop_edges()) - g.add_edges_from(zip(g.nodes, g.nodes)) - # Generate adjacency matrix adjacency = nx.to_scipy_sparse_matrix(g) - params['data'] = adjacency.data.astype('float32') + params['g_data'] = adjacency.data.astype('float32') params['indices'] = adjacency.indices.astype('int32') params['indptr'] = adjacency.indptr.astype('int32') @@ -146,136 +251,84 @@ def load_dataset(dataset="cora"): params['norm'] = np.power(degs, -0.5).astype('float32') params['norm'] = params['norm'].reshape((params['norm'].shape[0], 1)) - return 
data, params + return params -###################################################################### -# Set up model Parameters -# ------------------ +params = prepare_params(g, data) -""" -Parameters ----------- -num_hidden: int - number of hidden layers - -hidden_dim: list of int - input dimension of hidden layers - -num_classes: int - dimension of model output (Number of classes) - -target: str - currently only support llvm - -activation: , - Activation function applied to the output. e.g. relay.nn.{relu, sigmoid, log_softmax, softmax, leaky_relu} - -dataset: str - Name of dataset. You can pick from ['cora', 'citeseer', 'pubmed'] or you can use your own. -""" - -dataset = "cora" -data, params = load_dataset(dataset) - -num_hidden = 1 -hidden_dim = [16] -num_classes = data.num_labels -test_mask = data.test_mask -labels = data.labels -target = 'llvm' -activation = relay.nn.relu - -# Check shape of features +# Check shape of features and the validity of adjacency matrix assert len(params['infeats'].shape) == 2 -nnodes, input_dim = params['infeats'].shape - -# Check validity of adjacency matrix -assert params['data'] is not None and params['indices'] is not None and params['indptr'] is not None -assert nnodes == params['indptr'].shape[0] - 1 +assert params['g_data'] is not None and params['indices'] is not None and params['indptr'] is not None +assert params['infeats'].shape[0] == params['indptr'].shape[0] - 1 ###################################################################### # Put layers together # ------------------ -# Define input features, norms, adjacency matrix -infeats = relay.var("infeats", shape=(nnodes, input_dim)) - +# Define input features, norms, adjacency matrix in Relay +infeats = relay.var("infeats", shape=data.features.shape) norm = relay.Constant(tvm.nd.array(params['norm'])) - -data = relay.Constant(tvm.nd.array(params['data'])) +g_data = relay.Constant(tvm.nd.array(params['g_data'])) indices = relay.Constant(tvm.nd.array(params['indices'])) 
indptr = relay.Constant(tvm.nd.array(params['indptr'])) Adjacency = namedtuple('Adjacency', ['data', 'indices', 'indptr']) -adj = Adjacency(data, indices, indptr) +adj = Adjacency(g_data, indices, indptr) -# Construct a 2-layer GCN +# Construct the 2-layer GCN layers = [] - layers.append(GraphConv( layer_name="layers.0", - input_dim=input_dim, - output_dim=hidden_dim[0], + input_dim=infeat_dim, + output_dim=num_hidden, adj=adj, input=infeats, norm=norm, - activation=activation + activation=relay.nn.relu )) - layers.append(GraphConv( layer_name="layers.1", - input_dim=hidden_dim[0], + input_dim=num_hidden, output_dim=num_classes, adj=adj, input=layers[-1], norm=norm, - activation=activation + activation=None )) +# Analyze free variables and generate Relay function output = layers[-1] - -# Analyze free variables and generate function func = relay.Function(relay.analysis.free_vars(output), output) ###################################################################### -# Compile and run +# Compile and run with TVM # ------------------ # -# DGL version: https://github.com/dmlc/dgl/blob/master/examples/mxnet/gcn/gcn.py -# Download pretrained GCN model -model_url = "https://homes.cs.washington.edu/~cyulin/media/gcn_%s.pickle"%(dataset) -model_path = download_testdata(model_url, 'gcn.pickle', module='gcn_model') +# Export the weigths from PyTorch model to Python Dict +model_params = {} +for param_tensor in torch_model.state_dict(): + model_params[param_tensor] = torch_model.state_dict()[param_tensor].numpy() -with open(model_path, 'rb') as fp: - model_params = pickle.load(fp) - -for i in range(num_hidden+1): +for i in range(num_layers+1): params["layers.%d.weight"%(i)] = model_params["layers.%d.weight"%(i)] params["layers.%d.bias"%(i)] = model_params["layers.%d.bias"%(i)] -# Build with relay +# Set the TVM build target +target = 'llvm' # Currently only support `llvm` as target + +# Build with Relay with relay.build_config(opt_level=0): # Currently only support opt_level=0 
graph, lib, params = relay.build(func, target, params=params) - lib.save("lib.o") # Generate graph runtime ctx = tvm.context(target, 0) m = graph_runtime.create(graph, lib, ctx) m.set_input(**params) -# Run the model for one time and test for accuracy +# Run the model m.run() -outval = m.get_output(0).asnumpy() -pred = outval.argmax(axis=1) -accuracy = ((pred == labels) * test_mask).sum() / test_mask.sum() -print("Test accuracy {:.2%}".format(accuracy)) - -# Evaluate the runtime -print("Evaluate inference time cost...") -timer = m.module.time_evaluator("run", ctx, number=1, repeat=10) -tcost = timer() -prof_res = tcost.results -prof_res = np.array(tcost.results) * 1000 # convert to millisecond -print("Mean inference time (std dev): %.6f ms (%.6f ms)" % - (np.mean(prof_res), np.std(prof_res))) +logits_tvm = m.get_output(0).asnumpy() +print("Print the first five outputs from TVM execution\n", logits_tvm[:5]) + +# Verify the results with DGL-PyTorch +tvm.testing.assert_allclose(logits_torch, logits_tvm, atol=1e-3) From e4a8b60c8528906fb183e7941edd9fff604a17e2 Mon Sep 17 00:00:00 2001 From: Chien-Yu Lin Date: Tue, 8 Oct 2019 16:02:09 -0700 Subject: [PATCH 6/9] add accuracy testing --- tutorials/frontend/build_gcn.py | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/tutorials/frontend/build_gcn.py b/tutorials/frontend/build_gcn.py index 1817d64f9274..8ce83d456687 100644 --- a/tutorials/frontend/build_gcn.py +++ b/tutorials/frontend/build_gcn.py @@ -72,7 +72,7 @@ def forward(self, features): ###################################################################### -# Load the dataset with DGL utilities +# Define the functions to load dataset and evaluate accuracy # ------------------ # You may substitute this part with your own dataset, here we load data from DGL @@ -91,8 +91,16 @@ def load_dataset(dataset="cora"): return g, data +def evaluate(data, logits): + test_mask = data.test_mask # the test set which isn't included in 
the training phase + + pred = logits.argmax(axis=1) + acc = ((pred == data.labels) * test_mask).sum() / test_mask.sum() + + return acc + ###################################################################### -# Set up model Parameters +# Load the data and set up model parameters # ------------------ """ Parameters @@ -146,12 +154,17 @@ def load_dataset(dataset="cora"): # Load the weights into the model torch_model.load_state_dict(torch.load(model_path)) -# Run the DGL model +###################################################################### +# Run the DGL model and test for accuracy +# ------------------ torch_model.eval() with torch.no_grad(): logits_torch = torch_model(features) print("Print the first five outputs from DGL-PyTorch execution\n", logits_torch[:5]) +acc = evaluate(data, logits_torch.numpy()) +print("Test accuracy of DGL results: {:.2%}".format(acc)) + ###################################################################### # Define Graph Convolution Layer in Relay # ---------------------------- @@ -325,10 +338,18 @@ def prepare_params(g, data): m = graph_runtime.create(graph, lib, ctx) m.set_input(**params) -# Run the model +###################################################################### +# Run the TVM model, test for accuracy and verify with DGL +# ------------------ m.run() logits_tvm = m.get_output(0).asnumpy() print("Print the first five outputs from TVM execution\n", logits_tvm[:5]) -# Verify the results with DGL-PyTorch +labels = data.labels +test_mask = data.test_mask + +acc = evaluate(data, logits_tvm) +print("Test accuracy of TVM results: {:.2%}".format(acc)) + +# Verify the results with the DGL model tvm.testing.assert_allclose(logits_torch, logits_tvm, atol=1e-3) From 6523d2d4d81778e1c5f4c74d57f86dd33b2c9228 Mon Sep 17 00:00:00 2001 From: Chien-Yu Lin Date: Wed, 9 Oct 2019 15:28:36 -0700 Subject: [PATCH 7/9] adjust import order --- tutorials/frontend/build_gcn.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git 
a/tutorials/frontend/build_gcn.py b/tutorials/frontend/build_gcn.py index 8ce83d456687..9095224bbd27 100644 --- a/tutorials/frontend/build_gcn.py +++ b/tutorials/frontend/build_gcn.py @@ -76,7 +76,6 @@ def forward(self, features): # ------------------ # You may substitute this part with your own dataset, here we load data from DGL -from dgl import DGLGraph from dgl.data import load_data from collections import namedtuple @@ -135,7 +134,7 @@ def evaluate(data, logits): # # The weights are trained with https://github.com/dmlc/dgl/blob/master/examples/pytorch/gcn/train.py from tvm.contrib.download import download_testdata - +from dgl import DGLGraph features = torch.FloatTensor(data.features) dgl_g = DGLGraph(g) From d54d6fb0d3df585737c21766d32fc8110f0d1f3e Mon Sep 17 00:00:00 2001 From: Chien-Yu Lin Date: Thu, 10 Oct 2019 10:08:35 -0700 Subject: [PATCH 8/9] handle different dgl versions --- tutorials/frontend/build_gcn.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/tutorials/frontend/build_gcn.py b/tutorials/frontend/build_gcn.py index 9095224bbd27..42906b14cc81 100644 --- a/tutorials/frontend/build_gcn.py +++ b/tutorials/frontend/build_gcn.py @@ -42,10 +42,10 @@ # # DGL example: https://github.com/dmlc/dgl/tree/master/examples/pytorch/gcn # This part reuses the code from the above example - import torch import torch.nn as nn import torch.nn.functional as F +import dgl from dgl.nn.pytorch import GraphConv class GCN(nn.Module): @@ -67,7 +67,11 @@ def __init__(self, def forward(self, features): h = features for i, layer in enumerate(self.layers): - h = layer(self.g, h) + # handle api changes for differnt DGL version + if dgl.__version__ <= '0.2': + h = layer(h, self.g) + else: + h = layer(self.g, h) return h @@ -75,7 +79,6 @@ def forward(self, features): # Define the functions to load dataset and evaluate accuracy # ------------------ # You may substitute this part with your own dataset, here we load data from DGL - from 
dgl.data import load_data from collections import namedtuple @@ -90,6 +93,7 @@ def load_dataset(dataset="cora"): return g, data + def evaluate(data, logits): test_mask = data.test_mask # the test set which isn't included in the training phase @@ -98,6 +102,7 @@ def evaluate(data, logits): return acc + ###################################################################### # Load the data and set up model parameters # ------------------ @@ -147,12 +152,13 @@ def evaluate(data, logits): F.relu) # Download the pretrained weights -model_url = "https://homes.cs.washington.edu/~cyulin/media/gcn_%s.torch"%(dataset) +model_url = "https://homes.cs.washington.edu/~cyulin/media/gnn_model/gcn_%s.torch"%(dataset) model_path = download_testdata(model_url, "gcn_%s.pickle"%(dataset), module='gcn_model') # Load the weights into the model torch_model.load_state_dict(torch.load(model_path)) + ###################################################################### # Run the DGL model and test for accuracy # ------------------ @@ -164,6 +170,7 @@ def evaluate(data, logits): acc = evaluate(data, logits_torch.numpy()) print("Test accuracy of DGL results: {:.2%}".format(acc)) + ###################################################################### # Define Graph Convolution Layer in Relay # ---------------------------- @@ -314,8 +321,6 @@ def prepare_params(g, data): ###################################################################### # Compile and run with TVM # ------------------ -# - # Export the weigths from PyTorch model to Python Dict model_params = {} for param_tensor in torch_model.state_dict(): From 5f4175911ee20861c2d471bb68a8639a3cdbafbb Mon Sep 17 00:00:00 2001 From: Chien-Yu Lin Date: Thu, 10 Oct 2019 15:33:22 -0700 Subject: [PATCH 9/9] update number for dgl version checking --- tutorials/frontend/build_gcn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tutorials/frontend/build_gcn.py b/tutorials/frontend/build_gcn.py index 
42906b14cc81..e97c83c6c44e 100644 --- a/tutorials/frontend/build_gcn.py +++ b/tutorials/frontend/build_gcn.py @@ -68,10 +68,10 @@ def forward(self, features): h = features for i, layer in enumerate(self.layers): # handle api changes for differnt DGL version - if dgl.__version__ <= '0.2': - h = layer(h, self.g) - else: + if dgl.__version__ > '0.3': h = layer(self.g, h) + else: + h = layer(h, self.g) return h