[REVIEW] Deterministic UMAP with floating point rounding. #3848

Merged
11 changes: 7 additions & 4 deletions cpp/include/cuml/manifold/umapparams.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2021, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -150,9 +150,12 @@ class UMAPParams {
 
   uint64_t random_state = 0;
 
-  bool multicore_implem = true;
-
-  int optim_batch_size = 0;
+  /**
+   * Whether to use the deterministic algorithm. This should be set to true if
+   * random_state is provided, otherwise it's false. When it's true, cuML has
+   * higher memory usage but produces stable numeric output.
+   */
+  bool deterministic = true;
 
   Internals::GraphBasedDimRedCallback* callback = nullptr;
 };
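Note: a minimal caller-side sketch (editorial illustration, not part of this diff) of how the new flag is meant to be driven; it assumes the public `ML::UMAPParams` fields shown above:

#include <cuml/manifold/umapparams.h>

int main() {
  ML::UMAPParams params;
  params.random_state = 42;     // a seed was supplied...
  params.deterministic = true;  // ...so request the reproducible path
  // Leaving deterministic = false writes updates in place, saving the
  // gradient buffers at the cost of run-to-run bitwise stability.
  return 0;
}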
2 changes: 1 addition & 1 deletion cpp/src/umap/runner.cuh
@@ -465,7 +465,7 @@ void _transform(const raft::handle_t &handle, const umap_inputs &inputs,
 
   SimplSetEmbedImpl::optimize_layout<TPB_X, value_t>(
     transformed, inputs.n, embedding, embedding_n, comp_coo.rows(),
-    comp_coo.cols(), comp_coo.nnz, epochs_per_sample.data(), inputs.n,
+    comp_coo.cols(), comp_coo.nnz, epochs_per_sample.data(),
     params->repulsion_strength, params, n_epochs, d_alloc, stream);
   ML::POP_RANGE();
 
207 changes: 136 additions & 71 deletions cpp/src/umap/simpl_set_embed/algo.cuh
@@ -17,23 +17,34 @@
 #pragma once
 
 #include <cuml/manifold/umapparams.h>
 #include <cuml/common/device_buffer.hpp>
+#include <cuml/common/logger.hpp>
+
+#include <rmm/device_uvector.hpp>
+#include <rmm/exec_policy.hpp>
 
 #include <curand.h>
 #include <math.h>
+#include <raft/cudart_utils.h>
 #include <thrust/device_ptr.h>
 #include <thrust/extrema.h>
+#include <thrust/for_each.h>
+#include <thrust/iterator/counting_iterator.h>
 #include <thrust/system/cuda/execution_policy.h>
 
 #include <common/fast_int_div.cuh>
 #include <cstdlib>
-#include <cuml/common/logger.hpp>
-#include <raft/cudart_utils.h>
+
 #include <raft/linalg/unary_op.cuh>
 #include <raft/mr/device/allocator.hpp>
 #include <raft/random/rng_impl.cuh>
 #include <raft/sparse/coo.cuh>
+
 #include <string>
 #include "optimize_batch_kernel.cuh"
+
+#include <thrust/iterator/discard_iterator.h>
+#include <raft/sparse/op/filter.cuh>
 
 #pragma once
+
 namespace UMAPAlgo {
@@ -79,28 +90,94 @@ void make_epochs_per_sample(T *weights, int weights_n, int n_epochs, T *result,
     stream);
 }
 
+template <typename T>
+void optimization_iteration_finalization(UMAPParams *params, T *head_embedding,
+                                         T &alpha, int n, int n_epochs,
+                                         uint64_t &seed) {
+  if (params->callback) params->callback->on_epoch_end(head_embedding);
+  alpha = params->initial_alpha * (1.0 - (T(n) / T(n_epochs)));
+  seed += 1;
+}
+
 /**
- * Kernel applying updates to embedding
- * TODO: Replace this kernel with modified version of Linalg::Add
- * as described at https://github.com/rapidsai/cuml/issues/1781
+ * Update the embeddings and clear the buffers when using the deterministic
+ * algorithm.
  */
-template <typename T, int TPB_X>
-__global__ void apply_optimization_kernel(T *embedding,
-                                          double *embedding_updates, int n) {
-  int idx = (blockIdx.x * TPB_X) + threadIdx.x;
-  if (idx < n) {
-    embedding[idx] += embedding_updates[idx];
+template <typename T>
+void apply_embedding_updates(T *head_embedding, T *head_buffer, int head_n,
+                             T *tail_embedding, T *tail_buffer, int tail_n,
+                             UMAPParams *params, bool move_other,
+                             rmm::cuda_stream_view stream) {
+  ASSERT(params->deterministic, "Only used when deterministic is set to true.");
+  if (move_other) {
+    auto n_components = params->n_components;
+    thrust::for_each(rmm::exec_policy(stream),
+                     thrust::make_counting_iterator(0u),
+                     thrust::make_counting_iterator(0u) +
+                       std::max(head_n, tail_n) * params->n_components,
+                     [=] __device__(uint32_t i) {
+                       if (i < head_n * n_components) {
+                         head_embedding[i] += head_buffer[i];
+                         head_buffer[i] = 0.0f;
+                       }
+                       if (i < tail_n * n_components) {
+                         tail_embedding[i] += tail_buffer[i];
+                         tail_buffer[i] = 0.0f;
+                       }
+                     });
+  } else {
+    // No need to update the reference embedding
+    thrust::for_each(
+      rmm::exec_policy(stream), thrust::make_counting_iterator(0u),
+      thrust::make_counting_iterator(0u) + head_n * params->n_components,
+      [=] __device__(uint32_t i) {
+        head_embedding[i] += head_buffer[i];
+        head_buffer[i] = 0.0f;
+      });
   }
 }
 
+/**
+ * \brief Constructs a rounding factor used to truncate elements in a sum
+ * such that the sum of the truncated elements is the same no matter what
+ * the order of the sum is.
+ *
+ * Algorithm 5: Reproducible Sequential Sum in 'Fast Reproducible
+ * Floating-Point Summation' by Demmel and Nguyen.
+ *
+ * In algorithm 5 the bound is calculated as $max(|v_i|) * n$. We use the
+ * maximum number of edges connected to a vertex as n.
+ *
+ * The calculation trick is borrowed from fbcuda, which is BSD-licensed.
+ */
 template <typename T>
-inline void optimization_iteration_finalization(UMAPParams *params,
-                                                T *head_embedding, T &alpha,
-                                                int n, int n_epochs,
-                                                uint64_t &seed) {
-  if (params->callback) params->callback->on_epoch_end(head_embedding);
-  alpha = params->initial_alpha * (1.0 - (T(n) / T(n_epochs)));
-  seed += 1;
+T create_rounding_factor(T max_abs, int n) {
+  T delta =
+    max_abs / (static_cast<T>(1.0) -
+               static_cast<T>(2.0) * n * std::numeric_limits<T>::epsilon());
+
+  // Calculate ceil(log_2(delta)).
+  // frexpf() calculates exp and returns `x` such that
+  // delta = x * 2^exp, where `x` is in (-1.0, -0.5] U [0.5, 1.0).
+  // Because |x| < 1, exp is exactly ceil(log_2(delta)).
+  int exp;
+  std::frexp(delta, &exp);
+
+  // return M = 2 ^ ceil(log_2(delta))
+  return std::ldexp(static_cast<T>(1.0), exp);
 }
 
+template <typename T>
+T create_gradient_rounding_factor(const int *head, int nnz, int n_samples,
+                                  T alpha, rmm::cuda_stream_view stream) {
+  rmm::device_uvector<T> buffer(n_samples, stream);
+  // calculate the maximum number of edges connected to 1 vertex.
+  thrust::reduce_by_key(rmm::exec_policy(stream), head, head + nnz,
+                        thrust::make_constant_iterator(1u),
+                        thrust::make_discard_iterator(), buffer.data());
+  auto ptr = thrust::device_pointer_cast(buffer.data());
+  uint32_t n_edges =
+    *(thrust::max_element(rmm::exec_policy(stream), ptr, ptr + buffer.size()));
+  T max_abs = T(n_edges) * T(4.0) * std::abs(alpha);
+  return create_rounding_factor(max_abs, n_edges);
+}
+
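A host-side sketch (editorial illustration, not part of the diff) of why this rounding factor gives order-independent sums: once every term is truncated against the power-of-two factor M via (M + x) - M, each term and each partial sum lands exactly on M's grid, so no addition incurs rounding and any summation order produces identical bits. `create_rounding_factor` repeats the function above; `truncate_value` is an illustrative stand-in for the truncation helper the optimizer kernels apply to each gradient term.

// demo_reproducible_sum.cpp -- illustrative only
#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <random>
#include <vector>

template <typename T>
T create_rounding_factor(T max_abs, int n) {
  T delta = max_abs / (static_cast<T>(1.0) -
                       static_cast<T>(2.0) * n * std::numeric_limits<T>::epsilon());
  int exp;
  std::frexp(delta, &exp);                      // delta = x * 2^exp, |x| in [0.5, 1)
  return std::ldexp(static_cast<T>(1.0), exp);  // M = 2^ceil(log2(delta))
}

template <typename T>
T truncate_value(T rounding_factor, T x) {
  // (M + x) - M rounds x onto M's grid; sums of such values are exact.
  return (rounding_factor + x) - rounding_factor;
}

int main() {
  std::mt19937 gen(0);
  std::uniform_real_distribution<float> dist(-4.0f, 4.0f);
  std::vector<float> v(1000);
  for (auto &x : v) x = dist(gen);

  // max_abs bounds the whole sum: max per-term magnitude (4) times n terms.
  float factor = create_rounding_factor(4.0f * v.size(), static_cast<int>(v.size()));

  float forward = 0.0f, shuffled = 0.0f;
  for (float x : v) forward += truncate_value(factor, x);
  std::shuffle(v.begin(), v.end(), gen);
  for (float x : v) shuffled += truncate_value(factor, x);

  std::cout << (forward == shuffled) << '\n';  // prints 1: bit-identical sums
}

The same reasoning is what lets the atomic adds issued by the embedding-update kernels land in any scheduler-chosen order without changing the accumulated result.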
@@ -115,19 +192,15 @@ inline void optimization_iteration_finalization(UMAPParams *params,
 template <int TPB_X, typename T>
 void optimize_layout(T *head_embedding, int head_n, T *tail_embedding,
                      int tail_n, const int *head, const int *tail, int nnz,
-                     T *epochs_per_sample, int n_vertices, float gamma,
-                     UMAPParams *params, int n_epochs,
+                     T *epochs_per_sample, float gamma, UMAPParams *params,
+                     int n_epochs,
                      std::shared_ptr<raft::mr::device::allocator> d_alloc,
                      cudaStream_t stream) {
   // Are we doing a fit or a transform?
   bool move_other = head_embedding == tail_embedding;
 
-  if (params->optim_batch_size <= 0) {
-    params->optim_batch_size = 100000 / params->n_components;
-  }
-
   T alpha = params->initial_alpha;
-
+  auto stream_view = rmm::cuda_stream_view(stream);
   MLCommon::device_buffer<T> epoch_of_next_negative_sample(d_alloc, stream,
                                                            nnz);
   T nsr_inv = T(1.0) / params->negative_sample_rate;
@@ -138,56 +211,48 @@ void optimize_layout(T *head_embedding, int head_n, T *tail_embedding,
   MLCommon::device_buffer<T> epoch_of_next_sample(d_alloc, stream, nnz);
   raft::copy(epoch_of_next_sample.data(), epochs_per_sample, nnz, stream);
 
+  // Buffers used to store the gradient updates to avoid conflicts
+  rmm::device_uvector<T> head_buffer(0, stream_view);
+  rmm::device_uvector<T> tail_buffer(0, stream_view);
+  // Write to the embedding directly if determinism is not needed.
+  T *d_head_buffer = head_embedding;
+  T *d_tail_buffer = tail_embedding;
+  if (params->deterministic) {
+    head_buffer.resize(head_n * params->n_components, stream_view);
+    CUDA_CHECK(cudaMemsetAsync(head_buffer.data(), '\0',
+                               sizeof(T) * head_buffer.size(), stream));
+    // No need for a tail buffer if it's not being written to.
+    if (move_other) {
+      tail_buffer.resize(tail_n * params->n_components, stream_view);
+      CUDA_CHECK(cudaMemsetAsync(tail_buffer.data(), '\0',
+                                 sizeof(T) * tail_buffer.size(), stream));
+    }
+    d_head_buffer = head_buffer.data();
+    d_tail_buffer = tail_buffer.data();
+  }
+
   dim3 grid(raft::ceildiv(nnz, TPB_X), 1, 1);
   dim3 blk(TPB_X, 1, 1);
   uint64_t seed = params->random_state;
 
-  MLCommon::FastIntDiv tail_n_fast(tail_n);
+  T rounding =
+    create_gradient_rounding_factor<T>(head, nnz, head_n, alpha, stream_view);
 
-  if (params->multicore_implem) {
-    for (int n = 0; n < n_epochs; n++) {
-      call_optimize_batch_kernel<T, TPB_X>(
-        head_embedding, head_n, tail_embedding, tail_n_fast, head, tail, nnz,
-        epochs_per_sample, n_vertices, epoch_of_next_negative_sample.data(),
-        epoch_of_next_sample.data(), alpha, n, gamma, seed, nullptr, move_other,
-        params, n, grid, blk, stream);
-      CUDA_CHECK(cudaGetLastError());
-      optimization_iteration_finalization(params, head_embedding, alpha, n,
-                                          n_epochs, seed);
-    }
-  } else {
-    MLCommon::device_buffer<double> embedding_updates_buf(
-      d_alloc, stream, n_vertices * params->n_components);
-    double *embedding_updates = embedding_updates_buf.data();
-    dim3 grid2(raft::ceildiv(n_vertices * params->n_components, TPB_X));
-
-    for (int n = 0; n < n_epochs; n++) {
-      CUDA_CHECK(cudaMemsetAsync(
-        embedding_updates, 0,
-        n_vertices * params->n_components * sizeof(double), stream));
-
-      int toDo = nnz;
-      int offset = 0;
-      while (toDo > 0) {
-        int curBatchSize = min(toDo, params->optim_batch_size);
-        call_optimize_batch_kernel<T, TPB_X>(
-          head_embedding, head_n, tail_embedding, tail_n_fast, head, tail,
-          offset + curBatchSize, epochs_per_sample, n_vertices,
-          epoch_of_next_negative_sample.data(), epoch_of_next_sample.data(),
-          alpha, n, gamma, seed, embedding_updates, move_other, params, n, grid,
-          blk, stream, offset);
-        CUDA_CHECK(cudaGetLastError());
-
-        toDo -= curBatchSize;
-        offset += curBatchSize;
-      }
-
-      apply_optimization_kernel<T, TPB_X><<<grid2, blk, 0, stream>>>(
-        head_embedding, embedding_updates, n_vertices * params->n_components);
-      CUDA_CHECK(cudaGetLastError());
-      optimization_iteration_finalization(params, head_embedding, alpha, n,
-                                          n_epochs, seed);
-    }
-  }
+  MLCommon::FastIntDiv tail_n_fast(tail_n);
+  for (int n = 0; n < n_epochs; n++) {
+    call_optimize_batch_kernel<T, TPB_X>(
+      head_embedding, d_head_buffer, head_n, tail_embedding, d_tail_buffer,
+      tail_n_fast, head, tail, nnz, epochs_per_sample,
+      epoch_of_next_negative_sample.data(), epoch_of_next_sample.data(), alpha,
+      gamma, seed, move_other, params, n, grid, blk, stream, rounding);
+    if (params->deterministic) {
+      apply_embedding_updates(head_embedding, d_head_buffer, head_n,
+                              tail_embedding, d_tail_buffer, tail_n, params,
+                              move_other, stream_view);
+    }
+    CUDA_CHECK(cudaGetLastError());
+    optimization_iteration_finalization(params, head_embedding, alpha, n,
+                                        n_epochs, seed);
+  }
 }
 
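For the create_gradient_rounding_factor call above, a small editorial sketch (illustrative, not part of the diff) of its degree-count step: reducing a stream of 1s keyed by vertex id yields per-vertex edge counts, whose maximum bounds how many truncated gradient terms can land on one vertex. It assumes `head` is sorted by vertex id, as required by reduce_by_key.

// max_degree.cu -- illustrative only
#include <cstdio>
#include <thrust/device_vector.h>
#include <thrust/extrema.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/reduce.h>

int main() {
  // Sorted edge-list heads: vertex 0 has 3 edges, vertex 1 has 1, vertex 2 has 2.
  int h_head[] = {0, 0, 0, 1, 2, 2};
  thrust::device_vector<int> head(h_head, h_head + 6);
  thrust::device_vector<int> counts(3);

  // Reduce a stream of 1s keyed by vertex id -> per-vertex edge counts.
  thrust::reduce_by_key(head.begin(), head.end(),
                        thrust::make_constant_iterator(1),
                        thrust::make_discard_iterator(), counts.begin());

  int max_degree = *thrust::max_element(counts.begin(), counts.end());
  std::printf("max degree = %d\n", max_degree);  // prints 3
  return 0;
}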
@@ -250,7 +315,7 @@ void launcher(int m, int n, raft::sparse::COO<T> *in, UMAPParams *params,
   }
 
   optimize_layout<TPB_X, T>(embedding, m, embedding, m, out.rows(), out.cols(),
-                            out.nnz, epochs_per_sample.data(), m,
+                            out.nnz, epochs_per_sample.data(),
                             params->repulsion_strength, params, n_epochs,
                             d_alloc, stream);
 
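Finally, an editorial sketch of the accumulate-then-apply pattern that optimize_layout now uses when `deterministic` is set (not the PR's actual kernels; `accumulate` and `apply_updates` are illustrative names): truncated gradients are scattered into a zeroed side buffer with atomics, then folded into the embedding once per epoch. Because every term is truncated first, the scheduler-dependent order of the atomics cannot change the accumulated bits.

// deterministic_scatter.cu -- illustrative only
#include <cstdio>
#include <cuda_runtime.h>

__device__ float truncate_value(float rounding, float x) {
  return (rounding + x) - rounding;  // truncated terms add exactly
}

// Many edges touch the same vertex, so updates are scattered with atomicAdd.
__global__ void accumulate(const int *head, const float *grad, int nnz,
                           float rounding, float *buffer) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < nnz) atomicAdd(buffer + head[i], truncate_value(rounding, grad[i]));
}

// Fold buffered updates into the embedding and zero the buffer for the
// next epoch, mirroring apply_embedding_updates above.
__global__ void apply_updates(float *embedding, float *buffer, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    embedding[i] += buffer[i];
    buffer[i] = 0.0f;
  }
}

int main() {
  const int nnz = 4, n = 2;
  int head[nnz] = {0, 1, 0, 1};
  float grad[nnz] = {0.25f, -1.5f, 2.0f, 0.125f};
  int *d_head; float *d_grad, *d_emb, *d_buf;
  cudaMalloc(&d_head, sizeof(head));
  cudaMalloc(&d_grad, sizeof(grad));
  cudaMalloc(&d_emb, n * sizeof(float));
  cudaMalloc(&d_buf, n * sizeof(float));
  cudaMemcpy(d_head, head, sizeof(head), cudaMemcpyHostToDevice);
  cudaMemcpy(d_grad, grad, sizeof(grad), cudaMemcpyHostToDevice);
  cudaMemset(d_emb, 0, n * sizeof(float));
  cudaMemset(d_buf, 0, n * sizeof(float));

  float rounding = 8.0f;  // stands in for create_gradient_rounding_factor
  accumulate<<<1, 32>>>(d_head, d_grad, nnz, rounding, d_buf);
  apply_updates<<<1, 32>>>(d_emb, d_buf, n);

  float emb[n];
  cudaMemcpy(emb, d_emb, sizeof(emb), cudaMemcpyDeviceToHost);
  std::printf("%g %g\n", emb[0], emb[1]);  // 2.25 -1.375 on every run
  return 0;
}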