Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable wider load/store for multi_tensor_apply kernels #763

Merged
merged 2 commits into from
Apr 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 82 additions & 31 deletions apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,17 @@
#define BLOCK_SIZE 512
#define ILP 4

template<typename T>
FDecaYed marked this conversation as resolved.
Show resolved Hide resolved
__device__ __forceinline__ bool is_aligned(T* p){
return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

#include "type_shim.h"

typedef enum{
Expand Down Expand Up @@ -99,49 +110,90 @@ struct AdamFunctor
T incoming_v[ILP];
T incoming_g[ILP];

for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP) {
// to make things simple, we put aligned case in a different code path
if(n % ILP == 0 &&
chunk_size % ILP == 0 &&
is_aligned(p) &&
is_aligned(m) &&
is_aligned(v) &&
is_aligned(g) &&
is_aligned(p_copy))
{
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
// load
GRAD_T tmp_g[ILP];
load_store(incoming_p, p, 0, i_start);
load_store(incoming_m, m, 0, i_start);
load_store(incoming_v, v, 0, i_start);
load_store(tmp_g, g, 0, i_start);
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
incoming_g[ii] = static_cast<T>(tmp_g[ii]);
T scaled_grad = incoming_g[ii]/grad_scale;
incoming_m[ii] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
incoming_v[ii] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(incoming_v[ii] + eps);
else // Mode 1
denom = sqrtf(incoming_v[ii]) + eps;
float update = (incoming_m[ii]/denom) + (decay*incoming_p[ii]);
incoming_p[ii] = incoming_p[ii] - (step_size*update);
if (DEPTH == 5) tmp_g[ii] = static_cast<GRAD_T>(incoming_p[ii]);
}
load_store(p, incoming_p, i_start, 0);
load_store(m, incoming_m, i_start, 0);
load_store(v, incoming_v, i_start, 0);
if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0);
}
}
else
{
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP) {

#pragma unroll
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
incoming_p[ii] = 0;
incoming_m[ii] = 0;
incoming_v[ii] = 0;
incoming_g[ii] = 0;
incoming_p[ii] = 0;
incoming_m[ii] = 0;
incoming_v[ii] = 0;
incoming_g[ii] = 0;

int i = i_start + threadIdx.x + ii*blockDim.x;
if (i < n && i < chunk_size) {
incoming_p[ii] = p[i];
incoming_m[ii] = m[i];
incoming_v[ii] = v[i];
incoming_g[ii] = static_cast<T>(g[i]);
}
int i = i_start + threadIdx.x + ii*blockDim.x;
if (i < n && i < chunk_size) {
incoming_p[ii] = p[i];
incoming_m[ii] = m[i];
incoming_v[ii] = v[i];
incoming_g[ii] = static_cast<T>(g[i]);
}
}

// note for clarification to future michael:
// From a pure memory dependency perspective, there's likely no point unrolling
// the write loop, since writes just fire off once their LDGs arrive.
// Put another way, the STGs are dependent on the LDGs, but not on each other.
// There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = i_start + threadIdx.x + ii*blockDim.x;
int j = i_start + threadIdx.x + ii*blockDim.x;

if(j < n && j < chunk_size) {
T scaled_grad = incoming_g[ii]/grad_scale;
m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
float update = (m[j]/denom) + (decay*incoming_p[ii]);
p[j] = incoming_p[ii] - (step_size*update);
if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j];
}
if(j < n && j < chunk_size) {
T scaled_grad = incoming_g[ii]/grad_scale;
m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
float update = (m[j]/denom) + (decay*incoming_p[ii]);
p[j] = incoming_p[ii] - (step_size*update);
if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j];
}
}
}
}
}
};
Expand Down Expand Up @@ -332,4 +384,3 @@ void fused_adam_cuda_mt(
}
THCudaCheck(cudaGetLastError());
}

93 changes: 66 additions & 27 deletions csrc/multi_tensor_axpby_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@
#define BLOCK_SIZE 512
#define ILP 4

template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

