Skip to content

Commit

Permalink
accelerate subgraph construction.
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-da committed Nov 3, 2018
1 parent 248b921 commit 9aa1dcb
Showing 1 changed file with 101 additions and 17 deletions.
118 changes: 101 additions & 17 deletions src/operator/dgl_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,103 @@ static bool DGLSubgraphType(const nnvm::NodeAttrs& attrs,

typedef int64_t dgl_id_t;

/*
 * A fixed-size bit table used as a cheap pre-filter for vertex Ids.
 *
 * The table holds 4M (2^22) bits and an Id is mapped to a slot by keeping
 * only its low bits (id & mask). Distinct Ids can therefore share a slot:
 * test() never returns false for an Id that was given to the constructor,
 * but it may return true for an Id that was not. Callers must confirm a
 * positive answer with an exact lookup.
 */
class Bitmap {
  // Number of bits in the table; it is a power of two so that `mask`
  // selects a valid slot.
  const size_t size = 1024 * 1024 * 4;
  const size_t mask = size - 1;
  std::vector<bool> map;

  // Slot of a vertex Id: the low bits of the Id.
  size_t hash(dgl_id_t id) const {
    return static_cast<size_t>(id) & mask;
  }

 public:
  /*
   * Marks the slot of every Id in vid_data[0, len).
   */
  Bitmap(const dgl_id_t *vid_data, int64_t len): map(size) {
    int64_t pos = 0;
    while (pos < len) {
      map[hash(vid_data[pos])] = true;
      ++pos;
    }
  }

  /*
   * Returns true if the slot of `id` has been marked.
   */
  bool test(dgl_id_t id) const {
    const size_t slot = hash(id);
    return map[slot];
  }
};

/*
* This uses a hashtable to check if a node is in the given node list.
*/
class HashTableChecker {
std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;
Bitmap map;
public:
HashTableChecker(const dgl_id_t *vid_data, int64_t len): map(vid_data, len) {
oldv2newv.reserve(len);
for (int64_t i = 0; i < len; ++i) {
oldv2newv[vid_data[i]] = i;
}
}

void CollectOnRow(const dgl_id_t col_idx[], const dgl_id_t eids[], size_t row_len,
std::vector<dgl_id_t> *new_col_idx,
std::vector<dgl_id_t> *orig_eids) {
// TODO(zhengda) I need to make sure the column index in each row is sorted.
for (size_t j = 0; j < row_len; ++j) {
const dgl_id_t oldsucc = col_idx[j];
const dgl_id_t eid = eids[j];
Collect(oldsucc, eid, new_col_idx, orig_eids);
}
}

void Collect(const dgl_id_t old_id, const dgl_id_t old_eid,
std::vector<dgl_id_t> *col_idx,
std::vector<dgl_id_t> *orig_eids) {
if (!map.test(old_id))
return;

auto it = oldv2newv.find(old_id);
if (it != oldv2newv.end()) {
const dgl_id_t new_id = it->second;
col_idx->push_back(new_id);
if (orig_eids)
orig_eids->push_back(old_eid);
}
}
};

/*
 * This checks whether the column indices of a row are in the given vertex
 * list by walking the row and the vertex list side by side.
 *
 * Both the vertex list and the column indices of the row have to be sorted
 * in ascending order; entries that are out of order may be missed.
 * The checker keeps a pointer to the vertex list and doesn't copy it, so the
 * list has to outlive the checker.
 */
class ScanChecker {
  const dgl_id_t *vid_data;
  size_t len;

 public:
  ScanChecker(const dgl_id_t *vid_data, size_t len): vid_data(vid_data), len(len) {
  }

  /*
   * For every column index of the row that appears in the vertex list,
   * appends the position of that vertex in the list (its Id in the subgraph)
   * to new_col_idx and, when orig_eids isn't null, the matching edge Id to
   * orig_eids.
   *
   * col_idx and eids are parallel arrays of row_len entries.
   */
  void CollectOnRow(const dgl_id_t col_idx[], const dgl_id_t eids[], size_t row_len,
                    std::vector<dgl_id_t> *new_col_idx,
                    std::vector<dgl_id_t> *orig_eids) {
    size_t v_idx = 0;
    size_t r_idx = 0;
    while (v_idx < len && r_idx < row_len) {
      if (col_idx[r_idx] == vid_data[v_idx]) {
        // The subgraph uses the position in the vertex list as the vertex Id,
        // which is what HashTableChecker emits as well. Storing the original
        // Id here would mix two Id spaces in the same column index array.
        new_col_idx->push_back(static_cast<dgl_id_t>(v_idx));
        if (orig_eids)
          orig_eids->push_back(eids[r_idx]);
        // Only the row cursor moves on a match: the next entry of the row may
        // carry the same column index (parallel edges) and has to be kept too.
        ++r_idx;
      } else if (col_idx[r_idx] < vid_data[v_idx]) {
        ++r_idx;
      } else {
        ++v_idx;
      }
    }
  }
};

static void GetSubgraph(const NDArray &csr_arr, const NDArray &varr,
const NDArray &sub_csr, const NDArray *old_eids) {
TBlob data = varr.data();
int64_t num_vertices = csr_arr.shape()[0];
const auto len = varr.shape()[0];
const size_t len = varr.shape()[0];
const dgl_id_t *vid_data = data.dptr<dgl_id_t>();
std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;
for (int64_t i = 0; i < len; ++i) {
oldv2newv[vid_data[i]] = i;
}
HashTableChecker def_check(vid_data, len);
ScanChecker scan_check(vid_data, len);

// Collect the non-zero entries from the original graph.
std::vector<dgl_id_t> row_idx(len + 1);
Expand All @@ -123,24 +210,21 @@ static void GetSubgraph(const NDArray &csr_arr, const NDArray &varr,
const dgl_id_t *eids = csr_arr.data().dptr<dgl_id_t>();
const dgl_id_t *indptr = csr_arr.aux_data(csr::kIndPtr).dptr<dgl_id_t>();
const dgl_id_t *indices = csr_arr.aux_data(csr::kIdx).dptr<dgl_id_t>();
for (int64_t i = 0; i < len; ++i) {
for (size_t i = 0; i < len; ++i) {
const dgl_id_t oldvid = vid_data[i];
CHECK_LT(oldvid, num_vertices) << "Vertex Id " << oldvid << " isn't in a graph of "
<< num_vertices << " vertices";
size_t row_start = indptr[oldvid];
size_t row_len = indptr[oldvid + 1] - indptr[oldvid];
// TODO(zhengda) I need to make sure the column index in each row is sorted.
for (size_t j = 0; j < row_len; ++j) {
const dgl_id_t oldsucc = indices[row_start + j];
const dgl_id_t eid = eids[row_start + j];
auto it = oldv2newv.find(oldsucc);
if (it != oldv2newv.end()) {
const dgl_id_t newsucc = it->second;
col_idx.push_back(newsucc);
if (old_eids)
orig_eids.push_back(eid);
}
// If there aren't so many elements in a row
if (row_len < len / 2) {
def_check.CollectOnRow(indices + row_start, eids + row_start, row_len,
&col_idx, old_eids == nullptr ? nullptr : &orig_eids);
} else {
scan_check.CollectOnRow(indices + row_start, eids + row_start, row_len,
&col_idx, old_eids == nullptr ? nullptr : &orig_eids);
}

row_idx[i + 1] = col_idx.size();
}

Expand Down

0 comments on commit 9aa1dcb

Please sign in to comment.