Commit ce19d0f

first weighted edges support commit
thunderock committed Jul 8, 2024
1 parent 89b74f0 commit ce19d0f
Showing 7 changed files with 419 additions and 5 deletions.
172 changes: 172 additions & 0 deletions csrc/cpu/rw_cpu.cpp
@@ -107,6 +107,178 @@ void rejection_sampling(const int64_t *rowptr, const int64_t *col,
});
}


void compute_cdf(const int64_t *rowptr, const float_t *edge_weight,
float_t *edge_weight_cdf, int64_t numel) {
/* Convert edge weights to CDF as given in [1]
[1] https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L148
*/
at::parallel_for(0, numel - 1, at::internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) {
for(int64_t i = begin; i < end; i++) {
int64_t row_start = rowptr[i], row_end = rowptr[i + 1];

// Compute sum to normalize weights
float_t sum = 0.0;

for(int64_t j = row_start; j < row_end; j++) {
sum += edge_weight[j];
}

float_t acc = 0.0;

for(int64_t j = row_start; j < row_end; j++) {
acc += edge_weight[j] / sum;
edge_weight_cdf[j] = acc;
}
}
});
}
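
For intuition, here is a minimal self-contained sketch of the same per-row normalization on a made-up CSR graph (plain loops standing in for at::parallel_for; all values are illustrative):

#include <cstdint>
#include <cstdio>

int main() {
  // Toy CSR graph with 2 nodes: node 0 has edge weights {2, 1, 1},
  // node 1 has edge weights {3, 1}.
  const int64_t rowptr[3] = {0, 3, 5};
  const float edge_weight[5] = {2.f, 1.f, 1.f, 3.f, 1.f};
  float cdf[5];

  for (int64_t i = 0; i < 2; i++) { // numel - 1 rows, as in compute_cdf()
    float sum = 0.f, acc = 0.f;
    for (int64_t j = rowptr[i]; j < rowptr[i + 1]; j++)
      sum += edge_weight[j];
    for (int64_t j = rowptr[i]; j < rowptr[i + 1]; j++) {
      acc += edge_weight[j] / sum;
      cdf[j] = acc;
    }
  }

  for (int64_t j = 0; j < 5; j++)
    printf("%.2f ", cdf[j]); // prints: 0.50 0.75 1.00 0.75 1.00
  return 0;
}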


int64_t get_offset(const float_t *edge_weight, int64_t start, int64_t end) {
/*
The implementation given in [1] utilizes the `searchsorted` function in NumPy.
It is also available in PyTorch and its C++ API (via `at::searchsorted()`).
However, that implementation is adapted to the general case where the searched
values can be a multidimensional tensor. In our case, we have a 1D tensor of
edge weights (in the form of a cumulative distribution function) and a single
value whose position we want to compute. To eliminate the overhead introduced
in the PyTorch implementation, one can examine the source code of
`searchsorted` [2] and find that for our case the whole function call can be
reduced to calling the `cus_lower_bound()` function. Unfortunately, we cannot
access it directly (the namespace is not exposed in the public API), but the
implementation is just a simple binary search. The code was copied here and
reduced to the bare minimum.
[1] https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L69
[2] https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Bucketization.cpp
*/
float_t value = ((float_t)rand() / RAND_MAX); // uniform in [0, 1]
int64_t original_start = start;

while (start < end) {
const int64_t mid = start + ((end - start) >> 1);
const float_t mid_val = edge_weight[mid];
if (!(mid_val >= value)) {
start = mid + 1;
}
else {
end = mid;
}
}

return start - original_start;
}
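
The loop above is a standard lower-bound binary search: it returns the offset (relative to start) of the first CDF entry greater than or equal to the drawn value, which is exactly inverse-CDF sampling of a discrete distribution. A minimal equivalent sketch using the standard library (illustrative only; note that rand(), as in the function above, shares global state across threads):

#include <algorithm>
#include <cstdint>
#include <cstdlib>

// Sample an absolute edge index in [row_start, row_end) from a per-row CDF.
// Equivalent to row_start + get_offset(cdf, row_start, row_end).
int64_t sample_edge(const float *cdf, int64_t row_start, int64_t row_end) {
  float value = (float)rand() / RAND_MAX; // uniform in [0, 1]
  // First position whose CDF entry is >= value.
  const float *pos = std::lower_bound(cdf + row_start, cdf + row_end, value);
  return pos - cdf;
}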

// See: https://louisabraham.github.io/articles/node2vec-sampling.html
// See also: https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L69
void rejection_sampling_weighted(const int64_t *rowptr, const int64_t *col,
const float_t *edge_weight_cdf, int64_t *start,
int64_t *n_out, int64_t *e_out,
const int64_t numel, const int64_t walk_length,
const double p, const double q) {

double max_prob = fmax(fmax(1. / p, 1.), 1. / q);
double prob_0 = 1. / p / max_prob;
double prob_1 = 1. / max_prob;
double prob_2 = 1. / q / max_prob;

int64_t grain_size = at::internal::GRAIN_SIZE / walk_length;
at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {
for (auto n = begin; n < end; n++) {
int64_t t = start[n], v, x, e_cur, row_start, row_end;

n_out[n * (walk_length + 1)] = t;

row_start = rowptr[t], row_end = rowptr[t + 1];

if (row_end - row_start == 0) {
e_cur = -1;
v = t;
} else {
e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end);
v = col[e_cur];
}
n_out[n * (walk_length + 1) + 1] = v;
e_out[n * walk_length] = e_cur;

for (auto l = 1; l < walk_length; l++) {
row_start = rowptr[v], row_end = rowptr[v + 1];

if (row_end - row_start == 0) {
e_cur = -1;
x = v;
} else if (row_end - row_start == 1) {
e_cur = row_start;
x = col[e_cur];
} else {
if (p == 1. && q == 1.) {
e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end);
x = col[e_cur];
}
else {
while (true) {
e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end);
x = col[e_cur];

auto r = ((double)rand() / RAND_MAX); // uniform in [0, 1]

if (x == t && r < prob_0)
break;
else if (is_neighbor(rowptr, col, x, t) && r < prob_1)
break;
else if (r < prob_2)
break;
}
}
}

n_out[n * (walk_length + 1) + (l + 1)] = x;
e_out[n * walk_length + l] = e_cur;
t = v;
v = x;
}
}
});
}
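
The prob_0/prob_1/prob_2 constants realize the node2vec search bias alpha(t, x) by rejection: a candidate x (reached from v, with t the previous node) is accepted with probability proportional to 1/p when x == t, 1 when x is a neighbor of t, and 1/q otherwise, all scaled by max_prob so the largest acceptance probability is exactly 1. A minimal sketch of the acceptance rule in isolation (names are illustrative):

#include <cmath>

// Acceptance probability for a candidate step t -> v -> x under node2vec's
// bias alpha(t, x): 1/p if x == t, 1 if x neighbors t, 1/q otherwise.
double acceptance_prob(bool x_equals_t, bool x_neighbors_t,
                       double p, double q) {
  const double max_prob = std::fmax(std::fmax(1. / p, 1.), 1. / q);
  if (x_equals_t)
    return (1. / p) / max_prob; // distance d(t, x) == 0
  if (x_neighbors_t)
    return 1. / max_prob;       // d(t, x) == 1
  return (1. / q) / max_prob;   // d(t, x) == 2
}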


std::tuple<torch::Tensor, torch::Tensor>
random_walk_weighted_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor edge_weight, torch::Tensor start,
int64_t walk_length, double p, double q) {
CHECK_CPU(rowptr);
CHECK_CPU(col);
CHECK_CPU(edge_weight);
CHECK_CPU(start);

CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
CHECK_INPUT(edge_weight.dim() == 1);
CHECK_INPUT(start.dim() == 1);

auto n_out = torch::empty({start.size(0), walk_length + 1}, start.options());
auto e_out = torch::empty({start.size(0), walk_length}, start.options());

auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto edge_weight_data = edge_weight.data_ptr<float_t>();
auto start_data = start.data_ptr<int64_t>();
auto n_out_data = n_out.data_ptr<int64_t>();
auto e_out_data = e_out.data_ptr<int64_t>();

auto edge_weight_cdf = torch::empty({edge_weight.size(0)}, edge_weight.options());
auto edge_weight_cdf_data = edge_weight_cdf.data_ptr<float_t>();

compute_cdf(rowptr_data, edge_weight_data, edge_weight_cdf_data, rowptr.numel());

rejection_sampling_weighted(rowptr_data, col_data, edge_weight_cdf_data,
start_data, n_out_data, e_out_data, start.numel(),
walk_length, p, q);

return std::make_tuple(n_out, e_out);
}
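
As a usage sketch, the expected inputs are a CSR graph (rowptr, col) with one float32 weight per edge, aligned with col. The toy graph and values below are made up for illustration:

#include <torch/torch.h>
#include <tuple>

void example() {
  // Toy graph: 0 -> {1, 2}, 1 -> {0}, 2 -> {0}; edge weights biased
  // so that walks from node 0 mostly move to node 1.
  auto rowptr = torch::tensor({0, 2, 3, 4}, torch::kInt64);
  auto col = torch::tensor({1, 2, 0, 0}, torch::kInt64);
  auto edge_weight = torch::tensor({0.9f, 0.1f, 1.0f, 1.0f}, torch::kFloat32);
  auto start = torch::tensor({0, 1}, torch::kInt64);

  torch::Tensor node_seq, edge_seq;
  std::tie(node_seq, edge_seq) = random_walk_weighted_cpu(
      rowptr, col, edge_weight, start, /*walk_length=*/3, /*p=*/1., /*q=*/1.);
  // node_seq: [2, 4] visited nodes per walk; edge_seq: [2, 3] traversed
  // edge indices (-1 where a node had no outgoing edge).
}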


std::tuple<torch::Tensor, torch::Tensor>
random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
int64_t walk_length, double p, double q) {
5 changes: 5 additions & 0 deletions csrc/cpu/rw_cpu.h
@@ -5,3 +5,8 @@
std::tuple<torch::Tensor, torch::Tensor>
random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
int64_t walk_length, double p, double q);

std::tuple<torch::Tensor, torch::Tensor>
random_walk_weighted_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor edge_weight, torch::Tensor start,
int64_t walk_length, double p, double q);
164 changes: 164 additions & 0 deletions csrc/cuda/rw_cuda.cu
@@ -38,6 +38,170 @@ __global__ void uniform_sampling_kernel(const int64_t *rowptr,
}
}


__global__ void cdf_kernel(const int64_t *rowptr, const float_t *edge_weight,
float_t *edge_weight_cdf, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;

if (thread_idx < numel - 1) {
int64_t row_start = rowptr[thread_idx], row_end = rowptr[thread_idx + 1];

float_t sum = 0.0;

for(int64_t i = row_start; i < row_end; i++) {
sum += edge_weight[i];
}

float_t acc = 0.0;

for(int64_t i = row_start; i < row_end; i++) {
acc += edge_weight[i] / sum;
edge_weight_cdf[i] = acc;
}
}
}
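
The kernel assigns one thread per row, so each row's prefix sum is computed sequentially by a single thread; since the accumulation happens in float32, the last CDF entry of a row can land slightly off 1.0. A minimal host-side sanity check, assuming non-negative weights (the tolerance is made up):

#include <cassert>
#include <cmath>
#include <cstdint>

// Verify each non-empty row's CDF is non-decreasing and ends near 1.0.
void check_cdf(const int64_t *rowptr, const float *cdf, int64_t num_rows) {
  for (int64_t i = 0; i < num_rows; i++) {
    float prev = 0.f;
    for (int64_t j = rowptr[i]; j < rowptr[i + 1]; j++) {
      assert(cdf[j] >= prev); // non-decreasing within the row
      prev = cdf[j];
    }
    if (rowptr[i + 1] > rowptr[i])
      assert(std::fabs(prev - 1.f) < 1e-5f); // final entry close to 1.0
  }
}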

__device__ void get_offset(const float_t *edge_weight, int64_t start, int64_t end,
float_t value, int64_t *position_out) {
int64_t original_start = start;

while (start < end) {
const int64_t mid = start + ((end - start) >> 1);
const float_t mid_val = edge_weight[mid];
if (!(mid_val >= value)) {
start = mid + 1;
}
else {
end = mid;
}
}

*position_out = start - original_start;
}

__global__ void
rejection_sampling_weighted_kernel(unsigned int seed, const int64_t *rowptr,
const int64_t *col, const float_t *edge_weight_cdf,
const int64_t *start, int64_t *n_out,
int64_t *e_out, const int64_t walk_length,
const int64_t numel, const double p,
const double q) {

const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;

curandState_t state;
// Give each thread its own curand subsequence; seeding every thread with the
// same (seed, 0) pair would make all threads draw identical random numbers.
curand_init(seed, thread_idx, 0, &state);

double max_prob = fmax(fmax(1. / p, 1.), 1. / q);
double prob_0 = 1. / p / max_prob;
double prob_1 = 1. / max_prob;
double prob_2 = 1. / q / max_prob;


if (thread_idx < numel) {
int64_t t = start[thread_idx], v, x, e_cur, row_start, row_end, offset;

n_out[thread_idx] = t;

row_start = rowptr[t], row_end = rowptr[t + 1];

if (row_end - row_start == 0) {
e_cur = -1;
v = t;
} else {
get_offset(edge_weight_cdf, row_start, row_end, curand_uniform(&state), &offset);
e_cur = row_start + offset;
v = col[e_cur];
}

n_out[numel + thread_idx] = v;
e_out[thread_idx] = e_cur;

for (int64_t l = 1; l < walk_length; l++) {
row_start = rowptr[v], row_end = rowptr[v + 1];

if (row_end - row_start == 0) {
e_cur = -1;
x = v;
} else if (row_end - row_start == 1) {
e_cur = row_start;
x = col[e_cur];
} else {
if (p == 1. && q == 1.) {
get_offset(edge_weight_cdf, row_start, row_end, curand_uniform(&state), &offset);
e_cur = row_start + offset;
x = col[e_cur];
}
else {
while (true) {
get_offset(edge_weight_cdf, row_start, row_end, curand_uniform(&state), &offset);
e_cur = row_start + offset;
x = col[e_cur];

double r = curand_uniform(&state); // (0, 1]

if (x == t && r < prob_0)
break;

bool is_neighbor = false;
// Use separate bounds for the neighbor scan: overwriting `row_start` and
// `row_end` here would corrupt the sampling range used by the next
// iteration of the rejection loop.
const int64_t x_row_start = rowptr[x], x_row_end = rowptr[x + 1];
for (int64_t i = x_row_start; i < x_row_end; i++) {
if (col[i] == t) {
is_neighbor = true;
break;
}
}

if (is_neighbor && r < prob_1)
break;
else if (r < prob_2)
break;
}
}
}

n_out[(l + 1) * numel + thread_idx] = x;
e_out[l * numel + thread_idx] = e_cur;
t = v;
v = x;
}
}
}

std::tuple<torch::Tensor, torch::Tensor>
random_walk_weighted_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor edge_weight, torch::Tensor start,
int64_t walk_length, double p, double q) {
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_CUDA(edge_weight);
CHECK_CUDA(start);
cudaSetDevice(rowptr.get_device());

CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
CHECK_INPUT(edge_weight.dim() == 1);
CHECK_INPUT(start.dim() == 1);

auto n_out = torch::empty({walk_length + 1, start.size(0)}, start.options());
auto e_out = torch::empty({walk_length, start.size(0)}, start.options());

auto stream = at::cuda::getCurrentCUDAStream();

auto edge_weight_cdf = torch::empty({edge_weight.size(0)}, edge_weight.options());

cdf_kernel<<<BLOCKS(rowptr.numel()), THREADS, 0, stream>>>(
rowptr.data_ptr<int64_t>(), edge_weight.data_ptr<float_t>(),
edge_weight_cdf.data_ptr<float_t>(), rowptr.numel());

rejection_sampling_weighted_kernel<<<BLOCKS(start.numel()), THREADS, 0, stream>>>(
time(NULL), rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
edge_weight_cdf.data_ptr<float_t>(), start.data_ptr<int64_t>(),
n_out.data_ptr<int64_t>(), e_out.data_ptr<int64_t>(),
walk_length, start.numel(), p, q);

return std::make_tuple(n_out.t().contiguous(), e_out.t().contiguous());
}
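
Note the output layout: unlike the CPU version, the kernel writes step l of walk n to n_out[l * numel + n], so at any given step adjacent threads write adjacent addresses (coalesced global-memory access); the trailing .t().contiguous() then restores the row-per-walk layout. A side-by-side sketch of the two index mappings (helper names are illustrative):

#include <cstdint>

// CPU layout: one row per walk; consecutive steps of a walk are adjacent.
int64_t cpu_index(int64_t walk, int64_t step, int64_t walk_length) {
  return walk * (walk_length + 1) + step;
}

// CUDA layout before the final transpose: one row per step; at a fixed
// step, consecutive walks (threads) touch consecutive addresses.
int64_t cuda_index(int64_t walk, int64_t step, int64_t num_walks) {
  return step * num_walks + walk;
}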

__global__ void
rejection_sampling_kernel(unsigned int seed, const int64_t *rowptr,
const int64_t *col, const int64_t *start,
5 changes: 5 additions & 0 deletions csrc/cuda/rw_cuda.h
@@ -5,3 +5,8 @@
std::tuple<torch::Tensor, torch::Tensor>
random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
int64_t walk_length, double p, double q);

std::tuple<torch::Tensor, torch::Tensor>
random_walk_weighted_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor edge_weight, torch::Tensor start,
int64_t walk_length, double p, double q);
20 changes: 19 additions & 1 deletion csrc/rw.cpp
@@ -33,5 +33,23 @@ random_walk(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
}
}

CLUSTER_API std::tuple<torch::Tensor, torch::Tensor>
random_walk_weighted(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor edge_weight, torch::Tensor start,
int64_t walk_length, double p, double q) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
return random_walk_weighted_cuda(rowptr, col, edge_weight, start, walk_length, p, q);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return random_walk_weighted_cpu(rowptr, col, edge_weight, start, walk_length, p, q);
}
}

static auto registry =
    torch::RegisterOperators()
        .op("torch_cluster::random_walk", &random_walk)
        .op("torch_cluster::random_walk_weighted", &random_walk_weighted);


(2 remaining changed files not shown)
