diff --git a/clients/common/cblas_interface.cpp b/clients/common/cblas_interface.cpp index 978433869..269063276 100644 --- a/clients/common/cblas_interface.cpp +++ b/clients/common/cblas_interface.cpp @@ -2085,6 +2085,68 @@ void cblas_gemm(hipblasOperation_t transA, } } +template <> +void cblas_gemm(hipblasOperation_t transA, + hipblasOperation_t transB, + int m, + int n, + int k, + float alpha_float, + hipblasHalf* A, + int lda, + hipblasHalf* B, + int ldb, + float beta_float, + hipblasHalf* C, + int ldc) +{ + // cblas does not support hipblasHalf, so convert to higher precision float + // This will give more precise result which is acceptable for testing + + int sizeA = transA == HIPBLAS_OP_N ? k * lda : m * lda; + int sizeB = transB == HIPBLAS_OP_N ? n * ldb : k * ldb; + int sizeC = n * ldc; + + std::unique_ptr A_float(new float[sizeA]()); + std::unique_ptr B_float(new float[sizeB]()); + std::unique_ptr C_float(new float[sizeC]()); + + for(int i = 0; i < sizeA; i++) + { + A_float[i] = half_to_float(A[i]); + } + for(int i = 0; i < sizeB; i++) + { + B_float[i] = half_to_float(B[i]); + } + for(int i = 0; i < sizeC; i++) + { + C_float[i] = half_to_float(C[i]); + } + + // just directly cast, since transA, transB are integers in the enum + // printf("transA: rocblas =%d, cblas=%d\n", transA, (CBLAS_TRANSPOSE)transA ); + cblas_sgemm(CblasColMajor, + (CBLAS_TRANSPOSE)transA, + (CBLAS_TRANSPOSE)transB, + m, + n, + k, + alpha_float, + const_cast(A_float.get()), + lda, + const_cast(B_float.get()), + ldb, + beta_float, + static_cast(C_float.get()), + ldc); + + for(int i = 0; i < sizeC; i++) + { + C[i] = float_to_half(C_float[i]); + } +} + template <> void cblas_gemm(hipblasOperation_t transA, hipblasOperation_t transB, @@ -2212,6 +2274,64 @@ void cblas_gemm(hipblasOperation_t transA, ldc); } +template <> +void cblas_gemm(hipblasOperation_t transA, + hipblasOperation_t transB, + int m, + int n, + int k, + int32_t alpha, + int8_t* A, + int lda, + int8_t* B, + int ldb, + int32_t beta, + int32_t* C, + int ldc) +{ + double alpha_double = static_cast(alpha); + double beta_double = static_cast(beta); + + size_t const sizeA = ((transA == HIPBLAS_OP_N) ? k : m) * size_t(lda); + size_t const sizeB = ((transB == HIPBLAS_OP_N) ? n : k) * size_t(ldb); + size_t const sizeC = n * size_t(ldc); + + std::unique_ptr A_double(new double[sizeA]()); + std::unique_ptr B_double(new double[sizeB]()); + std::unique_ptr C_double(new double[sizeC]()); + + for(int i = 0; i < sizeA; i++) + { + A_double[i] = static_cast(A[i]); + } + for(int i = 0; i < sizeB; i++) + { + B_double[i] = static_cast(B[i]); + } + for(int i = 0; i < sizeC; i++) + { + C_double[i] = static_cast(C[i]); + } + + cblas_dgemm(CblasColMajor, + static_cast(transA), + static_cast(transB), + m, + n, + k, + alpha_double, + const_cast(A_double.get()), + lda, + const_cast(B_double.get()), + ldb, + beta_double, + static_cast(C_double.get()), + ldc); + + for(size_t i = 0; i < sizeC; i++) + C[i] = static_cast(C_double[i]); +} + // hemm template <> void cblas_hemm(hipblasSideMode_t side, diff --git a/clients/gtest/gemm_ex_gtest.cpp b/clients/gtest/gemm_ex_gtest.cpp index 42bf9f04d..1a89a9abe 100644 --- a/clients/gtest/gemm_ex_gtest.cpp +++ b/clients/gtest/gemm_ex_gtest.cpp @@ -33,6 +33,17 @@ typedef std::tuple, vector, vector, vector> int8_matrix_size_range = { + + { 4, 4, 4, 4, 4, 4}, + { 8, 8, 8, 8, 8, 8}, + {12, 12, 12, 12, 12, 12}, + {16, 16, 16, 16, 16, 16}, + {20, 20, 20, 20, 20, 20}, + { 8, 4, 4, 8, 8, 8}, + { 8, 12, 12, 12, 12, 12}, +}; + const vector> small_matrix_size_range = { { 1, 1, 1, 1, 1, 1}, { 1, 2, 3, 4, 5, 6}, @@ -169,6 +180,12 @@ const vector> small_alpha_beta_range = { const vector> full_alpha_beta_range = { {1.0, 0.0}, {-1.0, -1.0}, {2.0, 1.0}, {0.0, 1.0}}; +// For Cuda v < 10.0, only alpha and beta = 1 or = 0 are +// supported. +const vector> alpha_beta_range_int8 = { + {1.0, 1.0}, {1.0, 0.0}, +}; + // vector of vector, each pair is a {transA, transB}; // add/delete this list in pairs, like {'N', 'T'} // for single/double precision, 'C'(conjTranspose) will downgraded to 'T' (transpose) internally in @@ -201,6 +218,25 @@ const vector> precision_double = {{ HIPBLAS_R_64F, HIPBLAS_R_64F, HIPBLAS_R_64F }}; +const vector> precision_single_complex = {{ HIPBLAS_C_32F, + HIPBLAS_C_32F, + HIPBLAS_C_32F, + HIPBLAS_C_32F, + HIPBLAS_C_32F }}; + +const vector> precision_double_complex = {{ HIPBLAS_C_64F, + HIPBLAS_C_64F, + HIPBLAS_C_64F, + HIPBLAS_C_64F, + HIPBLAS_C_64F }}; + +const vector> precision_int8 = {{ HIPBLAS_R_8I, + HIPBLAS_R_8I, + HIPBLAS_R_32I, + HIPBLAS_R_32I, + HIPBLAS_R_32I }}; + + const vector> precision_type_range = {{HIPBLAS_R_16F, HIPBLAS_R_16F, HIPBLAS_R_16F, @@ -220,7 +256,22 @@ const vector> precision_type_range = {{HIPBLAS_R_16F, HIPBLAS_R_64F, HIPBLAS_R_64F, HIPBLAS_R_64F, - HIPBLAS_R_64F}}; + HIPBLAS_R_64F}, + {HIPBLAS_C_32F, + HIPBLAS_C_32F, + HIPBLAS_C_32F, + HIPBLAS_C_32F, + HIPBLAS_C_32F}, + {HIPBLAS_C_64F, + HIPBLAS_C_64F, + HIPBLAS_C_64F, + HIPBLAS_C_64F, + HIPBLAS_C_64F}, + {HIPBLAS_R_8I, + HIPBLAS_R_8I, + HIPBLAS_R_32I, + HIPBLAS_R_32I, + HIPBLAS_R_32I}}; const int batch_count_range[] = { -1, 1, 5 }; const int batch_count_range_small[] = { 1 }; @@ -364,13 +415,19 @@ TEST_P(parameterized_gemm_batched_ex, standard_batched) { EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); } - else + else if(status == HIPBLAS_STATUS_ARCH_MISMATCH) { // Only available in cuda cc >= 5.0. // If we want we can change this to call query_device_property() and // call this only if cc < 5.0 on a CUDA device, else fail. EXPECT_EQ(HIPBLAS_STATUS_ARCH_MISMATCH, status); } + else + { + // cublas/rocblas do not have identical support + // (i.e. cublas doesn't support i8/i32 here) + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); + } } } @@ -400,13 +457,19 @@ TEST_P(parameterized_gemm_batched_ex, standard_strided_batched) { EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); } - else + else if(status == HIPBLAS_STATUS_ARCH_MISMATCH) { // Only available in cuda cc >= 5.0. // If we want we can change this to call query_device_property() and // call this only if cc < 5.0 on a CUDA device, else fail. EXPECT_EQ(HIPBLAS_STATUS_ARCH_MISMATCH, status); } + else + { + // cublas/rocblas do not have identical support + // (i.e. cublas doesn't support i8/i32 here) + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); + } } } @@ -502,6 +565,34 @@ INSTANTIATE_TEST_CASE_P(quick_blas_ex_small_double, ValuesIn(precision_double), ValuesIn(batch_count_range_small), ValuesIn(is_fortran))); + +INSTANTIATE_TEST_CASE_P(quick_blas_ex_small_single_complex, + parameterized_gemm_ex, + Combine(ValuesIn(small_matrix_size_range), + ValuesIn(alpha_beta_range), + ValuesIn(transA_transB_range), + ValuesIn(precision_single_complex), + ValuesIn(batch_count_range_small), + ValuesIn(is_fortran))); + +INSTANTIATE_TEST_CASE_P(quick_blas_ex_small_double_complex, + parameterized_gemm_ex, + Combine(ValuesIn(small_matrix_size_range), + ValuesIn(alpha_beta_range), + ValuesIn(transA_transB_range), + ValuesIn(precision_double_complex), + ValuesIn(batch_count_range_small), + ValuesIn(is_fortran))); + +INSTANTIATE_TEST_CASE_P(quick_blas_ex_small_int8, + parameterized_gemm_ex, + Combine(ValuesIn(int8_matrix_size_range), + ValuesIn(alpha_beta_range_int8), + ValuesIn(transA_transB_range), + ValuesIn(precision_int8), + ValuesIn(batch_count_range_small), + ValuesIn(is_fortran))); + //----medium INSTANTIATE_TEST_CASE_P(pre_checkin_blas_ex_medium_hpa_half, parameterized_gemm_ex, @@ -575,3 +666,30 @@ INSTANTIATE_TEST_CASE_P(quick_blas_batched_ex_small_double, ValuesIn(precision_double), ValuesIn(batch_count_range), ValuesIn(is_fortran))); + +INSTANTIATE_TEST_CASE_P(quick_blas_batched_ex_small_single_complex, + parameterized_gemm_batched_ex, + Combine(ValuesIn(small_matrix_size_range), + ValuesIn(alpha_beta_range), + ValuesIn(transA_transB_range), + ValuesIn(precision_single_complex), + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); + +INSTANTIATE_TEST_CASE_P(quick_blas_batched_ex_small_double_complex, + parameterized_gemm_batched_ex, + Combine(ValuesIn(small_matrix_size_range), + ValuesIn(alpha_beta_range), + ValuesIn(transA_transB_range), + ValuesIn(precision_double_complex), + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); + +INSTANTIATE_TEST_CASE_P(quick_blas_batched_ex_small_int8, + parameterized_gemm_batched_ex, + Combine(ValuesIn(int8_matrix_size_range), + ValuesIn(alpha_beta_range_int8), + ValuesIn(transA_transB_range), + ValuesIn(precision_int8), + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/include/cblas_interface.h b/clients/include/cblas_interface.h index 9382491e6..bb7cd7021 100644 --- a/clients/include/cblas_interface.h +++ b/clients/include/cblas_interface.h @@ -312,19 +312,19 @@ void cblas_geam(hipblasOperation_t transa, int ldc); // gemm -template +template void cblas_gemm(hipblasOperation_t transA, hipblasOperation_t transB, int m, int n, int k, - T alpha, - T* A, + Tc alpha, + Ti* A, int lda, - T* B, + Ti* B, int ldb, - T beta, - T* C, + Tc beta, + To* C, int ldc); // hemm diff --git a/clients/include/testing_gemm_batched.hpp b/clients/include/testing_gemm_batched.hpp index 51b042307..3ee6b8272 100644 --- a/clients/include/testing_gemm_batched.hpp +++ b/clients/include/testing_gemm_batched.hpp @@ -220,12 +220,12 @@ hipblasStatus_t testing_GemmBatched(Arguments argus) N, K, h_alpha, - hA_array[i], + (T*)hA_array[i], lda, - hB_array[i], + (T*)hB_array[i], ldb, h_beta, - hC_copy_array[i], + (T*)hC_copy_array[i], ldc); } diff --git a/clients/include/testing_gemm_batched_ex.hpp b/clients/include/testing_gemm_batched_ex.hpp index b1c696e98..5482a1122 100644 --- a/clients/include/testing_gemm_batched_ex.hpp +++ b/clients/include/testing_gemm_batched_ex.hpp @@ -24,7 +24,7 @@ using namespace std; /* ============================================================================================ */ -template +template hipblasStatus_t testing_gemm_batched_ex_template(hipblasOperation_t transA, hipblasOperation_t transB, int M, @@ -50,46 +50,38 @@ hipblasStatus_t testing_gemm_batched_ex_template(hipblasOperation_t transA, uint32_t solution_index = 0; uint32_t flags = 0; - Td h_alpha_Td; - Td h_beta_Td; + Tex h_alpha_Tc; + Tex h_beta_Tc; - if(is_same::value) + if(is_same::value) { - h_alpha_Td = float_to_half(alpha_float); - h_beta_Td = float_to_half(beta_float); + h_alpha_Tc = float_to_half(alpha_float); + h_beta_Tc = float_to_half(beta_float); } - else if(is_same::value) + else if(is_same::value) { - h_alpha_Td = static_cast(alpha_float); - h_beta_Td = static_cast(beta_float); + h_alpha_Tc = static_cast(alpha_float); + h_beta_Tc = static_cast(beta_float); } - else if(is_same::value) + else if(is_same::value) { - h_alpha_Td = static_cast(alpha_float); - h_beta_Td = static_cast(beta_float); + h_alpha_Tc = static_cast(alpha_float); + h_beta_Tc = static_cast(beta_float); } - else + else if(is_same::value) { - return HIPBLAS_STATUS_NOT_SUPPORTED; + h_alpha_Tc = static_cast(alpha_float); + h_beta_Tc = static_cast(beta_float); } - - Tc h_alpha_Tc; - Tc h_beta_Tc; - - if(is_same::value) + else if(is_same::value) { - h_alpha_Tc = float_to_half(alpha_float); - h_beta_Tc = float_to_half(beta_float); + h_alpha_Tc = static_cast(alpha_float); + h_beta_Tc = static_cast(beta_float); } - else if(is_same::value) + else if(is_same::value) { - h_alpha_Tc = static_cast(alpha_float); - h_beta_Tc = static_cast(beta_float); - } - else if(is_same::value) - { - h_alpha_Tc = static_cast(alpha_float); - h_beta_Tc = static_cast(beta_float); + h_alpha_Tc = static_cast(alpha_float); + h_beta_Tc = static_cast(beta_float); } else { @@ -111,15 +103,15 @@ hipblasStatus_t testing_gemm_batched_ex_template(hipblasOperation_t transA, const size_t size_B = static_cast(ldb) * static_cast(B_col); const size_t size_C = static_cast(ldc) * static_cast(N); - device_vector dA(batch_count); - device_vector dB(batch_count); - device_vector dC(batch_count); - device_vector d_alpha_Tc(1); - device_vector d_beta_Tc(1); + device_vector dA(batch_count); + device_vector dB(batch_count); + device_vector dC(batch_count); + device_vector d_alpha_Tc(1); + device_vector d_beta_Tc(1); - device_batch_vector bA(batch_count, size_A); - device_batch_vector bB(batch_count, size_B); - device_batch_vector bC(batch_count, size_C); + device_batch_vector bA(batch_count, size_A); + device_batch_vector bB(batch_count, size_B); + device_batch_vector bC(batch_count, size_C); int last = batch_count - 1; if(!dA || !dB || !dC || !bA[last] || !bB[last] || !bC[last]) @@ -133,34 +125,61 @@ hipblasStatus_t testing_gemm_batched_ex_template(hipblasOperation_t transA, hipblasCreate(&handle); // Naming: dX is in GPU (device) memory. hK is in CPU (host) memory - host_vector hA[batch_count]; - host_vector hB[batch_count]; - host_vector hC[batch_count]; - host_vector hC_gold[batch_count]; + host_vector hA[batch_count]; + host_vector hB[batch_count]; + host_vector hC[batch_count]; + host_vector hC_gold[batch_count]; // Initial Data on CPU srand(1); for(int b = 0; b < batch_count; b++) { - hA[b] = host_vector(size_A); - hB[b] = host_vector(size_B); - hC[b] = host_vector(size_C); - hC_gold[b] = host_vector(size_C); + hA[b] = host_vector(size_A); + hB[b] = host_vector(size_B); + hC[b] = host_vector(size_C); + hC_gold[b] = host_vector(size_C); - hipblas_init(hA[b], A_row, A_col, lda); - hipblas_init_alternating_sign(hB[b], B_row, B_col, ldb); - hipblas_init(hC[b], M, N, ldc); + hipblas_init(hA[b], A_row, A_col, lda); + hipblas_init_alternating_sign(hB[b], B_row, B_col, ldb); + hipblas_init(hC[b], M, N, ldc); hC_gold[b] = hC[b]; - - CHECK_HIP_ERROR(hipMemcpy(bA[b], hA[b].data(), sizeof(Td) * size_A, hipMemcpyHostToDevice)); - CHECK_HIP_ERROR(hipMemcpy(bB[b], hB[b].data(), sizeof(Td) * size_B, hipMemcpyHostToDevice)); - CHECK_HIP_ERROR(hipMemcpy(bC[b], hC[b].data(), sizeof(Td) * size_C, hipMemcpyHostToDevice)); +#ifdef __HIP_PLATFORM_NVCC__ + CHECK_HIP_ERROR(hipMemcpy(bA[b], hA[b].data(), sizeof(Ta) * size_A, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(bB[b], hB[b].data(), sizeof(Tb) * size_B, hipMemcpyHostToDevice)); +#else + if(std::is_same{} && transA == HIPBLAS_OP_N) + { + vector hA_packed(hA[b]); + hipblas_packInt8(hA_packed, M, K, lda); + CHECK_HIP_ERROR( + hipMemcpy(bA[b], hA_packed.data(), sizeof(Ta) * size_A, hipMemcpyHostToDevice)); + } + else + { + CHECK_HIP_ERROR( + hipMemcpy(bA[b], hA[b].data(), sizeof(Ta) * size_A, hipMemcpyHostToDevice)); + } + + if(std::is_same{} && transB != HIPBLAS_OP_N) + { + vector hB_packed(hB[b]); + hipblas_packInt8(hB_packed, N, K, ldb); + CHECK_HIP_ERROR( + hipMemcpy(bB[b], hB_packed.data(), sizeof(Tb) * size_B, hipMemcpyHostToDevice)); + } + else + { + CHECK_HIP_ERROR( + hipMemcpy(bB[b], hB[b].data(), sizeof(Tb) * size_B, hipMemcpyHostToDevice)); + } +#endif + CHECK_HIP_ERROR(hipMemcpy(bC[b], hC[b].data(), sizeof(Tc) * size_C, hipMemcpyHostToDevice)); } - CHECK_HIP_ERROR(hipMemcpy(dA, bA, sizeof(Td*) * batch_count, hipMemcpyHostToDevice)); - CHECK_HIP_ERROR(hipMemcpy(dB, bB, sizeof(Td*) * batch_count, hipMemcpyHostToDevice)); - CHECK_HIP_ERROR(hipMemcpy(dC, bC, sizeof(Td*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dA, bA, sizeof(Ta*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dB, bB, sizeof(Tb*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dC, bC, sizeof(Tc*) * batch_count, hipMemcpyHostToDevice)); status = hipblasGemmBatchedExFn(handle, transA, @@ -169,14 +188,14 @@ hipblasStatus_t testing_gemm_batched_ex_template(hipblasOperation_t transA, N, K, &h_alpha_Tc, - (const void**)(Td**)dA, + (const void**)(Ta**)dA, a_type, lda, - (const void**)(Td**)dB, + (const void**)(Tb**)dB, b_type, ldb, &h_beta_Tc, - (void**)(Td**)dC, + (void**)(Tc**)dC, c_type, ldc, batch_count, @@ -192,25 +211,25 @@ hipblasStatus_t testing_gemm_batched_ex_template(hipblasOperation_t transA, for(int b = 0; b < batch_count; b++) { - CHECK_HIP_ERROR(hipMemcpy(hC[b].data(), bC[b], sizeof(Td) * size_C, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(hC[b].data(), bC[b], sizeof(Tc) * size_C, hipMemcpyDeviceToHost)); } // CPU BLAS for(int b = 0; b < batch_count; b++) { - cblas_gemm(transA, - transB, - M, - N, - K, - h_alpha_Td, - hA[b].data(), - lda, - hB[b].data(), - ldb, - h_beta_Td, - hC_gold[b].data(), - ldc); + cblas_gemm(transA, + transB, + M, + N, + K, + h_alpha_Tc, + hA[b].data(), + lda, + hB[b].data(), + ldb, + h_beta_Tc, + hC_gold[b].data(), + ldc); } // enable unit check, notice unit check is not invasive, but norm check is, @@ -218,7 +237,7 @@ hipblasStatus_t testing_gemm_batched_ex_template(hipblasOperation_t transA, if(unit_check) { for(int b = 0; b < batch_count; b++) - unit_check_general(M, N, ldc, hC_gold[b].data(), hC[b].data()); + unit_check_general(M, N, ldc, hC_gold[b].data(), hC[b].data()); } hipblasDestroy(handle); @@ -257,73 +276,96 @@ hipblasStatus_t testing_gemm_batched_ex(Arguments argus) if(a_type == HIPBLAS_R_16F && b_type == HIPBLAS_R_16F && c_type == HIPBLAS_R_16F && c_type == HIPBLAS_R_16F && compute_type == HIPBLAS_R_16F) { - status = testing_gemm_batched_ex_template(transA, - transB, - M, - N, - K, - alpha, - lda, - ldb, - beta, - ldc, - norm_check, - unit_check, - a_type, - b_type, - c_type, - batch_count, - compute_type, - argus.fortran); + status = testing_gemm_batched_ex_template(transA, + transB, + M, + N, + K, + alpha, + lda, + ldb, + beta, + ldc, + norm_check, + unit_check, + a_type, + b_type, + c_type, + batch_count, + compute_type, + argus.fortran); } else if(a_type == HIPBLAS_R_16F && b_type == HIPBLAS_R_16F && c_type == HIPBLAS_R_16F && c_type == HIPBLAS_R_16F && compute_type == HIPBLAS_R_32F) { - status = testing_gemm_batched_ex_template(transA, - transB, - M, - N, - K, - alpha, - lda, - ldb, - beta, - ldc, - norm_check, - unit_check, - a_type, - b_type, - c_type, - batch_count, - compute_type, - argus.fortran); + status = testing_gemm_batched_ex_template( + transA, + transB, + M, + N, + K, + alpha, + lda, + ldb, + beta, + ldc, + norm_check, + unit_check, + a_type, + b_type, + c_type, + batch_count, + compute_type, + argus.fortran); } else if(a_type == HIPBLAS_R_32F && b_type == HIPBLAS_R_32F && c_type == HIPBLAS_R_32F && c_type == HIPBLAS_R_32F && compute_type == HIPBLAS_R_32F) { - status = testing_gemm_batched_ex_template(transA, - transB, - M, - N, - K, - alpha, - lda, - ldb, - beta, - ldc, - norm_check, - unit_check, - a_type, - b_type, - c_type, - batch_count, - compute_type, - argus.fortran); + status = testing_gemm_batched_ex_template(transA, + transB, + M, + N, + K, + alpha, + lda, + ldb, + beta, + ldc, + norm_check, + unit_check, + a_type, + b_type, + c_type, + batch_count, + compute_type, + argus.fortran); } else if(a_type == HIPBLAS_R_64F && b_type == HIPBLAS_R_64F && c_type == HIPBLAS_R_64F && c_type == HIPBLAS_R_64F && compute_type == HIPBLAS_R_64F) { - status = testing_gemm_batched_ex_template(transA, + status = testing_gemm_batched_ex_template(transA, + transB, + M, + N, + K, + alpha, + lda, + ldb, + beta, + ldc, + norm_check, + unit_check, + a_type, + b_type, + c_type, + batch_count, + compute_type, + argus.fortran); + } + else if(a_type == HIPBLAS_C_32F && b_type == HIPBLAS_C_32F && c_type == HIPBLAS_C_32F + && c_type == HIPBLAS_C_32F && compute_type == HIPBLAS_C_32F) + { + status = testing_gemm_batched_ex_template(transA, transB, M, N, @@ -342,6 +384,50 @@ hipblasStatus_t testing_gemm_batched_ex(Arguments argus) compute_type, argus.fortran); } + else if(a_type == HIPBLAS_C_64F && b_type == HIPBLAS_C_64F && c_type == HIPBLAS_C_64F + && c_type == HIPBLAS_C_64F && compute_type == HIPBLAS_C_64F) + { + status = testing_gemm_batched_ex_template(transA, + transB, + M, + N, + K, + alpha, + lda, + ldb, + beta, + ldc, + norm_check, + unit_check, + a_type, + b_type, + c_type, + batch_count, + compute_type, + argus.fortran); + } + else if(a_type == HIPBLAS_R_8I && b_type == HIPBLAS_R_8I && c_type == HIPBLAS_R_32I + && c_type == HIPBLAS_R_32I && compute_type == HIPBLAS_R_32I) + { + status = testing_gemm_batched_ex_template(transA, + transB, + M, + N, + K, + alpha, + lda, + ldb, + beta, + ldc, + norm_check, + unit_check, + a_type, + b_type, + c_type, + batch_count, + compute_type, + argus.fortran); + } else { status = HIPBLAS_STATUS_NOT_SUPPORTED; diff --git a/clients/include/testing_gemm_ex.hpp b/clients/include/testing_gemm_ex.hpp index f9353c7ad..c98848935 100644 --- a/clients/include/testing_gemm_ex.hpp +++ b/clients/include/testing_gemm_ex.hpp @@ -24,7 +24,7 @@ using namespace std; /* ============================================================================================ */ -template +template hipblasStatus_t testing_gemm_ex_template(hipblasOperation_t transA, hipblasOperation_t transB, int M, @@ -51,46 +51,38 @@ hipblasStatus_t testing_gemm_ex_template(hipblasOperation_t transA, size_t* workspace_size = 0; void* workspace = 0; - Td h_alpha_Td; - Td h_beta_Td; + Tex h_alpha_Tc; + Tex h_beta_Tc; - if(is_same::value) + if(is_same::value) { - h_alpha_Td = float_to_half(alpha_float); - h_beta_Td = float_to_half(beta_float); + h_alpha_Tc = float_to_half(alpha_float); + h_beta_Tc = float_to_half(beta_float); } - else if(is_same::value) + else if(is_same::value) { - h_alpha_Td = static_cast(alpha_float); - h_beta_Td = static_cast(beta_float); + h_alpha_Tc = static_cast(alpha_float); + h_beta_Tc = static_cast(beta_float); } - else if(is_same::value) + else if(is_same::value) { - h_alpha_Td = static_cast(alpha_float); - h_beta_Td = static_cast(beta_float); + h_alpha_Tc = static_cast(alpha_float); + h_beta_Tc = static_cast(beta_float); } - else + else if(is_same::value) { - return HIPBLAS_STATUS_NOT_SUPPORTED; + h_alpha_Tc = static_cast(alpha_float); + h_beta_Tc = static_cast(beta_float); } - - Tc h_alpha_Tc; - Tc h_beta_Tc; - - if(is_same::value) + else if(is_same::value) { - h_alpha_Tc = float_to_half(alpha_float); - h_beta_Tc = float_to_half(beta_float); + h_alpha_Tc = static_cast(alpha_float); + h_beta_Tc = static_cast(beta_float); } - else if(is_same::value) + else if(is_same::value) { - h_alpha_Tc = static_cast(alpha_float); - h_beta_Tc = static_cast(beta_float); - } - else if(is_same::value) - { - h_alpha_Tc = static_cast(alpha_float); - h_beta_Tc = static_cast(beta_float); + h_alpha_Tc = static_cast(alpha_float); + h_beta_Tc = static_cast(beta_float); } else { @@ -132,15 +124,17 @@ hipblasStatus_t testing_gemm_ex_template(hipblasOperation_t transA, // Tc* d_alpha_Tc = (Tc*)d_alpha_Tc_managed.get(); // Tc* d_beta_Tc = (Tc*)d_beta_Tc_managed.get(); - Td *dA, *dB, *dC; - Tc *d_alpha_Tc, *d_beta_Tc; + Ta* dA; + Tb* dB; + Tc* dC; + Tex *d_alpha_Tc, *d_beta_Tc; - CHECK_HIP_ERROR(hipMalloc(&dA, size_A * sizeof(Td))); - CHECK_HIP_ERROR(hipMalloc(&dB, size_B * sizeof(Td))); - CHECK_HIP_ERROR(hipMalloc(&dC, size_C * sizeof(Td))); + CHECK_HIP_ERROR(hipMalloc(&dA, size_A * sizeof(Ta))); + CHECK_HIP_ERROR(hipMalloc(&dB, size_B * sizeof(Tb))); + CHECK_HIP_ERROR(hipMalloc(&dC, size_C * sizeof(Tc))); - CHECK_HIP_ERROR(hipMalloc(&d_alpha_Tc, sizeof(Td))); - CHECK_HIP_ERROR(hipMalloc(&d_beta_Tc, sizeof(Td))); + CHECK_HIP_ERROR(hipMalloc(&d_alpha_Tc, sizeof(Tex))); + CHECK_HIP_ERROR(hipMalloc(&d_beta_Tc, sizeof(Tex))); if(!dA || !dB || !dC || !d_alpha_Tc || !d_beta_Tc) { @@ -153,16 +147,16 @@ hipblasStatus_t testing_gemm_ex_template(hipblasOperation_t transA, hipblasCreate(&handle); // Naming: dX is in GPU (device) memory. hK is in CPU (host) memory - vector hA(size_A); - vector hB(size_B); - vector hC(size_C); - vector hC_gold(size_C); + vector hA(size_A); + vector hB(size_B); + vector hC(size_C); + vector hC_gold(size_C); // Initial Data on CPU srand(1); - hipblas_init(hA, A_row, A_col, lda); - hipblas_init_alternating_sign(hB, B_row, B_col, ldb); - hipblas_init(hC, M, N, ldc); + hipblas_init(hA, A_row, A_col, lda); + hipblas_init_alternating_sign(hB, B_row, B_col, ldb); + hipblas_init(hC, M, N, ldc); // if(is_same::value) // { @@ -192,9 +186,37 @@ hipblasStatus_t testing_gemm_ex_template(hipblasOperation_t transA, hC_gold = hC; // copy data from CPU to device - CHECK_HIP_ERROR(hipMemcpy(dA, hA.data(), sizeof(Td) * size_A, hipMemcpyHostToDevice)); - CHECK_HIP_ERROR(hipMemcpy(dB, hB.data(), sizeof(Td) * size_B, hipMemcpyHostToDevice)); - CHECK_HIP_ERROR(hipMemcpy(dC, hC.data(), sizeof(Td) * size_C, hipMemcpyHostToDevice)); + + // CUDA doesn't do packing +#ifdef __HIP_PLATFORM_NVCC__ + CHECK_HIP_ERROR(hipMemcpy(dA, hA.data(), sizeof(Ta) * size_A, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dB, hB.data(), sizeof(Tb) * size_B, hipMemcpyHostToDevice)); +#else + if(std::is_same{} && transA == HIPBLAS_OP_N) + { + vector hA_packed(hA); + hipblas_packInt8(hA_packed, M, K, lda); + CHECK_HIP_ERROR( + hipMemcpy(dA, hA_packed.data(), sizeof(Ta) * size_A, hipMemcpyHostToDevice)); + } + else + { + CHECK_HIP_ERROR(hipMemcpy(dA, hA.data(), sizeof(Ta) * size_A, hipMemcpyHostToDevice)); + } + + if(std::is_same{} && transB != HIPBLAS_OP_N) + { + vector hB_packed(hB); + hipblas_packInt8(hB_packed, N, K, ldb); + CHECK_HIP_ERROR( + hipMemcpy(dB, hB_packed.data(), sizeof(Tb) * size_B, hipMemcpyHostToDevice)); + } + else + { + CHECK_HIP_ERROR(hipMemcpy(dB, hB.data(), sizeof(Tb) * size_B, hipMemcpyHostToDevice)); + } +#endif + CHECK_HIP_ERROR(hipMemcpy(dC, hC.data(), sizeof(Tc) * size_C, hipMemcpyHostToDevice)); status = hipblasGemmExFn(handle, transA, @@ -229,7 +251,7 @@ hipblasStatus_t testing_gemm_ex_template(hipblasOperation_t transA, return status; } - CHECK_HIP_ERROR(hipMemcpy(hC.data(), dC, sizeof(Td) * size_C, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(hC.data(), dC, sizeof(Tc) * size_C, hipMemcpyDeviceToHost)); // std::cout << std::endl << "-----hD_1---------------------------------------" << // std::endl; @@ -247,19 +269,19 @@ hipblasStatus_t testing_gemm_ex_template(hipblasOperation_t transA, // CPU BLAS - cblas_gemm(transA, - transB, - M, - N, - K, - h_alpha_Td, - hA.data(), - lda, - hB.data(), - ldb, - h_beta_Td, - hC_gold.data(), - ldc); + cblas_gemm(transA, + transB, + M, + N, + K, + h_alpha_Tc, + hA.data(), + lda, + hB.data(), + ldb, + h_beta_Tc, + hC_gold.data(), + ldc); // std::cout << std::endl << "---gold---gold---gold---------------------" << std::endl; // if(is_same::value) @@ -280,7 +302,7 @@ hipblasStatus_t testing_gemm_ex_template(hipblasOperation_t transA, // unit check and norm check can not be interchanged their order if(unit_check) { - unit_check_general(M, N, ldc, hC_gold.data(), hC.data()); + unit_check_general(M, N, ldc, hC_gold.data(), hC.data()); } hipblasDestroy(handle); @@ -323,70 +345,92 @@ hipblasStatus_t testing_gemm_ex(Arguments argus) if(a_type == HIPBLAS_R_16F && b_type == HIPBLAS_R_16F && c_type == HIPBLAS_R_16F && c_type == HIPBLAS_R_16F && compute_type == HIPBLAS_R_16F) { - status = testing_gemm_ex_template(transA, - transB, - M, - N, - K, - alpha, - lda, - ldb, - beta, - ldc, - norm_check, - unit_check, - a_type, - b_type, - c_type, - compute_type, - argus.fortran); + status = testing_gemm_ex_template(transA, + transB, + M, + N, + K, + alpha, + lda, + ldb, + beta, + ldc, + norm_check, + unit_check, + a_type, + b_type, + c_type, + compute_type, + argus.fortran); } else if(a_type == HIPBLAS_R_16F && b_type == HIPBLAS_R_16F && c_type == HIPBLAS_R_16F && c_type == HIPBLAS_R_16F && compute_type == HIPBLAS_R_32F) { - status = testing_gemm_ex_template(transA, - transB, - M, - N, - K, - alpha, - lda, - ldb, - beta, - ldc, - norm_check, - unit_check, - a_type, - b_type, - c_type, - compute_type, - argus.fortran); + status + = testing_gemm_ex_template(transA, + transB, + M, + N, + K, + alpha, + lda, + ldb, + beta, + ldc, + norm_check, + unit_check, + a_type, + b_type, + c_type, + compute_type, + argus.fortran); } else if(a_type == HIPBLAS_R_32F && b_type == HIPBLAS_R_32F && c_type == HIPBLAS_R_32F && c_type == HIPBLAS_R_32F && compute_type == HIPBLAS_R_32F) { - status = testing_gemm_ex_template(transA, - transB, - M, - N, - K, - alpha, - lda, - ldb, - beta, - ldc, - norm_check, - unit_check, - a_type, - b_type, - c_type, - compute_type, - argus.fortran); + status = testing_gemm_ex_template(transA, + transB, + M, + N, + K, + alpha, + lda, + ldb, + beta, + ldc, + norm_check, + unit_check, + a_type, + b_type, + c_type, + compute_type, + argus.fortran); } else if(a_type == HIPBLAS_R_64F && b_type == HIPBLAS_R_64F && c_type == HIPBLAS_R_64F && c_type == HIPBLAS_R_64F && compute_type == HIPBLAS_R_64F) { - status = testing_gemm_ex_template(transA, + status = testing_gemm_ex_template(transA, + transB, + M, + N, + K, + alpha, + lda, + ldb, + beta, + ldc, + norm_check, + unit_check, + a_type, + b_type, + c_type, + compute_type, + argus.fortran); + } + else if(a_type == HIPBLAS_C_32F && b_type == HIPBLAS_C_32F && c_type == HIPBLAS_C_32F + && c_type == HIPBLAS_C_32F && compute_type == HIPBLAS_C_32F) + { + status = testing_gemm_ex_template(transA, transB, M, N, @@ -404,6 +448,48 @@ hipblasStatus_t testing_gemm_ex(Arguments argus) compute_type, argus.fortran); } + else if(a_type == HIPBLAS_C_64F && b_type == HIPBLAS_C_64F && c_type == HIPBLAS_C_64F + && c_type == HIPBLAS_C_64F && compute_type == HIPBLAS_C_64F) + { + status = testing_gemm_ex_template(transA, + transB, + M, + N, + K, + alpha, + lda, + ldb, + beta, + ldc, + norm_check, + unit_check, + a_type, + b_type, + c_type, + compute_type, + argus.fortran); + } + else if(a_type == HIPBLAS_R_8I && b_type == HIPBLAS_R_8I && c_type == HIPBLAS_R_32I + && c_type == HIPBLAS_R_32I && compute_type == HIPBLAS_R_32I) + { + status = testing_gemm_ex_template(transA, + transB, + M, + N, + K, + alpha, + lda, + ldb, + beta, + ldc, + norm_check, + unit_check, + a_type, + b_type, + c_type, + compute_type, + argus.fortran); + } else { status = HIPBLAS_STATUS_NOT_SUPPORTED; diff --git a/clients/include/testing_gemm_strided_batched_ex.hpp b/clients/include/testing_gemm_strided_batched_ex.hpp index 3998c1c7c..467526190 100644 --- a/clients/include/testing_gemm_strided_batched_ex.hpp +++ b/clients/include/testing_gemm_strided_batched_ex.hpp @@ -24,7 +24,7 @@ using namespace std; /* ============================================================================================ */ -template +template hipblasStatus_t testing_gemm_strided_batched_ex_template(hipblasOperation_t transA, hipblasOperation_t transB, int M, @@ -51,46 +51,38 @@ hipblasStatus_t testing_gemm_strided_batched_ex_template(hipblasOperation_t tran uint32_t solution_index = 0; uint32_t flags = 0; - Td h_alpha_Td; - Td h_beta_Td; + Tex h_alpha_Tc; + Tex h_beta_Tc; - if(is_same::value) + if(is_same::value) { - h_alpha_Td = float_to_half(alpha_float); - h_beta_Td = float_to_half(beta_float); + h_alpha_Tc = float_to_half(alpha_float); + h_beta_Tc = float_to_half(beta_float); } - else if(is_same::value) + else if(is_same::value) { - h_alpha_Td = static_cast(alpha_float); - h_beta_Td = static_cast(beta_float); + h_alpha_Tc = static_cast(alpha_float); + h_beta_Tc = static_cast(beta_float); } - else if(is_same::value) + else if(is_same::value) { - h_alpha_Td = static_cast(alpha_float); - h_beta_Td = static_cast(beta_float); + h_alpha_Tc = static_cast(alpha_float); + h_beta_Tc = static_cast(beta_float); } - else - { - return HIPBLAS_STATUS_NOT_SUPPORTED; - } - - Tc h_alpha_Tc; - Tc h_beta_Tc; - - if(is_same::value) + else if(is_same::value) { - h_alpha_Tc = float_to_half(alpha_float); - h_beta_Tc = float_to_half(beta_float); + h_alpha_Tc = static_cast(alpha_float); + h_beta_Tc = static_cast(beta_float); } - else if(is_same::value) + else if(is_same::value) { - h_alpha_Tc = static_cast(alpha_float); - h_beta_Tc = static_cast(beta_float); + h_alpha_Tc = static_cast(alpha_float); + h_beta_Tc = static_cast(beta_float); } - else if(is_same::value) + else if(is_same::value) { - h_alpha_Tc = static_cast(alpha_float); - h_beta_Tc = static_cast(beta_float); + h_alpha_Tc = static_cast(alpha_float); + h_beta_Tc = static_cast(beta_float); } else { @@ -116,9 +108,9 @@ hipblasStatus_t testing_gemm_strided_batched_ex_template(hipblasOperation_t tran const size_t size_B = stride_B * batch_count; const size_t size_C = stride_C * batch_count; - device_vector dA(size_A); - device_vector dB(size_B); - device_vector dC(size_C); + device_vector dA(size_A); + device_vector dB(size_B); + device_vector dC(size_C); device_vector d_alpha_Tc(1); device_vector d_beta_Tc(1); @@ -134,22 +126,48 @@ hipblasStatus_t testing_gemm_strided_batched_ex_template(hipblasOperation_t tran hipblasCreate(&handle); // Naming: dX is in GPU (device) memory. hK is in CPU (host) memory - host_vector hA(size_A); - host_vector hB(size_B); - host_vector hC(size_C); - host_vector hC_gold(size_C); + host_vector hA(size_A); + host_vector hB(size_B); + host_vector hC(size_C); + host_vector hC_gold(size_C); // Initial Data on CPU srand(1); - hipblas_init(hA, A_row, A_col, lda, stride_A, batch_count); - hipblas_init_alternating_sign(hB, B_row, B_col, ldb, stride_B, batch_count); - hipblas_init(hC, M, N, ldc, stride_C, batch_count); + hipblas_init(hA, A_row, A_col, lda, stride_A, batch_count); + hipblas_init_alternating_sign(hB, B_row, B_col, ldb, stride_B, batch_count); + hipblas_init(hC, M, N, ldc, stride_C, batch_count); hC_gold = hC; // copy data from CPU to device - CHECK_HIP_ERROR(hipMemcpy(dA, hA.data(), sizeof(Td) * size_A, hipMemcpyHostToDevice)); - CHECK_HIP_ERROR(hipMemcpy(dB, hB.data(), sizeof(Td) * size_B, hipMemcpyHostToDevice)); - CHECK_HIP_ERROR(hipMemcpy(dC, hC.data(), sizeof(Td) * size_C, hipMemcpyHostToDevice)); +#ifdef __HIP_PLATFORM_NVCC__ + CHECK_HIP_ERROR(hipMemcpy(dA, hA.data(), sizeof(Ta) * size_A, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dB, hB.data(), sizeof(Tb) * size_B, hipMemcpyHostToDevice)); +#else + if(std::is_same{} && transA == HIPBLAS_OP_N) + { + vector hA_packed(hA); + hipblas_packInt8(hA_packed, M, K, lda, batch_count, stride_A); + CHECK_HIP_ERROR( + hipMemcpy(dA, hA_packed.data(), sizeof(Ta) * size_A, hipMemcpyHostToDevice)); + } + else + { + CHECK_HIP_ERROR(hipMemcpy(dA, hA.data(), sizeof(Ta) * size_A, hipMemcpyHostToDevice)); + } + + if(std::is_same{} && transB != HIPBLAS_OP_N) + { + vector hB_packed(hB); + hipblas_packInt8(hB_packed, N, K, ldb, batch_count, stride_B); + CHECK_HIP_ERROR( + hipMemcpy(dB, hB_packed.data(), sizeof(Tb) * size_B, hipMemcpyHostToDevice)); + } + else + { + CHECK_HIP_ERROR(hipMemcpy(dB, hB.data(), sizeof(Tb) * size_B, hipMemcpyHostToDevice)); + } +#endif + CHECK_HIP_ERROR(hipMemcpy(dC, hC.data(), sizeof(Tc) * size_C, hipMemcpyHostToDevice)); status = hipblasGemmStridedBatchedExFn(handle, transA, @@ -181,24 +199,24 @@ hipblasStatus_t testing_gemm_strided_batched_ex_template(hipblasOperation_t tran return status; } - CHECK_HIP_ERROR(hipMemcpy(hC.data(), dC, sizeof(Td) * size_C, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(hC.data(), dC, sizeof(Tc) * size_C, hipMemcpyDeviceToHost)); // CPU BLAS for(int b = 0; b < batch_count; b++) { - cblas_gemm(transA, - transB, - M, - N, - K, - h_alpha_Td, - hA.data() + b * stride_A, - lda, - hB.data() + b * stride_B, - ldb, - h_beta_Td, - hC_gold.data() + b * stride_C, - ldc); + cblas_gemm(transA, + transB, + M, + N, + K, + h_alpha_Tc, + hA.data() + b * stride_A, + lda, + hB.data() + b * stride_B, + ldb, + h_beta_Tc, + hC_gold.data() + b * stride_C, + ldc); } // enable unit check, notice unit check is not invasive, but norm check is, @@ -207,7 +225,7 @@ hipblasStatus_t testing_gemm_strided_batched_ex_template(hipblasOperation_t tran { for(int b = 0; b < batch_count; b++) { - unit_check_general( + unit_check_general( M, N, ldc, hC_gold.data() + b * stride_C, hC.data() + b * stride_C); } } @@ -247,73 +265,98 @@ hipblasStatus_t testing_gemm_strided_batched_ex(Arguments argus) if(a_type == HIPBLAS_R_16F && b_type == HIPBLAS_R_16F && c_type == HIPBLAS_R_16F && c_type == HIPBLAS_R_16F && compute_type == HIPBLAS_R_16F) { - status = testing_gemm_strided_batched_ex_template(transA, - transB, - M, - N, - K, - alpha, - lda, - ldb, - beta, - ldc, - norm_check, - unit_check, - a_type, - b_type, - c_type, - batch_count, - compute_type, - argus.fortran); + status = testing_gemm_strided_batched_ex_template(transA, + transB, + M, + N, + K, + alpha, + lda, + ldb, + beta, + ldc, + norm_check, + unit_check, + a_type, + b_type, + c_type, + batch_count, + compute_type, + argus.fortran); } else if(a_type == HIPBLAS_R_16F && b_type == HIPBLAS_R_16F && c_type == HIPBLAS_R_16F && c_type == HIPBLAS_R_16F && compute_type == HIPBLAS_R_32F) { - status = testing_gemm_strided_batched_ex_template(transA, - transB, - M, - N, - K, - alpha, - lda, - ldb, - beta, - ldc, - norm_check, - unit_check, - a_type, - b_type, - c_type, - batch_count, - compute_type, - argus.fortran); + status = testing_gemm_strided_batched_ex_template(transA, + transB, + M, + N, + K, + alpha, + lda, + ldb, + beta, + ldc, + norm_check, + unit_check, + a_type, + b_type, + c_type, + batch_count, + compute_type, + argus.fortran); } else if(a_type == HIPBLAS_R_32F && b_type == HIPBLAS_R_32F && c_type == HIPBLAS_R_32F && c_type == HIPBLAS_R_32F && compute_type == HIPBLAS_R_32F) { - status = testing_gemm_strided_batched_ex_template(transA, - transB, - M, - N, - K, - alpha, - lda, - ldb, - beta, - ldc, - norm_check, - unit_check, - a_type, - b_type, - c_type, - batch_count, - compute_type, - argus.fortran); + status = testing_gemm_strided_batched_ex_template(transA, + transB, + M, + N, + K, + alpha, + lda, + ldb, + beta, + ldc, + norm_check, + unit_check, + a_type, + b_type, + c_type, + batch_count, + compute_type, + argus.fortran); } else if(a_type == HIPBLAS_R_64F && b_type == HIPBLAS_R_64F && c_type == HIPBLAS_R_64F && c_type == HIPBLAS_R_64F && compute_type == HIPBLAS_R_64F) { - status = testing_gemm_strided_batched_ex_template(transA, + status = testing_gemm_strided_batched_ex_template(transA, + transB, + M, + N, + K, + alpha, + lda, + ldb, + beta, + ldc, + norm_check, + unit_check, + a_type, + b_type, + c_type, + batch_count, + compute_type, + argus.fortran); + } + else if(a_type == HIPBLAS_C_32F && b_type == HIPBLAS_C_32F && c_type == HIPBLAS_C_32F + && c_type == HIPBLAS_C_32F && compute_type == HIPBLAS_C_32F) + { + status = testing_gemm_strided_batched_ex_template(transA, transB, M, N, @@ -332,6 +375,51 @@ hipblasStatus_t testing_gemm_strided_batched_ex(Arguments argus) compute_type, argus.fortran); } + else if(a_type == HIPBLAS_C_64F && b_type == HIPBLAS_C_64F && c_type == HIPBLAS_C_64F + && c_type == HIPBLAS_C_64F && compute_type == HIPBLAS_C_64F) + { + status = testing_gemm_strided_batched_ex_template(transA, + transB, + M, + N, + K, + alpha, + lda, + ldb, + beta, + ldc, + norm_check, + unit_check, + a_type, + b_type, + c_type, + batch_count, + compute_type, + argus.fortran); + } + else if(a_type == HIPBLAS_R_8I && b_type == HIPBLAS_R_8I && c_type == HIPBLAS_R_32I + && c_type == HIPBLAS_R_32I && compute_type == HIPBLAS_R_32I) + { + status = testing_gemm_strided_batched_ex_template( + transA, + transB, + M, + N, + K, + alpha, + lda, + ldb, + beta, + ldc, + norm_check, + unit_check, + a_type, + b_type, + c_type, + batch_count, + compute_type, + argus.fortran); + } else { status = HIPBLAS_STATUS_NOT_SUPPORTED; diff --git a/clients/include/testing_getrs.hpp b/clients/include/testing_getrs.hpp index 82949f6a6..2ca73b8b0 100644 --- a/clients/include/testing_getrs.hpp +++ b/clients/include/testing_getrs.hpp @@ -78,7 +78,7 @@ hipblasStatus_t testing_getrs(Arguments argus) // Calculate hB = hA*hX; hipblasOperation_t op = HIPBLAS_OP_N; - cblas_gemm(op, op, N, 1, N, 1, hA.data(), lda, hX.data(), ldb, 0, hB.data(), ldb); + cblas_gemm(op, op, N, 1, N, (T)1, hA.data(), lda, hX.data(), ldb, (T)0, hB.data(), ldb); // LU factorize hA on the CPU info = cblas_getrf(N, N, hA.data(), lda, hIpiv.data()); diff --git a/clients/include/testing_getrs_batched.hpp b/clients/include/testing_getrs_batched.hpp index 8e09b3b90..7085e0c67 100644 --- a/clients/include/testing_getrs_batched.hpp +++ b/clients/include/testing_getrs_batched.hpp @@ -96,7 +96,7 @@ hipblasStatus_t testing_getrs_batched(Arguments argus) // Calculate hB = hA*hX; cblas_gemm( - op, op, N, 1, N, 1, hA[b].data(), lda, hX[b].data(), ldb, 0, hB[b].data(), ldb); + op, op, N, 1, N, (T)1, hA[b].data(), lda, hX[b].data(), ldb, (T)0, hB[b].data(), ldb); // LU factorize hA on the CPU info = cblas_getrf(N, N, hA[b].data(), lda, hIpiv.data() + b * strideP); diff --git a/clients/include/testing_getrs_strided_batched.hpp b/clients/include/testing_getrs_strided_batched.hpp index 539aab084..c1625571e 100644 --- a/clients/include/testing_getrs_strided_batched.hpp +++ b/clients/include/testing_getrs_strided_batched.hpp @@ -95,7 +95,7 @@ hipblasStatus_t testing_getrs_strided_batched(Arguments argus) } // Calculate hB = hA*hX; - cblas_gemm(op, op, N, 1, N, 1, hAb, lda, hXb, ldb, 0, hBb, ldb); + cblas_gemm(op, op, N, 1, N, (T)1, hAb, lda, hXb, ldb, (T)0, hBb, ldb); // LU factorize hA on the CPU info = cblas_getrf(N, N, hAb, lda, hIpivb); diff --git a/clients/include/testing_tpsv.hpp b/clients/include/testing_tpsv.hpp index 8254617fe..81f4959f9 100644 --- a/clients/include/testing_tpsv.hpp +++ b/clients/include/testing_tpsv.hpp @@ -72,8 +72,19 @@ hipblasStatus_t testing_tpsv(Arguments argus) hipblas_init(hA, N, N, 1); // calculate AAT = hA * hA ^ T - cblas_gemm( - HIPBLAS_OP_N, HIPBLAS_OP_T, N, N, N, 1.0, hA.data(), N, hA.data(), N, 0.0, AAT.data(), N); + cblas_gemm(HIPBLAS_OP_N, + HIPBLAS_OP_T, + N, + N, + N, + (T)1.0, + hA.data(), + N, + hA.data(), + N, + (T)0.0, + AAT.data(), + N); // copy AAT into hA, make hA strictly diagonal dominant, and therefore SPD for(int i = 0; i < N; i++) diff --git a/clients/include/testing_tpsv_batched.hpp b/clients/include/testing_tpsv_batched.hpp index 735c4d9d8..682a6b73c 100644 --- a/clients/include/testing_tpsv_batched.hpp +++ b/clients/include/testing_tpsv_batched.hpp @@ -99,7 +99,19 @@ hipblasStatus_t testing_tpsv_batched(Arguments argus) hipblas_init(hA[b], N, N, N); // calculate AAT = hA * hA ^ T - cblas_gemm(HIPBLAS_OP_N, HIPBLAS_OP_T, N, N, N, 1.0, hA[b], N, hA[b], N, 0.0, AAT[b], N); + cblas_gemm(HIPBLAS_OP_N, + HIPBLAS_OP_T, + N, + N, + N, + (T)1.0, + (T*)hA[b], + N, + (T*)hA[b], + N, + (T)0.0, + (T*)AAT[b], + N); // copy AAT into hA, make hA strictly diagonal dominant, and therefore SPD for(int i = 0; i < N; i++) diff --git a/clients/include/testing_tpsv_strided_batched.hpp b/clients/include/testing_tpsv_strided_batched.hpp index 157de0d8d..9a6add631 100644 --- a/clients/include/testing_tpsv_strided_batched.hpp +++ b/clients/include/testing_tpsv_strided_batched.hpp @@ -87,7 +87,7 @@ hipblasStatus_t testing_tpsv_strided_batched(Arguments argus) T* AATb = AAT.data() + b * strideA; T* hbb = hb.data() + b * stridex; // calculate AAT = hA * hA ^ T - cblas_gemm(HIPBLAS_OP_N, HIPBLAS_OP_T, N, N, N, 1.0, hAb, N, hAb, N, 0.0, AATb, N); + cblas_gemm(HIPBLAS_OP_N, HIPBLAS_OP_T, N, N, N, (T)1.0, hAb, N, hAb, N, (T)0.0, AATb, N); // copy AAT into hA, make hA strictly diagonal dominant, and therefore SPD for(int i = 0; i < N; i++) diff --git a/clients/include/testing_trsv.hpp b/clients/include/testing_trsv.hpp index 09c4daf5b..986982682 100644 --- a/clients/include/testing_trsv.hpp +++ b/clients/include/testing_trsv.hpp @@ -87,12 +87,12 @@ hipblasStatus_t testing_trsv(Arguments argus) M, M, M, - 1.0, + (T)1.0, hA.data(), lda, hA.data(), lda, - 0.0, + (T)0.0, AAT.data(), lda); diff --git a/clients/include/testing_trsv_batched.hpp b/clients/include/testing_trsv_batched.hpp index 023809417..dabca1d31 100644 --- a/clients/include/testing_trsv_batched.hpp +++ b/clients/include/testing_trsv_batched.hpp @@ -97,8 +97,19 @@ hipblasStatus_t testing_trsv_batched(Arguments argus) hipblas_init(hA[b], M, M, lda); // calculate AAT = hA * hA ^ T - cblas_gemm( - HIPBLAS_OP_N, HIPBLAS_OP_T, M, M, M, 1.0, hA[b], lda, hA[b], lda, 0.0, AAT[b], lda); + cblas_gemm(HIPBLAS_OP_N, + HIPBLAS_OP_T, + M, + M, + M, + (T)1.0, + (T*)hA[b], + lda, + (T*)hA[b], + lda, + (T)0.0, + (T*)AAT[b], + lda); // copy AAT into hA, make hA strictly diagonal dominant, and therefore SPD for(int i = 0; i < M; i++) diff --git a/clients/include/testing_trsv_strided_batched.hpp b/clients/include/testing_trsv_strided_batched.hpp index b11d5b2a2..dbe15ea57 100644 --- a/clients/include/testing_trsv_strided_batched.hpp +++ b/clients/include/testing_trsv_strided_batched.hpp @@ -87,7 +87,8 @@ hipblasStatus_t testing_trsv_strided_batched(Arguments argus) T* AATb = AAT.data() + b * strideA; T* hbb = hb.data() + b * stridex; // calculate AAT = hA * hA ^ T - cblas_gemm(HIPBLAS_OP_N, HIPBLAS_OP_T, M, M, M, 1.0, hAb, lda, hAb, lda, 0.0, AATb, lda); + cblas_gemm( + HIPBLAS_OP_N, HIPBLAS_OP_T, M, M, M, (T)1.0, hAb, lda, hAb, lda, (T)0.0, AATb, lda); // copy AAT into hA, make hA strictly diagonal dominant, and therefore SPD for(int i = 0; i < M; i++) diff --git a/clients/include/utility.h b/clients/include/utility.h index a605ecf35..f0ecc437f 100644 --- a/clients/include/utility.h +++ b/clients/include/utility.h @@ -328,6 +328,23 @@ inline hipblasDoubleComplex random_generator_negative() /* ============================================================================================ */ +/* ============================================================================================ */ +/*! \brief Packs strided_batched matricies into groups of 4 in N */ +template +void hipblas_packInt8( + std::vector& A, size_t M, size_t N, size_t lda, size_t batch_count = 1, size_t stride_a = 0) +{ + std::vector temp(A); + for(size_t b = 0; b < batch_count; b++) + for(size_t colBase = 0; colBase < N; colBase += 4) + for(size_t row = 0; row < lda; row++) + for(size_t colOffset = 0; colOffset < 4; colOffset++) + A[(colBase * lda + 4 * row) + colOffset + (stride_a * b)] + = temp[(colBase + colOffset) * lda + row + (stride_a * b)]; +} + +/* ============================================================================================ */ + /* ============================================================================================ */ /*! \brief matrix/vector initialization: */ // for vector x (M=1, N=lengthX, lda=incx); diff --git a/library/src/hcc_detail/hipblas.cpp b/library/src/hcc_detail/hipblas.cpp index 6fe3b883b..945035cd6 100644 --- a/library/src/hcc_detail/hipblas.cpp +++ b/library/src/hcc_detail/hipblas.cpp @@ -199,6 +199,12 @@ rocblas_datatype HIPDatatypeToRocblasDatatype(hipblasDatatype_t type) case HIPBLAS_R_64F: return rocblas_datatype_f64_r; + case HIPBLAS_R_8I: + return rocblas_datatype_i8_r; + + case HIPBLAS_R_32I: + return rocblas_datatype_i32_r; + case HIPBLAS_C_16F: return rocblas_datatype_f16_c; @@ -224,6 +230,12 @@ hipblasDatatype_t RocblasDatatypeToHIPDatatype(rocblas_datatype type) case rocblas_datatype_f64_r: return HIPBLAS_R_64F; + case rocblas_datatype_i8_r: + return HIPBLAS_R_8I; + + case rocblas_datatype_i32_r: + return HIPBLAS_R_32I; + case rocblas_datatype_f16_c: return HIPBLAS_C_16F; diff --git a/library/src/nvcc_detail/hipblas.cpp b/library/src/nvcc_detail/hipblas.cpp index e3dfd58f0..4832358ab 100644 --- a/library/src/nvcc_detail/hipblas.cpp +++ b/library/src/nvcc_detail/hipblas.cpp @@ -169,6 +169,12 @@ cudaDataType_t HIPDatatypeToCudaDatatype(hipblasDatatype_t type) case HIPBLAS_R_64F: return CUDA_R_64F; + case HIPBLAS_R_8I: + return CUDA_R_8I; + + case HIPBLAS_R_32I: + return CUDA_R_32I; + case HIPBLAS_C_16F: return CUDA_C_16F;