Missing types master (#258)
daineAMD authored Sep 18, 2020
1 parent 22084f0 commit e4d9e7b
Showing 19 changed files with 945 additions and 377 deletions.
120 changes: 120 additions & 0 deletions clients/common/cblas_interface.cpp
@@ -2085,6 +2085,68 @@ void cblas_gemm<hipblasHalf>(hipblasOperation_t transA,
}
}

template <>
void cblas_gemm<hipblasHalf, hipblasHalf, float>(hipblasOperation_t transA,
hipblasOperation_t transB,
int m,
int n,
int k,
float alpha_float,
hipblasHalf* A,
int lda,
hipblasHalf* B,
int ldb,
float beta_float,
hipblasHalf* C,
int ldc)
{
// cblas does not support hipblasHalf, so convert to higher-precision float.
// This gives a more precise result, which is acceptable for testing.

int sizeA = transA == HIPBLAS_OP_N ? k * lda : m * lda;
int sizeB = transB == HIPBLAS_OP_N ? n * ldb : k * ldb;
int sizeC = n * ldc;

std::unique_ptr<float[]> A_float(new float[sizeA]());
std::unique_ptr<float[]> B_float(new float[sizeB]());
std::unique_ptr<float[]> C_float(new float[sizeC]());

for(int i = 0; i < sizeA; i++)
{
A_float[i] = half_to_float(A[i]);
}
for(int i = 0; i < sizeB; i++)
{
B_float[i] = half_to_float(B[i]);
}
for(int i = 0; i < sizeC; i++)
{
C_float[i] = half_to_float(C[i]);
}

// Just cast directly, since transA and transB are integer-valued enums
cblas_sgemm(CblasColMajor,
(CBLAS_TRANSPOSE)transA,
(CBLAS_TRANSPOSE)transB,
m,
n,
k,
alpha_float,
const_cast<const float*>(A_float.get()),
lda,
const_cast<const float*>(B_float.get()),
ldb,
beta_float,
static_cast<float*>(C_float.get()),
ldc);

for(int i = 0; i < sizeC; i++)
{
C[i] = float_to_half(C_float[i]);
}
}
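
The half_to_float and float_to_half helpers used above come from the hipBLAS client utilities. For context, here is a minimal sketch of the half-to-float direction, assuming hipblasHalf carries a raw IEEE-754 binary16 bit pattern (1 sign, 5 exponent, 10 mantissa bits); this is an illustration only, not the tests' implementation:

#include <cstdint>
#include <cstring>

// Hypothetical illustration; the real half_to_float lives in the client code.
static float half_bits_to_float(uint16_t h)
{
    uint32_t sign = (uint32_t)(h & 0x8000u) << 16; // sign moves to bit 31
    uint32_t exp  = (h >> 10) & 0x1Fu;             // 5-bit exponent field
    uint32_t man  = h & 0x3FFu;                    // 10-bit mantissa field
    uint32_t bits;
    if(exp == 0x1Fu)                 // Inf or NaN: saturated exponent
        bits = sign | 0x7F800000u | (man << 13);
    else if(exp != 0)                // normal number: rebias 15 -> 127
        bits = sign | ((exp + 112u) << 23) | (man << 13);
    else if(man == 0)                // signed zero
        bits = sign;
    else                             // subnormal half: renormalize
    {
        int e = -1;
        do { man <<= 1; e++; } while(!(man & 0x400u));
        bits = sign | ((uint32_t)(112 - e) << 23) | ((man & 0x3FFu) << 13);
    }
    float f;
    std::memcpy(&f, &bits, sizeof(f)); // bit-exact reinterpretation
    return f;
}

Every binary16 value is exactly representable in binary32, so this direction is lossless; float_to_half is the lossy side and needs a rounding policy.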

template <>
void cblas_gemm<float>(hipblasOperation_t transA,
hipblasOperation_t transB,
@@ -2212,6 +2274,64 @@ void cblas_gemm<hipblasDoubleComplex>(hipblasOperation_t transA,
ldc);
}

template <>
void cblas_gemm<int8_t, int32_t, int32_t>(hipblasOperation_t transA,
hipblasOperation_t transB,
int m,
int n,
int k,
int32_t alpha,
int8_t* A,
int lda,
int8_t* B,
int ldb,
int32_t beta,
int32_t* C,
int ldc)
{
double alpha_double = static_cast<double>(alpha);
double beta_double = static_cast<double>(beta);

size_t const sizeA = ((transA == HIPBLAS_OP_N) ? k : m) * size_t(lda);
size_t const sizeB = ((transB == HIPBLAS_OP_N) ? n : k) * size_t(ldb);
size_t const sizeC = n * size_t(ldc);

std::unique_ptr<double[]> A_double(new double[sizeA]());
std::unique_ptr<double[]> B_double(new double[sizeB]());
std::unique_ptr<double[]> C_double(new double[sizeC]());

for(size_t i = 0; i < sizeA; i++)
{
A_double[i] = static_cast<double>(A[i]);
}
for(size_t i = 0; i < sizeB; i++)
{
B_double[i] = static_cast<double>(B[i]);
}
for(size_t i = 0; i < sizeC; i++)
{
C_double[i] = static_cast<double>(C[i]);
}

cblas_dgemm(CblasColMajor,
static_cast<CBLAS_TRANSPOSE>(transA),
static_cast<CBLAS_TRANSPOSE>(transB),
m,
n,
k,
alpha_double,
const_cast<const double*>(A_double.get()),
lda,
const_cast<const double*>(B_double.get()),
ldb,
beta_double,
static_cast<double*>(C_double.get()),
ldc);

for(size_t i = 0; i < sizeC; i++)
C[i] = static_cast<int32_t>(C_double[i]);
}
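
A note on why routing through cblas_dgemm is exact here: each int8 product has magnitude at most 128 * 128, so for these test sizes the accumulated dot products stay far below 2^53 and every intermediate is exactly representable in a double. A hypothetical pure-integer reference (not part of this change) makes the intended semantics explicit:

#include <cstddef>
#include <cstdint>

// Hypothetical column-major reference with the same semantics as the
// cblas_dgemm call above: C = alpha * op(A) * op(B) + beta * C.
static void gemm_i8_i32_ref(bool transA, bool transB, int m, int n, int k,
                            int32_t alpha, const int8_t* A, int lda,
                            const int8_t* B, int ldb, int32_t beta,
                            int32_t* C, int ldc)
{
    for(int j = 0; j < n; j++)
        for(int i = 0; i < m; i++)
        {
            int64_t acc = 0; // 64-bit accumulator avoids intermediate overflow
            for(int l = 0; l < k; l++)
            {
                int32_t a = transA ? A[l + i * size_t(lda)] : A[i + l * size_t(lda)];
                int32_t b = transB ? B[j + l * size_t(ldb)] : B[l + j * size_t(ldb)];
                acc += a * b;
            }
            int64_t c = int64_t(alpha) * acc + int64_t(beta) * C[i + j * size_t(ldc)];
            C[i + j * size_t(ldc)] = static_cast<int32_t>(c);
        }
}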

// hemm
template <>
void cblas_hemm(hipblasSideMode_t side,
124 changes: 121 additions & 3 deletions clients/gtest/gemm_ex_gtest.cpp
@@ -33,6 +33,17 @@ typedef std::tuple<vector<int>, vector<double>, vector<char>, vector<hipblasData
// clang-format off
// vector of vector, each vector is a {M, N, K, lda, ldb, ldc};
// add/delete as a group
const vector<vector<int>> int8_matrix_size_range = {
{ 4,  4,  4,  4,  4,  4},
{ 8,  8,  8,  8,  8,  8},
{12, 12, 12, 12, 12, 12},
{16, 16, 16, 16, 16, 16},
{20, 20, 20, 20, 20, 20},
{ 8,  4,  4,  8,  8,  8},
{ 8, 12, 12, 12, 12, 12},
};
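
A hedged aside on the values above: every dimension is a multiple of 4, presumably because the packed int8 GEMM path consumes four values at a time and therefore needs k and the int8 leading dimensions divisible by 4. A sketch of such a validity check (hypothetical helper, not part of this change):

// Assumption: packed-int8 kernels require k, lda, and ldb to be multiples of 4.
static bool int8_gemm_dims_supported(int k, int lda, int ldb)
{
    return (k % 4 == 0) && (lda % 4 == 0) && (ldb % 4 == 0);
}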

const vector<vector<int>> small_matrix_size_range = {
{ 1, 1, 1, 1, 1, 1},
{ 1, 2, 3, 4, 5, 6},
@@ -169,6 +180,12 @@ const vector<vector<double>> small_alpha_beta_range = {
const vector<vector<double>> full_alpha_beta_range = {
{1.0, 0.0}, {-1.0, -1.0}, {2.0, 1.0}, {0.0, 1.0}};

// For CUDA versions < 10.0, only alpha and beta values of 1 or 0 are
// supported.
const vector<vector<double>> alpha_beta_range_int8 = {
{1.0, 1.0}, {1.0, 0.0},
};

// vector of vector, each pair is a {transA, transB};
// add/delete this list in pairs, like {'N', 'T'}
// for single/double precision, 'C'(conjTranspose) will downgraded to 'T' (transpose) internally in
@@ -201,6 +218,25 @@ const vector<vector<hipblasDatatype_t>> precision_double = {{ HIPBLAS_R_64F,
HIPBLAS_R_64F,
HIPBLAS_R_64F }};

const vector<vector<hipblasDatatype_t>> precision_single_complex = {{ HIPBLAS_C_32F,
HIPBLAS_C_32F,
HIPBLAS_C_32F,
HIPBLAS_C_32F,
HIPBLAS_C_32F }};

const vector<vector<hipblasDatatype_t>> precision_double_complex = {{ HIPBLAS_C_64F,
HIPBLAS_C_64F,
HIPBLAS_C_64F,
HIPBLAS_C_64F,
HIPBLAS_C_64F }};

const vector<vector<hipblasDatatype_t>> precision_int8 = {{ HIPBLAS_R_8I,
HIPBLAS_R_8I,
HIPBLAS_R_32I,
HIPBLAS_R_32I,
HIPBLAS_R_32I }};


const vector<vector<hipblasDatatype_t>> precision_type_range = {{HIPBLAS_R_16F,
HIPBLAS_R_16F,
HIPBLAS_R_16F,
@@ -220,7 +256,22 @@ const vector<vector<hipblasDatatype_t>> precision_type_range = {{HIPBLAS_R_16F,
HIPBLAS_R_64F,
HIPBLAS_R_64F,
HIPBLAS_R_64F,
- HIPBLAS_R_64F}};
+ HIPBLAS_R_64F},
{HIPBLAS_C_32F,
HIPBLAS_C_32F,
HIPBLAS_C_32F,
HIPBLAS_C_32F,
HIPBLAS_C_32F},
{HIPBLAS_C_64F,
HIPBLAS_C_64F,
HIPBLAS_C_64F,
HIPBLAS_C_64F,
HIPBLAS_C_64F},
{HIPBLAS_R_8I,
HIPBLAS_R_8I,
HIPBLAS_R_32I,
HIPBLAS_R_32I,
HIPBLAS_R_32I}};

const int batch_count_range[] = { -1, 1, 5 };
const int batch_count_range_small[] = { 1 };
@@ -364,13 +415,19 @@ TEST_P(parameterized_gemm_batched_ex, standard_batched)
{
EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status);
}
- else
+ else if(status == HIPBLAS_STATUS_ARCH_MISMATCH)
{
// Only available on CUDA devices with compute capability >= 5.0.
// If desired, this could call query_device_property() and expect
// this status only when cc < 5.0 on a CUDA device, failing otherwise.
EXPECT_EQ(HIPBLAS_STATUS_ARCH_MISMATCH, status);
}
else
{
// cublas/rocblas do not have identical support
// (e.g. cuBLAS does not support i8/i32 here)
EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status);
}
}
}

@@ -400,13 +457,19 @@ TEST_P(parameterized_gemm_batched_ex, standard_strided_batched)
{
EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status);
}
- else
+ else if(status == HIPBLAS_STATUS_ARCH_MISMATCH)
{
// Only available on CUDA devices with compute capability >= 5.0.
// If desired, this could call query_device_property() and expect
// this status only when cc < 5.0 on a CUDA device, failing otherwise.
EXPECT_EQ(HIPBLAS_STATUS_ARCH_MISMATCH, status);
}
else
{
// cublas/rocblas do not have identical support
// (e.g. cuBLAS does not support i8/i32 here)
EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status);
}
}
}

@@ -502,6 +565,34 @@ INSTANTIATE_TEST_CASE_P(quick_blas_ex_small_double,
ValuesIn(precision_double),
ValuesIn(batch_count_range_small),
ValuesIn(is_fortran)));

INSTANTIATE_TEST_CASE_P(quick_blas_ex_small_single_complex,
parameterized_gemm_ex,
Combine(ValuesIn(small_matrix_size_range),
ValuesIn(alpha_beta_range),
ValuesIn(transA_transB_range),
ValuesIn(precision_single_complex),
ValuesIn(batch_count_range_small),
ValuesIn(is_fortran)));

INSTANTIATE_TEST_CASE_P(quick_blas_ex_small_double_complex,
parameterized_gemm_ex,
Combine(ValuesIn(small_matrix_size_range),
ValuesIn(alpha_beta_range),
ValuesIn(transA_transB_range),
ValuesIn(precision_double_complex),
ValuesIn(batch_count_range_small),
ValuesIn(is_fortran)));

INSTANTIATE_TEST_CASE_P(quick_blas_ex_small_int8,
parameterized_gemm_ex,
Combine(ValuesIn(int8_matrix_size_range),
ValuesIn(alpha_beta_range_int8),
ValuesIn(transA_transB_range),
ValuesIn(precision_int8),
ValuesIn(batch_count_range_small),
ValuesIn(is_fortran)));

//----medium
INSTANTIATE_TEST_CASE_P(pre_checkin_blas_ex_medium_hpa_half,
parameterized_gemm_ex,
@@ -575,3 +666,30 @@ INSTANTIATE_TEST_CASE_P(quick_blas_batched_ex_small_double,
ValuesIn(precision_double),
ValuesIn(batch_count_range),
ValuesIn(is_fortran)));

INSTANTIATE_TEST_CASE_P(quick_blas_batched_ex_small_single_complex,
parameterized_gemm_batched_ex,
Combine(ValuesIn(small_matrix_size_range),
ValuesIn(alpha_beta_range),
ValuesIn(transA_transB_range),
ValuesIn(precision_single_complex),
ValuesIn(batch_count_range),
ValuesIn(is_fortran)));

INSTANTIATE_TEST_CASE_P(quick_blas_batched_ex_small_double_complex,
parameterized_gemm_batched_ex,
Combine(ValuesIn(small_matrix_size_range),
ValuesIn(alpha_beta_range),
ValuesIn(transA_transB_range),
ValuesIn(precision_double_complex),
ValuesIn(batch_count_range),
ValuesIn(is_fortran)));

INSTANTIATE_TEST_CASE_P(quick_blas_batched_ex_small_int8,
parameterized_gemm_batched_ex,
Combine(ValuesIn(int8_matrix_size_range),
ValuesIn(alpha_beta_range_int8),
ValuesIn(transA_transB_range),
ValuesIn(precision_int8),
ValuesIn(batch_count_range),
ValuesIn(is_fortran)));
12 changes: 6 additions & 6 deletions clients/include/cblas_interface.h
@@ -312,19 +312,19 @@ void cblas_geam(hipblasOperation_t transa,
int ldc);

// gemm
- template <typename T>
+ template <typename Ti, typename To = Ti, typename Tc = To>
void cblas_gemm(hipblasOperation_t transA,
hipblasOperation_t transB,
int m,
int n,
int k,
- T alpha,
- T* A,
+ Tc alpha,
+ Ti* A,
int lda,
- T* B,
+ Ti* B,
int ldb,
- T beta,
- T* C,
+ Tc beta,
+ To* C,
int ldc);
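
The two defaulted parameters keep existing single-type call sites compiling unchanged while enabling the new mixed-type specializations. A minimal sketch of how calls resolve (hypothetical call sites; A_f/B_f/C_f and the other buffers are assumed to be suitably sized column-major arrays):

// Ti = To = Tc = float: identical to the old single-parameter usage.
cblas_gemm<float>(HIPBLAS_OP_N, HIPBLAS_OP_N, m, n, k,
                  1.0f, A_f, lda, B_f, ldb, 0.0f, C_f, ldc);

// Ti = To = hipblasHalf, Tc = float: half data with float alpha/beta.
cblas_gemm<hipblasHalf, hipblasHalf, float>(HIPBLAS_OP_N, HIPBLAS_OP_N, m, n, k,
                                            1.0f, A_h, lda, B_h, ldb, 0.0f, C_h, ldc);

// Ti = int8_t, To = Tc = int32_t: int8 inputs, int32 output and scalars.
cblas_gemm<int8_t, int32_t, int32_t>(HIPBLAS_OP_N, HIPBLAS_OP_N, m, n, k,
                                     1, A_i8, lda, B_i8, ldb, 0, C_i32, ldc);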

// hemm
6 changes: 3 additions & 3 deletions clients/include/testing_gemm_batched.hpp
@@ -220,12 +220,12 @@ hipblasStatus_t testing_GemmBatched(Arguments argus)
N,
K,
h_alpha,
- hA_array[i],
+ (T*)hA_array[i],
lda,
- hB_array[i],
+ (T*)hB_array[i],
ldb,
h_beta,
- hC_copy_array[i],
+ (T*)hC_copy_array[i],
ldc);
}
