Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove cutlass usage in row major input for euclidean exp/unexp, cosine and L1 distance matrix #3589

Merged
merged 10 commits into from
Mar 26, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/cmake/Dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ else(DEFINED ENV{RAFT_PATH})

ExternalProject_Add(raft
GIT_REPOSITORY https://github.com/rapidsai/raft.git
GIT_TAG 4a79adcb0c0e87964dcdc9b9122f242b5235b702
GIT_TAG 479b99bfae19747bc4159ee6267bdc6ec959af95
PREFIX ${RAFT_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
Expand Down
189 changes: 180 additions & 9 deletions cpp/src_prims/distance/cosine.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2020, NVIDIA CORPORATION.
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,15 +16,105 @@

#pragma once
#include <linalg/eltwise2d.cuh>
#include "algo1.cuh"
#include "distance_fragment_multiply_add.cuh"
#include "pairwise_distance_base.cuh"

#include <linalg/cutlass_gemm.cuh>
#include <raft/linalg/norm.cuh>
#include "distance_epilogue.cuh"
#include "distance_epilogue_functor.cuh"
#include "distance_epilogue_traits.h"

#include <cutlass/gemm/gemm_epilogue_traits.h>
#include <cutlass/gemm/thread_multiply_add.h>
#include <cutlass/shape.h>

#include <type_traits>

namespace MLCommon {
namespace Distance {

/**
 * @brief the cosine distance matrix calculation implementer
 *  It computes the following equation:
 *    C = 1 - op(A * B / (sqrt(A^2) * sqrt(B^2)))
 * @tparam DataT input data-type (for A and B matrices)
 * @tparam AccT accumulation data-type
 * @tparam OutT output data-type (for C and D matrices)
 * @tparam IdxT index data-type
 * @tparam VecLen number of k-elements loaded by each thread for every LDG call
 *         it makes. check contractions.cuh for details.
 * @tparam FinalLambda the final lambda called on final distance value
 * @param[in] x input matrix (m rows of k columns)
 * @param[in] y input matrix (n rows of k columns)
 * @param[in] xn row norms of input matrix x (one per each of the m rows)
 * @param[in] yn row norms of input matrix y (one per each of the n rows)
 * @param[in] m number of rows of A and C/D
 * @param[in] n number of columns of B and C/D
 * @param[in] k number of cols of A and rows of B
 * @param[out] dOutput output matrix
 * @param fin_op the final gemm epilogue lambda
 * @param stream cuda stream to launch cuda operations.
 */
template <typename DataT, typename AccT, typename OutT, typename IdxT,
          int VecLen, typename FinalLambda>
void cosineImpl(const DataT *x, const DataT *y, const DataT *xn,
                const DataT *yn, IdxT m, IdxT n, IdxT k, OutT *dOutput,
                FinalLambda fin_op, cudaStream_t stream) {
  // Tiling policy (block shape, threads, smem) for this DataT/VecLen combo
  typedef typename raft::linalg::Policy4x4<DataT, VecLen>::Policy Policy;
  // 2D grid tiling the m x n output: x-dim over row tiles, y-dim over col tiles
  dim3 grid(raft::ceildiv<int>(m, Policy::Mblk),
            raft::ceildiv<int>(n, Policy::Nblk));
  dim3 blk(Policy::Nthreads);

  // Accumulation operation lambda: plain dot-product accumulation
  auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) {
    acc += x * y;
  };

  // epilogue operation lambda for final value calculation: divide each
  // accumulated dot product by the product of the two row norms
  // (regxn/regyn hold the per-thread slices of xn/yn)
  auto epilog_lambda = [] __device__(
                         AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh],
                         DataT * regxn, DataT * regyn) {
#pragma unroll
    for (int i = 0; i < Policy::AccRowsPerTh; ++i) {
#pragma unroll
      for (int j = 0; j < Policy::AccColsPerTh; ++j) {
        acc[i][j] = acc[i][j] / (regxn[i] * regyn[j]);
      }
    }
  };

  // NOTE(review): the 1 - <cosine similarity> step is not applied here; it is
  // expected to be folded into fin_op by the caller — confirm at call sites.
  pairwiseDistanceMatKernel<raft::distance::DistanceType::CosineExpanded, DataT,
                            AccT, OutT, IdxT, Policy, decltype(core_lambda),
                            decltype(epilog_lambda), FinalLambda>
    <<<grid, blk, Policy::SmemSize, stream>>>(
      x, y, xn, yn, m, n, k, dOutput, core_lambda, epilog_lambda, fin_op);

  // surface launch-configuration errors immediately (async exec errors appear
  // at the next synchronizing call)
  CUDA_CHECK(cudaGetLastError());
}

/**
 * @brief cosine distance dispatcher: selects the widest vectorized load
 *        (VecLen) that the per-row byte count of the k-dimension permits,
 *        then forwards to cosineImpl.
 * @param m number of rows of x and the output
 * @param n number of rows of y / columns of the output
 * @param k number of columns of x and y
 * @param[in] x input matrix
 * @param[in] y input matrix
 * @param[in] xn row norms of x
 * @param[in] yn row norms of y
 * @param[out] dOutput output distance matrix
 * @param fin_op final epilogue lambda applied to each distance value
 * @param stream cuda stream to launch cuda operations.
 */
template <typename DataT, typename AccT, typename OutT, typename IdxT,
          typename FinalLambda>
void cosine(IdxT m, IdxT n, IdxT k, const DataT *x, const DataT *y,
            const DataT *xn, const DataT *yn, OutT *dOutput, FinalLambda fin_op,
            cudaStream_t stream) {
  // bytes occupied by one k-length row; alignment of this decides vector width
  const size_t rowBytes = sizeof(DataT) * k;
  if ((rowBytes % 16 == 0) && (16 % sizeof(DataT) == 0)) {
    // 16-byte (e.g. float4-wide) loads
    cosineImpl<DataT, AccT, OutT, IdxT, 16 / sizeof(DataT), FinalLambda>(
      x, y, xn, yn, m, n, k, dOutput, fin_op, stream);
  } else if ((rowBytes % 8 == 0) && (8 % sizeof(DataT) == 0)) {
    // 8-byte loads
    cosineImpl<DataT, AccT, OutT, IdxT, 8 / sizeof(DataT), FinalLambda>(
      x, y, xn, yn, m, n, k, dOutput, fin_op, stream);
  } else {
    // fallback: one element per load
    cosineImpl<DataT, AccT, OutT, IdxT, 1, FinalLambda>(
      x, y, xn, yn, m, n, k, dOutput, fin_op, stream);
  }
}

/**
* @brief the expanded cosine distance matrix calculation
* It computes the following equation: C = op(A^2 + B^2 - 2AB)
 * It computes the following equation:
 *   C = 1 - op(A * B / (sqrt(A^2) * sqrt(B^2)))
* @tparam IType input data-type (for A and B matrices)
* @tparam AccType accumulation data-type
* @tparam OType output data-type (for C and D matrices)
Expand All @@ -50,19 +140,100 @@ void cosineAlgo1(Index_ m, Index_ n, Index_ k, const InType *pA,
const InType *pB, OutType *pD, AccType *workspace,
size_t worksize, FinalLambda fin_op, cudaStream_t stream,
bool isRowMajor) {
typedef ExpandedDistanceFragmentMultiplyAdd<CosFusedDistance>
FragmentMultiplyAdd_;
auto norm_op = [] __device__(AccType in) { return raft::mySqrt(in); };

// Wrap fin_op to allow computing 1 - pA before calling fin_op
auto wrapped_fin_op = [fin_op] __device__(AccType d_val, Index_ g_d_idx) {
return fin_op(static_cast<AccType>(1.0) - d_val, g_d_idx);
};

distanceAlgo1<InType, AccType, OutType, OutputTile_, FragmentMultiplyAdd_,
decltype(wrapped_fin_op), decltype(norm_op), Index_>(
m, n, k, pA, pB, pD, false, workspace, worksize, wrapped_fin_op, norm_op,
stream, isRowMajor);
typedef std::is_same<OutType, bool> is_bool;

ASSERT(!(((pA != pB) && (worksize < (m + n) * sizeof(AccType))) ||
(worksize < m * sizeof(AccType))),
"workspace size error");
ASSERT(workspace != nullptr, "workspace is null");

InType *col_vec = workspace;
InType *row_vec = workspace;
if (pA != pB) {
row_vec += m;
raft::linalg::rowNorm(col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor,
stream, norm_op);
raft::linalg::rowNorm(row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor,
stream, norm_op);
} else {
raft::linalg::rowNorm(col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor,
stream, norm_op);
}

if (isRowMajor) {
typedef typename std::conditional<is_bool::value, OutType, AccType>::type
CosOutType;

cosine<InType, AccType, CosOutType, Index_, decltype(wrapped_fin_op)>(
m, n, k, pA, pB, col_vec, row_vec, reinterpret_cast<CosOutType *>(pD),
wrapped_fin_op, stream);
} else {
typedef ExpandedDistanceFragmentMultiplyAdd<CosFusedDistance>
FragmentMultiplyAdd_;
typedef typename std::conditional<is_bool::value, AccType, OutType>::type
EffOutType;
EffOutType *pDCast =
reinterpret_cast<EffOutType *>(pD); // Pretend to be EffOutType;
typedef typename cutlass::Shape<8, 8, 8> AccumulatorsPerThread_;
typedef cutlass::gemm::ThreadMultiplyAdd<
AccumulatorsPerThread_, cutlass::Shape<1, 4, 8>, InType, InType, AccType>
MainLoopFunctor_;
typedef LinAlg::CustomGemmConfig<InType, AccType, EffOutType, OutputTile_,
AccumulatorsPerThread_, MainLoopFunctor_>
GemmConfig_;

typedef ExpandedDistanceEpilogueFunctor<InType, AccType, GemmConfig_,
FragmentMultiplyAdd_>
EpilogueFunctor_;

typedef typename std::conditional<
is_bool::value,
BoolEpilogueTraitsHelper<GemmConfig_, EpilogueFunctor_, Index_>,
cutlass::gemm::GemmEpilogueTraitsHelper<
GemmConfig_, EpilogueFunctor_, Index_>>::type EpilogueTraitsHelper_;

typedef typename cutlass::gemm::SimplifiedGemmEpilogueTraits<
GemmConfig_, EpilogueFunctor_, Index_, EpilogueTraitsHelper_>
GemmEpilogueTraits_;
typedef ExpandedDistanceGemmEpilogue<GemmEpilogueTraits_> GemmEpilogue_;
typedef typename EpilogueFunctor_::Params EpiParams;

cublasOperation_t transa, transb;
const InType *aPtr, *bPtr;
Index_ lda, ldb, ldd;
Index_ gemm_m, gemm_n;
InType *rvec, *cvec;

transa = CUBLAS_OP_N;
transb = CUBLAS_OP_T;
aPtr = pA;
bPtr = pB;
lda = m;
ldb = n;
ldd = m;
gemm_m = m;
gemm_n = n;
cvec = row_vec;
rvec = col_vec;

LinAlg::gemm<InType, AccType, EffOutType, OutputTile_,
AccumulatorsPerThread_, MainLoopFunctor_, Index_, GemmConfig_,
EpilogueFunctor_, GemmEpilogueTraits_, GemmEpilogue_>(
transa, transb, gemm_m, gemm_n, k, (EffOutType)1, aPtr, lda, bPtr, ldb,
(EffOutType)0, nullptr, ldd, pDCast,
[cvec, rvec] HD(EpiParams & p) {
int err = p.initializeExtra(cvec, rvec, false);
return err;
},
wrapped_fin_op, stream);
}
}

}; // end namespace Distance
Expand Down
Loading