Skip to content

Commit

Permalink
Accelerate bernoulli number generation on CPU (pytorch#7171)
Browse files Browse the repository at this point in the history
* opt bernoulli rng with vsl and openmp

* detect cpu vendor for bernnoulli

* retrigger test platform

*  check the vendor more severely

* use cpuinfo to check vendor
  • Loading branch information
MlWoo authored and soumith committed Jun 5, 2018
1 parent ee0b75a commit 227a764
Show file tree
Hide file tree
Showing 11 changed files with 228 additions and 9 deletions.
13 changes: 13 additions & 0 deletions aten/src/ATen/Declarations.cwrap
Original file line number Diff line number Diff line change
Expand Up @@ -3568,6 +3568,19 @@
kwarg_only: True
- double p
]]
[[
name: _cpu_bernoulli_
backends:
- CPU
cname: bernoulli
return: self
arguments:
- THTensor* self
- arg: THGenerator* generator
default: nullptr
kwarg_only: True
- double p
]]
[[
name: _th_bernoulli
types:
Expand Down
8 changes: 1 addition & 7 deletions aten/src/ATen/native/Distributions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,7 @@ Tensor& bernoulli_(Tensor& self, const Tensor& p_, Generator* gen) {

Tensor& bernoulli_(Tensor& self, double p, Generator* gen) {
if (!self.is_cuda()) {
AT_DISPATCH_ALL_TYPES(self.type(), "bernoulli_", [&] {
THGenerator* generator = get_generator(gen);
std::lock_guard<std::mutex> lock(generator->mutex);
CPU_tensor_apply1<scalar_t>(self, [generator, p](scalar_t& ret_val) {
ret_val = (scalar_t)THRandom_bernoulli(generator, p);
});
});
self._cpu_bernoulli_(p, gen);
return self;
}
Tensor probs = self.type().toScalarType(kDouble).tensor({}).fill_(p);
Expand Down
4 changes: 4 additions & 0 deletions aten/src/TH/THGeneral.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
#include <stddef.h>
#include <inttypes.h>

#ifdef TH_BLAS_MKL
#include <mkl_vsl.h>
#endif

#cmakedefine USE_BLAS
#cmakedefine USE_LAPACK
#cmakedefine BLAS_F2C
Expand Down
4 changes: 2 additions & 2 deletions aten/src/TH/THTensorApply.h
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@
TYPE1 *rp = TENSOR1->storage->data<TYPE1>()+TENSOR1->storageOffset; \
TYPE2 *tp = TENSOR2->storage->data<TYPE2>()+TENSOR2->storageOffset; \
ptrdiff_t iter = 0; \
if(tp != rp) { \
if(tp != (TYPE2*)rp) { \
PRAGMA(ivdep) \
PRAGMA( omp parallel for if (SIZE > OMP_THRESHOLD * 10) firstprivate(rp, tp)) \
for (iter = 0; iter < SIZE; iter++) { \
Expand Down Expand Up @@ -449,7 +449,7 @@
TYPE2 *tp = TENSOR2->storage->data<TYPE2>()+TENSOR2->storageOffset; \
TYPE3 *srcp = TENSOR3->storage->data<TYPE3>()+TENSOR3->storageOffset; \
ptrdiff_t iter = 0;\
if (rp != tp) { \
if(tp != (TYPE2*)rp) { \
PRAGMA(ivdep) \
PRAGMA( omp parallel for if (SIZE > OMP_THRESHOLD * 10) ) \
for (iter = 0; iter < SIZE; iter++) {\
Expand Down
89 changes: 89 additions & 0 deletions aten/src/TH/generic/THTensorRandom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
#define TH_GENERIC_FILE "generic/THTensorRandom.cpp"
#else

#ifdef _OPENMP
#include <omp.h>
#endif

#include <cpuinfo.h>

#include "THGenerator.hpp"

void THTensor_(random)(THTensor *self, THGenerator *_generator)
Expand Down Expand Up @@ -51,10 +57,93 @@ void THTensor_(geometric)(THTensor *self, THGenerator *_generator, double p)
TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_geometric(_generator, p););
}

#ifdef TH_BLAS_MKL
#define BERNOULLI_OMP 800
#define TH_OMP_OVERHEAD_THRESHOLD_COPY 20000

void iBernoulli_generate_copy(THTensor *self, THGenerator *_generator, const double p)
{
int64_t seed = THRandom_random(_generator);
int64_t n = THTensor_(nElement)(self);
int contig = THTensor_(isContiguous)(self);
int *tmp = NULL;
THIntTensor* intTensor = NULL;

if (contig) {
#ifdef TH_REAL_IS_INT
tmp = THIntTensor_data(self);
#else
tmp = (int*)THAlloc(n*sizeof(int));
#endif
} else {
intTensor = THIntTensor_new();
THIntTensor_resizeNd(intTensor, self->nDimension, self->size, NULL);
tmp = THIntTensor_data(intTensor);
}

#ifdef _OPENMP
size_t nthr = !omp_in_parallel() && n >= BERNOULLI_OMP ? omp_get_num_threads() : 1;
#pragma omp parallel num_threads(nthr) firstprivate(nthr)
{
size_t tid = omp_get_thread_num();
int64_t seg_len_tmp = n / nthr;
int64_t line_index_offset = tid * seg_len_tmp;
int64_t line_seg_len = (tid == nthr - 1)? (n-line_index_offset) : seg_len_tmp;
#else
{
int64_t line_index_offset = 0;
int64_t line_seg_len = n;
#endif

if (line_seg_len > 0) {
VSLStreamStatePtr stream;
vslNewStream(&stream, VSL_BRNG_MCG31, seed);
vslSkipAheadStream(stream, line_index_offset);
viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, stream, line_seg_len,
tmp + line_index_offset, p);
vslDeleteStream(&stream);

#ifndef TH_REAL_IS_INT
if (contig) {
real* self_seg = THTensor_(data)(self) + line_index_offset;
int* tmp_seg = tmp + line_index_offset;
THVector_(cvtFromInt)(self_seg, tmp_seg, line_seg_len);
}
#endif
}
}

if(contig) {
#ifndef TH_REAL_IS_INT
THFree(tmp);
#endif
} else {
#ifdef _OPENMP
TH_TENSOR_APPLY2_OMP(n, 1, 0, int, intTensor, real, self, *self_data = *intTensor_data;, TH_OMP_OVERHEAD_THRESHOLD_COPY)
#else
TH_TENSOR_APPLY2(int, intTensor, real, self, *self_data = *intTensor_data;)
#endif
THIntTensor_free(intTensor);
}

}

#endif

void THTensor_(bernoulli)(THTensor *self, THGenerator *_generator, double p)
{
#ifdef TH_BLAS_MKL
if(cpuinfo_initialize() && cpuinfo_vendor_intel == cpuinfo_get_processor(0)->core->vendor) {
std::lock_guard<std::mutex> lock(_generator->mutex);
iBernoulli_generate_copy(self, _generator, p);
} else {
std::lock_guard<std::mutex> lock(_generator->mutex);
TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_bernoulli(_generator, p););
}
#else
std::lock_guard<std::mutex> lock(_generator->mutex);
TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_bernoulli(_generator, p););
#endif
}

void THTensor_(bernoulli_FloatTensor)(THTensor *self, THGenerator *_generator, THFloatTensor *p)
Expand Down
3 changes: 3 additions & 0 deletions aten/src/TH/generic/THVector.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ TH_API void THVector_(normal_fill)(real *data,
struct THGenerator *generator,
const real mean,
const real stddev);
#ifndef TH_REAL_IS_INT
TH_API void THVector_(cvtFromInt)(real *y, const int *x, const ptrdiff_t n);
#endif

#if defined(TH_REAL_IS_SHORT) || defined(TH_REAL_IS_INT) || defined(TH_REAL_IS_LONG)
TH_API void THVector_(abs)(real *y, const real *x, const ptrdiff_t n);
Expand Down
18 changes: 18 additions & 0 deletions aten/src/TH/generic/THVectorDefault.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,24 @@ void THVector_(divs_DEFAULT)(real *y, const real *x, const real c, const ptrdiff
y[i] = x[i] / c;
}

#ifndef TH_REAL_IS_INT
void THVector_(cvtFromInt_DEFAULT)(real *y, const int *x, const ptrdiff_t n)
{
ptrdiff_t i = 0;

for(; i<n-4; i+=4)
{
y[i] = (real)x[i];
y[i+1] = (real)x[i+1];
y[i+2] = (real)x[i+2];
y[i+3] = (real)x[i+3];
}

for(; i < n; i++)
y[i] = (real)x[i];
}
#endif

// Fills 16 normally distributed samples into data, interleaved with a
// stride of 8, i.e. in order of ([0], [8]), ([1], [9]), ...
static void THVector_(interleaved_normal_fill_16)(real *data,
Expand Down
27 changes: 27 additions & 0 deletions aten/src/TH/generic/THVectorDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,29 @@ void THVector_(copy)(real *y, const real *x, const ptrdiff_t n) {
THVector_(copy_DISPATCHPTR)(y, x, n);
}

#ifndef TH_REAL_IS_INT
static void (*THVector_(cvtFromInt_DISPATCHPTR))(real *, const int *, const ptrdiff_t) = &THVector_(cvtFromInt_DEFAULT);
static FunctionDescription THVector_(cvtFromInt_DISPATCHTABLE)[] = {
#if defined(USE_AVX)
#if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
FUNCTION_IMPL(THVector_(cvtFromInt_AVX), SIMDExtension_AVX),
#endif
#endif
#if defined(USE_SSE2) || defined(USE_SSE3) || defined(USE_SSSE3) \
|| defined(USE_SSE4_1) || defined(USE_SSE4_2)
#if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
FUNCTION_IMPL(THVector_(cvtFromInt_SSE), SIMDExtension_SSE),
#endif
#endif


FUNCTION_IMPL(THVector_(cvtFromInt_DEFAULT), SIMDExtension_DEFAULT)
};
void THVector_(cvtFromInt)(real *y, const int *x, const ptrdiff_t n) {
THVector_(cvtFromInt_DISPATCHPTR)(y, x, n);
}
#endif

static void (*THVector_(normal_fill_DISPATCHPTR))(real *, const int64_t, THGenerator *, const real, const real) = &THVector_(normal_fill_DEFAULT);
static FunctionDescription THVector_(normal_fill_DISPATCHTABLE)[] = {
#if defined(TH_REAL_IS_FLOAT) && defined(USE_AVX2)
Expand Down Expand Up @@ -290,6 +313,10 @@ struct THVector_(startup) {
INIT_DISPATCH_PTR(copy);
INIT_DISPATCH_PTR(normal_fill);

#ifndef TH_REAL_IS_INT
INIT_DISPATCH_PTR(cvtFromInt);
#endif

#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
INIT_DISPATCH_PTR(sigmoid);
#endif
Expand Down
34 changes: 34 additions & 0 deletions aten/src/TH/vector/AVX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,4 +271,38 @@ void THFloatVector_adds_AVX(float *y, const float *x, const float c, const ptrdi
}
}

void THFloatVector_cvtFromInt_AVX(float *y, const int *x, const ptrdiff_t n) {
ptrdiff_t i;
__m256i YMM0, YMM1;
__m256 YMM2, YMM3;
for (i=0; i<=((n)-16); i+=16) {
YMM0 = _mm256_loadu_si256((__m256i const*)(x+i));
YMM1 = _mm256_loadu_si256((__m256i const*)(x+i+8));
YMM2 = _mm256_cvtepi32_ps(YMM0);
YMM3 = _mm256_cvtepi32_ps(YMM1);
_mm256_storeu_ps(y+i, YMM2);
_mm256_storeu_ps(y+i+8, YMM3);
}
for (; i<(n); i++) {
y[i] = (float)x[i];
}
}

void THDoubleVector_cvtFromInt_AVX(double *y, const int *x, const ptrdiff_t n) {
ptrdiff_t i;
__m128i YMM0, YMM1;
__m256d YMM2, YMM3;
for (i=0; i<=((n)- 8); i+=8) {
YMM0 = _mm_loadu_si128((__m128i const*)(x+i));
YMM1 = _mm_loadu_si128((__m128i const*)(x+i+4));
YMM2 = _mm256_cvtepi32_pd(YMM0);
YMM3 = _mm256_cvtepi32_pd(YMM1);
_mm256_storeu_pd(y+i, YMM2);
_mm256_storeu_pd(y+i+4, YMM3);
}
for (; i<(n); i++) {
y[i] = (double)x[i];
}
}

#endif // defined(__AVX__)
2 changes: 2 additions & 0 deletions aten/src/TH/vector/AVX.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ TH_API void THDoubleVector_cmul_AVX(double *z, const double *x, const double *y,
TH_API void THDoubleVector_muls_AVX(double *y, const double *x, const double c, const ptrdiff_t n);
TH_API void THDoubleVector_cadd_AVX(double *z, const double *x, const double *y, const double c, const ptrdiff_t n);
TH_API void THDoubleVector_adds_AVX(double *y, const double *x, const double c, const ptrdiff_t n);
TH_API void THDoubleVector_cvtFromInt_AVX(double *y, const int *x, const ptrdiff_t n);
TH_API void THFloatVector_copy_AVX(float *y, const float *x, const ptrdiff_t n);
TH_API void THFloatVector_fill_AVX(float *x, const float c, const ptrdiff_t n);
TH_API void THFloatVector_cdiv_AVX(float *z, const float *x, const float *y, const ptrdiff_t n);
Expand All @@ -20,4 +21,5 @@ TH_API void THFloatVector_cmul_AVX(float *z, const float *x, const float *y, con
TH_API void THFloatVector_muls_AVX(float *y, const float *x, const float c, const ptrdiff_t n);
TH_API void THFloatVector_cadd_AVX(float *z, const float *x, const float *y, const float c, const ptrdiff_t n);
TH_API void THFloatVector_adds_AVX(float *y, const float *x, const float c, const ptrdiff_t n);
TH_API void THFloatVector_cvtFromInt_AVX(float *y, const int *x, const ptrdiff_t n);
#endif
35 changes: 35 additions & 0 deletions aten/src/TH/vector/SSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,38 @@ static void THFloatVector_divs_SSE(float *y, const float *x, const float c, cons
y[i] = x[i] / c;
}
}

static void THFloatVector_cvtFromInt_SSE(float *y, const int *x, const ptrdiff_t n) {
ptrdiff_t i;
__m128i YMM0, YMM1;
__m128 YMM2, YMM3;
for (i=0; i<=((n)-8); i+=8) {
YMM0 = _mm_loadu_si128((__m128i const*)(x+i));
YMM1 = _mm_loadu_si128((__m128i const*)(x+i+4));
YMM2 = _mm_cvtepi32_ps(YMM0);
YMM3 = _mm_cvtepi32_ps(YMM1);
_mm_storeu_ps(y+i, YMM2);
_mm_storeu_ps(y+i+4, YMM3);
}
for (; i<(n); i++) {
y[i] = (float)x[i];
}
}

static void THDoubleVector_cvtFromInt_SSE(double *y, const int *x, const ptrdiff_t n) {
ptrdiff_t i;
__m128i YMM0, YMM1;
__m128d YMM2, YMM3;
for (i=0; i<=((n)- 4); i+=4) {
YMM0 = _mm_loadu_si128((__m128i const*)(x+i));
YMM2 = _mm_cvtepi32_pd(YMM0);
YMM1 = _mm_srli_si128(YMM0, 8);
YMM3 = _mm_cvtepi32_pd(YMM1);
_mm_storeu_pd(y+i, YMM2);
_mm_storeu_pd(y+i+2, YMM3);
}
for (; i<(n); i++) {
y[i] = (double)x[i];
}
}

0 comments on commit 227a764

Please sign in to comment.