template<typename x_t, typename y_t, typename out_t>
struct AxpbyFunctor
{
Expand Down Expand Up @@ -43,46 +54,74 @@ struct AxpbyFunctor

n -= chunk_idx*chunk_size;

// Non-divergent exit condition for __syncthreads, not necessary here
float xs[ILP];
float ys[ILP];
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
bool finite = true;
x_t r_x[ILP];
y_t r_y[ILP];
out_t r_out[ILP];

// to make things simple, we put aligned case in a different code path
if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x) && is_aligned(y) && is_aligned(out))
{
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
xs[ii] = 0;
ys[ii] = 0;
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
// load
load_store(r_x, x, 0 , i_start);
load_store(r_y, y, 0 , i_start);
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
xs[ii] = static_cast<float>(x[i]);
ys[ii] = static_cast<float>(y[i]);
r_out[ii] = a*static_cast<float>(r_x[ii]) + b*static_cast<float>(r_y[ii]);
if(arg_to_check == -1)
finite = finite && (isfinite(r_x[ii]) && isfinite(r_y[ii]));
if(arg_to_check == 0)
finite = finite && isfinite(r_x[ii]);
if(arg_to_check == 1)
finite = finite && isfinite(r_y[ii]);
}
// store
load_store(out, r_out, i_start , 0);
}

// see note in multi_tensor_scale_kernel.cu
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
}
else
{
// Non-divergent exit condition for __syncthreads, not necessary here
for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
out[i] = static_cast<out_t>(a*xs[ii] + b*ys[ii]);
bool finite = true;
r_x[ii] = 0;
r_y[ii] = 0;
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
r_x[ii] = x[i];
r_y[ii] = y[i];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_out[ii] = a*static_cast<float>(r_x[ii]) + b*static_cast<float>(r_y[ii]);
if(arg_to_check == -1)
finite = (isfinite(xs[ii]) && isfinite(ys[ii]));
finite = finite && (isfinite(r_x[ii]) && isfinite(r_y[ii]));
if(arg_to_check == 0)
finite = isfinite(xs[ii]);
finite = finite && isfinite(r_x[ii]);
if(arg_to_check == 1)
finite = isfinite(ys[ii]);
if(!finite)
*noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
finite = finite && isfinite(r_y[ii]);
}
// see note in multi_tensor_scale_kernel.cu
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
out[i] = r_out[ii];
}
}
}
if(!finite)
*noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
}
};

Expand Down
79 changes: 67 additions & 12 deletions csrc/multi_tensor_l2norm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@
#define BLOCK_SIZE 512
#define ILP 4

template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

template<typename x_t>
struct L2NormFunctor
{
Expand Down Expand Up @@ -41,22 +52,44 @@ struct L2NormFunctor
__shared__ float s_vals[512];

float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
x_t r_x[ILP];
for(int i = 0; i < ILP; i++)
{
vals[i] = 0.f;
r_x[i] = 0;
}

for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
// to make things simple, we put aligned case in a different code path
if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
{
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
// load
load_store(r_x, x, 0 , i_start);
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
float next = static_cast<float>(x[i]);
float next = static_cast<float>(r_x[ii]);
vals[ii] += next*next;
}
}
}
else
{
for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
{
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
float next = static_cast<float>(x[i]);
vals[ii] += next*next;
}
}
}
}

float val = 0.f;
for(int i = 0; i < ILP; i++)
Expand Down Expand Up @@ -104,22 +137,44 @@ struct MaxNormFunctor
__shared__ float s_vals[512];

float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
x_t r_x[ILP];
for(int i = 0; i < ILP; i++)
{
vals[i] = 0.f;
r_x[i] = 0;
}

for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
// to make things simple, we put aligned case in a different code path
if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
{
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
// load
load_store(r_x, x, 0 , i_start);
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
float next = static_cast<float>(x[i]);
float next = static_cast<float>(r_x[ii]);
vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
}
}
}
else
{
for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
{
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
float next = static_cast<float>(x[i]);
vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
}
}
}
}

float val = 0.f;
for(int i = 0; i < ILP; i++)
Expand Down
Loading