From 9aa1dcbac5ceaac769bed25d29256997967f8472 Mon Sep 17 00:00:00 2001
From: Da Zheng
Date: Sat, 3 Nov 2018 01:06:54 +0000
Subject: [PATCH] accelerate subgraph construction.

---
 src/operator/dgl_graph.cc | 133 ++++++++++++++++++++++++++++++++------
 1 file changed, 116 insertions(+), 17 deletions(-)

diff --git a/src/operator/dgl_graph.cc b/src/operator/dgl_graph.cc
index 11a782a49229..ce7c884ac391 100644
--- a/src/operator/dgl_graph.cc
+++ b/src/operator/dgl_graph.cc
@@ -103,16 +103,118 @@ static bool DGLSubgraphType(const nnvm::NodeAttrs& attrs,
 
 typedef int64_t dgl_id_t;
 
+/*
+ * A fixed-size presence filter over vertex Ids. An Id is mapped to a slot by
+ * masking its low bits, so distinct Ids may share a slot: test() returning
+ * false means the Id was never inserted, while true still needs confirmation.
+ */
+class Bitmap {
+  const size_t size = 1024 * 1024 * 4;
+  const size_t mask = size - 1;
+  std::vector<bool> map;
+
+  size_t hash(dgl_id_t id) const {
+    return id & mask;
+  }
+ public:
+  Bitmap(const dgl_id_t *vid_data, int64_t len): map(size) {
+    for (int64_t i = 0; i < len; ++i) {
+      map[hash(vid_data[i])] = 1;
+    }
+  }
+
+  bool test(dgl_id_t id) const {
+    return map[hash(id)];
+  }
+};
+
+/*
+ * This uses a hashtable to check if a node is in the given node list.
+ * The position of a vertex in the node list is reported as its new Id.
+ * The bitmap is tested first so that Ids it rules out skip the hashtable.
+ */
+class HashTableChecker {
+  std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;
+  Bitmap map;
+ public:
+  HashTableChecker(const dgl_id_t *vid_data, int64_t len): map(vid_data, len) {
+    oldv2newv.reserve(len);
+    for (int64_t i = 0; i < len; ++i) {
+      oldv2newv[vid_data[i]] = i;
+    }
+  }
+
+  // Run Collect() on each of the row_len (column, edge Id) pairs of one row.
+  void CollectOnRow(const dgl_id_t col_idx[], const dgl_id_t eids[], size_t row_len,
+                    std::vector<dgl_id_t> *new_col_idx,
+                    std::vector<dgl_id_t> *orig_eids) {
+    // TODO(zhengda) I need to make sure the column index in each row is sorted.
+    for (size_t j = 0; j < row_len; ++j) {
+      const dgl_id_t oldsucc = col_idx[j];
+      const dgl_id_t eid = eids[j];
+      Collect(oldsucc, eid, new_col_idx, orig_eids);
+    }
+  }
+
+  // Append the new Id of old_id if it is in the node list; orig_eids may be null.
+  void Collect(const dgl_id_t old_id, const dgl_id_t old_eid,
+               std::vector<dgl_id_t> *col_idx,
+               std::vector<dgl_id_t> *orig_eids) {
+    if (!map.test(old_id))
+      return;
+
+    auto it = oldv2newv.find(old_id);
+    if (it != oldv2newv.end()) {
+      const dgl_id_t new_id = it->second;
+      col_idx->push_back(new_id);
+      if (orig_eids)
+        orig_eids->push_back(old_eid);
+    }
+  }
+};
+
+/*
+ * This checks membership by scanning a row and the node list in lockstep,
+ * which requires both the column indices of the row and the node list to be
+ * sorted in ascending order. A match is reported as its node list position.
+ */
+class ScanChecker {
+  const dgl_id_t *vid_data;
+  size_t len;
+ public:
+  ScanChecker(const dgl_id_t *vid_data, size_t len) {
+    this->vid_data = vid_data;
+    this->len = len;
+  }
+
+  void CollectOnRow(const dgl_id_t col_idx[], const dgl_id_t eids[], size_t row_len,
+                    std::vector<dgl_id_t> *new_col_idx,
+                    std::vector<dgl_id_t> *orig_eids) {
+    for (size_t v_idx = 0, r_idx = 0; v_idx < len && r_idx < row_len; ) {
+      if (col_idx[r_idx] == vid_data[v_idx]) {
+        // Emit the position in the node list, as HashTableChecker does.
+        new_col_idx->push_back(static_cast<dgl_id_t>(v_idx));
+        if (orig_eids)
+          orig_eids->push_back(eids[r_idx]);
+        // Keep v_idx in place so a repeated column in the row still matches.
+        r_idx++;
+      } else if (col_idx[r_idx] < vid_data[v_idx]) {
+        r_idx++;
+      } else {
+        v_idx++;
+      }
+    }
+  }
+};
+
 static void GetSubgraph(const NDArray &csr_arr, const NDArray &varr,
                         const NDArray &sub_csr, const NDArray *old_eids) {
   TBlob data = varr.data();
   int64_t num_vertices = csr_arr.shape()[0];
-  const auto len = varr.shape()[0];
+  const size_t len = varr.shape()[0];
   const dgl_id_t *vid_data = data.dptr<dgl_id_t>();
-  std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;
-  for (int64_t i = 0; i < len; ++i) {
-    oldv2newv[vid_data[i]] = i;
-  }
+  HashTableChecker def_check(vid_data, len);
+  ScanChecker scan_check(vid_data, len);
 
   // Collect the non-zero entries in from the original graph.
   std::vector<dgl_id_t> row_idx(len + 1);
@@ -123,24 +225,21 @@ static void GetSubgraph(const NDArray &csr_arr, const NDArray &varr,
   const dgl_id_t *eids = csr_arr.data().dptr<dgl_id_t>();
   const dgl_id_t *indptr = csr_arr.aux_data(csr::kIndPtr).dptr<dgl_id_t>();
   const dgl_id_t *indices = csr_arr.aux_data(csr::kIdx).dptr<dgl_id_t>();
-  for (int64_t i = 0; i < len; ++i) {
+  for (size_t i = 0; i < len; ++i) {
     const dgl_id_t oldvid = vid_data[i];
     CHECK_LT(oldvid, num_vertices) << "Vertex Id " << oldvid << " isn't in a graph of "
       << num_vertices << " vertices";
     size_t row_start = indptr[oldvid];
     size_t row_len = indptr[oldvid + 1] - indptr[oldvid];
-    // TODO(zhengda) I need to make sure the column index in each row is sorted.
-    for (size_t j = 0; j < row_len; ++j) {
-      const dgl_id_t oldsucc = indices[row_start + j];
-      const dgl_id_t eid = eids[row_start + j];
-      auto it = oldv2newv.find(oldsucc);
-      if (it != oldv2newv.end()) {
-        const dgl_id_t newsucc = it->second;
-        col_idx.push_back(newsucc);
-        if (old_eids)
-          orig_eids.push_back(eid);
-      }
+    // If there aren't so many elements in a row
+    if (row_len < len / 2) {
+      def_check.CollectOnRow(indices + row_start, eids + row_start, row_len,
+                             &col_idx, old_eids == nullptr ? nullptr : &orig_eids);
+    } else {
+      scan_check.CollectOnRow(indices + row_start, eids + row_start, row_len,
+                              &col_idx, old_eids == nullptr ? nullptr : &orig_eids);
     }
+
     row_idx[i + 1] = col_idx.size();
   }
 