From 6d7982024afaf227a78b6dae225bcbb4e935b72a Mon Sep 17 00:00:00 2001 From: daineAMD Date: Wed, 3 Jun 2020 09:40:39 -0600 Subject: [PATCH 1/9] Version for develop branch release 3.6. --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index dca786e4a..8f01e0d76 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,7 +45,7 @@ include( ROCMInstallTargets ) include( ROCMPackageConfigHelpers ) include( ROCMInstallSymlinks ) -set ( VERSION_STRING "0.30.0" ) +set ( VERSION_STRING "0.31.0" ) rocm_setup_version( VERSION ${VERSION_STRING} ) # Append our library helper cmake path and the cmake path for hip (for convenience) From a44f42169b29fb913f6bf971f382cb603cc16b07 Mon Sep 17 00:00:00 2001 From: tfalders <58866654+tfalders@users.noreply.github.com> Date: Mon, 8 Jun 2020 11:23:54 -0600 Subject: [PATCH 2/9] rocSOLVER support 2 (#222) * Removed USE_DEVICE_MODE macro and added complex versions of geqrf * Call batched geqrf through geqrf_ptr_batched * Added test cases for complex functions * Preventing singularities in test matrices --- clients/common/utility.cpp | 24 + clients/gtest/geqrf_batched_gtest.cpp | 48 +- clients/gtest/geqrf_gtest.cpp | 48 +- clients/gtest/geqrf_strided_batched_gtest.cpp | 48 +- clients/gtest/getrf_batched_gtest.cpp | 48 +- clients/gtest/getrf_gtest.cpp | 48 +- clients/gtest/getrf_strided_batched_gtest.cpp | 48 +- clients/gtest/getrs_batched_gtest.cpp | 48 +- clients/gtest/getrs_gtest.cpp | 48 +- clients/gtest/getrs_strided_batched_gtest.cpp | 48 +- clients/include/testing_geqrf.hpp | 22 +- clients/include/testing_geqrf_batched.hpp | 22 +- .../include/testing_geqrf_strided_batched.hpp | 22 +- clients/include/testing_getrf.hpp | 18 +- clients/include/testing_getrf_batched.hpp | 18 +- .../include/testing_getrf_strided_batched.hpp | 18 +- clients/include/testing_getrs.hpp | 14 +- clients/include/testing_getrs_batched.hpp | 14 +- .../include/testing_getrs_strided_batched.hpp | 14 +- 
clients/include/utility.h | 6 + library/src/hcc_detail/hipblas.cpp | 721 +++++++----------- 21 files changed, 822 insertions(+), 523 deletions(-) diff --git a/clients/common/utility.cpp b/clients/common/utility.cpp index a82d35fff..d6b2a768f 100644 --- a/clients/common/utility.cpp +++ b/clients/common/utility.cpp @@ -32,6 +32,30 @@ char type2char() // return 'z'; // } +template <> +int type2int(float val) +{ + return (int)val; +} + +template <> +int type2int(double val) +{ + return (int)val; +} + +template <> +int type2int(hipblasComplex val) +{ + return (int)val.real(); +} + +template <> +int type2int(hipblasDoubleComplex val) +{ + return (int)val.real(); +} + #ifdef __cplusplus extern "C" { #endif diff --git a/clients/gtest/geqrf_batched_gtest.cpp b/clients/gtest/geqrf_batched_gtest.cpp index 79885a6fb..0880d0854 100644 --- a/clients/gtest/geqrf_batched_gtest.cpp +++ b/clients/gtest/geqrf_batched_gtest.cpp @@ -60,7 +60,7 @@ TEST_P(geqrf_batched_gtest, geqrf_batched_gtest_float) Arguments arg = setup_geqrf_batched_arguments(GetParam()); - hipblasStatus_t status = testing_geqrf_batched(arg); + hipblasStatus_t status = testing_geqrf_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -82,7 +82,51 @@ TEST_P(geqrf_batched_gtest, geqrf_batched_gtest_double) Arguments arg = setup_geqrf_batched_arguments(GetParam()); - hipblasStatus_t status = testing_geqrf_batched(arg); + hipblasStatus_t status = testing_geqrf_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.M < 0 || arg.N < 0 || arg.lda < arg.M || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); + } + } +} + +TEST_P(geqrf_batched_gtest, geqrf_batched_gtest_float_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. 
+ + Arguments arg = setup_geqrf_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_geqrf_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.M < 0 || arg.N < 0 || arg.lda < arg.M || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); + } + } +} + +TEST_P(geqrf_batched_gtest, geqrf_batched_gtest_double_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_geqrf_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_geqrf_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { diff --git a/clients/gtest/geqrf_gtest.cpp b/clients/gtest/geqrf_gtest.cpp index 4b71597f4..e794aadf3 100644 --- a/clients/gtest/geqrf_gtest.cpp +++ b/clients/gtest/geqrf_gtest.cpp @@ -60,7 +60,7 @@ TEST_P(geqrf_gtest, geqrf_gtest_float) Arguments arg = setup_geqrf_arguments(GetParam()); - hipblasStatus_t status = testing_geqrf(arg); + hipblasStatus_t status = testing_geqrf(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -82,7 +82,51 @@ TEST_P(geqrf_gtest, geqrf_gtest_double) Arguments arg = setup_geqrf_arguments(GetParam()); - hipblasStatus_t status = testing_geqrf(arg); + hipblasStatus_t status = testing_geqrf(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.M < 0 || arg.N < 0 || arg.lda < arg.M) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(geqrf_gtest, geqrf_gtest_float_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. 
+ + Arguments arg = setup_geqrf_arguments(GetParam()); + + hipblasStatus_t status = testing_geqrf(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.M < 0 || arg.N < 0 || arg.lda < arg.M) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(geqrf_gtest, geqrf_gtest_double_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_geqrf_arguments(GetParam()); + + hipblasStatus_t status = testing_geqrf(arg); if(status != HIPBLAS_STATUS_SUCCESS) { diff --git a/clients/gtest/geqrf_strided_batched_gtest.cpp b/clients/gtest/geqrf_strided_batched_gtest.cpp index 2985af85a..8dfe1c0d4 100644 --- a/clients/gtest/geqrf_strided_batched_gtest.cpp +++ b/clients/gtest/geqrf_strided_batched_gtest.cpp @@ -60,7 +60,7 @@ TEST_P(geqrf_strided_batched_gtest, geqrf_strided_batched_gtest_float) Arguments arg = setup_geqrf_strided_batched_arguments(GetParam()); - hipblasStatus_t status = testing_geqrf_strided_batched(arg); + hipblasStatus_t status = testing_geqrf_strided_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -82,7 +82,51 @@ TEST_P(geqrf_strided_batched_gtest, geqrf_strided_batched_gtest_double) Arguments arg = setup_geqrf_strided_batched_arguments(GetParam()); - hipblasStatus_t status = testing_geqrf_strided_batched(arg); + hipblasStatus_t status = testing_geqrf_strided_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.M < 0 || arg.N < 0 || arg.lda < arg.M || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(geqrf_strided_batched_gtest, geqrf_strided_batched_gtest_float_complex) +{ + // GetParam returns a tuple. 
The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_geqrf_strided_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_geqrf_strided_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.M < 0 || arg.N < 0 || arg.lda < arg.M || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(geqrf_strided_batched_gtest, geqrf_strided_batched_gtest_double_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_geqrf_strided_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_geqrf_strided_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { diff --git a/clients/gtest/getrf_batched_gtest.cpp b/clients/gtest/getrf_batched_gtest.cpp index a386aee4f..3e844e9e5 100644 --- a/clients/gtest/getrf_batched_gtest.cpp +++ b/clients/gtest/getrf_batched_gtest.cpp @@ -63,7 +63,7 @@ TEST_P(getrf_batched_gtest, getrf_batched_gtest_float) Arguments arg = setup_getrf_batched_arguments(GetParam()); - hipblasStatus_t status = testing_getrf_batched(arg); + hipblasStatus_t status = testing_getrf_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -85,7 +85,51 @@ TEST_P(getrf_batched_gtest, getrf_batched_gtest_double) Arguments arg = setup_getrf_batched_arguments(GetParam()); - hipblasStatus_t status = testing_getrf_batched(arg); + hipblasStatus_t status = testing_getrf_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); + } + } +} + +TEST_P(getrf_batched_gtest, getrf_batched_gtest_float_complex) +{ + // GetParam returns a tuple. 
The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getrf_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getrf_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); + } + } +} + +TEST_P(getrf_batched_gtest, getrf_batched_gtest_double_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getrf_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getrf_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { diff --git a/clients/gtest/getrf_gtest.cpp b/clients/gtest/getrf_gtest.cpp index 701a904bb..3957815cd 100644 --- a/clients/gtest/getrf_gtest.cpp +++ b/clients/gtest/getrf_gtest.cpp @@ -63,7 +63,7 @@ TEST_P(getrf_gtest, getrf_gtest_float) Arguments arg = setup_getrf_arguments(GetParam()); - hipblasStatus_t status = testing_getrf(arg); + hipblasStatus_t status = testing_getrf(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -85,7 +85,51 @@ TEST_P(getrf_gtest, getrf_gtest_double) Arguments arg = setup_getrf_arguments(GetParam()); - hipblasStatus_t status = testing_getrf(arg); + hipblasStatus_t status = testing_getrf(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(getrf_gtest, getrf_gtest_float_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. 
+ + Arguments arg = setup_getrf_arguments(GetParam()); + + hipblasStatus_t status = testing_getrf(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(getrf_gtest, getrf_gtest_double_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getrf_arguments(GetParam()); + + hipblasStatus_t status = testing_getrf(arg); if(status != HIPBLAS_STATUS_SUCCESS) { diff --git a/clients/gtest/getrf_strided_batched_gtest.cpp b/clients/gtest/getrf_strided_batched_gtest.cpp index e5129c43c..2cda93f97 100644 --- a/clients/gtest/getrf_strided_batched_gtest.cpp +++ b/clients/gtest/getrf_strided_batched_gtest.cpp @@ -63,7 +63,7 @@ TEST_P(getrf_strided_batched_gtest, getrf_strided_batched_gtest_float) Arguments arg = setup_getrf_strided_batched_arguments(GetParam()); - hipblasStatus_t status = testing_getrf_strided_batched(arg); + hipblasStatus_t status = testing_getrf_strided_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -85,7 +85,51 @@ TEST_P(getrf_strided_batched_gtest, getrf_strided_batched_gtest_double) Arguments arg = setup_getrf_strided_batched_arguments(GetParam()); - hipblasStatus_t status = testing_getrf_strided_batched(arg); + hipblasStatus_t status = testing_getrf_strided_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(getrf_strided_batched_gtest, getrf_strided_batched_gtest_float_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. 
+ + Arguments arg = setup_getrf_strided_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getrf_strided_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(getrf_strided_batched_gtest, getrf_strided_batched_gtest_double_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getrf_strided_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getrf_strided_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { diff --git a/clients/gtest/getrs_batched_gtest.cpp b/clients/gtest/getrs_batched_gtest.cpp index 6f79e55d8..c6da21065 100644 --- a/clients/gtest/getrs_batched_gtest.cpp +++ b/clients/gtest/getrs_batched_gtest.cpp @@ -59,7 +59,7 @@ TEST_P(getrs_batched_gtest, getrs_batched_gtest_float) Arguments arg = setup_getrs_batched_arguments(GetParam()); - hipblasStatus_t status = testing_getrs_batched(arg); + hipblasStatus_t status = testing_getrs_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -81,7 +81,51 @@ TEST_P(getrs_batched_gtest, getrs_batched_gtest_double) Arguments arg = setup_getrs_batched_arguments(GetParam()); - hipblasStatus_t status = testing_getrs_batched(arg); + hipblasStatus_t status = testing_getrs_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.ldb < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); + } + } +} + +TEST_P(getrs_batched_gtest, getrs_batched_gtest_float_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. 
+ + Arguments arg = setup_getrs_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getrs_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.ldb < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); + } + } +} + +TEST_P(getrs_batched_gtest, getrs_batched_gtest_double_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getrs_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getrs_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { diff --git a/clients/gtest/getrs_gtest.cpp b/clients/gtest/getrs_gtest.cpp index 214c78b07..dae2d2fe4 100644 --- a/clients/gtest/getrs_gtest.cpp +++ b/clients/gtest/getrs_gtest.cpp @@ -59,7 +59,7 @@ TEST_P(getrs_gtest, getrs_gtest_float) Arguments arg = setup_getrs_arguments(GetParam()); - hipblasStatus_t status = testing_getrs(arg); + hipblasStatus_t status = testing_getrs(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -81,7 +81,51 @@ TEST_P(getrs_gtest, getrs_gtest_double) Arguments arg = setup_getrs_arguments(GetParam()); - hipblasStatus_t status = testing_getrs(arg); + hipblasStatus_t status = testing_getrs(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.ldb < arg.N) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(getrs_gtest, getrs_gtest_float_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. 
+ + Arguments arg = setup_getrs_arguments(GetParam()); + + hipblasStatus_t status = testing_getrs(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.ldb < arg.N) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(getrs_gtest, getrs_gtest_double_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getrs_arguments(GetParam()); + + hipblasStatus_t status = testing_getrs(arg); if(status != HIPBLAS_STATUS_SUCCESS) { diff --git a/clients/gtest/getrs_strided_batched_gtest.cpp b/clients/gtest/getrs_strided_batched_gtest.cpp index f4b1b2761..bb6f075bd 100644 --- a/clients/gtest/getrs_strided_batched_gtest.cpp +++ b/clients/gtest/getrs_strided_batched_gtest.cpp @@ -59,7 +59,7 @@ TEST_P(getrs_strided_batched_gtest, getrs_strided_batched_gtest_float) Arguments arg = setup_getrs_strided_batched_arguments(GetParam()); - hipblasStatus_t status = testing_getrs_strided_batched(arg); + hipblasStatus_t status = testing_getrs_strided_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -81,7 +81,51 @@ TEST_P(getrs_strided_batched_gtest, getrs_strided_batched_gtest_double) Arguments arg = setup_getrs_strided_batched_arguments(GetParam()); - hipblasStatus_t status = testing_getrs_strided_batched(arg); + hipblasStatus_t status = testing_getrs_strided_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.ldb < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(getrs_strided_batched_gtest, getrs_strided_batched_gtest_float_complex) +{ + // GetParam returns a tuple. 
The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getrs_strided_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getrs_strided_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.ldb < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(getrs_strided_batched_gtest, getrs_strided_batched_gtest_double_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getrs_strided_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getrs_strided_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { diff --git a/clients/include/testing_geqrf.hpp b/clients/include/testing_geqrf.hpp index bd67c625f..c2c13644e 100644 --- a/clients/include/testing_geqrf.hpp +++ b/clients/include/testing_geqrf.hpp @@ -17,7 +17,7 @@ using namespace std; -template +template hipblasStatus_t testing_geqrf(Arguments argus) { int M = argus.M; @@ -56,6 +56,18 @@ hipblasStatus_t testing_geqrf(Arguments argus) srand(1); hipblas_init(hA, M, N, lda); + // scale A to avoid singularities + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + if(i == j) + hA[i + j * lda] += 400; + else + hA[i + j * lda] -= 4; + } + } + // Copy data from CPU to device CHECK_HIP_ERROR(hipMemcpy(dA, hA.data(), A_size * sizeof(T), hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemset(dIpiv, 0, Ipiv_size * sizeof(T))); @@ -85,7 +97,7 @@ hipblasStatus_t testing_geqrf(Arguments argus) // Workspace query host_vector work(1); cblas_geqrf(M, N, hA.data(), lda, hIpiv.data(), work.data(), -1); - int lwork = (int)work[0]; + int lwork = type2int(work[0]); // Perform factorization work = host_vector(lwork); @@ 
-93,14 +105,14 @@ hipblasStatus_t testing_geqrf(Arguments argus) if(argus.unit_check) { - T eps = std::numeric_limits::epsilon(); + U eps = std::numeric_limits::epsilon(); double tolerance = eps * 2000; - double e1 = norm_check_general('M', M, N, lda, hA.data(), hA1.data()); + double e1 = norm_check_general('F', M, N, lda, hA.data(), hA1.data()); unit_check_error(e1, tolerance); double e2 - = norm_check_general('M', min(M, N), 1, min(M, N), hIpiv.data(), hIpiv1.data()); + = norm_check_general('F', min(M, N), 1, min(M, N), hIpiv.data(), hIpiv1.data()); unit_check_error(e2, tolerance); } } diff --git a/clients/include/testing_geqrf_batched.hpp b/clients/include/testing_geqrf_batched.hpp index f24944665..b3ab323e7 100644 --- a/clients/include/testing_geqrf_batched.hpp +++ b/clients/include/testing_geqrf_batched.hpp @@ -17,7 +17,7 @@ using namespace std; -template +template hipblasStatus_t testing_geqrf_batched(Arguments argus) { int M = argus.M; @@ -71,6 +71,18 @@ hipblasStatus_t testing_geqrf_batched(Arguments argus) hipblas_init(hA[b], M, N, lda); + // scale A to avoid singularities + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + if(i == j) + hA[b][i + j * lda] += 400; + else + hA[b][i + j * lda] -= 4; + } + } + // Copy data from CPU to device CHECK_HIP_ERROR(hipMemcpy(bA[b], hA[b].data(), A_size * sizeof(T), hipMemcpyHostToDevice)); CHECK_HIP_ERROR( @@ -109,7 +121,7 @@ hipblasStatus_t testing_geqrf_batched(Arguments argus) // Workspace query host_vector work(1); cblas_geqrf(M, N, hA[0].data(), lda, hIpiv[0].data(), work.data(), -1); - int lwork = (int)work[0]; + int lwork = type2int(work[0]); // Perform factorization work = host_vector(lwork); @@ -119,14 +131,14 @@ hipblasStatus_t testing_geqrf_batched(Arguments argus) if(argus.unit_check) { - T eps = std::numeric_limits::epsilon(); + U eps = std::numeric_limits::epsilon(); double tolerance = eps * 2000; - double e1 = norm_check_general('M', M, N, lda, hA[b].data(), hA1[b].data()); + double e1 = 
norm_check_general('F', M, N, lda, hA[b].data(), hA1[b].data()); unit_check_error(e1, tolerance); double e2 = norm_check_general( - 'M', min(M, N), 1, min(M, N), hIpiv[b].data(), hIpiv1[b].data()); + 'F', min(M, N), 1, min(M, N), hIpiv[b].data(), hIpiv1[b].data()); unit_check_error(e2, tolerance); } } diff --git a/clients/include/testing_geqrf_strided_batched.hpp b/clients/include/testing_geqrf_strided_batched.hpp index 447476f38..243923aa9 100644 --- a/clients/include/testing_geqrf_strided_batched.hpp +++ b/clients/include/testing_geqrf_strided_batched.hpp @@ -17,7 +17,7 @@ using namespace std; -template +template hipblasStatus_t testing_geqrf_strided_batched(Arguments argus) { int M = argus.M; @@ -67,6 +67,18 @@ hipblasStatus_t testing_geqrf_strided_batched(Arguments argus) T* hAb = hA.data() + b * strideA; hipblas_init(hAb, M, N, lda); + + // scale A to avoid singularities + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + if(i == j) + hAb[i + j * lda] += 400; + else + hAb[i + j * lda] -= 4; + } + } } // Copy data from CPU to device @@ -99,7 +111,7 @@ hipblasStatus_t testing_geqrf_strided_batched(Arguments argus) // Workspace query host_vector work(1); cblas_geqrf(M, N, hA.data(), lda, hIpiv.data(), work.data(), -1); - int lwork = (int)work[0]; + int lwork = type2int(work[0]); // Perform factorization work = host_vector(lwork); @@ -110,14 +122,14 @@ hipblasStatus_t testing_geqrf_strided_batched(Arguments argus) if(argus.unit_check) { - T eps = std::numeric_limits::epsilon(); + U eps = std::numeric_limits::epsilon(); double tolerance = eps * 2000; double e1 = norm_check_general( - 'M', M, N, lda, hA.data() + b * strideA, hA1.data() + b * strideA); + 'F', M, N, lda, hA.data() + b * strideA, hA1.data() + b * strideA); unit_check_error(e1, tolerance); - double e2 = norm_check_general('M', + double e2 = norm_check_general('F', min(M, N), 1, strideP, diff --git a/clients/include/testing_getrf.hpp b/clients/include/testing_getrf.hpp index 
b075b84b6..3ef2a61c7 100644 --- a/clients/include/testing_getrf.hpp +++ b/clients/include/testing_getrf.hpp @@ -17,7 +17,7 @@ using namespace std; -template +template hipblasStatus_t testing_getrf(Arguments argus) { int M = argus.N; @@ -58,6 +58,18 @@ hipblasStatus_t testing_getrf(Arguments argus) srand(1); hipblas_init(hA, M, N, lda); + // scale A to avoid singularities + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + if(i == j) + hA[i + j * lda] += 400; + else + hA[i + j * lda] -= 4; + } + } + // Copy data from CPU to device CHECK_HIP_ERROR(hipMemcpy(dA, hA.data(), A_size * sizeof(T), hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemset(dIpiv, 0, Ipiv_size * sizeof(int))); @@ -91,10 +103,10 @@ hipblasStatus_t testing_getrf(Arguments argus) if(argus.unit_check) { - T eps = std::numeric_limits::epsilon(); + U eps = std::numeric_limits::epsilon(); double tolerance = eps * 2000; - double e = norm_check_general('M', M, N, lda, hA.data(), hA1.data()); + double e = norm_check_general('F', M, N, lda, hA.data(), hA1.data()); unit_check_error(e, tolerance); } } diff --git a/clients/include/testing_getrf_batched.hpp b/clients/include/testing_getrf_batched.hpp index 060ccd76a..017c11d7a 100644 --- a/clients/include/testing_getrf_batched.hpp +++ b/clients/include/testing_getrf_batched.hpp @@ -17,7 +17,7 @@ using namespace std; -template +template hipblasStatus_t testing_getrf_batched(Arguments argus) { int M = argus.N; @@ -71,6 +71,18 @@ hipblasStatus_t testing_getrf_batched(Arguments argus) hipblas_init(hA[b], M, N, lda); + // scale A to avoid singularities + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + if(i == j) + hA[b][i + j * lda] += 400; + else + hA[b][i + j * lda] -= 4; + } + } + // Copy data from CPU to device CHECK_HIP_ERROR(hipMemcpy(bA[b], hA[b].data(), A_size * sizeof(T), hipMemcpyHostToDevice)); } @@ -111,10 +123,10 @@ hipblasStatus_t testing_getrf_batched(Arguments argus) if(argus.unit_check) { - T eps = 
std::numeric_limits::epsilon(); + U eps = std::numeric_limits::epsilon(); double tolerance = eps * 2000; - double e = norm_check_general('M', M, N, lda, hA[b].data(), hA1[b].data()); + double e = norm_check_general('F', M, N, lda, hA[b].data(), hA1[b].data()); unit_check_error(e, tolerance); } } diff --git a/clients/include/testing_getrf_strided_batched.hpp b/clients/include/testing_getrf_strided_batched.hpp index 013da1a0e..f916c2939 100644 --- a/clients/include/testing_getrf_strided_batched.hpp +++ b/clients/include/testing_getrf_strided_batched.hpp @@ -17,7 +17,7 @@ using namespace std; -template +template hipblasStatus_t testing_getrf_strided_batched(Arguments argus) { int M = argus.N; @@ -69,6 +69,18 @@ hipblasStatus_t testing_getrf_strided_batched(Arguments argus) T* hAb = hA.data() + b * strideA; hipblas_init(hAb, M, N, lda); + + // scale A to avoid singularities + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + if(i == j) + hAb[i + j * lda] += 400; + else + hAb[i + j * lda] -= 4; + } + } } // Copy data from CPU to device @@ -108,11 +120,11 @@ hipblasStatus_t testing_getrf_strided_batched(Arguments argus) if(argus.unit_check) { - T eps = std::numeric_limits::epsilon(); + U eps = std::numeric_limits::epsilon(); double tolerance = eps * 2000; double e = norm_check_general( - 'M', M, N, lda, hA.data() + b * strideA, hA1.data() + b * strideA); + 'F', M, N, lda, hA.data() + b * strideA, hA1.data() + b * strideA); unit_check_error(e, tolerance); } } diff --git a/clients/include/testing_getrs.hpp b/clients/include/testing_getrs.hpp index 5aa13e48e..7da1f6607 100644 --- a/clients/include/testing_getrs.hpp +++ b/clients/include/testing_getrs.hpp @@ -17,7 +17,7 @@ using namespace std; -template +template hipblasStatus_t testing_getrs(Arguments argus) { int N = argus.N; @@ -61,15 +61,15 @@ hipblasStatus_t testing_getrs(Arguments argus) hipblas_init(hA, N, N, lda); hipblas_init(hX, N, 1, ldb); - // Put hA entries into range [0, 1], make diagonally 
dominant + // scale A to avoid singularities for(int i = 0; i < N; i++) { for(int j = 0; j < N; j++) { - hA[i + j * lda] = (hA[i + j * lda] - 1.0) / 10.0; - if(i == j) - hA[i + j * lda] *= 100; + hA[i + j * lda] += 400; + else + hA[i + j * lda] -= 4; } } @@ -122,10 +122,10 @@ hipblasStatus_t testing_getrs(Arguments argus) if(argus.unit_check) { - T eps = std::numeric_limits::epsilon(); + U eps = std::numeric_limits::epsilon(); double tolerance = N * eps * 100; - double e = norm_check_general('M', N, 1, ldb, hB.data(), hB1.data()); + double e = norm_check_general('F', N, 1, ldb, hB.data(), hB1.data()); unit_check_error(e, tolerance); } } diff --git a/clients/include/testing_getrs_batched.hpp b/clients/include/testing_getrs_batched.hpp index 5cabffc32..dd4e1139f 100644 --- a/clients/include/testing_getrs_batched.hpp +++ b/clients/include/testing_getrs_batched.hpp @@ -17,7 +17,7 @@ using namespace std; -template +template hipblasStatus_t testing_getrs_batched(Arguments argus) { int N = argus.N; @@ -78,15 +78,15 @@ hipblasStatus_t testing_getrs_batched(Arguments argus) hipblas_init(hA[b], N, N, lda); hipblas_init(hX[b], N, 1, ldb); - // Put hA entries into range [0, 1], make diagonally dominant + // scale A to avoid singularities for(int i = 0; i < N; i++) { for(int j = 0; j < N; j++) { - hA[b][i + j * lda] = (hA[b][i + j * lda] - 1.0) / 10.0; - if(i == j) - hA[b][i + j * lda] *= 100; + hA[b][i + j * lda] += 400; + else + hA[b][i + j * lda] -= 4; } } @@ -143,10 +143,10 @@ hipblasStatus_t testing_getrs_batched(Arguments argus) if(argus.unit_check) { - T eps = std::numeric_limits::epsilon(); + U eps = std::numeric_limits::epsilon(); double tolerance = N * eps * 100; - double e = norm_check_general('M', N, 1, ldb, hB[b].data(), hB1[b].data()); + double e = norm_check_general('F', N, 1, ldb, hB[b].data(), hB1[b].data()); unit_check_error(e, tolerance); } } diff --git a/clients/include/testing_getrs_strided_batched.hpp b/clients/include/testing_getrs_strided_batched.hpp 
index 50416eb81..95277e362 100644 --- a/clients/include/testing_getrs_strided_batched.hpp +++ b/clients/include/testing_getrs_strided_batched.hpp @@ -17,7 +17,7 @@ using namespace std; -template +template hipblasStatus_t testing_getrs_strided_batched(Arguments argus) { int N = argus.N; @@ -78,15 +78,15 @@ hipblasStatus_t testing_getrs_strided_batched(Arguments argus) hipblas_init(hAb, N, N, lda); hipblas_init(hXb, N, 1, ldb); - // Put hA entries into range [0, 1], make diagonally dominant + // scale A to avoid singularities for(int i = 0; i < N; i++) { for(int j = 0; j < N; j++) { - hAb[i + j * lda] = (hAb[i + j * lda] - 1.0) / 10.0; - if(i == j) - hAb[i + j * lda] *= 100; + hAb[i + j * lda] += 400; + else + hAb[i + j * lda] -= 4; } } @@ -144,11 +144,11 @@ hipblasStatus_t testing_getrs_strided_batched(Arguments argus) if(argus.unit_check) { - T eps = std::numeric_limits::epsilon(); + U eps = std::numeric_limits::epsilon(); double tolerance = N * eps * 100; double e = norm_check_general( - 'M', N, 1, ldb, hB.data() + b * strideB, hB1.data() + b * strideB); + 'F', N, 1, ldb, hB.data() + b * strideB, hB1.data() + b * strideB); unit_check_error(e, tolerance); } } diff --git a/clients/include/utility.h b/clients/include/utility.h index 954ee03c0..a605ecf35 100644 --- a/clients/include/utility.h +++ b/clients/include/utility.h @@ -581,6 +581,12 @@ void prepare_triangular_solve(T* hA, int lda, T* AAT, int N, char char_uplo) template char type2char(); +/* ============================================================================================ */ +/*! \brief turn float -> int, double -> int, hipblas_float_complex.real() -> int, + * hipblas_double_complex.real() -> int */ +template +int type2int(T val); + /* ============================================================================================ */ /*! 
\brief Debugging purpose, print out CPU and GPU result matrix, not valid in complex number */ template , int> = 0> diff --git a/library/src/hcc_detail/hipblas.cpp b/library/src/hcc_detail/hipblas.cpp index 928f98dd3..c7b01ed30 100644 --- a/library/src/hcc_detail/hipblas.cpp +++ b/library/src/hcc_detail/hipblas.cpp @@ -10,18 +10,6 @@ #include #include -#define USE_DEVICE_POINTER_MODE(handle, cmd) \ - do \ - { \ - hipblasPointerMode_t mode; \ - hipblasGetPointerMode(handle, &mode); \ - hipblasSetPointerMode(handle, HIPBLAS_POINTER_MODE_DEVICE); \ - \ - cmd; \ - \ - hipblasSetPointerMode(handle, mode); \ - } while(0); - #ifdef __cplusplus extern "C" { #endif @@ -12266,34 +12254,68 @@ hipblasStatus_t hipblasZdgmmStridedBatched(hipblasHandle_t handle, //rocSOLVER functions //-------------------------------------------------------------------------------------- +// The following functions are not included in the public API and must be declared + +#ifdef __cplusplus +extern "C" { +#endif + +rocblas_status rocsolver_sgeqrf_ptr_batched(rocblas_handle handle, + const rocblas_int m, + const rocblas_int n, + float* const A[], + const rocblas_int lda, + float* const ipiv[], + const rocblas_int batch_count); + +rocblas_status rocsolver_dgeqrf_ptr_batched(rocblas_handle handle, + const rocblas_int m, + const rocblas_int n, + double* const A[], + const rocblas_int lda, + double* const ipiv[], + const rocblas_int batch_count); + +rocblas_status rocsolver_cgeqrf_ptr_batched(rocblas_handle handle, + const rocblas_int m, + const rocblas_int n, + rocblas_float_complex* const A[], + const rocblas_int lda, + rocblas_float_complex* const ipiv[], + const rocblas_int batch_count); + +rocblas_status rocsolver_zgeqrf_ptr_batched(rocblas_handle handle, + const rocblas_int m, + const rocblas_int n, + rocblas_double_complex* const A[], + const rocblas_int lda, + rocblas_double_complex* const ipiv[], + const rocblas_int batch_count); + +#ifdef __cplusplus +} +#endif + // getrf hipblasStatus_t 
hipblasSgetrf( hipblasHandle_t handle, const int n, float* A, const int lda, int* ipiv, int* info) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE( - handle, status = rocsolver_sgetrf((rocblas_handle)handle, n, n, A, lda, ipiv, info)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus( + rocsolver_sgetrf((rocblas_handle)handle, n, n, A, lda, ipiv, info)); } hipblasStatus_t hipblasDgetrf( hipblasHandle_t handle, const int n, double* A, const int lda, int* ipiv, int* info) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE( - handle, status = rocsolver_dgetrf((rocblas_handle)handle, n, n, A, lda, ipiv, info)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus( + rocsolver_dgetrf((rocblas_handle)handle, n, n, A, lda, ipiv, info)); } hipblasStatus_t hipblasCgetrf( hipblasHandle_t handle, const int n, hipblasComplex* A, const int lda, int* ipiv, int* info) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE( - handle, - status = rocsolver_cgetrf( - (rocblas_handle)handle, n, n, (rocblas_float_complex*)A, lda, ipiv, info)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus( + rocsolver_cgetrf((rocblas_handle)handle, n, n, (rocblas_float_complex*)A, lda, ipiv, info)); } hipblasStatus_t hipblasZgetrf(hipblasHandle_t handle, @@ -12303,12 +12325,8 @@ hipblasStatus_t hipblasZgetrf(hipblasHandle_t handle, int* ipiv, int* info) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE( - handle, - status = rocsolver_zgetrf( - (rocblas_handle)handle, n, n, (rocblas_double_complex*)A, lda, ipiv, info)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_zgetrf( + (rocblas_handle)handle, n, n, (rocblas_double_complex*)A, lda, ipiv, info)); } // getrf_batched @@ -12320,11 +12338,8 @@ hipblasStatus_t hipblasSgetrfBatched(hipblasHandle_t handle, int* info, const int batch_count) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = 
rocsolver_sgetrf_batched( - (rocblas_handle)handle, n, n, A, lda, ipiv, n, info, batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus( + rocsolver_sgetrf_batched((rocblas_handle)handle, n, n, A, lda, ipiv, n, info, batch_count)); } hipblasStatus_t hipblasDgetrfBatched(hipblasHandle_t handle, @@ -12335,11 +12350,8 @@ hipblasStatus_t hipblasDgetrfBatched(hipblasHandle_t handle, int* info, const int batch_count) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_dgetrf_batched( - (rocblas_handle)handle, n, n, A, lda, ipiv, n, info, batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus( + rocsolver_dgetrf_batched((rocblas_handle)handle, n, n, A, lda, ipiv, n, info, batch_count)); } hipblasStatus_t hipblasCgetrfBatched(hipblasHandle_t handle, @@ -12350,18 +12362,8 @@ hipblasStatus_t hipblasCgetrfBatched(hipblasHandle_t handle, int* info, const int batch_count) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_cgetrf_batched((rocblas_handle)handle, - n, - n, - (rocblas_float_complex**)A, - lda, - ipiv, - n, - info, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_cgetrf_batched( + (rocblas_handle)handle, n, n, (rocblas_float_complex**)A, lda, ipiv, n, info, batch_count)); } hipblasStatus_t hipblasZgetrfBatched(hipblasHandle_t handle, @@ -12372,18 +12374,15 @@ hipblasStatus_t hipblasZgetrfBatched(hipblasHandle_t handle, int* info, const int batch_count) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_zgetrf_batched((rocblas_handle)handle, - n, - n, - (rocblas_double_complex**)A, - lda, - ipiv, - n, - info, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_zgetrf_batched((rocblas_handle)handle, + n, + n, + (rocblas_double_complex**)A, + lda, + ipiv, + n, + info, + batch_count)); } // 
getrf_strided_batched @@ -12397,12 +12396,8 @@ hipblasStatus_t hipblasSgetrfStridedBatched(hipblasHandle_t handle, int* info, const int batch_count) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE( - handle, - status = rocsolver_sgetrf_strided_batched( - (rocblas_handle)handle, n, n, A, lda, strideA, ipiv, strideP, info, batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_sgetrf_strided_batched( + (rocblas_handle)handle, n, n, A, lda, strideA, ipiv, strideP, info, batch_count)); } hipblasStatus_t hipblasDgetrfStridedBatched(hipblasHandle_t handle, @@ -12415,12 +12410,8 @@ hipblasStatus_t hipblasDgetrfStridedBatched(hipblasHandle_t handle, int* info, const int batch_count) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE( - handle, - status = rocsolver_dgetrf_strided_batched( - (rocblas_handle)handle, n, n, A, lda, strideA, ipiv, strideP, info, batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_dgetrf_strided_batched( + (rocblas_handle)handle, n, n, A, lda, strideA, ipiv, strideP, info, batch_count)); } hipblasStatus_t hipblasCgetrfStridedBatched(hipblasHandle_t handle, @@ -12433,19 +12424,16 @@ hipblasStatus_t hipblasCgetrfStridedBatched(hipblasHandle_t handle, int* info, const int batch_count) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_cgetrf_strided_batched((rocblas_handle)handle, - n, - n, - (rocblas_float_complex*)A, - lda, - strideA, - ipiv, - strideP, - info, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_cgetrf_strided_batched((rocblas_handle)handle, + n, + n, + (rocblas_float_complex*)A, + lda, + strideA, + ipiv, + strideP, + info, + batch_count)); } hipblasStatus_t hipblasZgetrfStridedBatched(hipblasHandle_t handle, @@ -12458,19 +12446,16 @@ hipblasStatus_t hipblasZgetrfStridedBatched(hipblasHandle_t handle, int* info, const int batch_count) { - 
rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_zgetrf_strided_batched((rocblas_handle)handle, - n, - n, - (rocblas_double_complex*)A, - lda, - strideA, - ipiv, - strideP, - info, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_zgetrf_strided_batched((rocblas_handle)handle, + n, + n, + (rocblas_double_complex*)A, + lda, + strideA, + ipiv, + strideP, + info, + batch_count)); } // getrs @@ -12504,18 +12489,8 @@ hipblasStatus_t hipblasSgetrs(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_sgetrs((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - A, - lda, - ipiv, - B, - ldb)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_sgetrs( + (rocblas_handle)handle, hipOperationToHCCOperation(trans), n, nrhs, A, lda, ipiv, B, ldb)); } hipblasStatus_t hipblasDgetrs(hipblasHandle_t handle, @@ -12548,18 +12523,8 @@ hipblasStatus_t hipblasDgetrs(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_dgetrs((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - A, - lda, - ipiv, - B, - ldb)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_dgetrs( + (rocblas_handle)handle, hipOperationToHCCOperation(trans), n, nrhs, A, lda, ipiv, B, ldb)); } hipblasStatus_t hipblasCgetrs(hipblasHandle_t handle, @@ -12592,18 +12557,15 @@ hipblasStatus_t hipblasCgetrs(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_cgetrs((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - (rocblas_float_complex*)A, - lda, - ipiv, - (rocblas_float_complex*)B, - ldb)); - return rocBLASStatusToHIPStatus(status); + return 
rocBLASStatusToHIPStatus(rocsolver_cgetrs((rocblas_handle)handle, + hipOperationToHCCOperation(trans), + n, + nrhs, + (rocblas_float_complex*)A, + lda, + ipiv, + (rocblas_float_complex*)B, + ldb)); } hipblasStatus_t hipblasZgetrs(hipblasHandle_t handle, @@ -12636,18 +12598,15 @@ hipblasStatus_t hipblasZgetrs(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_zgetrs((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - (rocblas_double_complex*)A, - lda, - ipiv, - (rocblas_double_complex*)B, - ldb)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_zgetrs((rocblas_handle)handle, + hipOperationToHCCOperation(trans), + n, + nrhs, + (rocblas_double_complex*)A, + lda, + ipiv, + (rocblas_double_complex*)B, + ldb)); } // getrs_batched @@ -12684,20 +12643,17 @@ hipblasStatus_t hipblasSgetrsBatched(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_sgetrs_batched((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - A, - lda, - ipiv, - n, - B, - ldb, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_sgetrs_batched((rocblas_handle)handle, + hipOperationToHCCOperation(trans), + n, + nrhs, + A, + lda, + ipiv, + n, + B, + ldb, + batch_count)); } hipblasStatus_t hipblasDgetrsBatched(hipblasHandle_t handle, @@ -12733,20 +12689,17 @@ hipblasStatus_t hipblasDgetrsBatched(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_dgetrs_batched((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - A, - lda, - ipiv, - n, - B, - ldb, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_dgetrs_batched((rocblas_handle)handle, + hipOperationToHCCOperation(trans), + n, + nrhs, + A, + 
lda, + ipiv, + n, + B, + ldb, + batch_count)); } hipblasStatus_t hipblasCgetrsBatched(hipblasHandle_t handle, @@ -12782,20 +12735,17 @@ hipblasStatus_t hipblasCgetrsBatched(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_cgetrs_batched((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - (rocblas_float_complex**)A, - lda, - ipiv, - n, - (rocblas_float_complex**)B, - ldb, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_cgetrs_batched((rocblas_handle)handle, + hipOperationToHCCOperation(trans), + n, + nrhs, + (rocblas_float_complex**)A, + lda, + ipiv, + n, + (rocblas_float_complex**)B, + ldb, + batch_count)); } hipblasStatus_t hipblasZgetrsBatched(hipblasHandle_t handle, @@ -12831,20 +12781,17 @@ hipblasStatus_t hipblasZgetrsBatched(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_zgetrs_batched((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - (rocblas_double_complex**)A, - lda, - ipiv, - n, - (rocblas_double_complex**)B, - ldb, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_zgetrs_batched((rocblas_handle)handle, + hipOperationToHCCOperation(trans), + n, + nrhs, + (rocblas_double_complex**)A, + lda, + ipiv, + n, + (rocblas_double_complex**)B, + ldb, + batch_count)); } // getrs_strided_batched @@ -12884,23 +12831,20 @@ hipblasStatus_t hipblasSgetrsStridedBatched(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status - = rocsolver_sgetrs_strided_batched((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - A, - lda, - strideA, - ipiv, - strideP, - B, - ldb, - strideB, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus( + 
rocsolver_sgetrs_strided_batched((rocblas_handle)handle, + hipOperationToHCCOperation(trans), + n, + nrhs, + A, + lda, + strideA, + ipiv, + strideP, + B, + ldb, + strideB, + batch_count)); } hipblasStatus_t hipblasDgetrsStridedBatched(hipblasHandle_t handle, @@ -12939,23 +12883,20 @@ hipblasStatus_t hipblasDgetrsStridedBatched(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status - = rocsolver_dgetrs_strided_batched((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - A, - lda, - strideA, - ipiv, - strideP, - B, - ldb, - strideB, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus( + rocsolver_dgetrs_strided_batched((rocblas_handle)handle, + hipOperationToHCCOperation(trans), + n, + nrhs, + A, + lda, + strideA, + ipiv, + strideP, + B, + ldb, + strideB, + batch_count)); } hipblasStatus_t hipblasCgetrsStridedBatched(hipblasHandle_t handle, @@ -12994,23 +12935,20 @@ hipblasStatus_t hipblasCgetrsStridedBatched(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status - = rocsolver_cgetrs_strided_batched((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - (rocblas_float_complex*)A, - lda, - strideA, - ipiv, - strideP, - (rocblas_float_complex*)B, - ldb, - strideB, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus( + rocsolver_cgetrs_strided_batched((rocblas_handle)handle, + hipOperationToHCCOperation(trans), + n, + nrhs, + (rocblas_float_complex*)A, + lda, + strideA, + ipiv, + strideP, + (rocblas_float_complex*)B, + ldb, + strideB, + batch_count)); } hipblasStatus_t hipblasZgetrsStridedBatched(hipblasHandle_t handle, @@ -13049,23 +12987,20 @@ hipblasStatus_t hipblasZgetrsStridedBatched(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status - = 
rocsolver_zgetrs_strided_batched((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - (rocblas_double_complex*)A, - lda, - strideA, - ipiv, - strideP, - (rocblas_double_complex*)B, - ldb, - strideB, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus( + rocsolver_zgetrs_strided_batched((rocblas_handle)handle, + hipOperationToHCCOperation(trans), + n, + nrhs, + (rocblas_double_complex*)A, + lda, + strideA, + ipiv, + strideP, + (rocblas_double_complex*)B, + ldb, + strideB, + batch_count)); } // geqrf @@ -13092,10 +13027,7 @@ hipblasStatus_t hipblasSgeqrf(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_sgeqrf((rocblas_handle)handle, m, n, A, lda, tau)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_sgeqrf((rocblas_handle)handle, m, n, A, lda, tau)); } hipblasStatus_t hipblasDgeqrf(hipblasHandle_t handle, @@ -13121,10 +13053,7 @@ hipblasStatus_t hipblasDgeqrf(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_dgeqrf((rocblas_handle)handle, m, n, A, lda, tau)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_dgeqrf((rocblas_handle)handle, m, n, A, lda, tau)); } hipblasStatus_t hipblasCgeqrf(hipblasHandle_t handle, @@ -13150,17 +13079,8 @@ hipblasStatus_t hipblasCgeqrf(hipblasHandle_t handle, else *info = 0; - // rocsolver_status status; - // USE_DEVICE_POINTER_MODE(handle, status = rocsolver_cgeqrf( - // (rocblas_handle)handle, - // m, - // n, - // (rocblas_float_complex*)A, - // lda, - // (rocblas_float_complex*)tau)); - // return rocBLASStatusToHIPStatus(status); - - return HIPBLAS_STATUS_NOT_SUPPORTED; + return rocBLASStatusToHIPStatus(rocsolver_cgeqrf( + (rocblas_handle)handle, m, n, (rocblas_float_complex*)A, lda, (rocblas_float_complex*)tau)); } hipblasStatus_t 
hipblasZgeqrf(hipblasHandle_t handle, @@ -13186,17 +13106,12 @@ hipblasStatus_t hipblasZgeqrf(hipblasHandle_t handle, else *info = 0; - // rocsolver_status status; - // USE_DEVICE_POINTER_MODE(handle, status = rocsolver_zgeqrf( - // (rocblas_handle)handle, - // m, - // n, - // (rocblas_double_complex*)A, - // lda, - // (rocblas_double_complex*)tau)); - // return rocBLASStatusToHIPStatus(status); - - return HIPBLAS_STATUS_NOT_SUPPORTED; + return rocBLASStatusToHIPStatus(rocsolver_zgeqrf((rocblas_handle)handle, + m, + n, + (rocblas_double_complex*)A, + lda, + (rocblas_double_complex*)tau)); } // geqrf_batched @@ -13226,31 +13141,8 @@ hipblasStatus_t hipblasSgeqrfBatched(hipblasHandle_t handle, else *info = 0; - int perArray = std::min(m, n); - float* ipiv = NULL; - if(perArray > 0) - hipMalloc(&ipiv, batch_count * perArray * sizeof(float)); - - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_sgeqrf_batched( - (rocblas_handle)handle, m, n, A, lda, ipiv, perArray, batch_count)); - - if(status == rocblas_status_success) - { - // TO DO: Copy ipiv into tau using a fast kernel call - float* hTau[batch_count]; - hipMemcpy(hTau, tau, sizeof(float*) * batch_count, hipMemcpyDeviceToHost); - - for(int b = 0; b < batch_count; b++) - hipMemcpy( - hTau[b], ipiv + b * perArray, sizeof(float) * perArray, hipMemcpyDeviceToDevice); - } - - if(perArray > 0) - hipFree(ipiv); - - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus( + rocsolver_sgeqrf_ptr_batched((rocblas_handle)handle, m, n, A, lda, tau, batch_count)); } hipblasStatus_t hipblasDgeqrfBatched(hipblasHandle_t handle, @@ -13279,31 +13171,8 @@ hipblasStatus_t hipblasDgeqrfBatched(hipblasHandle_t handle, else *info = 0; - int perArray = std::min(m, n); - double* ipiv = NULL; - if(perArray > 0) - hipMalloc(&ipiv, batch_count * perArray * sizeof(double)); - - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_dgeqrf_batched( - 
(rocblas_handle)handle, m, n, A, lda, ipiv, perArray, batch_count)); - - if(status == rocblas_status_success) - { - // TO DO: Copy ipiv into tau using a fast kernel call - double* hTau[batch_count]; - hipMemcpy(hTau, tau, sizeof(double*) * batch_count, hipMemcpyDeviceToHost); - - for(int b = 0; b < batch_count; b++) - hipMemcpy( - hTau[b], ipiv + b * perArray, sizeof(double) * perArray, hipMemcpyDeviceToDevice); - } - - if(perArray > 0) - hipFree(ipiv); - - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus( + rocsolver_dgeqrf_ptr_batched((rocblas_handle)handle, m, n, A, lda, tau, batch_count)); } hipblasStatus_t hipblasCgeqrfBatched(hipblasHandle_t handle, @@ -13332,38 +13201,13 @@ hipblasStatus_t hipblasCgeqrfBatched(hipblasHandle_t handle, else *info = 0; - // int perArray = std::min(m, n); - // rocblas_float_complex* ipiv = NULL; - // if (perArray > 0) - // hipMalloc(&ipiv, batch_count * perArray * sizeof(rocblas_float_complex)); - - // rocsolver_status status; - // USE_DEVICE_POINTER_MODE(handle, status = rocsolver_cgeqrf_batched( - // (rocblas_handle)handle, - // m, - // n, - // (rocblas_float_complex**)A, - // lda, - // (rocblas_float_complex*)ipiv, - // perArray, - // batch_count)); - - // if (status == rocblas_status_success) - // { - // // TO DO: Copy ipiv into tau using a fast kernel call - // rocblas_float_complex* hTau[batch_count]; - // hipMemcpy(hTau, tau, sizeof(rocblas_float_complex*) * batch_count, hipMemcpyDeviceToHost); - - // for(int b = 0; b < batch_count; b++) - // hipMemcpy(hTau[b], ipiv + b * perArray, sizeof(rocblas_float_complex) * perArray, hipMemcpyDeviceToDevice); - // } - - // if (perArray > 0) - // hipFree(ipiv); - - // return rocBLASStatusToHIPStatus(status); - - return HIPBLAS_STATUS_NOT_SUPPORTED; + return rocBLASStatusToHIPStatus(rocsolver_cgeqrf_ptr_batched((rocblas_handle)handle, + m, + n, + (rocblas_float_complex**)A, + lda, + (rocblas_float_complex**)tau, + batch_count)); } hipblasStatus_t 
hipblasZgeqrfBatched(hipblasHandle_t handle, @@ -13392,38 +13236,13 @@ hipblasStatus_t hipblasZgeqrfBatched(hipblasHandle_t handle, else *info = 0; - // int perArray = std::min(m, n); - // rocblas_double_complex* ipiv = NULL; - // if (perArray > 0) - // hipMalloc(&ipiv, batch_count * perArray * sizeof(rocblas_double_complex)); - - // rocsolver_status status; - // USE_DEVICE_POINTER_MODE(handle, status = rocsolver_zgeqrf_batched( - // (rocblas_handle)handle, - // m, - // n, - // (rocblas_double_complex**)A, - // lda, - // (rocblas_double_complex*)ipiv, - // perArray, - // batch_count)); - - // if (status == rocblas_status_success) - // { - // // TO DO: Copy ipiv into tau using a fast kernel call - // rocblas_double_complex* hTau[batch_count]; - // hipMemcpy(hTau, tau, sizeof(rocblas_double_complex*) * batch_count, hipMemcpyDeviceToHost); - - // for(int b = 0; b < batch_count; b++) - // hipMemcpy(hTau[b], ipiv + b * perArray, sizeof(rocblas_double_complex) * perArray, hipMemcpyDeviceToDevice); - // } - - // if (perArray > 0) - // hipFree(ipiv); - - // return rocBLASStatusToHIPStatus(status); - - return HIPBLAS_STATUS_NOT_SUPPORTED; + return rocBLASStatusToHIPStatus(rocsolver_zgeqrf_ptr_batched((rocblas_handle)handle, + m, + n, + (rocblas_double_complex**)A, + lda, + (rocblas_double_complex**)tau, + batch_count)); } // geqrf_strided_batched @@ -13455,12 +13274,8 @@ hipblasStatus_t hipblasSgeqrfStridedBatched(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE( - handle, - status = rocsolver_sgeqrf_strided_batched( - (rocblas_handle)handle, m, n, A, lda, strideA, tau, strideT, batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_sgeqrf_strided_batched( + (rocblas_handle)handle, m, n, A, lda, strideA, tau, strideT, batch_count)); } hipblasStatus_t hipblasDgeqrfStridedBatched(hipblasHandle_t handle, @@ -13491,12 +13306,8 @@ hipblasStatus_t hipblasDgeqrfStridedBatched(hipblasHandle_t 
handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE( - handle, - status = rocsolver_dgeqrf_strided_batched( - (rocblas_handle)handle, m, n, A, lda, strideA, tau, strideT, batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_dgeqrf_strided_batched( + (rocblas_handle)handle, m, n, A, lda, strideA, tau, strideT, batch_count)); } hipblasStatus_t hipblasCgeqrfStridedBatched(hipblasHandle_t handle, @@ -13527,20 +13338,15 @@ hipblasStatus_t hipblasCgeqrfStridedBatched(hipblasHandle_t handle, else *info = 0; - // rocsolver_status status; - // USE_DEVICE_POINTER_MODE(handle, status = rocsolver_cgeqrf_strided_batched( - // (rocblas_handle)handle, - // m, - // n, - // (rocblas_float_complex*)A, - // lda, - // strideA, - // (rocblas_float_complex*)tau, - // strideT, - // batch_count)); - // return rocBLASStatusToHIPStatus(status); - - return HIPBLAS_STATUS_NOT_SUPPORTED; + return rocBLASStatusToHIPStatus(rocsolver_cgeqrf_strided_batched((rocblas_handle)handle, + m, + n, + (rocblas_float_complex*)A, + lda, + strideA, + (rocblas_float_complex*)tau, + strideT, + batch_count)); } hipblasStatus_t hipblasZgeqrfStridedBatched(hipblasHandle_t handle, @@ -13571,20 +13377,15 @@ hipblasStatus_t hipblasZgeqrfStridedBatched(hipblasHandle_t handle, else *info = 0; - // rocsolver_status status; - // USE_DEVICE_POINTER_MODE(handle, status = rocsolver_zgeqrf_strided_batched( - // (rocblas_handle)handle, - // m, - // n, - // (rocblas_double_complex*)A, - // lda, - // strideA, - // (rocblas_double_complex*)tau, - // strideT, - // batch_count)); - // return rocBLASStatusToHIPStatus(status); - - return HIPBLAS_STATUS_NOT_SUPPORTED; + return rocBLASStatusToHIPStatus(rocsolver_zgeqrf_strided_batched((rocblas_handle)handle, + m, + n, + (rocblas_double_complex*)A, + lda, + strideA, + (rocblas_double_complex*)tau, + strideT, + batch_count)); } #endif From 4142e4b32294513c754258f0c0e602365c0434cf Mon Sep 17 00:00:00 2001 From: 
daineAMD <51674140+daineAMD@users.noreply.github.com> Date: Mon, 8 Jun 2020 11:54:52 -0600 Subject: [PATCH 3/9] Adding rocblas_path build variable for custom rocblas path. (#230) --- install.sh | 39 ++++++++++++++++++++++++++------------- library/src/CMakeLists.txt | 7 ++++++- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/install.sh b/install.sh index 0d59634ae..dac6b27c4 100755 --- a/install.sh +++ b/install.sh @@ -24,6 +24,7 @@ function display_help() echo " [--custom-target] link against custom target (e.g. host, device)" echo " [-v|--rocm-dev] Set specific rocm-dev version" echo " [-b|--rocblas] Set specific rocblas version" + echo " [--rocblas-path] Set specific path to custom built rocblas" } # This function is helpful for dockerfiles that do not have sudo installed, but the default user is root @@ -147,18 +148,22 @@ install_packages( ) fi # Custom rocblas installation - if [[ -z ${custom_rocblas+foo} ]]; then - # Install base rocblas package unless -b/--rocblas flag is passed - library_dependencies_ubuntu+=( "rocblas" ) - library_dependencies_centos+=( "rocblas" ) - library_dependencies_fedora+=( "rocblas" ) - library_dependencies_sles+=( "rocblas" ) - else - # Install rocm-specific rocblas package - library_dependencies_ubuntu+=( "${custom_rocblas}" ) - library_dependencies_centos+=( "${custom_rocblas}" ) - library_dependencies_fedora+=( "${custom_rocblas}" ) - library_dependencies_sles+=( "${custom_rocblas}" ) + # Do not install rocblas if --rocblas-path flag is set, + # as we will be building against our own rocblas instead.
+ if [[ -z ${rocblas_path+foo} ]]; then + if [[ -z ${custom_rocblas+foo} ]]; then + # Install base rocblas package unless -b/--rocblas flag is passed + library_dependencies_ubuntu+=( "rocblas" ) + library_dependencies_centos+=( "rocblas" ) + library_dependencies_fedora+=( "rocblas" ) + library_dependencies_sles+=( "rocblas" ) + else + # Install rocm-specific rocblas package + library_dependencies_ubuntu+=( "${custom_rocblas}" ) + library_dependencies_centos+=( "${custom_rocblas}" ) + library_dependencies_fedora+=( "${custom_rocblas}" ) + library_dependencies_sles+=( "${custom_rocblas}" ) + fi fi if [[ "${build_solver}" == true ]]; then @@ -286,7 +291,7 @@ compiler=g++ # check if we have a modern version of getopt that can handle whitespace and long parameters getopt -T if [[ $? -eq 4 ]]; then - GETOPT_PARSE=$(getopt --name "${0}" --longoptions help,install,clients,no-solver,dependencies,debug,hip-clang,compiler:,cuda,cmakepp,relocatable:,rocm-dev:,rocblas:,custom-target: --options rhicndgp:v:b: -- "$@") + GETOPT_PARSE=$(getopt --name "${0}" --longoptions help,install,clients,no-solver,dependencies,debug,hip-clang,compiler:,cuda,cmakepp,relocatable:,rocm-dev:,rocblas:,rocblas-path:,custom-target: --options rhicndgp:v:b: -- "$@") else echo "Need a new version of getopt" exit 1 @@ -344,6 +349,9 @@ while true; do -b|--rocblas) custom_rocblas=${2} shift 2;; + --rocblas-path) + rocblas_path=${2} + shift 2 ;; --prefix) install_prefix=${2} shift 2 ;; @@ -438,6 +446,11 @@ pushd . 
cmake_common_options="${cmake_common_options} -DCUSTOM_TARGET=${custom_target}" fi + # custom rocblas + if [[ ${rocblas_path+foo} ]]; then + cmake_common_options="${cmake_common_options} -DCUSTOM_ROCBLAS=${rocblas_path}" + fi + # Build library if [[ "${build_relocatable}" == true ]]; then CXX=${compiler} ${cmake_executable} ${cmake_common_options} ${cmake_client_options} -DCPACK_SET_DESTDIR=OFF -DCMAKE_INSTALL_PREFIX="${rocm_path}" \ diff --git a/library/src/CMakeLists.txt b/library/src/CMakeLists.txt index f21a7046c..258d742d3 100755 --- a/library/src/CMakeLists.txt +++ b/library/src/CMakeLists.txt @@ -49,7 +49,12 @@ target_include_directories( hipblas # Build hipblas from source on AMD platform if( NOT CUDA_FOUND ) if( NOT TARGET rocblas ) - find_package( rocblas REQUIRED CONFIG PATHS /opt/rocm /opt/rocm/rocblas ) + if( CUSTOM_ROCBLAS ) + set ( ENV{rocblas_DIR} ${CUSTOM_ROCBLAS}) + find_package( rocblas REQUIRED CONFIG NO_CMAKE_PATH ) + else( ) + find_package( rocblas REQUIRED CONFIG PATHS /opt/rocm /opt/rocm/rocblas ) + endif( ) endif( ) target_compile_definitions( hipblas PRIVATE __HIP_PLATFORM_HCC__ ) From d4c962b353d32539c0a1792ef0618b9abc40a343 Mon Sep 17 00:00:00 2001 From: Lee Killough Date: Fri, 5 Jun 2020 21:43:39 -0400 Subject: [PATCH 4/9] Use PATH and robustly test for clang-format --- .githooks/pre-commit | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/.githooks/pre-commit b/.githooks/pre-commit index 13f070903..87fee17b1 100755 --- a/.githooks/pre-commit +++ b/.githooks/pre-commit @@ -4,9 +4,7 @@ # are installed, and if so, uses the installed version to format # the staged changes. -set -x - -format=/opt/rocm/hcc/bin/clang-format +export PATH=/usr/bin:/bin:/opt/rocm/llvm/bin:/opt/rocm/hcc/bin # Redirect stdout to stderr. 
exec >&2 @@ -48,11 +46,11 @@ for file in $files; do done # if clang-format exists, run it on C/C++ files -if [[ -x $format ]]; then +if command -v clang-format >/dev/null; then for file in $files; do if [[ -e $file ]] && echo $file | grep -Eq '\.c$|\.h$|\.hpp$|\.cpp$|\.cl$|\.h\.in$|\.hpp\.in$|\.cpp\.in$'; then - echo "$format $file" - "$format" -i -style=file "$file" + echo "clang-format $file" + clang-format -i -style=file "$file" git add -u "$file" fi done From 3049ecf404e3ab7e6fd03d532debe8d2d9bb253a Mon Sep 17 00:00:00 2001 From: daineAMD <51674140+daineAMD@users.noreply.github.com> Date: Mon, 8 Jun 2020 12:07:19 -0600 Subject: [PATCH 5/9] Add trsm_ex (#231) --- clients/gtest/CMakeLists.txt | 1 + clients/gtest/trsm_ex_gtest.cpp | 325 ++++++++++++++++++ clients/include/hipblas_fortran.f90 | 172 ++++----- clients/include/hipblas_fortran.hpp | 54 +++ clients/include/testing_trsm_batched_ex.hpp | 271 +++++++++++++++ clients/include/testing_trsm_ex.hpp | 221 ++++++++++++ .../testing_trsm_strided_batched_ex.hpp | 258 ++++++++++++++ library/include/hipblas.h | 54 +++ library/src/hcc_detail/hipblas.cpp | 113 ++++++ library/src/nvcc_detail/hipblas.cpp | 64 ++++ 10 files changed, 1447 insertions(+), 86 deletions(-) create mode 100644 clients/gtest/trsm_ex_gtest.cpp create mode 100644 clients/include/testing_trsm_batched_ex.hpp create mode 100644 clients/include/testing_trsm_ex.hpp create mode 100644 clients/include/testing_trsm_strided_batched_ex.hpp diff --git a/clients/gtest/CMakeLists.txt b/clients/gtest/CMakeLists.txt index 956d6cd5d..3cb23d56f 100644 --- a/clients/gtest/CMakeLists.txt +++ b/clients/gtest/CMakeLists.txt @@ -86,6 +86,7 @@ set(hipblas_test_source syrk_gtest.cpp syr2k_gtest.cpp trsm_gtest.cpp + trsm_ex_gtest.cpp trmm_gtest.cpp trtri_gtest.cpp ) diff --git a/clients/gtest/trsm_ex_gtest.cpp b/clients/gtest/trsm_ex_gtest.cpp new file mode 100644 index 000000000..2b27c2615 --- /dev/null +++ b/clients/gtest/trsm_ex_gtest.cpp @@ -0,0 +1,325 @@ +/* 
************************************************************************ + * Copyright 2016-2020 Advanced Micro Devices, Inc. + * + * ************************************************************************ */ + +#include "testing_trsm_batched_ex.hpp" +#include "testing_trsm_ex.hpp" +#include "testing_trsm_strided_batched_ex.hpp" +#include "utility.h" +#include +#include +#include +#include + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; +using namespace std; + +// only GCC/VS 2010 comes with std::tr1::tuple, but it is unnecessary, std::tuple is good enough; + +typedef std::tuple, double, vector, double, int, bool> trsm_ex_tuple; + +/* ===================================================================== +README: This file contains testers to verify the correctness of + BLAS routines with google test + + It is supposed to be played/used by advance / expert users + Normal users only need to get the library routines without testers + =================================================================== */ + +/* ===================================================================== +Advance users only: BrainStorm the parameters but do not make artificial one which invalidates the +matrix. +like lda pairs with M, and "lda must >= M". case "lda < M" will be guarded by argument-checkers +inside API of course. +Yet, the goal of this file is to verify result correctness not argument-checkers. 
+ +Representative sampling is sufficient, endless brute-force sampling is not necessary +=================================================================== */ + +// vector of vector, each vector is a {M, N, lda, ldb}; +// add/delete as a group +const vector> full_matrix_size_range = { + {192, 192, 192, 192}, {640, 640, 960, 960}, + // {1000, 1000, 1000, 1000}, + // {2000, 2000, 2000, 2000}, +}; + +const vector alpha_range = {1.0, -5.0}; + +// vector of vector, each vector is a {side, uplo, transA, diag}; +// side has two options "Left (L), Right (R)" +// uplo has two options "Lower (L), Upper (U)" +// transA has three options ("Nontranspose (N), conjTranspose(C), transpose (T)") +// for single/double precision, 'C'(conjTranspose) will be downgraded to 'T' (transpose) automatically +// in strsm/dtrsm, +// so we use 'C' +// Diag has two options ("Non-unit (N), Unit (U)") + +// Each letter is capitalized, e.g. do not use 'l', but use 'L' instead. + +const vector> side_uplo_transA_diag_range = { + {'L', 'L', 'N', 'N'}, + {'R', 'L', 'N', 'N'}, + {'L', 'U', 'C', 'N'}, +}; + +const vector stride_scale_range = {2.5}; + +const vector batch_count_range = {-1, 0, 1, 2}; + +const bool is_fortran[] = {false, true}; + +/* ===============Google Unit Test==================================================== */ + +/* ===================================================================== + BLAS-3 trsm: +=================================================================== */ + +/* ============================Setup Arguments======================================= */ + +// Please use "class Arguments" (see utility.hpp) to pass parameters to templated testers; +// Some routines may not touch/use certain "members" of objects "argus". +// like BLAS-1 Scal does not have lda, BLAS-2 GEMV does not have ldb, ldc; +// That is fine. These testers & routines will leave untouched members alone.
+// Do not use std::tuple to directly pass parameters to testers +// by std:tuple, you have unpack it with extreme care for each one by like "std::get<0>" which is +// not intuitive and error-prone + +Arguments setup_trsm_ex_arguments(trsm_ex_tuple tup) +{ + + vector matrix_size = std::get<0>(tup); + double alpha = std::get<1>(tup); + vector side_uplo_transA_diag = std::get<2>(tup); + double stride_scale = std::get<3>(tup); + int batch_count = std::get<4>(tup); + bool fortran = std::get<5>(tup); + + Arguments arg; + + // see the comments about matrix_size_range above + arg.M = matrix_size[0]; + arg.N = matrix_size[1]; + arg.lda = matrix_size[2]; + arg.ldb = matrix_size[3]; + + arg.alpha = alpha; + + arg.side_option = side_uplo_transA_diag[0]; + arg.uplo_option = side_uplo_transA_diag[1]; + arg.transA_option = side_uplo_transA_diag[2]; + arg.diag_option = side_uplo_transA_diag[3]; + + arg.timing = 1; + + arg.stride_scale = stride_scale; + arg.batch_count = batch_count; + + arg.fortran = fortran; + + return arg; +} + +class trsm_ex_gtest : public ::TestWithParam +{ +protected: + trsm_ex_gtest() {} + virtual ~trsm_ex_gtest() {} + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_P(trsm_ex_gtest, trsm_ex_gtest_float) +{ + // GetParam return a tuple. Tee setup routine unpack the tuple + // and initializes arg(Arguments) which will be passed to testing routine + // The Arguments data struture have physical meaning associated. + // while the tuple is non-intuitive. + + Arguments arg = setup_trsm_ex_arguments(GetParam()); + arg.compute_type = HIPBLAS_R_32F; + + hipblasStatus_t status = testing_trsm_ex(arg); + + // if not success, then the input argument is problematic, so detect the error message + if(status != HIPBLAS_STATUS_SUCCESS) + { + + if(arg.M < 0 || arg.N < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else if(arg.side_option == 'L' ? 
arg.lda < arg.M : arg.lda < arg.N) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else if(arg.ldb < arg.M) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(trsm_ex_gtest, trsm_gtest_ex_double_complex) +{ + // GetParam return a tuple. Tee setup routine unpack the tuple + // and initializes arg(Arguments) which will be passed to testing routine + // The Arguments data struture have physical meaning associated. + // while the tuple is non-intuitive. + + Arguments arg = setup_trsm_ex_arguments(GetParam()); + arg.compute_type = HIPBLAS_C_64F; + + hipblasStatus_t status = testing_trsm_ex(arg); + + // if not success, then the input argument is problematic, so detect the error message + if(status != HIPBLAS_STATUS_SUCCESS) + { + + if(arg.M < 0 || arg.N < 0 || arg.ldb < arg.M + || (arg.side_option == 'L' ? arg.lda < arg.M : arg.lda < arg.N)) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(trsm_ex_gtest, trsm_batched_ex_gtest_float) +{ + // GetParam return a tuple. Tee setup routine unpack the tuple + // and initializes arg(Arguments) which will be passed to testing routine + // The Arguments data struture have physical meaning associated. + // while the tuple is non-intuitive. + + Arguments arg = setup_trsm_ex_arguments(GetParam()); + arg.compute_type = HIPBLAS_R_32F; + + hipblasStatus_t status = testing_trsm_batched_ex(arg); + + // if not success, then the input argument is problematic, so detect the error message + if(status != HIPBLAS_STATUS_SUCCESS) + { + + if(arg.M < 0 || arg.N < 0 || arg.lda < arg.K || arg.ldb < arg.M + || (arg.side_option == 'L' ? 
arg.lda < arg.M : arg.lda < arg.N) || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(trsm_ex_gtest, trsm_batched_ex_gtest_double_complex) +{ + // GetParam return a tuple. Tee setup routine unpack the tuple + // and initializes arg(Arguments) which will be passed to testing routine + // The Arguments data struture have physical meaning associated. + // while the tuple is non-intuitive. + + Arguments arg = setup_trsm_ex_arguments(GetParam()); + arg.compute_type = HIPBLAS_C_64F; + + hipblasStatus_t status = testing_trsm_batched_ex(arg); + + // if not success, then the input argument is problematic, so detect the error message + if(status != HIPBLAS_STATUS_SUCCESS) + { + + if(arg.M < 0 || arg.N < 0 || arg.lda < arg.K || arg.ldb < arg.M + || (arg.side_option == 'L' ? arg.lda < arg.M : arg.lda < arg.N) || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(trsm_ex_gtest, trsm_strided_batched_ex_gtest_float) +{ + // GetParam return a tuple. Tee setup routine unpack the tuple + // and initializes arg(Arguments) which will be passed to testing routine + // The Arguments data struture have physical meaning associated. + // while the tuple is non-intuitive. + + Arguments arg = setup_trsm_ex_arguments(GetParam()); + arg.compute_type = HIPBLAS_R_32F; + + hipblasStatus_t status = testing_trsm_strided_batched_ex(arg); + + // if not success, then the input argument is problematic, so detect the error message + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.M < 0 || arg.N < 0 || arg.ldb < arg.M + || (arg.side_option == 'L' ? 
arg.lda < arg.M : arg.lda < arg.N) || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(trsm_ex_gtest, trsm_strided_batched_ex_gtest_double_complex) +{ + // GetParam return a tuple. Tee setup routine unpack the tuple + // and initializes arg(Arguments) which will be passed to testing routine + // The Arguments data struture have physical meaning associated. + // while the tuple is non-intuitive. + + Arguments arg = setup_trsm_ex_arguments(GetParam()); + arg.compute_type = HIPBLAS_C_64F; + + hipblasStatus_t status = testing_trsm_strided_batched_ex(arg); + + // if not success, then the input argument is problematic, so detect the error message + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.M < 0 || arg.N < 0 || arg.ldb < arg.M + || (arg.side_option == 'L' ? arg.lda < arg.M : arg.lda < arg.N) || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +// notice we are using vector of vector +// so each elment in xxx_range is a avector, +// ValuesIn take each element (a vector) and combine them and feed them to test_p +// The combinations are { {M, N, lda, ldb}, alpha, {side, uplo, transA, diag} } + +// THis function mainly test the scope of matrix_size. the scope of side_uplo_transA_diag_range is +// small +// Testing order: side_uplo_transA_xx first, alpha_range second, full_matrix_size last +// i.e fix the matrix size and alpha, test all the side_uplo_transA_xx first. 
+INSTANTIATE_TEST_CASE_P(hipblasTrsm_matrix_size, + trsm_ex_gtest, + Combine(ValuesIn(full_matrix_size_range), + ValuesIn(alpha_range), + ValuesIn(side_uplo_transA_diag_range), + ValuesIn(stride_scale_range), + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/include/hipblas_fortran.f90 b/clients/include/hipblas_fortran.f90 index 3fc5c4f87..cd1db28b2 100644 --- a/clients/include/hipblas_fortran.f90 +++ b/clients/include/hipblas_fortran.f90 @@ -12113,91 +12113,91 @@ function hipblasGemmStridedBatchedExFortran(handle, transA, transB, m, n, k, alp batch_count, compute_type, algo, solution_index, flags) end function hipblasGemmStridedBatchedExFortran - ! ! trsmEx - ! function hipblasTrsmExFortran(handle, side, uplo, transA, diag, m, n, alpha, A, lda, & - ! B, ldb, invA, invA_size, compute_type) & - ! result(res) & - ! bind(c, name = 'hipblasTrsmExFortran') - ! use iso_c_binding - ! use hipblas_enums - ! implicit none - ! type(c_ptr), value :: handle - ! integer(kind(HIPBLAS_SIDE_LEFT)), value :: side - ! integer(kind(HIPBLAS_FILL_MODE_UPPER)), value :: uplo - ! integer(kind(HIPBLAS_OP_N)), value :: transA - ! integer(kind(HIPBLAS_DIAG_UNIT)), value :: diag - ! integer(c_int), value :: m - ! integer(c_int), value :: n - ! type(c_ptr), value :: alpha - ! type(c_ptr), value :: A - ! integer(c_int), value :: lda - ! type(c_ptr), value :: B - ! integer(c_int), value :: ldb - ! type(c_ptr), value :: invA - ! integer(c_int), value :: invA_size - ! integer(kind(HIPBLAS_R_16F)), value :: compute_type - ! integer(c_int) :: res - ! res = hipblasTrsmEx(handle, side, uplo, transA, diag, m, n, alpha,& - ! A, lda, B, ldb, invA, invA_size, compute_type) - ! end function hipblasTrsmExFortran - - ! function hipblasTrsmBatchedExFortran(handle, side, uplo, transA, diag, m, n, alpha, A, lda, & - ! B, ldb, batch_count, invA, invA_size, compute_type) & - ! result(res) & - ! bind(c, name = 'hipblasTrsmBatchedExFortran') - ! use iso_c_binding - ! 
use hipblas_enums - ! implicit none - ! type(c_ptr), value :: handle - ! integer(kind(HIPBLAS_SIDE_LEFT)), value :: side - ! integer(kind(HIPBLAS_FILL_MODE_UPPER)), value :: uplo - ! integer(kind(HIPBLAS_OP_N)), value :: transA - ! integer(kind(HIPBLAS_DIAG_UNIT)), value :: diag - ! integer(c_int), value :: m - ! integer(c_int), value :: n - ! type(c_ptr), value :: alpha - ! type(c_ptr), value :: A - ! integer(c_int), value :: lda - ! type(c_ptr), value :: B - ! integer(c_int), value :: ldb - ! integer(c_int), value :: batch_count - ! type(c_ptr), value :: invA - ! integer(c_int), value :: invA_size - ! integer(kind(HIPBLAS_R_16F)), value :: compute_type - ! integer(c_int) :: res - ! res = hipblasTrsmBatchedEx(handle, side, uplo, transA, diag, m, n, alpha,& - ! A, lda, B, ldb, batch_count, invA, invA_size, compute_type) - ! end function hipblasTrsmBatchedExFortran - - ! function hipblasTrsmStridedBatchedExFortran(handle, side, uplo, transA, diag, m, n, alpha, A, lda, stride_A, & - ! B, ldb, stride_B, batch_count, invA, invA_size, stride_invA, compute_type) & - ! result(res) & - ! bind(c, name = 'hipblasTrsmStridedBatchedExFortran') - ! use iso_c_binding - ! use hipblas_enums - ! implicit none - ! type(c_ptr), value :: handle - ! integer(kind(HIPBLAS_SIDE_LEFT)), value :: side - ! integer(kind(HIPBLAS_FILL_MODE_UPPER)), value :: uplo - ! integer(kind(HIPBLAS_OP_N)), value :: transA - ! integer(kind(HIPBLAS_DIAG_UNIT)), value :: diag - ! integer(c_int), value :: m - ! integer(c_int), value :: n - ! type(c_ptr), value :: alpha - ! type(c_ptr), value :: A - ! integer(c_int), value :: lda - ! integer(c_int64_t), value :: stride_A - ! type(c_ptr), value :: B - ! integer(c_int), value :: ldb - ! integer(c_int64_t), value :: stride_B - ! integer(c_int), value :: batch_count - ! type(c_ptr), value :: invA - ! integer(c_int), value :: invA_size - ! integer(c_int64_t), value :: stride_invA - ! integer(kind(HIPBLAS_R_16F)), value :: compute_type - ! integer(c_int) :: res - ! 
res = hipblasTrsmStridedBatchedEx(handle, side, uplo, transA, diag, m, n, alpha,& - ! A, lda, stride_A, B, ldb, stride_B, batch_count, invA, invA_size, stride_invA, compute_type) - ! end function hipblasTrsmStridedBatchedExFortran + ! trsmEx + function hipblasTrsmExFortran(handle, side, uplo, transA, diag, m, n, alpha, A, lda, & + B, ldb, invA, invA_size, compute_type) & + result(res) & + bind(c, name = 'hipblasTrsmExFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_SIDE_LEFT)), value :: side + integer(kind(HIPBLAS_FILL_MODE_UPPER)), value :: uplo + integer(kind(HIPBLAS_OP_N)), value :: transA + integer(kind(HIPBLAS_DIAG_UNIT)), value :: diag + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: alpha + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: invA + integer(c_int), value :: invA_size + integer(kind(HIPBLAS_R_16F)), value :: compute_type + integer(c_int) :: res + res = hipblasTrsmEx(handle, side, uplo, transA, diag, m, n, alpha,& + A, lda, B, ldb, invA, invA_size, compute_type) + end function hipblasTrsmExFortran + + function hipblasTrsmBatchedExFortran(handle, side, uplo, transA, diag, m, n, alpha, A, lda, & + B, ldb, batch_count, invA, invA_size, compute_type) & + result(res) & + bind(c, name = 'hipblasTrsmBatchedExFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_SIDE_LEFT)), value :: side + integer(kind(HIPBLAS_FILL_MODE_UPPER)), value :: uplo + integer(kind(HIPBLAS_OP_N)), value :: transA + integer(kind(HIPBLAS_DIAG_UNIT)), value :: diag + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: alpha + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: B + integer(c_int), value :: ldb + integer(c_int), value :: batch_count + type(c_ptr), value :: invA 
+ integer(c_int), value :: invA_size + integer(kind(HIPBLAS_R_16F)), value :: compute_type + integer(c_int) :: res + res = hipblasTrsmBatchedEx(handle, side, uplo, transA, diag, m, n, alpha,& + A, lda, B, ldb, batch_count, invA, invA_size, compute_type) + end function hipblasTrsmBatchedExFortran + + function hipblasTrsmStridedBatchedExFortran(handle, side, uplo, transA, diag, m, n, alpha, A, lda, stride_A, & + B, ldb, stride_B, batch_count, invA, invA_size, stride_invA, compute_type) & + result(res) & + bind(c, name = 'hipblasTrsmStridedBatchedExFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_SIDE_LEFT)), value :: side + integer(kind(HIPBLAS_FILL_MODE_UPPER)), value :: uplo + integer(kind(HIPBLAS_OP_N)), value :: transA + integer(kind(HIPBLAS_DIAG_UNIT)), value :: diag + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: alpha + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int64_t), value :: stride_A + type(c_ptr), value :: B + integer(c_int), value :: ldb + integer(c_int64_t), value :: stride_B + integer(c_int), value :: batch_count + type(c_ptr), value :: invA + integer(c_int), value :: invA_size + integer(c_int64_t), value :: stride_invA + integer(kind(HIPBLAS_R_16F)), value :: compute_type + integer(c_int) :: res + res = hipblasTrsmStridedBatchedEx(handle, side, uplo, transA, diag, m, n, alpha,& + A, lda, stride_A, B, ldb, stride_B, batch_count, invA, invA_size, stride_invA, compute_type) + end function hipblasTrsmStridedBatchedExFortran end module hipblas_interface diff --git a/clients/include/hipblas_fortran.hpp b/clients/include/hipblas_fortran.hpp index 70d66319a..e940c93c5 100644 --- a/clients/include/hipblas_fortran.hpp +++ b/clients/include/hipblas_fortran.hpp @@ -6745,6 +6745,60 @@ hipblasStatus_t hipblasGemmStridedBatchedExFortran(hipblasHandle_t handle, int batch_count, hipblasDatatype_t compute_type, hipblasGemmAlgo_t algo); + 
+// trsm_ex +hipblasStatus_t hipblasTrsmExFortran(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + void* B, + int ldb, + const void* invA, + int invA_size, + hipblasDatatype_t compute_type); + +hipblasStatus_t hipblasTrsmBatchedExFortran(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + void* B, + int ldb, + int batch_count, + const void* invA, + int invA_size, + hipblasDatatype_t compute_type); + +hipblasStatus_t hipblasTrsmStridedBatchedExFortran(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + int stride_A, + void* B, + int ldb, + int stride_B, + int batch_count, + const void* invA, + int invA_size, + int stride_invA, + hipblasDatatype_t compute_type); } #endif diff --git a/clients/include/testing_trsm_batched_ex.hpp b/clients/include/testing_trsm_batched_ex.hpp new file mode 100644 index 000000000..58c485596 --- /dev/null +++ b/clients/include/testing_trsm_batched_ex.hpp @@ -0,0 +1,271 @@ +/* ************************************************************************ + * Copyright 2016-2020 Advanced Micro Devices, Inc. 
+ * + * ************************************************************************ */ + +#include +#include +#include +#include + +#include "cblas_interface.h" +#include "flops.h" +#include "hipblas.hpp" +#include "hipblas_fortran.hpp" +#include "norm.h" +#include "unit.h" +#include "utility.h" + +using namespace std; + +#define TRSM_BLOCK 128 + +/* ============================================================================================ */ + +template +hipblasStatus_t testing_trsm_batched_ex(Arguments argus) +{ + bool FORTRAN = argus.fortran; + auto hipblasTrsmBatchedExFn = FORTRAN ? hipblasTrsmBatchedExFortran : hipblasTrsmBatchedEx; + + int M = argus.M; + int N = argus.N; + int lda = argus.lda; + int ldb = argus.ldb; + + char char_side = argus.side_option; + char char_uplo = argus.uplo_option; + char char_transA = argus.transA_option; + char char_diag = argus.diag_option; + T alpha = argus.alpha; + int batch_count = argus.batch_count; + + hipblasSideMode_t side = char2hipblas_side(char_side); + hipblasFillMode_t uplo = char2hipblas_fill(char_uplo); + hipblasOperation_t transA = char2hipblas_operation(char_transA); + hipblasDiagType_t diag = char2hipblas_diagonal(char_diag); + + int K = (side == HIPBLAS_SIDE_LEFT ? M : N); + int A_size = lda * K; + int B_size = ldb * N; + + hipblasStatus_t status = HIPBLAS_STATUS_SUCCESS; + + // check here to prevent undefined memory allocation error + // TODO: Workaround for cuda tests, not actually testing return values + if(M < 0 || N < 0 || lda < K || ldb < M || batch_count < 0) + { + return HIPBLAS_STATUS_INVALID_VALUE; + } + if(!M || !N || !lda || !ldb || !batch_count) + { + return HIPBLAS_STATUS_SUCCESS; + } + // Naming: dK is in GPU (device) memory. 
hK is in CPU (host) memory + host_vector hA[batch_count]; + host_vector hB[batch_count]; + host_vector hB_copy[batch_count]; + host_vector hX[batch_count]; + + device_batch_vector bA(batch_count, A_size); + device_batch_vector bB(batch_count, B_size); + device_batch_vector binvA(batch_count, TRSM_BLOCK * K); + + device_vector dA(batch_count); + device_vector dB(batch_count); + device_vector dinvA(batch_count); + + int last = batch_count - 1; + if(!dA || !dB || !dinvA || (!bA[last] && A_size) || (!bB[last] && B_size) || !binvA[last]) + { + return HIPBLAS_STATUS_ALLOC_FAILED; + } + + double gpu_time_used, cpu_time_used; + double hipblasGflops, cblas_gflops; + + hipblasHandle_t handle; + hipblasCreate(&handle); + + // Initial hA on CPU + srand(1); + for(int b = 0; b < batch_count; b++) + { + hA[b] = host_vector(A_size); + hB[b] = host_vector(B_size); + hB_copy[b] = host_vector(B_size); + hX[b] = host_vector(B_size); + + hipblas_init_symmetric(hA[b], K, lda); + // pad untouched area into zero + for(int i = K; i < lda; i++) + { + for(int j = 0; j < K; j++) + { + hA[b][i + j * lda] = 0.0; + } + } + + // proprocess the matrix to avoid ill-conditioned matrix + vector ipiv(K); + cblas_getrf(K, K, hA[b].data(), lda, ipiv.data()); + for(int i = 0; i < K; i++) + { + for(int j = i; j < K; j++) + { + hA[b][i + j * lda] = hA[b][j + i * lda]; + if(diag == HIPBLAS_DIAG_UNIT) + { + if(i == j) + hA[b][i + j * lda] = 1.0; + } + } + } + + // Initial hB, hX on CPU + hipblas_init(hB[b], M, N, ldb); + // pad untouched area into zero + for(int i = M; i < ldb; i++) + { + for(int j = 0; j < N; j++) + { + hB[b][i + j * ldb] = 0.0; + } + } + hX[b] = hB[b]; // original solution hX + + // Calculate hB = hA*hX; + cblas_trmm(side, + uplo, + transA, + diag, + M, + N, + T(1.0) / alpha, + (const T*)hA[b].data(), + lda, + hB[b].data(), + ldb); + + hB_copy[b] = hB[b]; + + // copy data from CPU to device + CHECK_HIP_ERROR(hipMemcpy(bA[b], hA[b], sizeof(T) * A_size, hipMemcpyHostToDevice)); + 
CHECK_HIP_ERROR(hipMemcpy(bB[b], hB[b], sizeof(T) * B_size, hipMemcpyHostToDevice)); + } + CHECK_HIP_ERROR(hipMemcpy(dA, bA, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dB, bB, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dinvA, binvA, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + + /* ===================================================================== + HIPBLAS + =================================================================== */ + + int stride_A = TRSM_BLOCK * lda + TRSM_BLOCK; + int stride_invA = TRSM_BLOCK * TRSM_BLOCK; + int blocks = K / TRSM_BLOCK; + + for(int b = 0; b < batch_count; b++) + { + if(blocks > 0) + { + status = hipblasTrtriStridedBatched(handle, + uplo, + diag, + TRSM_BLOCK, + bA[b], + lda, + stride_A, + binvA[b], + TRSM_BLOCK, + stride_invA, + blocks); + if(status != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status; + } + } + + if(K % TRSM_BLOCK != 0 || blocks == 0) + { + status = hipblasTrtriStridedBatched(handle, + uplo, + diag, + K - TRSM_BLOCK * blocks, + bA[b] + stride_A * blocks, + lda, + stride_A, + binvA[b] + stride_invA * blocks, + TRSM_BLOCK, + stride_invA, + 1); + if(status != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status; + } + } + } + + status = hipblasTrsmBatchedExFn(handle, + side, + uplo, + transA, + diag, + M, + N, + &alpha, + dA, + lda, + dB, + ldb, + batch_count, + dinvA, + TRSM_BLOCK * K, + argus.compute_type); + if(status != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status; + } + + // copy output from device to CPU + for(int b = 0; b < batch_count; b++) + CHECK_HIP_ERROR(hipMemcpy(hB[b], bB[b], sizeof(T) * B_size, hipMemcpyDeviceToHost)); + + if(argus.unit_check) + { + /* ===================================================================== + CPU BLAS + =================================================================== */ + + for(int b = 0; b < batch_count; b++) + { + 
cblas_trsm(side, + uplo, + transA, + diag, + M, + N, + alpha, + (const T*)hA[b].data(), + lda, + hB_copy[b].data(), + ldb); + } + + // if enable norm check, norm check is invasive + real_t eps = std::numeric_limits>::epsilon(); + double tolerance = eps * 40 * M; + + for(int b = 0; b < batch_count; b++) + { + double error = norm_check_general('F', M, N, ldb, hB_copy[b].data(), hB[b].data()); + unit_check_error(error, tolerance); + } + } + + hipblasDestroy(handle); + return HIPBLAS_STATUS_SUCCESS; +} diff --git a/clients/include/testing_trsm_ex.hpp b/clients/include/testing_trsm_ex.hpp new file mode 100644 index 000000000..5c6823d74 --- /dev/null +++ b/clients/include/testing_trsm_ex.hpp @@ -0,0 +1,221 @@ +/* ************************************************************************ + * Copyright 2016-2020 Advanced Micro Devices, Inc. + * + * ************************************************************************ */ + +#include +#include +#include +#include + +#include "cblas_interface.h" +#include "flops.h" +#include "hipblas.hpp" +#include "hipblas_fortran.hpp" +#include "norm.h" +#include "unit.h" +#include "utility.h" + +using namespace std; + +#define TRSM_BLOCK 128 + +/* ============================================================================================ */ + +template +hipblasStatus_t testing_trsm_ex(Arguments argus) +{ + bool FORTRAN = argus.fortran; + auto hipblasTrsmExFn = FORTRAN ? 
hipblasTrsmExFortran : hipblasTrsmEx; + + int M = argus.M; + int N = argus.N; + int lda = argus.lda; + int ldb = argus.ldb; + + char char_side = argus.side_option; + char char_uplo = argus.uplo_option; + char char_transA = argus.transA_option; + char char_diag = argus.diag_option; + T alpha = argus.alpha; + + hipblasSideMode_t side = char2hipblas_side(char_side); + hipblasFillMode_t uplo = char2hipblas_fill(char_uplo); + hipblasOperation_t transA = char2hipblas_operation(char_transA); + hipblasDiagType_t diag = char2hipblas_diagonal(char_diag); + + int K = (side == HIPBLAS_SIDE_LEFT ? M : N); + int A_size = lda * K; + int B_size = ldb * N; + + hipblasStatus_t status = HIPBLAS_STATUS_SUCCESS; + + // check here to prevent undefined memory allocation error + if(M < 0 || N < 0 || lda < K || ldb < M) + { + return HIPBLAS_STATUS_INVALID_VALUE; + } + // Naming: dK is in GPU (device) memory. hK is in CPU (host) memory + host_vector hA(A_size); + host_vector hB(B_size); + host_vector hB_copy(B_size); + host_vector hX(B_size); + + device_vector dA(A_size); + device_vector dB(B_size); + device_vector dinvA(TRSM_BLOCK * K); + + double gpu_time_used, cpu_time_used; + double hipblasGflops, cblas_gflops; + double rocblas_error; + + hipblasHandle_t handle; + hipblasCreate(&handle); + + // Initial hA on CPU + srand(1); + hipblas_init_symmetric(hA, K, lda); + // pad untouched area into zero + for(int i = K; i < lda; i++) + { + for(int j = 0; j < K; j++) + { + hA[i + j * lda] = 0.0; + } + } + // proprocess the matrix to avoid ill-conditioned matrix + vector ipiv(K); + cblas_getrf(K, K, hA.data(), lda, ipiv.data()); + for(int i = 0; i < K; i++) + { + for(int j = i; j < K; j++) + { + hA[i + j * lda] = hA[j + i * lda]; + if(diag == HIPBLAS_DIAG_UNIT) + { + if(i == j) + hA[i + j * lda] = 1.0; + } + } + } + + // Initial hB, hX on CPU + hipblas_init(hB, M, N, ldb); + // pad untouched area into zero + for(int i = M; i < ldb; i++) + { + for(int j = 0; j < N; j++) + { + hB[i + j * ldb] = 0.0; 
+ } + } + hX = hB; // original solution hX + + // Calculate hB = hA*hX; + cblas_trmm( + side, uplo, transA, diag, M, N, T(1.0) / alpha, (const T*)hA.data(), lda, hB.data(), ldb); + + hB_copy = hB; + + // copy data from CPU to device + CHECK_HIP_ERROR(hipMemcpy(dA, hA.data(), sizeof(T) * A_size, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dB, hB.data(), sizeof(T) * B_size, hipMemcpyHostToDevice)); + + int stride_A = TRSM_BLOCK * lda + TRSM_BLOCK; + int stride_invA = TRSM_BLOCK * TRSM_BLOCK; + int blocks = K / TRSM_BLOCK; + + /* ===================================================================== + HIPBLAS + =================================================================== */ + // Calculate invA + if(blocks > 0) + { + status = hipblasTrtriStridedBatched(handle, + uplo, + diag, + TRSM_BLOCK, + dA, + lda, + stride_A, + dinvA, + TRSM_BLOCK, + stride_invA, + blocks); + if(status != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status; + } + } + + if(K % TRSM_BLOCK != 0 || blocks == 0) + { + status = hipblasTrtriStridedBatched(handle, + uplo, + diag, + K - TRSM_BLOCK * blocks, + dA + stride_A * blocks, + lda, + stride_A, + dinvA + stride_invA * blocks, + TRSM_BLOCK, + stride_invA, + 1); + + if(blocks > 0) + { + if(status != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status; + } + } + } + + status = hipblasTrsmExFn(handle, + side, + uplo, + transA, + diag, + M, + N, + &alpha, + dA, + lda, + dB, + ldb, + dinvA, + TRSM_BLOCK * K, + argus.compute_type); + if(status != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status; + } + + // copy output from device to CPU + CHECK_HIP_ERROR(hipMemcpy(hB.data(), dB, sizeof(T) * B_size, hipMemcpyDeviceToHost)); + + if(argus.unit_check) + { + /* ===================================================================== + CPU BLAS + =================================================================== */ + + cblas_trsm( + side, uplo, transA, diag, M, N, alpha, (const T*)hA.data(), 
lda, hB_copy.data(), ldb); + + // print_matrix(hB_copy, hB, min(M, 3), min(N,3), ldb); + + // if enable norm check, norm check is invasive + real_t eps = std::numeric_limits>::epsilon(); + double tolerance = eps * 40 * M; + + double error = norm_check_general('F', M, N, ldb, hB_copy.data(), hB.data()); + unit_check_error(error, tolerance); + } + + hipblasDestroy(handle); + return HIPBLAS_STATUS_SUCCESS; +}
diff --git a/clients/include/testing_trsm_strided_batched_ex.hpp b/clients/include/testing_trsm_strided_batched_ex.hpp
new file mode 100644
index 000000000..2e37a4dc7
--- /dev/null
+++ b/clients/include/testing_trsm_strided_batched_ex.hpp
@@ -0,0 +1,275 @@
+/* ************************************************************************
+ * Copyright 2016-2020 Advanced Micro Devices, Inc.
+ *
+ * ************************************************************************ */
+
+// NOTE(review): the system header names and template argument lists in this file were missing
+// from the extracted patch and are reconstructed from how the code uses them; confirm upstream.
+#include <fstream>
+#include <iostream>
+#include <limits>
+#include <stdlib.h>
+#include <vector>
+
+#include "cblas_interface.h"
+#include "flops.h"
+#include "hipblas.hpp"
+#include "hipblas_fortran.hpp"
+#include "norm.h"
+#include "unit.h"
+#include "utility.h"
+
+using namespace std;
+
+#define TRSM_BLOCK 128
+
+/* ============================================================================================ */
+
+/*! \brief Unit test for hipblasTrsmStridedBatchedEx.
+ *
+ * For each batch a K x K matrix A (K = M for a left-side solve, N otherwise) and an M x N
+ * right-hand side B are generated on the host. The TRSM_BLOCK-sized diagonal blocks of A are
+ * inverted on the device with hipblasTrtriStridedBatched into dinvA, the solve under test is
+ * run, and when argus.unit_check is set each batch of the result is compared with cblas_trsm.
+ *
+ * \param argus test configuration: M, N, lda, ldb, the side/uplo/transA/diag options, alpha,
+ *              stride_scale, batch_count, compute_type, unit_check and fortran.
+ * \return HIPBLAS_STATUS_INVALID_VALUE when the sizes are rejected, HIPBLAS_STATUS_SUCCESS when
+ *         batch_count is 0 or the test ran to completion, otherwise the status of the first
+ *         failing hipBLAS call (the handle is destroyed before returning it).
+ */
+template <typename T>
+hipblasStatus_t testing_trsm_strided_batched_ex(Arguments argus)
+{
+    bool FORTRAN = argus.fortran;
+    // NOTE(review): both arms name the same function, so FORTRAN currently selects nothing.
+    auto hipblasTrsmStridedBatchedExFn
+        = FORTRAN ? hipblasTrsmStridedBatchedEx : hipblasTrsmStridedBatchedEx;
+
+    int M   = argus.M;
+    int N   = argus.N;
+    int lda = argus.lda;
+    int ldb = argus.ldb;
+
+    char   char_side    = argus.side_option;
+    char   char_uplo    = argus.uplo_option;
+    char   char_transA  = argus.transA_option;
+    char   char_diag    = argus.diag_option;
+    T      alpha        = argus.alpha;
+    double stride_scale = argus.stride_scale;
+    int    batch_count  = argus.batch_count;
+
+    hipblasSideMode_t  side   = char2hipblas_side(char_side);
+    hipblasFillMode_t  uplo   = char2hipblas_fill(char_uplo);
+    hipblasOperation_t transA = char2hipblas_operation(char_transA);
+    hipblasDiagType_t  diag   = char2hipblas_diagonal(char_diag);
+
+    // A is K x K; K is M for a left-side solve and N otherwise
+    int K = (side == HIPBLAS_SIDE_LEFT ? M : N);
+
+    int strideA     = lda * K * stride_scale;
+    int strideB     = ldb * N * stride_scale;
+    int stride_invA = TRSM_BLOCK * K;
+    int A_size      = strideA * batch_count;
+    int B_size      = strideB * batch_count;
+    int invA_size   = stride_invA * batch_count;
+
+    hipblasStatus_t status = HIPBLAS_STATUS_SUCCESS;
+
+    // check here to prevent undefined memory allocation error
+    // TODO: Workaround for cuda tests, not actually testing return values
+    if(M < 0 || N < 0 || lda < K || ldb < M || batch_count < 0)
+    {
+        return HIPBLAS_STATUS_INVALID_VALUE;
+    }
+    if(!batch_count)
+    {
+        return HIPBLAS_STATUS_SUCCESS;
+    }
+    // Naming: dK is in GPU (device) memory. hK is in CPU (host) memory
+    host_vector<T> hA(A_size);
+    host_vector<T> hB(B_size);
+    host_vector<T> hB_copy(B_size);
+    host_vector<T> hX(B_size);
+
+    device_vector<T> dA(A_size);
+    device_vector<T> dB(B_size);
+    device_vector<T> dinvA(invA_size);
+
+    hipblasHandle_t handle;
+    hipblasCreate(&handle);
+
+    // Initialize hA on CPU
+    srand(1);
+    hipblas_init_symmetric(hA, K, lda, strideA, batch_count);
+    for(int b = 0; b < batch_count; b++)
+    {
+        T* hAb = hA.data() + b * strideA;
+        T* hBb = hB.data() + b * strideB;
+
+        // pad the untouched area with zeros
+        for(int i = K; i < lda; i++)
+        {
+            for(int j = 0; j < K; j++)
+            {
+                hAb[i + j * lda] = 0.0;
+            }
+        }
+
+        // preprocess the matrix to avoid an ill-conditioned matrix
+        vector<int> ipiv(K);
+        cblas_getrf(K, K, hAb, lda, ipiv.data());
+        for(int i = 0; i < K; i++)
+        {
+            for(int j = i; j < K; j++)
+            {
+                hAb[i + j * lda] = hAb[j + i * lda];
+                if(diag == HIPBLAS_DIAG_UNIT)
+                {
+                    if(i == j)
+                        hAb[i + j * lda] = 1.0;
+                }
+            }
+        }
+
+        // Initialize hB, hX on CPU
+        hipblas_init(hBb, M, N, ldb);
+        // pad the untouched area with zeros
+        for(int i = M; i < ldb; i++)
+        {
+            for(int j = 0; j < N; j++)
+            {
+                hBb[i + j * ldb] = 0.0;
+            }
+        }
+
+        // Calculate hB = hA*hX;
+        cblas_trmm(side, uplo, transA, diag, M, N, T(1.0) / alpha, (const T*)hAb, lda, hBb, ldb);
+    }
+    hX      = hB; // original solutions hX
+    hB_copy = hB;
+
+    // copy data from CPU to device
+    CHECK_HIP_ERROR(hipMemcpy(dA, hA.data(), sizeof(T) * A_size, hipMemcpyHostToDevice));
+    CHECK_HIP_ERROR(hipMemcpy(dB, hB.data(), sizeof(T) * B_size, hipMemcpyHostToDevice));
+
+    /* =====================================================================
+           HIPBLAS
+    =================================================================== */
+    int sub_stride_A    = TRSM_BLOCK * lda + TRSM_BLOCK;
+    int sub_stride_invA = TRSM_BLOCK * TRSM_BLOCK;
+    int blocks          = K / TRSM_BLOCK;
+
+    // Invert (trtri) the blocks along the diagonal of every batch's A into dinvA: first the
+    // full TRSM_BLOCK-sized blocks, then the remaining smaller block (also taken when blocks == 0).
+    for(int b = 0; b < batch_count; b++)
+    {
+        if(blocks > 0)
+        {
+            status = hipblasTrtriStridedBatched(handle,
+                                                uplo,
+                                                diag,
+                                                TRSM_BLOCK,
+                                                dA + b * strideA,
+                                                lda,
+                                                sub_stride_A,
+                                                dinvA + b * stride_invA,
+                                                TRSM_BLOCK,
+                                                sub_stride_invA,
+                                                blocks);
+
+            if(status != HIPBLAS_STATUS_SUCCESS)
+            {
+                hipblasDestroy(handle);
+                return status;
+            }
+        }
+
+        if(K % TRSM_BLOCK != 0 || blocks == 0)
+        {
+            status
+                = hipblasTrtriStridedBatched(handle,
+                                             uplo,
+                                             diag,
+                                             K - TRSM_BLOCK * blocks,
+                                             dA + sub_stride_A * blocks + b * strideA,
+                                             lda,
+                                             sub_stride_A,
+                                             dinvA + sub_stride_invA * blocks + b * stride_invA,
+                                             TRSM_BLOCK,
+                                             sub_stride_invA,
+                                             1);
+
+            if(status != HIPBLAS_STATUS_SUCCESS)
+            {
+                hipblasDestroy(handle);
+                return status;
+            }
+        }
+    }
+
+    status = hipblasTrsmStridedBatchedExFn(handle,
+                                           side,
+                                           uplo,
+                                           transA,
+                                           diag,
+                                           M,
+                                           N,
+                                           &alpha,
+                                           dA,
+                                           lda,
+                                           strideA,
+                                           dB,
+                                           ldb,
+                                           strideB,
+                                           batch_count,
+                                           dinvA,
+                                           invA_size,
+                                           stride_invA,
+                                           argus.compute_type);
+
+    if(status != HIPBLAS_STATUS_SUCCESS)
+    {
+        hipblasDestroy(handle);
+        return status;
+    }
+
+    // copy output from device to CPU
+    CHECK_HIP_ERROR(hipMemcpy(hB.data(), dB, sizeof(T) * B_size, hipMemcpyDeviceToHost));
+
+    if(argus.unit_check)
+    {
+        /* =====================================================================
+           CPU BLAS
+        =================================================================== */
+
+        for(int b = 0; b < batch_count; b++)
+        {
+            cblas_trsm(side,
+                       uplo,
+                       transA,
+                       diag,
+                       M,
+                       N,
+                       alpha,
+                       (const T*)hA.data() + b * strideA,
+                       lda,
+                       hB_copy.data() + b * strideB,
+                       ldb);
+        }
+
+        // error bound for the norm check: 40 * M * epsilon of the real type behind T
+        real_t<T> eps       = std::numeric_limits<real_t<T>>::epsilon();
+        double    tolerance = eps * 40 * M;
+
+        for(int b = 0; b < batch_count; b++)
+        {
+            double error = norm_check_general(
+                'F', M, N, ldb, hB_copy.data() + b * strideB, hB.data() + b * strideB);
+            unit_check_error(error, tolerance);
+        }
+    }
+
+    hipblasDestroy(handle);
+    return HIPBLAS_STATUS_SUCCESS;
+}
diff --git a/library/include/hipblas.h b/library/include/hipblas.h index 4631e6650..c2eb8fe46 
100644 --- a/library/include/hipblas.h +++ b/library/include/hipblas.h @@ -6972,6 +6972,60 @@ HIPBLAS_EXPORT hipblasStatus_t hipblasGemmStridedBatchedEx(hipblasHandle_t ha hipblasDatatype_t compute_type, hipblasGemmAlgo_t algo); +// trsm_ex: plain, batched and strided-batched trsm entry points that additionally take invA, invA_size and compute_type +HIPBLAS_EXPORT hipblasStatus_t hipblasTrsmEx(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + void* B, + int ldb, + const void* invA, + int invA_size, + hipblasDatatype_t compute_type); + +HIPBLAS_EXPORT hipblasStatus_t hipblasTrsmBatchedEx(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + void* B, + int ldb, + int batch_count, + const void* invA, + int invA_size, + hipblasDatatype_t compute_type); + +HIPBLAS_EXPORT hipblasStatus_t hipblasTrsmStridedBatchedEx(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + int stride_A, + void* B, + int ldb, + int stride_B, + int batch_count, + const void* invA, + int invA_size, + int stride_invA, + hipblasDatatype_t compute_type); + #ifdef __cplusplus } #endif diff --git a/library/src/hcc_detail/hipblas.cpp b/library/src/hcc_detail/hipblas.cpp index c7b01ed30..3bc39d908 100644 --- a/library/src/hcc_detail/hipblas.cpp +++ b/library/src/hcc_detail/hipblas.cpp @@ -13977,6 +13977,7 @@ hipblasStatus_t hipblasZgemmStridedBatched(hipblasHandle_t handle, } #endif +// gemm_ex extern "C" hipblasStatus_t hipblasGemmEx(hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb, @@ -14136,3 +14137,115 @@ extern "C" hipblasStatus_t hipblasGemmStridedBatchedEx(hipblasHandle_t handle solution_index, flags)); } + +// trsm_ex: convert the hipBLAS enums and datatype, then forward to the matching rocblas_trsm*_ex call +extern "C" hipblasStatus_t 
hipblasTrsmEx(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + void* B, + int ldb, + const void* invA, + int invA_size, + hipblasDatatype_t compute_type) +{ + return rocBLASStatusToHIPStatus(rocblas_trsm_ex((rocblas_handle)handle, + hipSideToHCCSide(side), + hipFillToHCCFill(uplo), + hipOperationToHCCOperation(transA), + hipDiagonalToHCCDiagonal(diag), + m, + n, + alpha, + A, + lda, + B, + ldb, + invA, + invA_size, + HIPDatatypeToRocblasDatatype(compute_type))); +} + +extern "C" hipblasStatus_t hipblasTrsmBatchedEx(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + void* B, + int ldb, + int batch_count, + const void* invA, + int invA_size, + hipblasDatatype_t compute_type) +{ + return rocBLASStatusToHIPStatus( + rocblas_trsm_batched_ex((rocblas_handle)handle, + hipSideToHCCSide(side), + hipFillToHCCFill(uplo), + hipOperationToHCCOperation(transA), + hipDiagonalToHCCDiagonal(diag), + m, + n, + alpha, + A, + lda, + B, + ldb, + batch_count, + invA, + invA_size, + HIPDatatypeToRocblasDatatype(compute_type))); +} + +extern "C" hipblasStatus_t hipblasTrsmStridedBatchedEx(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + int stride_A, + void* B, + int ldb, + int stride_B, + int batch_count, + const void* invA, + int invA_size, + int stride_invA, + hipblasDatatype_t compute_type) +{ + return rocBLASStatusToHIPStatus( + rocblas_trsm_strided_batched_ex((rocblas_handle)handle, + hipSideToHCCSide(side), + hipFillToHCCFill(uplo), + hipOperationToHCCOperation(transA), + hipDiagonalToHCCDiagonal(diag), + m, + n, + alpha, + A, + lda, + stride_A, + B, + ldb, 
+ stride_B, + batch_count, + invA, + invA_size, + stride_invA, + HIPDatatypeToRocblasDatatype(compute_type))); +} diff --git a/library/src/nvcc_detail/hipblas.cpp b/library/src/nvcc_detail/hipblas.cpp index d7f26e84f..29a41d6f4 100644 --- a/library/src/nvcc_detail/hipblas.cpp +++ b/library/src/nvcc_detail/hipblas.cpp @@ -10064,6 +10064,7 @@ hipblasStatus_t hipblasZgemmStridedBatched(hipblasHandle_t handle, } #endif +// gemm_ex extern "C" hipblasStatus_t hipblasGemmEx(hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb, @@ -10197,3 +10198,66 @@ extern "C" hipblasStatus_t hipblasGemmStridedBatchedEx(hipblasHandle_t handle HIPDatatypeToCudaDatatype(compute_type), HIPGemmAlgoToCudaGemmAlgo(algo))); } + +// trsm_ex: not implemented for this backend, every variant returns HIPBLAS_STATUS_NOT_SUPPORTED +extern "C" hipblasStatus_t hipblasTrsmEx(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + void* B, + int ldb, + const void* invA, + int invA_size, + hipblasDatatype_t compute_type) +{ + return HIPBLAS_STATUS_NOT_SUPPORTED; +} + +extern "C" hipblasStatus_t hipblasTrsmBatchedEx(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + void* B, + int ldb, + int batch_count, + const void* invA, + int invA_size, + hipblasDatatype_t compute_type) +{ + return HIPBLAS_STATUS_NOT_SUPPORTED; +} + +extern "C" hipblasStatus_t hipblasTrsmStridedBatchedEx(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + int stride_A, + void* B, + int ldb, + int stride_B, + int batch_count, + const void* invA, + int invA_size, + int stride_invA, + hipblasDatatype_t 
compute_type) +{ + return HIPBLAS_STATUS_NOT_SUPPORTED; +} From 
d38ad4b6d5842c921daf3529b4097e1858b696d7 Mon Sep 17 00:00:00 2001 From: daineAMD <51674140+daineAMD@users.noreply.github.com> Date: Wed, 10 Jun 2020 10:34:05 -0600 Subject: [PATCH 6/9] Fortran for solver functions (#233) --- clients/CMakeLists.txt | 11 +- .../hipblas_template_specialization.cpp | 1052 ++++++++--------- clients/gtest/geqrf_batched_gtest.cpp | 10 +- clients/gtest/geqrf_gtest.cpp | 10 +- clients/gtest/geqrf_strided_batched_gtest.cpp | 10 +- clients/gtest/getrf_batched_gtest.cpp | 10 +- clients/gtest/getrf_gtest.cpp | 10 +- clients/gtest/getrf_strided_batched_gtest.cpp | 10 +- clients/gtest/getrs_batched_gtest.cpp | 10 +- clients/gtest/getrs_gtest.cpp | 10 +- clients/gtest/getrs_strided_batched_gtest.cpp | 10 +- clients/include/hipblas_fortran.hpp | 739 ++++++------ clients/include/hipblas_fortran_solver.f90 | 748 ++++++++++++ clients/include/testing_geqrf.hpp | 5 +- clients/include/testing_geqrf_batched.hpp | 5 +- .../include/testing_geqrf_strided_batched.hpp | 5 +- clients/include/testing_getrf.hpp | 5 +- clients/include/testing_getrf_batched.hpp | 5 +- .../include/testing_getrf_strided_batched.hpp | 5 +- clients/include/testing_getrs.hpp | 5 +- clients/include/testing_getrs_batched.hpp | 5 +- .../include/testing_getrs_strided_batched.hpp | 5 +- library/src/hipblas_module.f90 | 721 +++++++++++ 23 files changed, 2481 insertions(+), 925 deletions(-) create mode 100644 clients/include/hipblas_fortran_solver.f90 diff --git a/clients/CMakeLists.txt b/clients/CMakeLists.txt index e639b1831..296ca17e2 100644 --- a/clients/CMakeLists.txt +++ b/clients/CMakeLists.txt @@ -30,8 +30,17 @@ set(hipblas_f90_source_clients include/hipblas_fortran.f90 ) +set(hipblas_f90_source_clients_solver + include/hipblas_fortran_solver.f90 +) + if( BUILD_CLIENTS_TESTS OR BUILD_CLIENTS_SAMPLES ) - add_library(hipblas_fortran_client ${hipblas_f90_source_clients}) + if( BUILD_WITH_SOLVER) + add_library(hipblas_fortran_client ${hipblas_f90_source_clients} 
${hipblas_f90_source_clients_solver}) + else() + add_library(hipblas_fortran_client ${hipblas_f90_source_clients}) + endif() + add_dependencies(hipblas_fortran_client hipblas_fortran) include_directories(${CMAKE_BINARY_DIR}/include) endif( ) diff --git a/clients/common/hipblas_template_specialization.cpp b/clients/common/hipblas_template_specialization.cpp index 57144d4e6..0d59b9ae6 100644 --- a/clients/common/hipblas_template_specialization.cpp +++ b/clients/common/hipblas_template_specialization.cpp @@ -12,26 +12,6 @@ * \brief provide template functions interfaces to ROCBLAS C89 interfaces */ -/* - * =========================================================================== - * level 1 BLAS - * =========================================================================== - */ - -/* ************************************************************************ - * Copyright 2016-2020 Advanced Micro Devices, Inc. - * - * ************************************************************************/ - -#include "hipblas.h" -#include "hipblas.hpp" -#include "hipblas_fortran.hpp" -#include - -/*!\file - * \brief provide template functions interfaces to ROCBLAS C89 interfaces -*/ - /* * =========================================================================== * level 1 BLAS @@ -19654,519 +19634,519 @@ hipblasStatus_t batchCount); } -// #ifdef __HIP_PLATFORM_SOLVER__ - -// // getrf -// template <> -// hipblasStatus_t hipblasGetrf( -// hipblasHandle_t handle, const int n, float* A, const int lda, int* ipiv, int* info) -// { -// return hipblasSgetrfFortran(handle, n, A, lda, ipiv, info); -// } - -// template <> -// hipblasStatus_t hipblasGetrf( -// hipblasHandle_t handle, const int n, double* A, const int lda, int* ipiv, int* info) -// { -// return hipblasDgetrfFortran(handle, n, A, lda, ipiv, info); -// } - -// template <> -// hipblasStatus_t hipblasGetrf( -// hipblasHandle_t handle, const int n, hipblasComplex* A, const int lda, int* ipiv, int* info) -// { -// return 
hipblasCgetrfFortran(handle, n, A, lda, ipiv, info); -// } - -// template <> -// hipblasStatus_t hipblasGetrf(hipblasHandle_t handle, -// const int n, -// hipblasDoubleComplex* A, -// const int lda, -// int* ipiv, -// int* info) -// { -// return hipblasZgetrfFortran(handle, n, A, lda, ipiv, info); -// } - -// // getrf_batched -// template <> -// hipblasStatus_t hipblasGetrfBatched(hipblasHandle_t handle, -// const int n, -// float* const A[], -// const int lda, -// int* ipiv, -// int* info, -// const int batchCount) -// { -// return hipblasSgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrfBatched(hipblasHandle_t handle, -// const int n, -// double* const A[], -// const int lda, -// int* ipiv, -// int* info, -// const int batchCount) -// { -// return hipblasDgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrfBatched(hipblasHandle_t handle, -// const int n, -// hipblasComplex* const A[], -// const int lda, -// int* ipiv, -// int* info, -// const int batchCount) -// { -// return hipblasCgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrfBatched(hipblasHandle_t handle, -// const int n, -// hipblasDoubleComplex* const A[], -// const int lda, -// int* ipiv, -// int* info, -// const int batchCount) -// { -// return hipblasZgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batchCount); -// } - -// // getrf_strided_batched -// template <> -// hipblasStatus_t hipblasGetrfStridedBatched(hipblasHandle_t handle, -// const int n, -// float* A, -// const int lda, -// const int strideA, -// int* ipiv, -// const int strideP, -// int* info, -// const int batchCount) -// { -// return hipblasSgetrfStridedBatchedFortran(handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrfStridedBatched(hipblasHandle_t handle, -// 
const int n, -// double* A, -// const int lda, -// const int strideA, -// int* ipiv, -// const int strideP, -// int* info, -// const int batchCount) -// { -// return hipblasDgetrfStridedBatchedFortran(handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrfStridedBatched(hipblasHandle_t handle, -// const int n, -// hipblasComplex* A, -// const int lda, -// const int strideA, -// int* ipiv, -// const int strideP, -// int* info, -// const int batchCount) -// { -// return hipblasCgetrfStridedBatchedFortran(handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrfStridedBatched(hipblasHandle_t handle, -// const int n, -// hipblasDoubleComplex* A, -// const int lda, -// const int strideA, -// int* ipiv, -// const int strideP, -// int* info, -// const int batchCount) -// { -// return hipblasZgetrfStridedBatchedFortran(handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); -// } - -// // getrs -// template <> -// hipblasStatus_t hipblasGetrs(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// float* A, -// const int lda, -// const int* ipiv, -// float* B, -// const int ldb, -// int* info) -// { -// return hipblasSgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info); -// } - -// template <> -// hipblasStatus_t hipblasGetrs(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// double* A, -// const int lda, -// const int* ipiv, -// double* B, -// const int ldb, -// int* info) -// { -// return hipblasDgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info); -// } - -// template <> -// hipblasStatus_t hipblasGetrs(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// hipblasComplex* A, -// const int lda, -// const int* ipiv, -// hipblasComplex* B, -// const int ldb, -// int* info) -// { 
-// return hipblasCgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info); -// } - -// template <> -// hipblasStatus_t hipblasGetrs(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// hipblasDoubleComplex* A, -// const int lda, -// const int* ipiv, -// hipblasDoubleComplex* B, -// const int ldb, -// int* info) -// { -// return hipblasZgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info); -// } - -// // getrs_batched -// template <> -// hipblasStatus_t hipblasGetrsBatched(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// float* const A[], -// const int lda, -// const int* ipiv, -// float* const B[], -// const int ldb, -// int* info, -// const int batchCount) -// { -// return hipblasSgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrsBatched(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// double* const A[], -// const int lda, -// const int* ipiv, -// double* const B[], -// const int ldb, -// int* info, -// const int batchCount) -// { -// return hipblasDgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrsBatched(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// hipblasComplex* const A[], -// const int lda, -// const int* ipiv, -// hipblasComplex* const B[], -// const int ldb, -// int* info, -// const int batchCount) -// { -// return hipblasCgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrsBatched(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// hipblasDoubleComplex* const A[], -// const int lda, -// const int* ipiv, -// 
hipblasDoubleComplex* const B[], -// const int ldb, -// int* info, -// const int batchCount) -// { -// return hipblasZgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); -// } - -// // getrs_strided_batched -// template <> -// hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// float* A, -// const int lda, -// const int strideA, -// const int* ipiv, -// const int strideP, -// float* B, -// const int ldb, -// const int strideB, -// int* info, -// const int batchCount) -// { -// return hipblasSgetrsStridedBatchedFortran( -// handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// double* A, -// const int lda, -// const int strideA, -// const int* ipiv, -// const int strideP, -// double* B, -// const int ldb, -// const int strideB, -// int* info, -// const int batchCount) -// { -// return hipblasDgetrsStridedBatchedFortran( -// handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// hipblasComplex* A, -// const int lda, -// const int strideA, -// const int* ipiv, -// const int strideP, -// hipblasComplex* B, -// const int ldb, -// const int strideB, -// int* info, -// const int batchCount) -// { -// return hipblasCgetrsStridedBatchedFortran( -// handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// hipblasDoubleComplex* A, -// 
const int lda, -// const int strideA, -// const int* ipiv, -// const int strideP, -// hipblasDoubleComplex* B, -// const int ldb, -// const int strideB, -// int* info, -// const int batchCount) -// { -// return hipblasZgetrsStridedBatchedFortran( -// handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); -// } - -// // geqrf -// template <> -// hipblasStatus_t hipblasGeqrf(hipblasHandle_t handle, -// const int m, -// const int n, -// float* A, -// const int lda, -// float* ipiv, -// int* info) -// { -// return hipblasSgeqrfFortran(handle, m, n, A, lda, ipiv, info); -// } - -// template <> -// hipblasStatus_t hipblasGeqrf(hipblasHandle_t handle, -// const int m, -// const int n, -// double* A, -// const int lda, -// double* ipiv, -// int* info) -// { -// return hipblasDgeqrfFortran(handle, m, n, A, lda, ipiv, info); -// } - -// template <> -// hipblasStatus_t hipblasGeqrf(hipblasHandle_t handle, -// const int m, -// const int n, -// hipblasComplex* A, -// const int lda, -// hipblasComplex* ipiv, -// int* info) -// { -// return hipblasCgeqrfFortran(handle, m, n, A, lda, ipiv, info); -// } - -// template <> -// hipblasStatus_t hipblasGeqrf(hipblasHandle_t handle, -// const int m, -// const int n, -// hipblasDoubleComplex* A, -// const int lda, -// hipblasDoubleComplex* ipiv, -// int* info) -// { -// return hipblasZgeqrfFortran(handle, m, n, A, lda, ipiv, info); -// } - -// // geqrf_batched -// template <> -// hipblasStatus_t hipblasGeqrfBatched(hipblasHandle_t handle, -// const int m, -// const int n, -// float* const A[], -// const int lda, -// float* const ipiv[], -// int* info, -// const int batchCount) -// { -// return hipblasSgeqrfBatchedFortran(handle, m, n, A, lda, ipiv, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGeqrfBatched(hipblasHandle_t handle, -// const int m, -// const int n, -// double* const A[], -// const int lda, -// double* const ipiv[], -// int* info, -// const int batchCount) -// { 
-// return hipblasDgeqrfBatchedFortran(handle, m, n, A, lda, ipiv, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGeqrfBatched(hipblasHandle_t handle, -// const int m, -// const int n, -// hipblasComplex* const A[], -// const int lda, -// hipblasComplex* const ipiv[], -// int* info, -// const int batchCount) -// { -// return hipblasCgeqrfBatchedFortran(handle, m, n, A, lda, ipiv, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGeqrfBatched(hipblasHandle_t handle, -// const int m, -// const int n, -// hipblasDoubleComplex* const A[], -// const int lda, -// hipblasDoubleComplex* const ipiv[], -// int* info, -// const int batchCount) -// { -// return hipblasZgeqrfBatchedFortran(handle, m, n, A, lda, ipiv, info, batchCount); -// } - -// // geqrf_strided_batched -// template <> -// hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t handle, -// const int m, -// const int n, -// float* A, -// const int lda, -// const int strideA, -// float* ipiv, -// const int strideP, -// int* info, -// const int batchCount) -// { -// return hipblasSgeqrfStridedBatchedFortran( -// handle, m, n, A, lda, strideA, ipiv, strideP, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t handle, -// const int m, -// const int n, -// double* A, -// const int lda, -// const int strideA, -// double* ipiv, -// const int strideP, -// int* info, -// const int batchCount) -// { -// return hipblasDgeqrfStridedBatchedFortran( -// handle, m, n, A, lda, strideA, ipiv, strideP, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t handle, -// const int m, -// const int n, -// hipblasComplex* A, -// const int lda, -// const int strideA, -// hipblasComplex* ipiv, -// const int strideP, -// int* info, -// const int batchCount) -// { -// return hipblasCgeqrfStridedBatchedFortran( -// handle, m, n, A, lda, strideA, ipiv, strideP, info, batchCount); -// } - 
-// template <> -// hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t handle, -// const int m, -// const int n, -// hipblasDoubleComplex* A, -// const int lda, -// const int strideA, -// hipblasDoubleComplex* ipiv, -// const int strideP, -// int* info, -// const int batchCount) -// { -// return hipblasZgeqrfStridedBatchedFortran( -// handle, m, n, A, lda, strideA, ipiv, strideP, info, batchCount); -// } - -// #endif +#ifdef __HIP_PLATFORM_SOLVER__ + +// getrf +template <> +hipblasStatus_t hipblasGetrf( + hipblasHandle_t handle, const int n, float* A, const int lda, int* ipiv, int* info) +{ + return hipblasSgetrfFortran(handle, n, A, lda, ipiv, info); +} + +template <> +hipblasStatus_t hipblasGetrf( + hipblasHandle_t handle, const int n, double* A, const int lda, int* ipiv, int* info) +{ + return hipblasDgetrfFortran(handle, n, A, lda, ipiv, info); +} + +template <> +hipblasStatus_t hipblasGetrf( + hipblasHandle_t handle, const int n, hipblasComplex* A, const int lda, int* ipiv, int* info) +{ + return hipblasCgetrfFortran(handle, n, A, lda, ipiv, info); +} + +template <> +hipblasStatus_t hipblasGetrf(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* A, + const int lda, + int* ipiv, + int* info) +{ + return hipblasZgetrfFortran(handle, n, A, lda, ipiv, info); +} + +// getrf_batched +template <> +hipblasStatus_t hipblasGetrfBatched(hipblasHandle_t handle, + const int n, + float* const A[], + const int lda, + int* ipiv, + int* info, + const int batchCount) +{ + return hipblasSgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrfBatched(hipblasHandle_t handle, + const int n, + double* const A[], + const int lda, + int* ipiv, + int* info, + const int batchCount) +{ + return hipblasDgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrfBatched(hipblasHandle_t handle, + const int n, + hipblasComplex* const A[], + const int lda, + int* 
ipiv, + int* info, + const int batchCount) +{ + return hipblasCgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrfBatched(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* const A[], + const int lda, + int* ipiv, + int* info, + const int batchCount) +{ + return hipblasZgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batchCount); +} + +// getrf_strided_batched +template <> +hipblasStatus_t hipblasGetrfStridedBatched(hipblasHandle_t handle, + const int n, + float* A, + const int lda, + const int strideA, + int* ipiv, + const int strideP, + int* info, + const int batchCount) +{ + return hipblasSgetrfStridedBatchedFortran(handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrfStridedBatched(hipblasHandle_t handle, + const int n, + double* A, + const int lda, + const int strideA, + int* ipiv, + const int strideP, + int* info, + const int batchCount) +{ + return hipblasDgetrfStridedBatchedFortran(handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrfStridedBatched(hipblasHandle_t handle, + const int n, + hipblasComplex* A, + const int lda, + const int strideA, + int* ipiv, + const int strideP, + int* info, + const int batchCount) +{ + return hipblasCgetrfStridedBatchedFortran(handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrfStridedBatched(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* A, + const int lda, + const int strideA, + int* ipiv, + const int strideP, + int* info, + const int batchCount) +{ + return hipblasZgetrfStridedBatchedFortran(handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); +} + +// getrs +template <> +hipblasStatus_t hipblasGetrs(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + float* A, + const int lda, + const int* ipiv, + float* B, + const 
int ldb, + int* info) +{ + return hipblasSgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info); +} + +template <> +hipblasStatus_t hipblasGetrs(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + double* A, + const int lda, + const int* ipiv, + double* B, + const int ldb, + int* info) +{ + return hipblasDgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info); +} + +template <> +hipblasStatus_t hipblasGetrs(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasComplex* A, + const int lda, + const int* ipiv, + hipblasComplex* B, + const int ldb, + int* info) +{ + return hipblasCgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info); +} + +template <> +hipblasStatus_t hipblasGetrs(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasDoubleComplex* A, + const int lda, + const int* ipiv, + hipblasDoubleComplex* B, + const int ldb, + int* info) +{ + return hipblasZgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info); +} + +// getrs_batched +template <> +hipblasStatus_t hipblasGetrsBatched(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + float* const A[], + const int lda, + const int* ipiv, + float* const B[], + const int ldb, + int* info, + const int batchCount) +{ + return hipblasSgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrsBatched(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + double* const A[], + const int lda, + const int* ipiv, + double* const B[], + const int ldb, + int* info, + const int batchCount) +{ + return hipblasDgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrsBatched(hipblasHandle_t handle, + const hipblasOperation_t trans, + 
const int n, + const int nrhs, + hipblasComplex* const A[], + const int lda, + const int* ipiv, + hipblasComplex* const B[], + const int ldb, + int* info, + const int batchCount) +{ + return hipblasCgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrsBatched(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasDoubleComplex* const A[], + const int lda, + const int* ipiv, + hipblasDoubleComplex* const B[], + const int ldb, + int* info, + const int batchCount) +{ + return hipblasZgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); +} + +// getrs_strided_batched +template <> +hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + float* A, + const int lda, + const int strideA, + const int* ipiv, + const int strideP, + float* B, + const int ldb, + const int strideB, + int* info, + const int batchCount) +{ + return hipblasSgetrsStridedBatchedFortran( + handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + double* A, + const int lda, + const int strideA, + const int* ipiv, + const int strideP, + double* B, + const int ldb, + const int strideB, + int* info, + const int batchCount) +{ + return hipblasDgetrsStridedBatchedFortran( + handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasComplex* A, + const int lda, + const int strideA, + const int* ipiv, + const int strideP, + hipblasComplex* B, + const int ldb, + const int strideB, + int* info, + const 
int batchCount) +{ + return hipblasCgetrsStridedBatchedFortran( + handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasDoubleComplex* A, + const int lda, + const int strideA, + const int* ipiv, + const int strideP, + hipblasDoubleComplex* B, + const int ldb, + const int strideB, + int* info, + const int batchCount) +{ + return hipblasZgetrsStridedBatchedFortran( + handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); +} + +// geqrf +template <> +hipblasStatus_t hipblasGeqrf(hipblasHandle_t handle, + const int m, + const int n, + float* A, + const int lda, + float* ipiv, + int* info) +{ + return hipblasSgeqrfFortran(handle, m, n, A, lda, ipiv, info); +} + +template <> +hipblasStatus_t hipblasGeqrf(hipblasHandle_t handle, + const int m, + const int n, + double* A, + const int lda, + double* ipiv, + int* info) +{ + return hipblasDgeqrfFortran(handle, m, n, A, lda, ipiv, info); +} + +template <> +hipblasStatus_t hipblasGeqrf(hipblasHandle_t handle, + const int m, + const int n, + hipblasComplex* A, + const int lda, + hipblasComplex* ipiv, + int* info) +{ + return hipblasCgeqrfFortran(handle, m, n, A, lda, ipiv, info); +} + +template <> +hipblasStatus_t hipblasGeqrf(hipblasHandle_t handle, + const int m, + const int n, + hipblasDoubleComplex* A, + const int lda, + hipblasDoubleComplex* ipiv, + int* info) +{ + return hipblasZgeqrfFortran(handle, m, n, A, lda, ipiv, info); +} + +// geqrf_batched +template <> +hipblasStatus_t hipblasGeqrfBatched(hipblasHandle_t handle, + const int m, + const int n, + float* const A[], + const int lda, + float* const ipiv[], + int* info, + const int batchCount) +{ + return hipblasSgeqrfBatchedFortran(handle, m, n, A, lda, ipiv, info, batchCount); +} + +template <> +hipblasStatus_t 
hipblasGeqrfBatched(hipblasHandle_t handle, + const int m, + const int n, + double* const A[], + const int lda, + double* const ipiv[], + int* info, + const int batchCount) +{ + return hipblasDgeqrfBatchedFortran(handle, m, n, A, lda, ipiv, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGeqrfBatched(hipblasHandle_t handle, + const int m, + const int n, + hipblasComplex* const A[], + const int lda, + hipblasComplex* const ipiv[], + int* info, + const int batchCount) +{ + return hipblasCgeqrfBatchedFortran(handle, m, n, A, lda, ipiv, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGeqrfBatched(hipblasHandle_t handle, + const int m, + const int n, + hipblasDoubleComplex* const A[], + const int lda, + hipblasDoubleComplex* const ipiv[], + int* info, + const int batchCount) +{ + return hipblasZgeqrfBatchedFortran(handle, m, n, A, lda, ipiv, info, batchCount); +} + +// geqrf_strided_batched +template <> +hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t handle, + const int m, + const int n, + float* A, + const int lda, + const int strideA, + float* ipiv, + const int strideP, + int* info, + const int batchCount) +{ + return hipblasSgeqrfStridedBatchedFortran( + handle, m, n, A, lda, strideA, ipiv, strideP, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t handle, + const int m, + const int n, + double* A, + const int lda, + const int strideA, + double* ipiv, + const int strideP, + int* info, + const int batchCount) +{ + return hipblasDgeqrfStridedBatchedFortran( + handle, m, n, A, lda, strideA, ipiv, strideP, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t handle, + const int m, + const int n, + hipblasComplex* A, + const int lda, + const int strideA, + hipblasComplex* ipiv, + const int strideP, + int* info, + const int batchCount) +{ + return hipblasCgeqrfStridedBatchedFortran( + handle, m, n, A, lda, strideA, ipiv, strideP, info, 
batchCount); +} + +template <> +hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t handle, + const int m, + const int n, + hipblasDoubleComplex* A, + const int lda, + const int strideA, + hipblasDoubleComplex* ipiv, + const int strideP, + int* info, + const int batchCount) +{ + return hipblasZgeqrfStridedBatchedFortran( + handle, m, n, A, lda, strideA, ipiv, strideP, info, batchCount); +} + +#endif diff --git a/clients/gtest/geqrf_batched_gtest.cpp b/clients/gtest/geqrf_batched_gtest.cpp index 0880d0854..cd2bca749 100644 --- a/clients/gtest/geqrf_batched_gtest.cpp +++ b/clients/gtest/geqrf_batched_gtest.cpp @@ -16,7 +16,7 @@ using ::testing::Values; using ::testing::ValuesIn; using namespace std; -typedef std::tuple, double, int> geqrf_batched_tuple; +typedef std::tuple, double, int, bool> geqrf_batched_tuple; const vector> matrix_size_range = {{-1, -1, 1, 1}, {10, 10, 10, 10}, {10, 10, 20, 100}, {600, 500, 600, 600}}; @@ -25,11 +25,14 @@ const vector stride_scale_range = {2.5}; const vector batch_count_range = {-1, 0, 1, 2}; +const vector is_fortran = {false, true}; + Arguments setup_geqrf_batched_arguments(geqrf_batched_tuple tup) { vector matrix_size = std::get<0>(tup); double stride_scale = std::get<1>(tup); int batch_count = std::get<2>(tup); + bool fortran = std::get<3>(tup); Arguments arg; @@ -41,6 +44,8 @@ Arguments setup_geqrf_batched_arguments(geqrf_batched_tuple tup) arg.stride_scale = stride_scale; arg.batch_count = batch_count; + arg.fortran = fortran; + return arg; } @@ -150,4 +155,5 @@ INSTANTIATE_TEST_CASE_P(hipblasGeqrfBatched, geqrf_batched_gtest, Combine(ValuesIn(matrix_size_range), ValuesIn(stride_scale_range), - ValuesIn(batch_count_range))); + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/gtest/geqrf_gtest.cpp b/clients/gtest/geqrf_gtest.cpp index e794aadf3..bbcc9d741 100644 --- a/clients/gtest/geqrf_gtest.cpp +++ b/clients/gtest/geqrf_gtest.cpp @@ -16,7 +16,7 @@ using ::testing::Values; using 
::testing::ValuesIn; using namespace std; -typedef std::tuple, double, int> geqrf_tuple; +typedef std::tuple, double, int, bool> geqrf_tuple; const vector> matrix_size_range = {{-1, -1, 1, 1}, {10, 10, 10, 10}, {10, 10, 20, 100}, {600, 500, 600, 600}}; @@ -25,11 +25,14 @@ const vector stride_scale_range = {2.5}; const vector batch_count_range = {1}; +const vector is_fortran = {false, true}; + Arguments setup_geqrf_arguments(geqrf_tuple tup) { vector matrix_size = std::get<0>(tup); double stride_scale = std::get<1>(tup); int batch_count = std::get<2>(tup); + bool fortran = std::get<3>(tup); Arguments arg; @@ -41,6 +44,8 @@ Arguments setup_geqrf_arguments(geqrf_tuple tup) arg.stride_scale = stride_scale; arg.batch_count = batch_count; + arg.fortran = fortran; + return arg; } @@ -150,4 +155,5 @@ INSTANTIATE_TEST_CASE_P(hipblasGeqrf, geqrf_gtest, Combine(ValuesIn(matrix_size_range), ValuesIn(stride_scale_range), - ValuesIn(batch_count_range))); + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/gtest/geqrf_strided_batched_gtest.cpp b/clients/gtest/geqrf_strided_batched_gtest.cpp index 8dfe1c0d4..6a8fc0e44 100644 --- a/clients/gtest/geqrf_strided_batched_gtest.cpp +++ b/clients/gtest/geqrf_strided_batched_gtest.cpp @@ -16,7 +16,7 @@ using ::testing::Values; using ::testing::ValuesIn; using namespace std; -typedef std::tuple, double, int> geqrf_strided_batched_tuple; +typedef std::tuple, double, int, bool> geqrf_strided_batched_tuple; const vector> matrix_size_range = {{-1, -1, 1, 1}, {10, 10, 10, 10}, {10, 10, 20, 100}, {600, 500, 600, 600}}; @@ -25,11 +25,14 @@ const vector stride_scale_range = {2.5}; const vector batch_count_range = {-1, 0, 1, 2}; +const vector is_fortran = {false, true}; + Arguments setup_geqrf_strided_batched_arguments(geqrf_strided_batched_tuple tup) { vector matrix_size = std::get<0>(tup); double stride_scale = std::get<1>(tup); int batch_count = std::get<2>(tup); + bool fortran = std::get<3>(tup); Arguments arg; @@ 
-41,6 +44,8 @@ Arguments setup_geqrf_strided_batched_arguments(geqrf_strided_batched_tuple tup) arg.stride_scale = stride_scale; arg.batch_count = batch_count; + arg.fortran = fortran; + return arg; } @@ -150,4 +155,5 @@ INSTANTIATE_TEST_CASE_P(hipblasGeqrfStridedBatched, geqrf_strided_batched_gtest, Combine(ValuesIn(matrix_size_range), ValuesIn(stride_scale_range), - ValuesIn(batch_count_range))); + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/gtest/getrf_batched_gtest.cpp b/clients/gtest/getrf_batched_gtest.cpp index 3e844e9e5..284a4ca50 100644 --- a/clients/gtest/getrf_batched_gtest.cpp +++ b/clients/gtest/getrf_batched_gtest.cpp @@ -16,7 +16,7 @@ using ::testing::Values; using ::testing::ValuesIn; using namespace std; -typedef std::tuple, double, int> getrf_batched_tuple; +typedef std::tuple, double, int, bool> getrf_batched_tuple; const vector> matrix_size_range = {{-1, -1, 1, 1}, {10, 10, 10, 10}, @@ -28,11 +28,14 @@ const vector stride_scale_range = {2.5}; const vector batch_count_range = {-1, 0, 1, 2}; +const vector is_fortran = {false, true}; + Arguments setup_getrf_batched_arguments(getrf_batched_tuple tup) { vector matrix_size = std::get<0>(tup); double stride_scale = std::get<1>(tup); int batch_count = std::get<2>(tup); + bool fortran = std::get<3>(tup); Arguments arg; @@ -44,6 +47,8 @@ Arguments setup_getrf_batched_arguments(getrf_batched_tuple tup) arg.stride_scale = stride_scale; arg.batch_count = batch_count; + arg.fortran = fortran; + return arg; } @@ -153,4 +158,5 @@ INSTANTIATE_TEST_CASE_P(hipblasGetrfBatched, getrf_batched_gtest, Combine(ValuesIn(matrix_size_range), ValuesIn(stride_scale_range), - ValuesIn(batch_count_range))); + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/gtest/getrf_gtest.cpp b/clients/gtest/getrf_gtest.cpp index 3957815cd..fe99ca5e6 100644 --- a/clients/gtest/getrf_gtest.cpp +++ b/clients/gtest/getrf_gtest.cpp @@ -16,7 +16,7 @@ using ::testing::Values; using 
::testing::ValuesIn; using namespace std; -typedef std::tuple, double, int> getrf_tuple; +typedef std::tuple, double, int, bool> getrf_tuple; const vector> matrix_size_range = {{-1, -1, 1, 1}, {10, 10, 10, 10}, @@ -28,11 +28,14 @@ const vector stride_scale_range = {2.5}; const vector batch_count_range = {1}; +const vector is_fortran = {false, true}; + Arguments setup_getrf_arguments(getrf_tuple tup) { vector matrix_size = std::get<0>(tup); double stride_scale = std::get<1>(tup); int batch_count = std::get<2>(tup); + bool fortran = std::get<3>(tup); Arguments arg; @@ -44,6 +47,8 @@ Arguments setup_getrf_arguments(getrf_tuple tup) arg.stride_scale = stride_scale; arg.batch_count = batch_count; + arg.fortran = fortran; + return arg; } @@ -153,4 +158,5 @@ INSTANTIATE_TEST_CASE_P(hipblasGetrf, getrf_gtest, Combine(ValuesIn(matrix_size_range), ValuesIn(stride_scale_range), - ValuesIn(batch_count_range))); + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/gtest/getrf_strided_batched_gtest.cpp b/clients/gtest/getrf_strided_batched_gtest.cpp index 2cda93f97..30105f277 100644 --- a/clients/gtest/getrf_strided_batched_gtest.cpp +++ b/clients/gtest/getrf_strided_batched_gtest.cpp @@ -16,7 +16,7 @@ using ::testing::Values; using ::testing::ValuesIn; using namespace std; -typedef std::tuple, double, int> getrf_strided_batched_tuple; +typedef std::tuple, double, int, bool> getrf_strided_batched_tuple; const vector> matrix_size_range = {{-1, -1, 1, 1}, {10, 10, 10, 10}, @@ -28,11 +28,14 @@ const vector stride_scale_range = {2.5}; const vector batch_count_range = {-1, 0, 1, 2}; +const vector is_fortran = {false, true}; + Arguments setup_getrf_strided_batched_arguments(getrf_strided_batched_tuple tup) { vector matrix_size = std::get<0>(tup); double stride_scale = std::get<1>(tup); int batch_count = std::get<2>(tup); + bool fortran = std::get<3>(tup); Arguments arg; @@ -44,6 +47,8 @@ Arguments 
setup_getrf_strided_batched_arguments(getrf_strided_batched_tuple tup) arg.stride_scale = stride_scale; arg.batch_count = batch_count; + arg.fortran = fortran; + return arg; } @@ -153,4 +158,5 @@ INSTANTIATE_TEST_CASE_P(hipblasGetrfStridedBatched, getrf_strided_batched_gtest, Combine(ValuesIn(matrix_size_range), ValuesIn(stride_scale_range), - ValuesIn(batch_count_range))); + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/gtest/getrs_batched_gtest.cpp b/clients/gtest/getrs_batched_gtest.cpp index c6da21065..e2910ec34 100644 --- a/clients/gtest/getrs_batched_gtest.cpp +++ b/clients/gtest/getrs_batched_gtest.cpp @@ -16,7 +16,7 @@ using ::testing::Values; using ::testing::ValuesIn; using namespace std; -typedef std::tuple, double, int> getrs_batched_tuple; +typedef std::tuple, double, int, bool> getrs_batched_tuple; const vector> matrix_size_range = {{-1, 1, 1}, {10, 20, 100}, {500, 600, 600}, {1024, 1024, 1024}}; @@ -25,11 +25,14 @@ const vector stride_scale_range = {2.5}; const vector batch_count_range = {-1, 0, 1, 2}; +const vector is_fortran = {false, true}; + Arguments setup_getrs_batched_arguments(getrs_batched_tuple tup) { vector matrix_size = std::get<0>(tup); double stride_scale = std::get<1>(tup); int batch_count = std::get<2>(tup); + bool fortran = std::get<3>(tup); Arguments arg; @@ -40,6 +43,8 @@ Arguments setup_getrs_batched_arguments(getrs_batched_tuple tup) arg.stride_scale = stride_scale; arg.batch_count = batch_count; + arg.fortran = fortran; + return arg; } @@ -149,4 +154,5 @@ INSTANTIATE_TEST_CASE_P(hipblasGetrsBatched, getrs_batched_gtest, Combine(ValuesIn(matrix_size_range), ValuesIn(stride_scale_range), - ValuesIn(batch_count_range))); + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/gtest/getrs_gtest.cpp b/clients/gtest/getrs_gtest.cpp index dae2d2fe4..5bc30725a 100644 --- a/clients/gtest/getrs_gtest.cpp +++ b/clients/gtest/getrs_gtest.cpp @@ -16,7 +16,7 @@ using ::testing::Values; using 
::testing::ValuesIn; using namespace std; -typedef std::tuple, double, int> getrs_tuple; +typedef std::tuple, double, int, bool> getrs_tuple; const vector> matrix_size_range = {{-1, 1, 1}, {10, 20, 100}, {500, 600, 600}, {1024, 1024, 1024}}; @@ -25,11 +25,14 @@ const vector stride_scale_range = {2.5}; const vector batch_count_range = {1}; +const vector is_fortran = {false, true}; + Arguments setup_getrs_arguments(getrs_tuple tup) { vector matrix_size = std::get<0>(tup); double stride_scale = std::get<1>(tup); int batch_count = std::get<2>(tup); + bool fortran = std::get<3>(tup); Arguments arg; @@ -40,6 +43,8 @@ Arguments setup_getrs_arguments(getrs_tuple tup) arg.stride_scale = stride_scale; arg.batch_count = batch_count; + arg.fortran = fortran; + return arg; } @@ -149,4 +154,5 @@ INSTANTIATE_TEST_CASE_P(hipblasGetrs, getrs_gtest, Combine(ValuesIn(matrix_size_range), ValuesIn(stride_scale_range), - ValuesIn(batch_count_range))); + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/gtest/getrs_strided_batched_gtest.cpp b/clients/gtest/getrs_strided_batched_gtest.cpp index bb6f075bd..d1b13c535 100644 --- a/clients/gtest/getrs_strided_batched_gtest.cpp +++ b/clients/gtest/getrs_strided_batched_gtest.cpp @@ -16,7 +16,7 @@ using ::testing::Values; using ::testing::ValuesIn; using namespace std; -typedef std::tuple, double, int> getrs_strided_batched_tuple; +typedef std::tuple, double, int, bool> getrs_strided_batched_tuple; const vector> matrix_size_range = {{-1, 1, 1}, {10, 20, 100}, {500, 600, 600}, {1024, 1024, 1024}}; @@ -25,11 +25,14 @@ const vector stride_scale_range = {2.5}; const vector batch_count_range = {-1, 0, 1, 2}; +const vector is_fortran = {false, true}; + Arguments setup_getrs_strided_batched_arguments(getrs_strided_batched_tuple tup) { vector matrix_size = std::get<0>(tup); double stride_scale = std::get<1>(tup); int batch_count = std::get<2>(tup); + bool fortran = std::get<3>(tup); Arguments arg; @@ -40,6 +43,8 @@ Arguments 
setup_getrs_strided_batched_arguments(getrs_strided_batched_tuple tup) arg.stride_scale = stride_scale; arg.batch_count = batch_count; + arg.fortran = fortran; + return arg; } @@ -149,4 +154,5 @@ INSTANTIATE_TEST_CASE_P(hipblasGetrsStridedBatched, getrs_strided_batched_gtest, Combine(ValuesIn(matrix_size_range), ValuesIn(stride_scale_range), - ValuesIn(batch_count_range))); + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/include/hipblas_fortran.hpp b/clients/include/hipblas_fortran.hpp index e940c93c5..e1c02e770 100644 --- a/clients/include/hipblas_fortran.hpp +++ b/clients/include/hipblas_fortran.hpp @@ -6066,367 +6066,6 @@ hipblasStatus_t hipblasZtrsmStridedBatchedFortran(hipblasHandle_t ha int strideB, int batch_count); -// getrf -hipblasStatus_t hipblasSgetrf( - hipblasHandle_t handle, const int n, float* A, const int lda, int* ipiv, int* info); - -hipblasStatus_t hipblasDgetrf( - hipblasHandle_t handle, const int n, double* A, const int lda, int* ipiv, int* info); - -hipblasStatus_t hipblasCgetrf( - hipblasHandle_t handle, const int n, hipblasComplex* A, const int lda, int* ipiv, int* info); - -hipblasStatus_t hipblasZgetrfFortran(hipblasHandle_t handle, - const int n, - hipblasDoubleComplex* A, - const int lda, - int* ipiv, - int* info); - -// getrf_batched -hipblasStatus_t hipblasSgetrfBatchedFortran(hipblasHandle_t handle, - const int n, - float* const A[], - const int lda, - int* ipiv, - int* info, - const int batch_count); - -hipblasStatus_t hipblasDgetrfBatchedFortran(hipblasHandle_t handle, - const int n, - double* const A[], - const int lda, - int* ipiv, - int* info, - const int batch_count); - -hipblasStatus_t hipblasCgetrfBatchedFortran(hipblasHandle_t handle, - const int n, - hipblasComplex* const A[], - const int lda, - int* ipiv, - int* info, - const int batch_count); - -hipblasStatus_t hipblasZgetrfBatchedFortran(hipblasHandle_t handle, - const int n, - hipblasDoubleComplex* const A[], - const int lda, - int* ipiv, - 
int* info, - const int batch_count); - -// getrf_strided_batched -hipblasStatus_t hipblasSgetrfStridedBatchedFortran(hipblasHandle_t handle, - const int n, - float* A, - const int lda, - const int strideA, - int* ipiv, - const int strideP, - int* info, - const int batch_count); - -hipblasStatus_t hipblasDgetrfStridedBatchedFortran(hipblasHandle_t handle, - const int n, - double* A, - const int lda, - const int strideA, - int* ipiv, - const int strideP, - int* info, - const int batch_count); - -hipblasStatus_t hipblasCgetrfStridedBatchedFortran(hipblasHandle_t handle, - const int n, - hipblasComplex* A, - const int lda, - const int strideA, - int* ipiv, - const int strideP, - int* info, - const int batch_count); - -hipblasStatus_t hipblasZgetrfStridedBatchedFortran(hipblasHandle_t handle, - const int n, - hipblasDoubleComplex* A, - const int lda, - const int strideA, - int* ipiv, - const int strideP, - int* info, - const int batch_count); - -// getrs -hipblasStatus_t hipblasSgetrsFortran(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - float* A, - const int lda, - const int* ipiv, - float* B, - const int ldb, - int* info); - -hipblasStatus_t hipblasDgetrsFortran(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - double* A, - const int lda, - const int* ipiv, - double* B, - const int ldb, - int* info); - -hipblasStatus_t hipblasCgetrsFortran(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - hipblasComplex* A, - const int lda, - const int* ipiv, - hipblasComplex* B, - const int ldb, - int* info); - -hipblasStatus_t hipblasZgetrsFortran(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - hipblasDoubleComplex* A, - const int lda, - const int* ipiv, - hipblasDoubleComplex* B, - const int ldb, - int* info); - -// getrs_batched -hipblasStatus_t hipblasSgetrsBatchedFortran(hipblasHandle_t handle, - const 
hipblasOperation_t trans, - const int n, - const int nrhs, - float* const A[], - const int lda, - const int* ipiv, - float* const B[], - const int ldb, - int* info, - const int batch_count); - -hipblasStatus_t hipblasDgetrsBatchedFortran(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - double* const A[], - const int lda, - const int* ipiv, - double* const B[], - const int ldb, - int* info, - const int batch_count); - -hipblasStatus_t hipblasCgetrsBatchedFortran(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - hipblasComplex* const A[], - const int lda, - const int* ipiv, - hipblasComplex* const B[], - const int ldb, - int* info, - const int batch_count); - -hipblasStatus_t hipblasZgetrsBatchedFortran(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - hipblasDoubleComplex* const A[], - const int lda, - const int* ipiv, - hipblasDoubleComplex* const B[], - const int ldb, - int* info, - const int batch_count); - -// getrs_strided_batched -hipblasStatus_t hipblasSgetrsStridedBatchedFortran(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - float* A, - const int lda, - const int strideA, - const int* ipiv, - const int strideP, - float* B, - const int ldb, - const int strideB, - int* info, - const int batch_count); - -hipblasStatus_t hipblasDgetrsStridedBatchedFortran(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - double* A, - const int lda, - const int strideA, - const int* ipiv, - const int strideP, - double* B, - const int ldb, - const int strideB, - int* info, - const int batch_count); - -hipblasStatus_t hipblasCgetrsStridedBatchedFortran(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - hipblasComplex* A, - const int lda, - const int strideA, - const int* ipiv, - const int strideP, - hipblasComplex* B, - const 
int ldb, - const int strideB, - int* info, - const int batch_count); - -hipblasStatus_t hipblasZgetrsStridedBatchedFortran(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - hipblasDoubleComplex* A, - const int lda, - const int strideA, - const int* ipiv, - const int strideP, - hipblasDoubleComplex* B, - const int ldb, - const int strideB, - int* info, - const int batch_count); - -// geqrf -hipblasStatus_t hipblasSgeqrfFortran(hipblasHandle_t handle, - const int m, - const int n, - float* A, - const int lda, - float* ipiv, - int* info); - -hipblasStatus_t hipblasDgeqrfFortran(hipblasHandle_t handle, - const int m, - const int n, - double* A, - const int lda, - double* ipiv, - int* info); - -hipblasStatus_t hipblasCgeqrfFortran(hipblasHandle_t handle, - const int m, - const int n, - hipblasComplex* A, - const int lda, - hipblasComplex* ipiv, - int* info); - -hipblasStatus_t hipblasZgeqrfFortran(hipblasHandle_t handle, - const int m, - const int n, - hipblasDoubleComplex* A, - const int lda, - hipblasDoubleComplex* ipiv, - int* info); - -// geqrf_batched -hipblasStatus_t hipblasSgeqrfBatchedFortran(hipblasHandle_t handle, - const int m, - const int n, - float* const A[], - const int lda, - float* const ipiv[], - int* info, - const int batch_count); - -hipblasStatus_t hipblasDgeqrfBatchedFortran(hipblasHandle_t handle, - const int m, - const int n, - double* const A[], - const int lda, - double* const ipiv[], - int* info, - const int batch_count); - -hipblasStatus_t hipblasCgeqrfBatchedFortran(hipblasHandle_t handle, - const int m, - const int n, - hipblasComplex* const A[], - const int lda, - hipblasComplex* const ipiv[], - int* info, - const int batch_count); - -hipblasStatus_t hipblasZgeqrfBatchedFortran(hipblasHandle_t handle, - const int m, - const int n, - hipblasDoubleComplex* const A[], - const int lda, - hipblasDoubleComplex* const ipiv[], - int* info, - const int batch_count); - -// geqrf_strided_batched 
-hipblasStatus_t hipblasSgeqrfStridedBatchedFortran(hipblasHandle_t handle, - const int m, - const int n, - float* A, - const int lda, - const int strideA, - float* ipiv, - const int strideP, - int* info, - const int batch_count); - -hipblasStatus_t hipblasDgeqrfStridedBatchedFortran(hipblasHandle_t handle, - const int m, - const int n, - double* A, - const int lda, - const int strideA, - double* ipiv, - const int strideP, - int* info, - const int batch_count); - -hipblasStatus_t hipblasCgeqrfStridedBatchedFortran(hipblasHandle_t handle, - const int m, - const int n, - hipblasComplex* A, - const int lda, - const int strideA, - hipblasComplex* ipiv, - const int strideP, - int* info, - const int batch_count); - -hipblasStatus_t hipblasZgeqrfStridedBatchedFortran(hipblasHandle_t handle, - const int m, - const int n, - hipblasDoubleComplex* A, - const int lda, - const int strideA, - hipblasDoubleComplex* ipiv, - const int strideP, - int* info, - const int batch_count); - // gemm hipblasStatus_t hipblasHgemmFortran(hipblasHandle_t handle, hipblasOperation_t transa, @@ -6799,6 +6438,384 @@ hipblasStatus_t hipblasTrsmStridedBatchedExFortran(hipblasHandle_t handle, int invA_size, int stride_invA, hipblasDatatype_t compute_type); + +/* ========== + * Solver + * ========== */ + +// getrf +hipblasStatus_t hipblasSgetrfFortran(hipblasHandle_t handle, + const int n, + float* A, + const int lda, + int* ipiv, + int* info); + +hipblasStatus_t hipblasDgetrfFortran(hipblasHandle_t handle, + const int n, + double* A, + const int lda, + int* ipiv, + int* info); + +hipblasStatus_t hipblasCgetrfFortran(hipblasHandle_t handle, + const int n, + hipblasComplex* A, + const int lda, + int* ipiv, + int* info); + +hipblasStatus_t hipblasZgetrfFortran(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* A, + const int lda, + int* ipiv, + int* info); + +// getrf_batched +hipblasStatus_t hipblasSgetrfBatchedFortran(hipblasHandle_t handle, + const int n, + float* const A[], + const int 
lda, + int* ipiv, + int* info, + const int batch_count); + +hipblasStatus_t hipblasDgetrfBatchedFortran(hipblasHandle_t handle, + const int n, + double* const A[], + const int lda, + int* ipiv, + int* info, + const int batch_count); + +hipblasStatus_t hipblasCgetrfBatchedFortran(hipblasHandle_t handle, + const int n, + hipblasComplex* const A[], + const int lda, + int* ipiv, + int* info, + const int batch_count); + +hipblasStatus_t hipblasZgetrfBatchedFortran(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* const A[], + const int lda, + int* ipiv, + int* info, + const int batch_count); + +// getrf_strided_batched +hipblasStatus_t hipblasSgetrfStridedBatchedFortran(hipblasHandle_t handle, + const int n, + float* A, + const int lda, + const int stride_A, + int* ipiv, + const int stride_P, + int* info, + const int batch_count); + +hipblasStatus_t hipblasDgetrfStridedBatchedFortran(hipblasHandle_t handle, + const int n, + double* A, + const int lda, + const int stride_A, + int* ipiv, + const int stride_P, + int* info, + const int batch_count); + +hipblasStatus_t hipblasCgetrfStridedBatchedFortran(hipblasHandle_t handle, + const int n, + hipblasComplex* A, + const int lda, + const int stride_A, + int* ipiv, + const int stride_P, + int* info, + const int batch_count); + +hipblasStatus_t hipblasZgetrfStridedBatchedFortran(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* A, + const int lda, + const int stride_A, + int* ipiv, + const int stride_P, + int* info, + const int batch_count); + +// getrs +hipblasStatus_t hipblasSgetrsFortran(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + float* A, + const int lda, + const int* ipiv, + float* B, + const int ldb, + int* info); + +hipblasStatus_t hipblasDgetrsFortran(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + double* A, + const int lda, + const int* ipiv, + double* B, + const int ldb, + int* info); + 
+hipblasStatus_t hipblasCgetrsFortran(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasComplex* A, + const int lda, + const int* ipiv, + hipblasComplex* B, + const int ldb, + int* info); + +hipblasStatus_t hipblasZgetrsFortran(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasDoubleComplex* A, + const int lda, + const int* ipiv, + hipblasDoubleComplex* B, + const int ldb, + int* info); + +// getrs_batched +hipblasStatus_t hipblasSgetrsBatchedFortran(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + float* const A[], + const int lda, + const int* ipiv, + float* const B[], + const int ldb, + int* info, + const int batch_count); + +hipblasStatus_t hipblasDgetrsBatchedFortran(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + double* const A[], + const int lda, + const int* ipiv, + double* const B[], + const int ldb, + int* info, + const int batch_count); + +hipblasStatus_t hipblasCgetrsBatchedFortran(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasComplex* const A[], + const int lda, + const int* ipiv, + hipblasComplex* const B[], + const int ldb, + int* info, + const int batch_count); + +hipblasStatus_t hipblasZgetrsBatchedFortran(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasDoubleComplex* const A[], + const int lda, + const int* ipiv, + hipblasDoubleComplex* const B[], + const int ldb, + int* info, + const int batch_count); + +// getrs_strided_batched +hipblasStatus_t hipblasSgetrsStridedBatchedFortran(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + float* A, + const int lda, + const int stride_A, + const int* ipiv, + const int stride_P, + float* B, + const int ldb, + const int stride_B, + int* info, + const int batch_count); + 
+hipblasStatus_t hipblasDgetrsStridedBatchedFortran(hipblasHandle_t handle, // C prototype of the bind(c) wrapper in hipblas_fortran_solver.f90; that wrapper forwards to hipblasDgetrsStridedBatched
+    const hipblasOperation_t trans, // received on the Fortran side as integer(kind(HIPBLAS_OP_N))
+    const int n,
+    const int nrhs,
+    double* A,
+    const int lda,
+    const int stride_A,
+    const int* ipiv,
+    const int stride_P,
+    double* B,
+    const int ldb,
+    const int stride_B,
+    int* info,
+    const int batch_count);
+// Single-complex strided-batched getrs wrapper prototype; Fortran body forwards to hipblasCgetrsStridedBatched.
+hipblasStatus_t hipblasCgetrsStridedBatchedFortran(hipblasHandle_t handle,
+    const hipblasOperation_t trans,
+    const int n,
+    const int nrhs,
+    hipblasComplex* A,
+    const int lda,
+    const int stride_A,
+    const int* ipiv,
+    const int stride_P,
+    hipblasComplex* B,
+    const int ldb,
+    const int stride_B,
+    int* info,
+    const int batch_count);
+// Double-complex strided-batched getrs wrapper prototype; Fortran body forwards to hipblasZgetrsStridedBatched.
+hipblasStatus_t hipblasZgetrsStridedBatchedFortran(hipblasHandle_t handle,
+    const hipblasOperation_t trans,
+    const int n,
+    const int nrhs,
+    hipblasDoubleComplex* A,
+    const int lda,
+    const int stride_A,
+    const int* ipiv,
+    const int stride_P,
+    hipblasDoubleComplex* B,
+    const int ldb,
+    const int stride_B,
+    int* info,
+    const int batch_count);
+
+// geqrf: prototypes of the Fortran wrappers that forward to hipblas{S,D,C,Z}geqrf
+hipblasStatus_t hipblasSgeqrfFortran(hipblasHandle_t handle,
+    const int m,
+    const int n,
+    float* A, // every pointer argument is received by the wrapper as an opaque type(c_ptr)
+    const int lda,
+    float* tau,
+    int* info);
+// Double-precision geqrf wrapper prototype.
+hipblasStatus_t hipblasDgeqrfFortran(hipblasHandle_t handle,
+    const int m,
+    const int n,
+    double* A,
+    const int lda,
+    double* tau,
+    int* info);
+// Single-complex geqrf wrapper prototype.
+hipblasStatus_t hipblasCgeqrfFortran(hipblasHandle_t handle,
+    const int m,
+    const int n,
+    hipblasComplex* A,
+    const int lda,
+    hipblasComplex* tau,
+    int* info);
+// Double-complex geqrf wrapper prototype.
+hipblasStatus_t hipblasZgeqrfFortran(hipblasHandle_t handle,
+    const int m,
+    const int n,
+    hipblasDoubleComplex* A,
+    const int lda,
+    hipblasDoubleComplex* tau,
+    int* info);
+
+// geqrf_batched: A and tau are arrays of pointers; batch_count presumably gives their length (confirm against hipblas?geqrfBatched)
+hipblasStatus_t hipblasSgeqrfBatchedFortran(hipblasHandle_t handle,
+    const int m,
+    const int n,
+    float* const A[],
+    const int lda,
+    float* const tau[],
+    int* info,
+    const int batch_count);
+// Double-precision batched geqrf wrapper prototype.
+hipblasStatus_t hipblasDgeqrfBatchedFortran(hipblasHandle_t handle,
+    const int m,
+    const int n,
+    double* const A[],
+    const int lda,
+    double* const tau[],
+    int* info,
+    const int batch_count);
+// Single-complex batched geqrf wrapper prototype.
+hipblasStatus_t hipblasCgeqrfBatchedFortran(hipblasHandle_t handle,
+    const int m,
+    const int n,
+    hipblasComplex* const A[],
+    const int lda,
+    hipblasComplex* const tau[],
+    int* info,
+    const int batch_count);
+// Double-complex batched geqrf wrapper prototype.
+hipblasStatus_t hipblasZgeqrfBatchedFortran(hipblasHandle_t handle,
+    const int m,
+    const int n,
+    hipblasDoubleComplex* const A[],
+    const int lda,
+    hipblasDoubleComplex* const tau[],
+    int* info,
+    const int batch_count);
+
+// geqrf_strided_batched: single A/tau buffers plus stride_A/stride_T, presumably the distance between consecutive batch entries (confirm)
+hipblasStatus_t hipblasSgeqrfStridedBatchedFortran(hipblasHandle_t handle,
+    const int m,
+    const int n,
+    float* A,
+    const int lda,
+    const int stride_A,
+    float* tau,
+    const int stride_T,
+    int* info,
+    const int batch_count);
+// Double-precision strided-batched geqrf wrapper prototype.
+hipblasStatus_t hipblasDgeqrfStridedBatchedFortran(hipblasHandle_t handle,
+    const int m,
+    const int n,
+    double* A,
+    const int lda,
+    const int stride_A,
+    double* tau,
+    const int stride_T,
+    int* info,
+    const int batch_count);
+// Single-complex strided-batched geqrf wrapper prototype.
+hipblasStatus_t hipblasCgeqrfStridedBatchedFortran(hipblasHandle_t handle,
+    const int m,
+    const int n,
+    hipblasComplex* A,
+    const int lda,
+    const int stride_A,
+    hipblasComplex* tau,
+    const int stride_T,
+    int* info,
+    const int batch_count);
+// Double-complex strided-batched geqrf wrapper prototype.
+hipblasStatus_t hipblasZgeqrfStridedBatchedFortran(hipblasHandle_t handle,
+    const int m,
+    const int n,
+    hipblasDoubleComplex* A,
+    const int lda,
+    const int stride_A,
+    hipblasDoubleComplex* tau,
+    const int stride_T,
+    int* info,
+    const int batch_count);
+
 }
 #endif
diff --git a/clients/include/hipblas_fortran_solver.f90 b/clients/include/hipblas_fortran_solver.f90
new file mode 100644
index 000000000..f6d5958f3
--- /dev/null
+++ b/clients/include/hipblas_fortran_solver.f90
@@ -0,0 +1,748 @@
+module hipblas_interface ! container for the bind(c) solver wrappers named hipblas?<routine>Fortran
+    use iso_c_binding
+    use hipblas ! presumably supplies the hipblas?getrf/getrs/geqrf interfaces the wrappers call (confirm)
+
+    contains
+
+    !--------!
+    ! Solver !
+    !--------!
+
+    ! 
getrf
+    function hipblasSgetrfFortran(handle, n, A, lda, ipiv, info) &
+            result(res) &
+            bind(c, name = 'hipblasSgetrfFortran') ! C symbol under which this pass-through wrapper is exported
+        use iso_c_binding
+        use hipblas_enums ! NOTE(review): no hipblas_enums name appears to be referenced in this wrapper
+        implicit none
+        type(c_ptr), value :: handle
+        integer(c_int), value :: n
+        type(c_ptr), value :: A ! pointers stay opaque type(c_ptr); nothing is dereferenced here
+        integer(c_int), value :: lda
+        type(c_ptr), value :: ipiv
+        type(c_ptr), value :: info
+        integer(c_int) :: res ! function result handed back to the C caller
+        res = hipblasSgetrf(handle, n, A, lda, ipiv, info) ! forward all arguments unchanged, in the same order
+    end function hipblasSgetrfFortran
+    ! Double-precision getrf wrapper: same pass-through pattern, calling hipblasDgetrf.
+    function hipblasDgetrfFortran(handle, n, A, lda, ipiv, info) &
+            result(res) &
+            bind(c, name = 'hipblasDgetrfFortran')
+        use iso_c_binding
+        use hipblas_enums
+        implicit none
+        type(c_ptr), value :: handle
+        integer(c_int), value :: n
+        type(c_ptr), value :: A
+        integer(c_int), value :: lda
+        type(c_ptr), value :: ipiv
+        type(c_ptr), value :: info
+        integer(c_int) :: res
+        res = hipblasDgetrf(handle, n, A, lda, ipiv, info) ! forward all arguments unchanged
+    end function hipblasDgetrfFortran
+    ! Single-complex getrf wrapper: same pass-through pattern, calling hipblasCgetrf.
+    function hipblasCgetrfFortran(handle, n, A, lda, ipiv, info) &
+            result(res) &
+            bind(c, name = 'hipblasCgetrfFortran')
+        use iso_c_binding
+        use hipblas_enums
+        implicit none
+        type(c_ptr), value :: handle
+        integer(c_int), value :: n
+        type(c_ptr), value :: A
+        integer(c_int), value :: lda
+        type(c_ptr), value :: ipiv
+        type(c_ptr), value :: info
+        integer(c_int) :: res
+        res = hipblasCgetrf(handle, n, A, lda, ipiv, info) ! forward all arguments unchanged
+    end function hipblasCgetrfFortran
+    ! Double-complex getrf wrapper: same pass-through pattern, calling hipblasZgetrf.
+    function hipblasZgetrfFortran(handle, n, A, lda, ipiv, info) &
+            result(res) &
+            bind(c, name = 'hipblasZgetrfFortran')
+        use iso_c_binding
+        use hipblas_enums
+        implicit none
+        type(c_ptr), value :: handle
+        integer(c_int), value :: n
+        type(c_ptr), value :: A
+        integer(c_int), value :: lda
+        type(c_ptr), value :: ipiv
+        type(c_ptr), value :: info
+        integer(c_int) :: res
+        res = hipblasZgetrf(handle, n, A, lda, ipiv, info) ! forward all arguments unchanged
+    end function hipblasZgetrfFortran
+
+    ! 
getrf_batched
+    function hipblasSgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batch_count) &
+            result(res) &
+            bind(c, name = 'hipblasSgetrfBatchedFortran') ! C symbol under which this pass-through wrapper is exported
+        use iso_c_binding
+        use hipblas_enums ! NOTE(review): no hipblas_enums name appears to be referenced in this wrapper
+        implicit none
+        type(c_ptr), value :: handle
+        integer(c_int), value :: n
+        type(c_ptr), value :: A ! pointers stay opaque type(c_ptr); nothing is dereferenced here
+        integer(c_int), value :: lda
+        type(c_ptr), value :: ipiv
+        type(c_ptr), value :: info
+        integer(c_int), value :: batch_count
+        integer(c_int) :: res ! function result handed back to the C caller
+        res = hipblasSgetrfBatched(handle, n, A, lda, ipiv, info, batch_count) ! forward all arguments unchanged
+    end function hipblasSgetrfBatchedFortran
+    ! Double-precision batched getrf wrapper: same pass-through pattern, calling hipblasDgetrfBatched.
+    function hipblasDgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batch_count) &
+            result(res) &
+            bind(c, name = 'hipblasDgetrfBatchedFortran')
+        use iso_c_binding
+        use hipblas_enums
+        implicit none
+        type(c_ptr), value :: handle
+        integer(c_int), value :: n
+        type(c_ptr), value :: A
+        integer(c_int), value :: lda
+        type(c_ptr), value :: ipiv
+        type(c_ptr), value :: info
+        integer(c_int), value :: batch_count
+        integer(c_int) :: res
+        res = hipblasDgetrfBatched(handle, n, A, lda, ipiv, info, batch_count) ! forward all arguments unchanged
+    end function hipblasDgetrfBatchedFortran
+    ! Single-complex batched getrf wrapper: same pass-through pattern, calling hipblasCgetrfBatched.
+    function hipblasCgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batch_count) &
+            result(res) &
+            bind(c, name = 'hipblasCgetrfBatchedFortran')
+        use iso_c_binding
+        use hipblas_enums
+        implicit none
+        type(c_ptr), value :: handle
+        integer(c_int), value :: n
+        type(c_ptr), value :: A
+        integer(c_int), value :: lda
+        type(c_ptr), value :: ipiv
+        type(c_ptr), value :: info
+        integer(c_int), value :: batch_count
+        integer(c_int) :: res
+        res = hipblasCgetrfBatched(handle, n, A, lda, ipiv, info, batch_count) ! forward all arguments unchanged
+    end function hipblasCgetrfBatchedFortran
+    ! Double-complex batched getrf wrapper: same pass-through pattern, calling hipblasZgetrfBatched.
+    function hipblasZgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batch_count) &
+            result(res) &
+            bind(c, name = 'hipblasZgetrfBatchedFortran')
+        use iso_c_binding
+        use hipblas_enums
+        implicit none
+        type(c_ptr), value :: handle
+        integer(c_int), value :: n
+        type(c_ptr), value :: A
+        integer(c_int), value :: lda
+        type(c_ptr), value :: ipiv
+        type(c_ptr), value :: info
+        integer(c_int), value :: batch_count
+        integer(c_int) :: res
+        res = hipblasZgetrfBatched(handle, n, A, lda, ipiv, info, batch_count) ! forward all arguments unchanged
+    end function hipblasZgetrfBatchedFortran
+
+    ! getrf_strided_batched: wrappers that additionally pass stride_A and stride_P through by value
+    function hipblasSgetrfStridedBatchedFortran(handle, n, A, lda, stride_A,&
+            ipiv, stride_P, info, batch_count) &
+            result(res) &
+            bind(c, name = 'hipblasSgetrfStridedBatchedFortran') ! C symbol under which this pass-through wrapper is exported
+        use iso_c_binding
+        use hipblas_enums
+        implicit none
+        type(c_ptr), value :: handle
+        integer(c_int), value :: n
+        type(c_ptr), value :: A
+        integer(c_int), value :: lda
+        integer(c_int), value :: stride_A
+        type(c_ptr), value :: ipiv
+        integer(c_int), value :: stride_P
+        type(c_ptr), value :: info
+        integer(c_int), value :: batch_count
+        integer(c_int) :: res ! function result handed back to the C caller
+        res = hipblasSgetrfStridedBatched(handle, n, A, lda, stride_A,&
+            ipiv, stride_P, info, batch_count) ! forward all arguments unchanged
+    end function hipblasSgetrfStridedBatchedFortran
+    ! Double-precision strided-batched getrf wrapper: calls hipblasDgetrfStridedBatched.
+    function hipblasDgetrfStridedBatchedFortran(handle, n, A, lda, stride_A,&
+            ipiv, stride_P, info, batch_count) &
+            result(res) &
+            bind(c, name = 'hipblasDgetrfStridedBatchedFortran')
+        use iso_c_binding
+        use hipblas_enums
+        implicit none
+        type(c_ptr), value :: handle
+        integer(c_int), value :: n
+        type(c_ptr), value :: A
+        integer(c_int), value :: lda
+        integer(c_int), value :: stride_A
+        type(c_ptr), value :: ipiv
+        integer(c_int), value :: stride_P
+        type(c_ptr), value :: info
+        integer(c_int), value :: batch_count
+        integer(c_int) :: res
+        res = hipblasDgetrfStridedBatched(handle, n, A, lda, stride_A,&
+            ipiv, stride_P, info, batch_count) ! forward all arguments unchanged
+    end function hipblasDgetrfStridedBatchedFortran
+    ! Single-complex strided-batched getrf wrapper: calls hipblasCgetrfStridedBatched.
+    function hipblasCgetrfStridedBatchedFortran(handle, n, A, lda, stride_A,&
+            ipiv, stride_P, info, batch_count) &
+            result(res) &
+            bind(c, name = 'hipblasCgetrfStridedBatchedFortran')
+        use iso_c_binding
+        use hipblas_enums
+        implicit none
+        type(c_ptr), value :: handle
+        integer(c_int), value :: n
+        type(c_ptr), value :: A
+        integer(c_int), value :: lda
+        integer(c_int), value :: stride_A
+        type(c_ptr), value :: ipiv
+        integer(c_int), value :: stride_P
+        type(c_ptr), value :: info
+        integer(c_int), value :: batch_count
+        integer(c_int) :: res
+        res = hipblasCgetrfStridedBatched(handle, n, A, lda, stride_A,&
+            ipiv, stride_P, info, batch_count) ! forward all arguments unchanged
+    end function hipblasCgetrfStridedBatchedFortran
+    ! Double-complex strided-batched getrf wrapper: calls hipblasZgetrfStridedBatched.
+    function hipblasZgetrfStridedBatchedFortran(handle, n, A, lda, stride_A,&
+            ipiv, stride_P, info, batch_count) &
+            result(res) &
+            bind(c, name = 'hipblasZgetrfStridedBatchedFortran')
+        use iso_c_binding
+        use hipblas_enums
+        implicit none
+        type(c_ptr), value :: handle
+        integer(c_int), value :: n
+        type(c_ptr), value :: A
+        integer(c_int), value :: lda
+        integer(c_int), value :: stride_A
+        type(c_ptr), value :: ipiv
+        integer(c_int), value :: stride_P
+        type(c_ptr), value :: info
+        integer(c_int), value :: batch_count
+        integer(c_int) :: res
+        res = hipblasZgetrfStridedBatched(handle, n, A, lda, stride_A,&
+            ipiv, stride_P, info, batch_count) ! forward all arguments unchanged
+    end function hipblasZgetrfStridedBatchedFortran
+
+    ! 
getrs
+    function hipblasSgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv,&
+            B, ldb, info) &
+            result(res) &
+            bind(c, name = 'hipblasSgetrsFortran') ! C symbol under which this pass-through wrapper is exported
+        use iso_c_binding
+        use hipblas_enums
+        implicit none
+        type(c_ptr), value :: handle
+        integer(kind(HIPBLAS_OP_N)), value :: trans ! integer kind taken from the HIPBLAS_OP_N enumerator (presumably from hipblas_enums)
+        integer(c_int), value :: n
+        integer(c_int), value :: nrhs
+        type(c_ptr), value :: A ! pointers stay opaque type(c_ptr); nothing is dereferenced here
+        integer(c_int), value :: lda
+        type(c_ptr), value :: ipiv
+        type(c_ptr), value :: B
+        integer(c_int), value :: ldb
+        type(c_ptr), value :: info
+        integer(c_int) :: res ! function result handed back to the C caller
+        res = hipblasSgetrs(handle, trans, n, nrhs, A, lda,&
+            ipiv, B, ldb, info) ! forward all arguments unchanged, in the same order
+    end function hipblasSgetrsFortran
+    ! Double-precision getrs wrapper: same pass-through pattern, calling hipblasDgetrs.
+    function hipblasDgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv,&
+            B, ldb, info) &
+            result(res) &
+            bind(c, name = 'hipblasDgetrsFortran')
+        use iso_c_binding
+        use hipblas_enums
+        implicit none
+        type(c_ptr), value :: handle
+        integer(kind(HIPBLAS_OP_N)), value :: trans
+        integer(c_int), value :: n
+        integer(c_int), value :: nrhs
+        type(c_ptr), value :: A
+        integer(c_int), value :: lda
+        type(c_ptr), value :: ipiv
+        type(c_ptr), value :: B
+        integer(c_int), value :: ldb
+        type(c_ptr), value :: info
+        integer(c_int) :: res
+        res = hipblasDgetrs(handle, trans, n, nrhs, A, lda,&
+            ipiv, B, ldb, info) ! forward all arguments unchanged
+    end function hipblasDgetrsFortran
+    ! Single-complex getrs wrapper: same pass-through pattern, calling hipblasCgetrs.
+    function hipblasCgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv,&
+            B, ldb, info) &
+            result(res) &
+            bind(c, name = 'hipblasCgetrsFortran')
+        use iso_c_binding
+        use hipblas_enums
+        implicit none
+        type(c_ptr), value :: handle
+        integer(kind(HIPBLAS_OP_N)), value :: trans
+        integer(c_int), value :: n
+        integer(c_int), value :: nrhs
+        type(c_ptr), value :: A
+        integer(c_int), value :: lda
+        type(c_ptr), value :: ipiv
+        type(c_ptr), value :: B
+        integer(c_int), value :: ldb
+        type(c_ptr), value :: info
+        integer(c_int) :: res
+        res = hipblasCgetrs(handle, trans, n, nrhs, A, lda,&
+            ipiv, B, ldb, info) ! forward all arguments unchanged
+    end function hipblasCgetrsFortran
+    ! Double-complex getrs wrapper: same pass-through pattern, calling hipblasZgetrs.
+    function hipblasZgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv,&
+            B, ldb, info) &
+            result(res) &
+            bind(c, name = 'hipblasZgetrsFortran')
+        use iso_c_binding
+        use hipblas_enums
+        implicit none
+        type(c_ptr), value :: handle
+        integer(kind(HIPBLAS_OP_N)), value :: trans
+        integer(c_int), value :: n
+        integer(c_int), value :: nrhs
+        type(c_ptr), value :: A
+        integer(c_int), value :: lda
+        type(c_ptr), value :: ipiv
+        type(c_ptr), value :: B
+        integer(c_int), value :: ldb
+        type(c_ptr), value :: info
+        integer(c_int) :: res
+        res = hipblasZgetrs(handle, trans, n, nrhs, A, lda,&
+            ipiv, B, ldb, info) ! forward all arguments unchanged
+    end function hipblasZgetrsFortran
+
+    ! getrs_batched: wrappers that additionally pass batch_count through by value
+    function hipblasSgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv,&
+            B, ldb, info, batch_count) &
+            result(res) &
+            bind(c, name = 'hipblasSgetrsBatchedFortran') ! C symbol under which this pass-through wrapper is exported
+        use iso_c_binding
+        use hipblas_enums
+        implicit none
+        type(c_ptr), value :: handle
+        integer(kind(HIPBLAS_OP_N)), value :: trans
+        integer(c_int), value :: n
+        integer(c_int), value :: nrhs
+        type(c_ptr), value :: A
+        integer(c_int), value :: lda
+        type(c_ptr), value :: ipiv
+        type(c_ptr), value :: B
+        integer(c_int), value :: ldb
+        type(c_ptr), value :: info
+        integer(c_int), value :: batch_count
+        integer(c_int) :: res ! function result handed back to the C caller
+        res = hipblasSgetrsBatched(handle, trans, n, nrhs, A, lda,&
+            ipiv, B, ldb, info, batch_count) ! forward all arguments unchanged
+    end function hipblasSgetrsBatchedFortran
+    ! Double-precision batched getrs wrapper: calls hipblasDgetrsBatched.
+    function hipblasDgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv,&
+            B, ldb, info, batch_count) &
+            result(res) &
+            bind(c, name = 'hipblasDgetrsBatchedFortran')
+        use iso_c_binding
+        use hipblas_enums
+        implicit none
+        type(c_ptr), value :: handle
+        integer(kind(HIPBLAS_OP_N)), value :: trans
+        integer(c_int), value :: n
+        integer(c_int), value :: nrhs
+        type(c_ptr), value :: A
+        integer(c_int), value :: lda
+        type(c_ptr), value :: ipiv
+        type(c_ptr), value :: B
+        integer(c_int), value :: ldb
+        type(c_ptr), value :: info
+        integer(c_int), value :: batch_count
+        integer(c_int) :: res
+        res = hipblasDgetrsBatched(handle, trans, n, nrhs, A, lda,&
+            ipiv, B, ldb, info, batch_count) ! forward all arguments unchanged
+    end function hipblasDgetrsBatchedFortran
+    ! Single-complex batched getrs wrapper: calls hipblasCgetrsBatched.
+    function hipblasCgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv,&
+            B, ldb, info, batch_count) &
+            result(res) &
+            bind(c, name = 'hipblasCgetrsBatchedFortran')
+        use iso_c_binding
+        use hipblas_enums
+        implicit none
+        type(c_ptr), value :: handle
+        integer(kind(HIPBLAS_OP_N)), value :: trans
+        integer(c_int), value :: n
+        integer(c_int), value :: nrhs
+        type(c_ptr), value :: A
+        integer(c_int), value :: lda
+        type(c_ptr), value :: ipiv
+        type(c_ptr), value :: B
+        integer(c_int), value :: ldb
+        type(c_ptr), value :: info
+        integer(c_int), value :: batch_count
+        integer(c_int) :: res
+        res = hipblasCgetrsBatched(handle, trans, n, nrhs, A, lda,&
+            ipiv, B, ldb, info, batch_count) ! forward all arguments unchanged
+    end function hipblasCgetrsBatchedFortran
+    ! Double-complex batched getrs wrapper: calls hipblasZgetrsBatched.
+    function hipblasZgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv,&
+            B, ldb, info, batch_count) &
+            result(res) &
+            bind(c, name = 'hipblasZgetrsBatchedFortran')
+        use iso_c_binding
+        use hipblas_enums
+        implicit none
+        type(c_ptr), value :: handle
+        integer(kind(HIPBLAS_OP_N)), value :: trans
+        integer(c_int), value :: n
+        integer(c_int), value :: nrhs
+        type(c_ptr), value :: A
+        integer(c_int), value :: lda
+        type(c_ptr), value :: ipiv
+        type(c_ptr), value :: B
+        integer(c_int), value :: ldb
+        type(c_ptr), value :: info
+        integer(c_int), value :: batch_count
+        integer(c_int) :: res
+        res = hipblasZgetrsBatched(handle, trans, n, nrhs, A, lda,&
+            ipiv, B, ldb, info, batch_count) ! forward all arguments unchanged
+    end function hipblasZgetrsBatchedFortran
+
+    ! 
getrs_strided_batched
+    function hipblasSgetrsStridedBatchedFortran(handle, trans, n, nrhs, A, lda, stride_A, ipiv,&
+            stride_P, B, ldb, stride_B, info, batch_count) &
+            result(res) &
+            bind(c, name = 'hipblasSgetrsStridedBatchedFortran') ! C symbol under which this pass-through wrapper is exported
+        use iso_c_binding
+        use hipblas_enums
+        implicit none
+        type(c_ptr), value :: handle
+        integer(kind(HIPBLAS_OP_N)), value :: trans ! integer kind taken from the HIPBLAS_OP_N enumerator (presumably from hipblas_enums)
+        integer(c_int), value :: n
+        integer(c_int), value :: nrhs
+        type(c_ptr), value :: A ! pointers stay opaque type(c_ptr); nothing is dereferenced here
+        integer(c_int), value :: lda
+        integer(c_int), value :: stride_A
+        type(c_ptr), value :: ipiv
+        integer(c_int), value :: stride_P
+        type(c_ptr), value :: B
+        integer(c_int), value :: ldb
+        integer(c_int), value :: stride_B
+        type(c_ptr), value :: info
+        integer(c_int), value :: batch_count
+        integer(c_int) :: res ! function result handed back to the C caller
+        res = hipblasSgetrsStridedBatched(handle, trans, n, nrhs, A, lda, stride_A,&
+            ipiv, stride_P, B, ldb, stride_B, info, batch_count) ! forward all arguments unchanged, in the same order
+    end function hipblasSgetrsStridedBatchedFortran
+    ! Double-precision strided-batched getrs wrapper: calls hipblasDgetrsStridedBatched.
+    function hipblasDgetrsStridedBatchedFortran(handle, trans, n, nrhs, A, lda, stride_A, ipiv,&
+            stride_P, B, ldb, stride_B, info, batch_count) &
+            result(res) &
+            bind(c, name = 'hipblasDgetrsStridedBatchedFortran')
+        use iso_c_binding
+        use hipblas_enums
+        implicit none
+        type(c_ptr), value :: handle
+        integer(kind(HIPBLAS_OP_N)), value :: trans
+        integer(c_int), value :: n
+        integer(c_int), value :: nrhs
+        type(c_ptr), value :: A
+        integer(c_int), value :: lda
+        integer(c_int), value :: stride_A
+        type(c_ptr), value :: ipiv
+        integer(c_int), value :: stride_P
+        type(c_ptr), value :: B
+        integer(c_int), value :: ldb
+        integer(c_int), value :: stride_B
+        type(c_ptr), value :: info
+        integer(c_int), value :: batch_count
+        integer(c_int) :: res
+        res = hipblasDgetrsStridedBatched(handle, trans, n, nrhs, A, lda, stride_A,&
+            ipiv, stride_P, B, ldb, stride_B, info, batch_count) ! forward all arguments unchanged
+    end function hipblasDgetrsStridedBatchedFortran
+    ! Single-complex strided-batched getrs wrapper: calls hipblasCgetrsStridedBatched.
+    function hipblasCgetrsStridedBatchedFortran(handle, trans, n, nrhs, A, lda, stride_A, ipiv,&
+            stride_P, B, ldb, stride_B, info, batch_count) &
+            result(res) &
+            bind(c, name = 'hipblasCgetrsStridedBatchedFortran')
+        use iso_c_binding
+        use hipblas_enums
+        implicit none
+        type(c_ptr), value :: handle
+        integer(kind(HIPBLAS_OP_N)), value :: trans
+        integer(c_int), value :: n
+        integer(c_int), value :: nrhs
+        type(c_ptr), value :: A
+        integer(c_int), value :: lda
+        integer(c_int), value :: stride_A
+        type(c_ptr), value :: ipiv
+        integer(c_int), value :: stride_P
+        type(c_ptr), value :: B
+        integer(c_int), value :: ldb
+        integer(c_int), value :: stride_B
+        type(c_ptr), value :: info
+        integer(c_int), value :: batch_count
+        integer(c_int) :: res
+        res = hipblasCgetrsStridedBatched(handle, trans, n, nrhs, A, lda, stride_A,&
+            ipiv, stride_P, B, ldb, stride_B, info, batch_count) ! forward all arguments unchanged
+    end function hipblasCgetrsStridedBatchedFortran
+    ! Double-complex strided-batched getrs wrapper: calls hipblasZgetrsStridedBatched.
+    function hipblasZgetrsStridedBatchedFortran(handle, trans, n, nrhs, A, lda, stride_A, ipiv,&
+            stride_P, B, ldb, stride_B, info, batch_count) &
+            result(res) &
+            bind(c, name = 'hipblasZgetrsStridedBatchedFortran')
+        use iso_c_binding
+        use hipblas_enums
+        implicit none
+        type(c_ptr), value :: handle
+        integer(kind(HIPBLAS_OP_N)), value :: trans
+        integer(c_int), value :: n
+        integer(c_int), value :: nrhs
+        type(c_ptr), value :: A
+        integer(c_int), value :: lda
+        integer(c_int), value :: stride_A
+        type(c_ptr), value :: ipiv
+        integer(c_int), value :: stride_P
+        type(c_ptr), value :: B
+        integer(c_int), value :: ldb
+        integer(c_int), value :: stride_B
+        type(c_ptr), value :: info
+        integer(c_int), value :: batch_count
+        integer(c_int) :: res
+        res = hipblasZgetrsStridedBatched(handle, trans, n, nrhs, A, lda, stride_A,&
+            ipiv, stride_P, B, ldb, stride_B, info, batch_count) ! forward all arguments unchanged
+    end function hipblasZgetrsStridedBatchedFortran
+
+    ! 
geqrf + function hipblasSgeqrfFortran(handle, m, n, A, lda, tau, info) & + result(res) & + bind(c, name = 'hipblasSgeqrfFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int) :: res + res = hipblasSgeqrf(handle, m, n, A, lda, tau, info) + end function hipblasSgeqrfFortran + + function hipblasDgeqrfFortran(handle, m, n, A, lda, tau, info) & + result(res) & + bind(c, name = 'hipblasDgeqrfFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int) :: res + res = hipblasDgeqrf(handle, m, n, A, lda, tau, info) + end function hipblasDgeqrfFortran + + function hipblasCgeqrfFortran(handle, m, n, A, lda, tau, info) & + result(res) & + bind(c, name = 'hipblasCgeqrfFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int) :: res + res = hipblasCgeqrf(handle, m, n, A, lda, tau, info) + end function hipblasCgeqrfFortran + + function hipblasZgeqrfFortran(handle, m, n, A, lda, tau, info) & + result(res) & + bind(c, name = 'hipblasZgeqrfFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int) :: res + res = hipblasZgeqrf(handle, m, n, A, lda, tau, info) + end function 
hipblasZgeqrfFortran + + ! geqrf_batched + function hipblasSgeqrfBatchedFortran(handle, m, n, A, lda, tau, info, batch_count) & + result(res) & + bind(c, name = 'hipblasSgeqrfBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasSgeqrfBatched(handle, m, n, A, lda, tau, info, batch_count) + end function hipblasSgeqrfBatchedFortran + + function hipblasDgeqrfBatchedFortran(handle, m, n, A, lda, tau, info, batch_count) & + result(res) & + bind(c, name = 'hipblasDgeqrfBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasDgeqrfBatched(handle, m, n, A, lda, tau, info, batch_count) + end function hipblasDgeqrfBatchedFortran + + function hipblasCgeqrfBatchedFortran(handle, m, n, A, lda, tau, info, batch_count) & + result(res) & + bind(c, name = 'hipblasCgeqrfBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasCgeqrfBatched(handle, m, n, A, lda, tau, info, batch_count) + end function hipblasCgeqrfBatchedFortran + + function hipblasZgeqrfBatchedFortran(handle, m, n, A, lda, tau, info, batch_count) & + result(res) & + bind(c, name = 'hipblasZgeqrfBatchedFortran') + use iso_c_binding + use 
hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasZgeqrfBatched(handle, m, n, A, lda, tau, info, batch_count) + end function hipblasZgeqrfBatchedFortran + + ! geqrf_strided_batched + function hipblasSgeqrfStridedBatchedFortran(handle, m, n, A, lda, stride_A,& + tau, stride_T, info, batch_count) & + result(res) & + bind(c, name = 'hipblasSgeqrfStridedBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: tau + integer(c_int), value :: stride_T + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasSgeqrfStridedBatched(handle, m, n, A, lda, stride_A,& + tau, stride_T, info, batch_count) + end function hipblasSgeqrfStridedBatchedFortran + +function hipblasDgeqrfStridedBatchedFortran(handle, m, n, A, lda, stride_A,& + tau, stride_T, info, batch_count) & + result(res) & + bind(c, name = 'hipblasDgeqrfStridedBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: tau + integer(c_int), value :: stride_T + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasDgeqrfStridedBatched(handle, m, n, A, lda, stride_A,& + tau, stride_T, info, batch_count) +end function hipblasDgeqrfStridedBatchedFortran + + function hipblasCgeqrfStridedBatchedFortran(handle, m, n, A, lda, stride_A,& + 
tau, stride_T, info, batch_count) & + result(res) & + bind(c, name = 'hipblasCgeqrfStridedBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: tau + integer(c_int), value :: stride_T + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasCgeqrfStridedBatched(handle, m, n, A, lda, stride_A,& + tau, stride_T, info, batch_count) + end function hipblasCgeqrfStridedBatchedFortran + + function hipblasZgeqrfStridedBatchedFortran(handle, m, n, A, lda, stride_A,& + tau, stride_T, info, batch_count) & + result(res) & + bind(c, name = 'hipblasZgeqrfStridedBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: tau + integer(c_int), value :: stride_T + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasZgeqrfStridedBatched(handle, m, n, A, lda, stride_A,& + tau, stride_T, info, batch_count) + end function hipblasZgeqrfStridedBatchedFortran + +end module hipblas_interface \ No newline at end of file diff --git a/clients/include/testing_geqrf.hpp b/clients/include/testing_geqrf.hpp index c2c13644e..e5b49e505 100644 --- a/clients/include/testing_geqrf.hpp +++ b/clients/include/testing_geqrf.hpp @@ -20,6 +20,9 @@ using namespace std; template hipblasStatus_t testing_geqrf(Arguments argus) { + bool FORTRAN = argus.fortran; + auto hipblasGeqrfFn = FORTRAN ? 
hipblasGeqrf : hipblasGeqrf; + int M = argus.M; int N = argus.N; int lda = argus.lda; @@ -76,7 +79,7 @@ hipblasStatus_t testing_geqrf(Arguments argus) HIPBLAS =================================================================== */ - status = hipblasGeqrf(handle, M, N, dA, lda, dIpiv, &info); + status = hipblasGeqrfFn(handle, M, N, dA, lda, dIpiv, &info); if(status != HIPBLAS_STATUS_SUCCESS) { diff --git a/clients/include/testing_geqrf_batched.hpp b/clients/include/testing_geqrf_batched.hpp index b3ab323e7..5d9b8b5c8 100644 --- a/clients/include/testing_geqrf_batched.hpp +++ b/clients/include/testing_geqrf_batched.hpp @@ -20,6 +20,9 @@ using namespace std; template hipblasStatus_t testing_geqrf_batched(Arguments argus) { + bool FORTRAN = argus.fortran; + auto hipblasGeqrfBatchedFn = FORTRAN ? hipblasGeqrfBatched : hipblasGeqrfBatched; + int M = argus.M; int N = argus.N; int lda = argus.lda; @@ -96,7 +99,7 @@ hipblasStatus_t testing_geqrf_batched(Arguments argus) HIPBLAS =================================================================== */ - status = hipblasGeqrfBatched(handle, M, N, dA, lda, dIpiv, &info, batch_count); + status = hipblasGeqrfBatchedFn(handle, M, N, dA, lda, dIpiv, &info, batch_count); if(status != HIPBLAS_STATUS_SUCCESS) { diff --git a/clients/include/testing_geqrf_strided_batched.hpp b/clients/include/testing_geqrf_strided_batched.hpp index 243923aa9..654db1c49 100644 --- a/clients/include/testing_geqrf_strided_batched.hpp +++ b/clients/include/testing_geqrf_strided_batched.hpp @@ -20,6 +20,9 @@ using namespace std; template hipblasStatus_t testing_geqrf_strided_batched(Arguments argus) { + bool FORTRAN = argus.fortran; + auto hipblasGeqrfStridedBatchedFn = FORTRAN ? 
hipblasGeqrfStridedBatched : hipblasGeqrfStridedBatched; + int M = argus.M; int N = argus.N; int lda = argus.lda; @@ -89,7 +92,7 @@ hipblasStatus_t testing_geqrf_strided_batched(Arguments argus) HIPBLAS =================================================================== */ - status = hipblasGeqrfStridedBatched( + status = hipblasGeqrfStridedBatchedFn( handle, M, N, dA, lda, strideA, dIpiv, strideP, &info, batch_count); if(status != HIPBLAS_STATUS_SUCCESS) diff --git a/clients/include/testing_getrf.hpp b/clients/include/testing_getrf.hpp index 3ef2a61c7..af61325af 100644 --- a/clients/include/testing_getrf.hpp +++ b/clients/include/testing_getrf.hpp @@ -20,6 +20,9 @@ using namespace std; template hipblasStatus_t testing_getrf(Arguments argus) { + bool FORTRAN = argus.fortran; + auto hipblasGetrfFn = FORTRAN ? hipblasGetrf : hipblasGetrf; + int M = argus.N; int N = argus.N; int lda = argus.lda; @@ -79,7 +82,7 @@ hipblasStatus_t testing_getrf(Arguments argus) HIPBLAS =================================================================== */ - status = hipblasGetrf(handle, N, dA, lda, dIpiv, dInfo); + status = hipblasGetrfFn(handle, N, dA, lda, dIpiv, dInfo); if(status != HIPBLAS_STATUS_SUCCESS) { diff --git a/clients/include/testing_getrf_batched.hpp b/clients/include/testing_getrf_batched.hpp index 017c11d7a..9518e29bd 100644 --- a/clients/include/testing_getrf_batched.hpp +++ b/clients/include/testing_getrf_batched.hpp @@ -20,6 +20,9 @@ using namespace std; template hipblasStatus_t testing_getrf_batched(Arguments argus) { + bool FORTRAN = argus.fortran; + auto hipblasGetrfBatchedFn = FORTRAN ? 
hipblasGetrfBatched : hipblasGetrfBatched; + int M = argus.N; int N = argus.N; int lda = argus.lda; @@ -95,7 +98,7 @@ hipblasStatus_t testing_getrf_batched(Arguments argus) HIPBLAS =================================================================== */ - status = hipblasGetrfBatched(handle, N, dA, lda, dIpiv, dInfo, batch_count); + status = hipblasGetrfBatchedFn(handle, N, dA, lda, dIpiv, dInfo, batch_count); if(status != HIPBLAS_STATUS_SUCCESS) { diff --git a/clients/include/testing_getrf_strided_batched.hpp b/clients/include/testing_getrf_strided_batched.hpp index f916c2939..30156266b 100644 --- a/clients/include/testing_getrf_strided_batched.hpp +++ b/clients/include/testing_getrf_strided_batched.hpp @@ -20,6 +20,9 @@ using namespace std; template hipblasStatus_t testing_getrf_strided_batched(Arguments argus) { + bool FORTRAN = argus.fortran; + auto hipblasGetrfStridedBatchedFn = FORTRAN ? hipblasGetrfStridedBatched : hipblasGetrfStridedBatched; + int M = argus.N; int N = argus.N; int lda = argus.lda; @@ -92,7 +95,7 @@ hipblasStatus_t testing_getrf_strided_batched(Arguments argus) HIPBLAS =================================================================== */ - status = hipblasGetrfStridedBatched( + status = hipblasGetrfStridedBatchedFn( handle, N, dA, lda, strideA, dIpiv, strideP, dInfo, batch_count); if(status != HIPBLAS_STATUS_SUCCESS) diff --git a/clients/include/testing_getrs.hpp b/clients/include/testing_getrs.hpp index 7da1f6607..78757657b 100644 --- a/clients/include/testing_getrs.hpp +++ b/clients/include/testing_getrs.hpp @@ -20,6 +20,9 @@ using namespace std; template hipblasStatus_t testing_getrs(Arguments argus) { + bool FORTRAN = argus.fortran; + auto hipblasGetrsFn = FORTRAN ? 
hipblasGetrs : hipblasGetrs; + int N = argus.N; int lda = argus.lda; int ldb = argus.ldb; @@ -94,7 +97,7 @@ hipblasStatus_t testing_getrs(Arguments argus) HIPBLAS =================================================================== */ - status = hipblasGetrs(handle, op, N, 1, dA, lda, dIpiv, dB, ldb, &info); + status = hipblasGetrsFn(handle, op, N, 1, dA, lda, dIpiv, dB, ldb, &info); if(status != HIPBLAS_STATUS_SUCCESS) { diff --git a/clients/include/testing_getrs_batched.hpp b/clients/include/testing_getrs_batched.hpp index dd4e1139f..0e50ca900 100644 --- a/clients/include/testing_getrs_batched.hpp +++ b/clients/include/testing_getrs_batched.hpp @@ -20,6 +20,9 @@ using namespace std; template hipblasStatus_t testing_getrs_batched(Arguments argus) { + bool FORTRAN = argus.fortran; + auto hipblasGetrsBatchedFn = FORTRAN ? hipblasGetrsBatched : hipblasGetrsBatched; + int N = argus.N; int lda = argus.lda; int ldb = argus.ldb; @@ -116,7 +119,7 @@ hipblasStatus_t testing_getrs_batched(Arguments argus) HIPBLAS =================================================================== */ - status = hipblasGetrsBatched(handle, op, N, 1, dA, lda, dIpiv, dB, ldb, &info, batch_count); + status = hipblasGetrsBatchedFn(handle, op, N, 1, dA, lda, dIpiv, dB, ldb, &info, batch_count); if(status != HIPBLAS_STATUS_SUCCESS) { diff --git a/clients/include/testing_getrs_strided_batched.hpp b/clients/include/testing_getrs_strided_batched.hpp index 95277e362..df760f0ac 100644 --- a/clients/include/testing_getrs_strided_batched.hpp +++ b/clients/include/testing_getrs_strided_batched.hpp @@ -20,6 +20,9 @@ using namespace std; template hipblasStatus_t testing_getrs_strided_batched(Arguments argus) { + bool FORTRAN = argus.fortran; + auto hipblasGetrsStridedBatchedFn = FORTRAN ? 
hipblasGetrsStridedBatched : hipblasGetrsStridedBatched; + int N = argus.N; int lda = argus.lda; int ldb = argus.ldb; @@ -111,7 +114,7 @@ hipblasStatus_t testing_getrs_strided_batched(Arguments argus) HIPBLAS =================================================================== */ - status = hipblasGetrsStridedBatched( + status = hipblasGetrsStridedBatchedFn( handle, op, N, 1, dA, lda, strideA, dIpiv, strideP, dB, ldb, strideB, &info, batch_count); if(status != HIPBLAS_STATUS_SUCCESS) diff --git a/library/src/hipblas_module.f90 b/library/src/hipblas_module.f90 index 25ebf7ee7..962aaf7f1 100644 --- a/library/src/hipblas_module.f90 +++ b/library/src/hipblas_module.f90 @@ -11871,4 +11871,725 @@ function hipblasTrsmStridedBatchedEx(handle, side, uplo, transA, diag, m, n, alp end function hipblasTrsmStridedBatchedEx end interface + !--------! + ! Solver ! + !--------! + + ! getrf + interface + function hipblasSgetrf(handle, n, A, lda, ipiv, info) & + result(c_int) & + bind(c, name = 'hipblasSgetrf') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + end function hipblasSgetrf + end interface + + interface + function hipblasDgetrf(handle, n, A, lda, ipiv, info) & + result(c_int) & + bind(c, name = 'hipblasDgetrf') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + end function hipblasDgetrf + end interface + + interface + function hipblasCgetrf(handle, n, A, lda, ipiv, info) & + result(c_int) & + bind(c, name = 'hipblasCgetrf') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), 
value :: ipiv + type(c_ptr), value :: info + end function hipblasCgetrf + end interface + + interface + function hipblasZgetrf(handle, n, A, lda, ipiv, info) & + result(c_int) & + bind(c, name = 'hipblasZgetrf') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + end function hipblasZgetrf + end interface + + ! getrf_batched + interface + function hipblasSgetrfBatched(handle, n, A, lda, ipiv, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasSgetrfBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasSgetrfBatched + end interface + + interface + function hipblasDgetrfBatched(handle, n, A, lda, ipiv, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasDgetrfBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasDgetrfBatched + end interface + + interface + function hipblasCgetrfBatched(handle, n, A, lda, ipiv, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasCgetrfBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasCgetrfBatched + end interface + + interface + function hipblasZgetrfBatched(handle, n, A, lda, ipiv, 
info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasZgetrfBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasZgetrfBatched + end interface + + ! getrf_strided_batched + interface + function hipblasSgetrfStridedBatched(handle, n, A, lda, stride_A,& + ipiv, stride_P, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasSgetrfStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasSgetrfStridedBatched + end interface + + interface + function hipblasDgetrfStridedBatched(handle, n, A, lda, stride_A,& + ipiv, stride_P, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasDgetrfStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasDgetrfStridedBatched + end interface + + interface + function hipblasCgetrfStridedBatched(handle, n, A, lda, stride_A,& + ipiv, stride_P, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasCgetrfStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: 
stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasCgetrfStridedBatched + end interface + + interface + function hipblasZgetrfStridedBatched(handle, n, A, lda, stride_A,& + ipiv, stride_P, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasZgetrfStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasZgetrfStridedBatched + end interface + + ! getrs + interface + function hipblasSgetrs(handle, trans, n, nrhs, A, lda, ipiv,& + B, ldb, info) & + result(c_int) & + bind(c, name = 'hipblasSgetrs') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + end function hipblasSgetrs + end interface + + interface + function hipblasDgetrs(handle, trans, n, nrhs, A, lda, ipiv,& + B, ldb, info) & + result(c_int) & + bind(c, name = 'hipblasDgetrs') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + end function hipblasDgetrs + end interface + + interface + function hipblasCgetrs(handle, trans, n, nrhs, A, lda, ipiv,& + 
B, ldb, info) & + result(c_int) & + bind(c, name = 'hipblasCgetrs') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + end function hipblasCgetrs + end interface + + interface + function hipblasZgetrs(handle, trans, n, nrhs, A, lda, ipiv,& + B, ldb, info) & + result(c_int) & + bind(c, name = 'hipblasZgetrs') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + end function hipblasZgetrs + end interface + + ! 
getrs_batched + interface + function hipblasSgetrsBatched(handle, trans, n, nrhs, A, lda, ipiv,& + B, ldb, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasSgetrsBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasSgetrsBatched + end interface + + interface + function hipblasDgetrsBatched(handle, trans, n, nrhs, A, lda, ipiv,& + B, ldb, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasDgetrsBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasDgetrsBatched + end interface + + interface + function hipblasCgetrsBatched(handle, trans, n, nrhs, A, lda, ipiv,& + B, ldb, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasCgetrsBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasCgetrsBatched + end interface + + interface + function hipblasZgetrsBatched(handle, trans, n, nrhs, A, lda, ipiv,& + B, 
ldb, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasZgetrsBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasZgetrsBatched + end interface + + ! getrs_strided_batched + interface + function hipblasSgetrsStridedBatched(handle, trans, n, nrhs, A, lda, stride_A, ipiv,& + stride_P, B, ldb, stride_B, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasSgetrsStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: B + integer(c_int), value :: ldb + integer(c_int), value :: stride_B + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasSgetrsStridedBatched + end interface + + interface + function hipblasDgetrsStridedBatched(handle, trans, n, nrhs, A, lda, stride_A, ipiv,& + stride_P, B, ldb, stride_B, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasDgetrsStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: B + integer(c_int), value :: ldb + 
integer(c_int), value :: stride_B + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasDgetrsStridedBatched + end interface + + interface + function hipblasCgetrsStridedBatched(handle, trans, n, nrhs, A, lda, stride_A, ipiv,& + stride_P, B, ldb, stride_B, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasCgetrsStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: B + integer(c_int), value :: ldb + integer(c_int), value :: stride_B + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasCgetrsStridedBatched + end interface + + interface + function hipblasZgetrsStridedBatched(handle, trans, n, nrhs, A, lda, stride_A, ipiv,& + stride_P, B, ldb, stride_B, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasZgetrsStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: B + integer(c_int), value :: ldb + integer(c_int), value :: stride_B + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasZgetrsStridedBatched + end interface + + ! 
geqrf + interface + function hipblasSgeqrf(handle, m, n, A, lda, tau, info) & + result(c_int) & + bind(c, name = 'hipblasSgeqrf') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + end function hipblasSgeqrf + end interface + + interface + function hipblasDgeqrf(handle, m, n, A, lda, tau, info) & + result(c_int) & + bind(c, name = 'hipblasDgeqrf') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + end function hipblasDgeqrf + end interface + + interface + function hipblasCgeqrf(handle, m, n, A, lda, tau, info) & + result(c_int) & + bind(c, name = 'hipblasCgeqrf') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + end function hipblasCgeqrf + end interface + + interface + function hipblasZgeqrf(handle, m, n, A, lda, tau, info) & + result(c_int) & + bind(c, name = 'hipblasZgeqrf') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + end function hipblasZgeqrf + end interface + + ! 
geqrf_batched + interface + function hipblasSgeqrfBatched(handle, m, n, A, lda, tau, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasSgeqrfBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasSgeqrfBatched + end interface + + interface + function hipblasDgeqrfBatched(handle, m, n, A, lda, tau, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasDgeqrfBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasDgeqrfBatched + end interface + + interface + function hipblasCgeqrfBatched(handle, m, n, A, lda, tau, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasCgeqrfBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasCgeqrfBatched + end interface + + interface + function hipblasZgeqrfBatched(handle, m, n, A, lda, tau, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasZgeqrfBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function 
hipblasZgeqrfBatched + end interface + + ! geqrf_strided_batched + interface + function hipblasSgeqrfStridedBatched(handle, m, n, A, lda, stride_A,& + tau, stride_T, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasSgeqrfStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: tau + integer(c_int), value :: stride_T + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasSgeqrfStridedBatched + end interface + + interface + function hipblasDgeqrfStridedBatched(handle, m, n, A, lda, stride_A,& + tau, stride_T, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasDgeqrfStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: tau + integer(c_int), value :: stride_T + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasDgeqrfStridedBatched + end interface + + interface + function hipblasCgeqrfStridedBatched(handle, m, n, A, lda, stride_A,& + tau, stride_T, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasCgeqrfStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: tau + integer(c_int), value :: stride_T + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasCgeqrfStridedBatched + end interface + + interface + function hipblasZgeqrfStridedBatched(handle, m, n, A, lda, 
stride_A,& + tau, stride_T, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasZgeqrfStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: tau + integer(c_int), value :: stride_T + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasZgeqrfStridedBatched + end interface + end module hipblas \ No newline at end of file From b99abd5763aaacf5652765c3cc537c40196355ef Mon Sep 17 00:00:00 2001 From: daineAMD <51674140+daineAMD@users.noreply.github.com> Date: Fri, 12 Jun 2020 12:00:43 -0600 Subject: [PATCH 7/9] Adding async functions (#234) --- clients/gtest/set_get_matrix_gtest.cpp | 43 +- clients/gtest/set_get_vector_gtest.cpp | 38 +- clients/include/hipblas_fortran.f90 | 131 ++++ clients/include/hipblas_fortran.hpp | 666 +++++++++--------- clients/include/testing_set_get_matrix.hpp | 34 +- .../include/testing_set_get_matrix_async.hpp | 129 ++++ clients/include/testing_set_get_vector.hpp | 27 +- .../include/testing_set_get_vector_async.hpp | 109 +++ library/include/hipblas.h | 24 + library/src/hcc_detail/hipblas.cpp | 28 + library/src/hipblas_module.f90 | 68 +- library/src/nvcc_detail/hipblas.cpp | 26 + 12 files changed, 944 insertions(+), 379 deletions(-) create mode 100644 clients/include/testing_set_get_matrix_async.hpp create mode 100644 clients/include/testing_set_get_vector_async.hpp diff --git a/clients/gtest/set_get_matrix_gtest.cpp b/clients/gtest/set_get_matrix_gtest.cpp index a95b4ec3c..f120fe78b 100644 --- a/clients/gtest/set_get_matrix_gtest.cpp +++ b/clients/gtest/set_get_matrix_gtest.cpp @@ -4,6 +4,7 @@ * ************************************************************************ */ #include "testing_set_get_matrix.hpp" +#include "testing_set_get_matrix_async.hpp" #include "utility.h" 
#include #include @@ -20,7 +21,7 @@ using namespace std; // only GCC/VS 2010 comes with std::tr1::tuple, but it is unnecessary, std::tuple is good enough; -typedef std::tuple, vector> set_get_matrix_tuple; +typedef std::tuple, vector, bool> set_get_matrix_tuple; /* ===================================================================== README: This file contains testers to verify the correctness of @@ -66,6 +67,9 @@ const vector> lda_ldb_ldc_range = {{3, 3, 3}, {5, 5, 3}, {5, 5, 4}, {5, 5, 5}}; + +const bool is_fortran[] = {false, true}; + /* ===============Google Unit Test==================================================== */ /* ===================================================================== @@ -123,23 +127,32 @@ TEST_P(set_matrix_get_matrix_gtest, float) // if not success, then the input argument is problematic, so detect the error message if(status != HIPBLAS_STATUS_SUCCESS) { - if(arg.rows < 0) + if(arg.rows < 0 || arg.cols <= 0 || arg.lda <= 0 || arg.ldb <= 0 || arg.ldc <= 0) { EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); } - else if(arg.cols <= 0) - { - EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); - } - else if(arg.lda <= 0) - { - EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); - } - else if(arg.ldb <= 0) + else { - EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); // fail } - else if(arg.ldc <= 0) + } +} + +TEST_P(set_matrix_get_matrix_gtest, async_float) +{ + // GetParam return a tuple. Tee setup routine unpack the tuple + // and initializes arg(Arguments) which will be passed to testing routine + // The Arguments data struture have physical meaning associated. + // while the tuple is non-intuitive. 
+ + Arguments arg = setup_set_get_matrix_arguments(GetParam()); + + hipblasStatus_t status = testing_set_get_matrix_async(arg); + + // if not success, then the input argument is problematic, so detect the error message + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.rows < 0 || arg.cols <= 0 || arg.lda <= 0 || arg.ldb <= 0 || arg.ldc <= 0) { EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); } @@ -157,4 +170,6 @@ TEST_P(set_matrix_get_matrix_gtest, float) INSTANTIATE_TEST_CASE_P(hipblasAuxiliary_small, set_matrix_get_matrix_gtest, - Combine(ValuesIn(rows_cols_range), ValuesIn(lda_ldb_ldc_range))); + Combine(ValuesIn(rows_cols_range), + ValuesIn(lda_ldb_ldc_range), + ValuesIn(is_fortran))); diff --git a/clients/gtest/set_get_vector_gtest.cpp b/clients/gtest/set_get_vector_gtest.cpp index 9580534f0..2fdda1ccf 100644 --- a/clients/gtest/set_get_vector_gtest.cpp +++ b/clients/gtest/set_get_vector_gtest.cpp @@ -4,6 +4,7 @@ * ************************************************************************ */ #include "testing_set_get_vector.hpp" +#include "testing_set_get_vector_async.hpp" #include "utility.h" #include #include @@ -18,7 +19,7 @@ using namespace std; // only GCC/VS 2010 comes with std::tr1::tuple, but it is unnecessary, std::tuple is good enough; -typedef std::tuple> set_get_vector_tuple; +typedef std::tuple, bool> set_get_vector_tuple; /* ===================================================================== README: This file contains testers to verify the correctness of @@ -55,6 +56,8 @@ const vector> incx_incy_incd_range = {{1, 1, 1}, {3, 3, 1}, {3, 3, 3}}; +const bool is_fortran[] = {false, true}; + /* ===============Google Unit Test==================================================== */ /* ===================================================================== @@ -114,19 +117,32 @@ TEST_P(set_vector_get_vector_gtest, float) // if not success, then the input argument is problematic, so detect the error message if(status != HIPBLAS_STATUS_SUCCESS) { - if(arg.M 
< 0) - { - EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); - } - else if(arg.incx <= 0) + if(arg.M < 0 || arg.incx <= 0 || arg.incy <= 0 || arg.incd <= 0) { EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); } - else if(arg.incy <= 0) + else { - EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); // fail } - else if(arg.incd <= 0) + } +} + +TEST_P(set_vector_get_vector_gtest, async_float) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to the testing routine. + // The Arguments data structure has physical meaning associated, + // while the tuple is non-intuitive. + + Arguments arg = setup_set_get_vector_arguments(GetParam()); + + hipblasStatus_t status = testing_set_get_vector_async(arg); + + // if not success, then the input argument is problematic, so detect the error message + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.M < 0 || arg.incx <= 0 || arg.incy <= 0 || arg.incd <= 0) { EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); } @@ -144,4 +160,6 @@ TEST_P(set_vector_get_vector_gtest, float) INSTANTIATE_TEST_CASE_P(rocblas_auxiliary_small, set_vector_get_vector_gtest, - Combine(ValuesIn(M_range), ValuesIn(incx_incy_incd_range))); + Combine(ValuesIn(M_range), + ValuesIn(incx_incy_incd_range), + ValuesIn(is_fortran))); diff --git a/clients/include/hipblas_fortran.f90 b/clients/include/hipblas_fortran.f90 index cd1db28b2..722167d16 100644 --- a/clients/include/hipblas_fortran.f90 +++ b/clients/include/hipblas_fortran.f90 @@ -4,6 +4,137 @@ module hipblas_interface contains + !--------! + ! Aux ! + !--------!
+ function hipblasSetVectorFortran(n, elemSize, x, incx, y, incy) & + result(res) & + bind(c, name = 'hipblasSetVectorFortran') + use iso_c_binding + implicit none + integer(c_int), value :: n + integer(c_int), value :: elemSize + type(c_ptr), value :: x + integer(c_int), value :: incx + type(c_ptr), value :: y + integer(c_int), value :: incy + integer(c_int) :: res + res = hipblasSetVector(n, elemSize, x, incx, y, incy) + end function hipblasSetVectorFortran + + function hipblasGetVectorFortran(n, elemSize, x, incx, y, incy) & + result(res) & + bind(c, name = 'hipblasGetVectorFortran') + use iso_c_binding + implicit none + integer(c_int), value :: n + integer(c_int), value :: elemSize + type(c_ptr), value :: x + integer(c_int), value :: incx + type(c_ptr), value :: y + integer(c_int), value :: incy + integer(c_int) :: res + res = hipblasGetVector(n, elemSize, x, incx, y, incy) + end function hipblasGetVectorFortran + + function hipblasSetMatrixFortran(rows, cols, elemSize, A, lda, B, ldb) & + result(res) & + bind(c, name = 'hipblasSetMatrixFortran') + use iso_c_binding + implicit none + integer(c_int), value :: rows + integer(c_int), value :: cols + integer(c_int), value :: elemSize + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: B + integer(c_int), value :: ldb + integer(c_int) :: res + res = hipblasSetMatrix(rows, cols, elemSize, A, lda, B, ldb) + end function hipblasSetMatrixFortran + + function hipblasGetMatrixFortran(rows, cols, elemSize, A, lda, B, ldb) & + result(res) & + bind(c, name = 'hipblasGetMatrixFortran') + use iso_c_binding + implicit none + integer(c_int), value :: rows + integer(c_int), value :: cols + integer(c_int), value :: elemSize + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: B + integer(c_int), value :: ldb + integer(c_int) :: res + res = hipblasGetMatrix(rows, cols, elemSize, A, lda, B, ldb) + end function hipblasGetMatrixFortran + + function
hipblasSetVectorAsyncFortran(n, elemSize, x, incx, y, incy, stream) & + result(res) & + bind(c, name = 'hipblasSetVectorAsyncFortran') + use iso_c_binding + implicit none + integer(c_int), value :: n + integer(c_int), value :: elemSize + type(c_ptr), value :: x + integer(c_int), value :: incx + type(c_ptr), value :: y + integer(c_int), value :: incy + type(c_ptr), value :: stream + integer(c_int) :: res + res = hipblasSetVectorAsync(n, elemSize, x, incx, y, incy, stream) + end function hipblasSetVectorAsyncFortran + + function hipblasGetVectorAsyncFortran(n, elemSize, x, incx, y, incy, stream) & + result(res) & + bind(c, name = 'hipblasGetVectorAsyncFortran') + use iso_c_binding + implicit none + integer(c_int), value :: n + integer(c_int), value :: elemSize + type(c_ptr), value :: x + integer(c_int), value :: incx + type(c_ptr), value :: y + integer(c_int), value :: incy + type(c_ptr), value :: stream + integer(c_int) :: res + res = hipblasGetVectorAsync(n, elemSize, x, incx, y, incy, stream) + end function hipblasGetVectorAsyncFortran + + function hipblasSetMatrixAsyncFortran(rows, cols, elemSize, A, lda, B, ldb, stream) & + result(res) & + bind(c, name = 'hipblasSetMatrixAsyncFortran') + use iso_c_binding + implicit none + integer(c_int), value :: rows + integer(c_int), value :: cols + integer(c_int), value :: elemSize + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: stream + integer(c_int) :: res + res = hipblasSetMatrixAsync(rows, cols, elemSize, A, lda, B, ldb, stream) + end function hipblasSetMatrixAsyncFortran + + function hipblasGetMatrixAsyncFortran(rows, cols, elemSize, A, lda, B, ldb, stream) & + result(res) & + bind(c, name = 'hipblasGetMatrixAsyncFortran') + use iso_c_binding + implicit none + integer(c_int), value :: rows + integer(c_int), value :: cols + integer(c_int), value :: elemSize + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), 
value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: stream + integer(c_int) :: res + res = hipblasGetMatrixAsync(rows, cols, elemSize, A, lda, B, ldb, stream) + end function hipblasGetMatrixAsyncFortran + !--------! ! blas 1 ! !--------! diff --git a/clients/include/hipblas_fortran.hpp b/clients/include/hipblas_fortran.hpp index e1c02e770..50bf6f0c5 100644 --- a/clients/include/hipblas_fortran.hpp +++ b/clients/include/hipblas_fortran.hpp @@ -12,6 +12,33 @@ */ extern "C" { +/* ========== + * Aux + * ========== */ +hipblasStatus_t + hipblasSetVectorFortran(int n, int elemSize, const void* x, int incx, void* y, int incy); + +hipblasStatus_t + hipblasGetVectorFortran(int n, int elemSize, const void* x, int incx, void* y, int incy); + +hipblasStatus_t hipblasSetMatrixFortran( + int rows, int cols, int elemSize, const void* A, int lda, void* B, int ldb); + +hipblasStatus_t hipblasGetMatrixFortran( + int rows, int cols, int elemSize, const void* A, int lda, void* B, int ldb); + +hipblasStatus_t hipblasSetVectorAsyncFortran( + int n, int elemSize, const void* x, int incx, void* y, int incy, hipStream_t stream); + +hipblasStatus_t hipblasGetVectorAsyncFortran( + int n, int elemSize, const void* x, int incx, void* y, int incy, hipStream_t stream); + +hipblasStatus_t hipblasSetMatrixAsyncFortran( + int rows, int cols, int elemSize, const void* A, int lda, void* B, int ldb, hipStream_t stream); + +hipblasStatus_t hipblasGetMatrixAsyncFortran( + int rows, int cols, int elemSize, const void* A, int lda, void* B, int ldb, hipStream_t stream); + /* ========== * L1 * ========== */ @@ -6444,378 +6471,365 @@ hipblasStatus_t hipblasTrsmStridedBatchedExFortran(hipblasHandle_t handle, * ========== */ // getrf -hipblasStatus_t hipblasSgetrfFortran(hipblasHandle_t handle, - const int n, - float* A, - const int lda, - int* ipiv, - int* info); - -hipblasStatus_t hipblasDgetrfFortran(hipblasHandle_t handle, - const int n, - double* A, - const int lda, - int* ipiv, - int* 
info); - -hipblasStatus_t hipblasCgetrfFortran(hipblasHandle_t handle, - const int n, - hipblasComplex* A, - const int lda, - int* ipiv, - int* info); - -hipblasStatus_t hipblasZgetrfFortran(hipblasHandle_t handle, - const int n, - hipblasDoubleComplex* A, - const int lda, - int* ipiv, - int* info); +hipblasStatus_t hipblasSgetrfFortran( + hipblasHandle_t handle, const int n, float* A, const int lda, int* ipiv, int* info); + +hipblasStatus_t hipblasDgetrfFortran( + hipblasHandle_t handle, const int n, double* A, const int lda, int* ipiv, int* info); + +hipblasStatus_t hipblasCgetrfFortran( + hipblasHandle_t handle, const int n, hipblasComplex* A, const int lda, int* ipiv, int* info); + +hipblasStatus_t hipblasZgetrfFortran(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* A, + const int lda, + int* ipiv, + int* info); // getrf_batched hipblasStatus_t hipblasSgetrfBatchedFortran(hipblasHandle_t handle, - const int n, - float* const A[], - const int lda, - int* ipiv, - int* info, - const int batch_count); + const int n, + float* const A[], + const int lda, + int* ipiv, + int* info, + const int batch_count); hipblasStatus_t hipblasDgetrfBatchedFortran(hipblasHandle_t handle, - const int n, - double* const A[], - const int lda, - int* ipiv, - int* info, - const int batch_count); - -hipblasStatus_t hipblasCgetrfBatchedFortran(hipblasHandle_t handle, - const int n, - hipblasComplex* const A[], - const int lda, - int* ipiv, - int* info, - const int batch_count); - -hipblasStatus_t hipblasZgetrfBatchedFortran(hipblasHandle_t handle, - const int n, - hipblasDoubleComplex* const A[], - const int lda, - int* ipiv, - int* info, - const int batch_count); + const int n, + double* const A[], + const int lda, + int* ipiv, + int* info, + const int batch_count); + +hipblasStatus_t hipblasCgetrfBatchedFortran(hipblasHandle_t handle, + const int n, + hipblasComplex* const A[], + const int lda, + int* ipiv, + int* info, + const int batch_count); + +hipblasStatus_t 
hipblasZgetrfBatchedFortran(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* const A[], + const int lda, + int* ipiv, + int* info, + const int batch_count); // getrf_strided_batched hipblasStatus_t hipblasSgetrfStridedBatchedFortran(hipblasHandle_t handle, - const int n, - float* A, - const int lda, - const int stride_A, - int* ipiv, - const int stride_P, - int* info, - const int batch_count); + const int n, + float* A, + const int lda, + const int stride_A, + int* ipiv, + const int stride_P, + int* info, + const int batch_count); hipblasStatus_t hipblasDgetrfStridedBatchedFortran(hipblasHandle_t handle, - const int n, - double* A, - const int lda, - const int stride_A, - int* ipiv, - const int stride_P, - int* info, - const int batch_count); + const int n, + double* A, + const int lda, + const int stride_A, + int* ipiv, + const int stride_P, + int* info, + const int batch_count); hipblasStatus_t hipblasCgetrfStridedBatchedFortran(hipblasHandle_t handle, - const int n, - hipblasComplex* A, - const int lda, - const int stride_A, - int* ipiv, - const int stride_P, - int* info, - const int batch_count); - -hipblasStatus_t hipblasZgetrfStridedBatchedFortran(hipblasHandle_t handle, - const int n, - hipblasDoubleComplex* A, - const int lda, - const int stride_A, - int* ipiv, - const int stride_P, - int* info, - const int batch_count); + const int n, + hipblasComplex* A, + const int lda, + const int stride_A, + int* ipiv, + const int stride_P, + int* info, + const int batch_count); + +hipblasStatus_t hipblasZgetrfStridedBatchedFortran(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* A, + const int lda, + const int stride_A, + int* ipiv, + const int stride_P, + int* info, + const int batch_count); // getrs -hipblasStatus_t hipblasSgetrsFortran(hipblasHandle_t handle, +hipblasStatus_t hipblasSgetrsFortran(hipblasHandle_t handle, const hipblasOperation_t trans, - const int n, - const int nrhs, - float* A, - const int lda, - const int* ipiv, - 
float* B, - const int ldb, - int* info); - -hipblasStatus_t hipblasDgetrsFortran(hipblasHandle_t handle, + const int n, + const int nrhs, + float* A, + const int lda, + const int* ipiv, + float* B, + const int ldb, + int* info); + +hipblasStatus_t hipblasDgetrsFortran(hipblasHandle_t handle, const hipblasOperation_t trans, - const int n, - const int nrhs, - double* A, - const int lda, - const int* ipiv, - double* B, - const int ldb, - int* info); - -hipblasStatus_t hipblasCgetrsFortran(hipblasHandle_t handle, + const int n, + const int nrhs, + double* A, + const int lda, + const int* ipiv, + double* B, + const int ldb, + int* info); + +hipblasStatus_t hipblasCgetrsFortran(hipblasHandle_t handle, const hipblasOperation_t trans, - const int n, - const int nrhs, - hipblasComplex* A, - const int lda, - const int* ipiv, - hipblasComplex* B, - const int ldb, - int* info); - -hipblasStatus_t hipblasZgetrsFortran(hipblasHandle_t handle, + const int n, + const int nrhs, + hipblasComplex* A, + const int lda, + const int* ipiv, + hipblasComplex* B, + const int ldb, + int* info); + +hipblasStatus_t hipblasZgetrsFortran(hipblasHandle_t handle, const hipblasOperation_t trans, - const int n, - const int nrhs, - hipblasDoubleComplex* A, - const int lda, - const int* ipiv, - hipblasDoubleComplex* B, - const int ldb, - int* info); + const int n, + const int nrhs, + hipblasDoubleComplex* A, + const int lda, + const int* ipiv, + hipblasDoubleComplex* B, + const int ldb, + int* info); // getrs_batched -hipblasStatus_t hipblasSgetrsBatchedFortran(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - float* const A[], - const int lda, - const int* ipiv, - float* const B[], - const int ldb, - int* info, - const int batch_count); - -hipblasStatus_t hipblasDgetrsBatchedFortran(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - double* const A[], - const int lda, - const int* ipiv, - double* const B[], - const 
int ldb, - int* info, - const int batch_count); - -hipblasStatus_t hipblasCgetrsBatchedFortran(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - hipblasComplex* const A[], - const int lda, - const int* ipiv, - hipblasComplex* const B[], - const int ldb, - int* info, - const int batch_count); - -hipblasStatus_t hipblasZgetrsBatchedFortran(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - hipblasDoubleComplex* const A[], - const int lda, - const int* ipiv, - hipblasDoubleComplex* const B[], - const int ldb, - int* info, - const int batch_count); +hipblasStatus_t hipblasSgetrsBatchedFortran(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + float* const A[], + const int lda, + const int* ipiv, + float* const B[], + const int ldb, + int* info, + const int batch_count); + +hipblasStatus_t hipblasDgetrsBatchedFortran(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + double* const A[], + const int lda, + const int* ipiv, + double* const B[], + const int ldb, + int* info, + const int batch_count); + +hipblasStatus_t hipblasCgetrsBatchedFortran(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasComplex* const A[], + const int lda, + const int* ipiv, + hipblasComplex* const B[], + const int ldb, + int* info, + const int batch_count); + +hipblasStatus_t hipblasZgetrsBatchedFortran(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasDoubleComplex* const A[], + const int lda, + const int* ipiv, + hipblasDoubleComplex* const B[], + const int ldb, + int* info, + const int batch_count); // getrs_strided_batched -hipblasStatus_t hipblasSgetrsStridedBatchedFortran(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - float* A, - const int lda, - const int stride_A, - const int* 
ipiv, - const int stride_P, - float* B, - const int ldb, - const int stride_B, - int* info, - const int batch_count); - -hipblasStatus_t hipblasDgetrsStridedBatchedFortran(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - double* A, - const int lda, - const int stride_A, - const int* ipiv, - const int stride_P, - double* B, - const int ldb, - const int stride_B, - int* info, - const int batch_count); - -hipblasStatus_t hipblasCgetrsStridedBatchedFortran(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - hipblasComplex* A, - const int lda, - const int stride_A, - const int* ipiv, - const int stride_P, - hipblasComplex* B, - const int ldb, - const int stride_B, - int* info, - const int batch_count); - -hipblasStatus_t hipblasZgetrsStridedBatchedFortran(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - hipblasDoubleComplex* A, - const int lda, - const int stride_A, - const int* ipiv, - const int stride_P, - hipblasDoubleComplex* B, - const int ldb, - const int stride_B, - int* info, - const int batch_count); +hipblasStatus_t hipblasSgetrsStridedBatchedFortran(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + float* A, + const int lda, + const int stride_A, + const int* ipiv, + const int stride_P, + float* B, + const int ldb, + const int stride_B, + int* info, + const int batch_count); + +hipblasStatus_t hipblasDgetrsStridedBatchedFortran(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + double* A, + const int lda, + const int stride_A, + const int* ipiv, + const int stride_P, + double* B, + const int ldb, + const int stride_B, + int* info, + const int batch_count); + +hipblasStatus_t hipblasCgetrsStridedBatchedFortran(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasComplex* A, + const int lda, + const int 
stride_A, + const int* ipiv, + const int stride_P, + hipblasComplex* B, + const int ldb, + const int stride_B, + int* info, + const int batch_count); + +hipblasStatus_t hipblasZgetrsStridedBatchedFortran(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasDoubleComplex* A, + const int lda, + const int stride_A, + const int* ipiv, + const int stride_P, + hipblasDoubleComplex* B, + const int ldb, + const int stride_B, + int* info, + const int batch_count); // geqrf hipblasStatus_t hipblasSgeqrfFortran(hipblasHandle_t handle, - const int m, - const int n, - float* A, - const int lda, - float* tau, - int* info); + const int m, + const int n, + float* A, + const int lda, + float* tau, + int* info); hipblasStatus_t hipblasDgeqrfFortran(hipblasHandle_t handle, - const int m, - const int n, - double* A, - const int lda, - double* tau, - int* info); + const int m, + const int n, + double* A, + const int lda, + double* tau, + int* info); hipblasStatus_t hipblasCgeqrfFortran(hipblasHandle_t handle, - const int m, - const int n, + const int m, + const int n, hipblasComplex* A, - const int lda, + const int lda, hipblasComplex* tau, - int* info); + int* info); -hipblasStatus_t hipblasZgeqrfFortran(hipblasHandle_t handle, - const int m, - const int n, +hipblasStatus_t hipblasZgeqrfFortran(hipblasHandle_t handle, + const int m, + const int n, hipblasDoubleComplex* A, - const int lda, + const int lda, hipblasDoubleComplex* tau, - int* info); + int* info); // geqrf_batched hipblasStatus_t hipblasSgeqrfBatchedFortran(hipblasHandle_t handle, - const int m, - const int n, - float* const A[], - const int lda, - float* const tau[], - int* info, - const int batch_count); + const int m, + const int n, + float* const A[], + const int lda, + float* const tau[], + int* info, + const int batch_count); hipblasStatus_t hipblasDgeqrfBatchedFortran(hipblasHandle_t handle, - const int m, - const int n, - double* const A[], - const int lda, - double* 
const tau[], - int* info, - const int batch_count); - -hipblasStatus_t hipblasCgeqrfBatchedFortran(hipblasHandle_t handle, - const int m, - const int n, - hipblasComplex* const A[], - const int lda, - hipblasComplex* const tau[], - int* info, - const int batch_count); - -hipblasStatus_t hipblasZgeqrfBatchedFortran(hipblasHandle_t handle, - const int m, - const int n, - hipblasDoubleComplex* const A[], - const int lda, - hipblasDoubleComplex* const tau[], - int* info, - const int batch_count); + const int m, + const int n, + double* const A[], + const int lda, + double* const tau[], + int* info, + const int batch_count); + +hipblasStatus_t hipblasCgeqrfBatchedFortran(hipblasHandle_t handle, + const int m, + const int n, + hipblasComplex* const A[], + const int lda, + hipblasComplex* const tau[], + int* info, + const int batch_count); + +hipblasStatus_t hipblasZgeqrfBatchedFortran(hipblasHandle_t handle, + const int m, + const int n, + hipblasDoubleComplex* const A[], + const int lda, + hipblasDoubleComplex* const tau[], + int* info, + const int batch_count); // geqrf_strided_batched hipblasStatus_t hipblasSgeqrfStridedBatchedFortran(hipblasHandle_t handle, - const int m, - const int n, - float* A, - const int lda, - const int stride_A, - float* tau, - const int stride_T, - int* info, - const int batch_count); + const int m, + const int n, + float* A, + const int lda, + const int stride_A, + float* tau, + const int stride_T, + int* info, + const int batch_count); hipblasStatus_t hipblasDgeqrfStridedBatchedFortran(hipblasHandle_t handle, - const int m, - const int n, - double* A, - const int lda, - const int stride_A, - double* tau, - const int stride_T, - int* info, - const int batch_count); + const int m, + const int n, + double* A, + const int lda, + const int stride_A, + double* tau, + const int stride_T, + int* info, + const int batch_count); hipblasStatus_t hipblasCgeqrfStridedBatchedFortran(hipblasHandle_t handle, - const int m, - const int n, - hipblasComplex* 
A, - const int lda, - const int stride_A, - hipblasComplex* tau, - const int stride_T, - int* info, - const int batch_count); - -hipblasStatus_t hipblasZgeqrfStridedBatchedFortran(hipblasHandle_t handle, - const int m, - const int n, - hipblasDoubleComplex* A, - const int lda, - const int stride_A, - hipblasDoubleComplex* tau, - const int stride_T, - int* info, - const int batch_count); - + const int m, + const int n, + hipblasComplex* A, + const int lda, + const int stride_A, + hipblasComplex* tau, + const int stride_T, + int* info, + const int batch_count); + +hipblasStatus_t hipblasZgeqrfStridedBatchedFortran(hipblasHandle_t handle, + const int m, + const int n, + hipblasDoubleComplex* A, + const int lda, + const int stride_A, + hipblasDoubleComplex* tau, + const int stride_T, + int* info, + const int batch_count); } #endif diff --git a/clients/include/testing_set_get_matrix.hpp b/clients/include/testing_set_get_matrix.hpp index f7c07fff4..1e5c86aaf 100644 --- a/clients/include/testing_set_get_matrix.hpp +++ b/clients/include/testing_set_get_matrix.hpp @@ -11,6 +11,7 @@ #include "cblas_interface.h" #include "flops.h" #include "hipblas.hpp" +#include "hipblas_fortran.hpp" #include "norm.h" #include "unit.h" #include "utility.h" @@ -22,6 +23,10 @@ using namespace std; template hipblasStatus_t testing_set_get_matrix(Arguments argus) { + bool FORTRAN = argus.fortran; + auto hipblasSetMatrixFn = FORTRAN ? hipblasSetMatrixFortran : hipblasSetMatrix; + auto hipblasGetMatrixFn = FORTRAN ? hipblasGetMatrixFortran : hipblasGetMatrix; + int rows = argus.rows; int cols = argus.cols; int lda = argus.lda; @@ -61,12 +66,12 @@ hipblasStatus_t testing_set_get_matrix(Arguments argus) } // Naming: dK is in GPU (device) memory. 
hK is in CPU (host) memory - vector ha(cols * lda); - vector hb(cols * ldb); - vector hb_ref(cols * ldb); - vector hc(cols * ldc); + host_vector ha(cols * lda); + host_vector hb(cols * ldb); + host_vector hb_ref(cols * ldb); + host_vector hc(cols * ldc); - T* dc; + device_vector dc(cols * ldc); double gpu_time_used, cpu_time_used; double hipblasBandwidth, cpu_bandwidth; @@ -76,9 +81,6 @@ hipblasStatus_t testing_set_get_matrix(Arguments argus) hipblasCreate(&handle); - // allocate memory on device - CHECK_HIP_ERROR(hipMalloc(&dc, cols * ldc * sizeof(T))); - // Initial Data on CPU srand(1); hipblas_init(ha, rows, cols, lda); @@ -98,13 +100,18 @@ hipblasStatus_t testing_set_get_matrix(Arguments argus) ROCBLAS =================================================================== */ - status_set = hipblasSetMatrix(rows, cols, sizeof(T), (void*)ha.data(), lda, (void*)dc, ldc); - status_get = hipblasGetMatrix(rows, cols, sizeof(T), (void*)dc, ldc, (void*)hb.data(), ldb); - if(status_set != HIPBLAS_STATUS_SUCCESS || status_get != HIPBLAS_STATUS_SUCCESS) + status_set = hipblasSetMatrixFn(rows, cols, sizeof(T), (void*)ha.data(), lda, (void*)dc, ldc); + status_get = hipblasGetMatrixFn(rows, cols, sizeof(T), (void*)dc, ldc, (void*)hb.data(), ldb); + if(status_set != HIPBLAS_STATUS_SUCCESS) { - CHECK_HIP_ERROR(hipFree(dc)); hipblasDestroy(handle); - return status; + return status_set; + } + + if(status_get != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status_get; } if(argus.unit_check) @@ -130,7 +137,6 @@ hipblasStatus_t testing_set_get_matrix(Arguments argus) } } - CHECK_HIP_ERROR(hipFree(dc)); hipblasDestroy(handle); return HIPBLAS_STATUS_SUCCESS; } diff --git a/clients/include/testing_set_get_matrix_async.hpp b/clients/include/testing_set_get_matrix_async.hpp new file mode 100644 index 000000000..747f35c18 --- /dev/null +++ b/clients/include/testing_set_get_matrix_async.hpp @@ -0,0 +1,129 @@ +/* 
************************************************************************ + * Copyright 2016-2020 Advanced Micro Devices, Inc. + * + * ************************************************************************ */ + +#include +#include +#include +#include + +#include "cblas_interface.h" +#include "flops.h" +#include "hipblas.hpp" +#include "hipblas_fortran.hpp" +#include "norm.h" +#include "unit.h" +#include "utility.h" + +using namespace std; + +/* ============================================================================================ */ + +template +hipblasStatus_t testing_set_get_matrix_async(Arguments argus) +{ + bool FORTRAN = argus.fortran; + auto hipblasSetMatrixAsyncFn = FORTRAN ? hipblasSetMatrixAsyncFortran : hipblasSetMatrixAsync; + auto hipblasGetMatrixAsyncFn = FORTRAN ? hipblasGetMatrixAsyncFortran : hipblasGetMatrixAsync; + + int rows = argus.rows; + int cols = argus.cols; + int lda = argus.lda; + int ldb = argus.ldb; + int ldc = argus.ldc; + + hipblasStatus_t status = HIPBLAS_STATUS_SUCCESS; + hipblasStatus_t status_set = HIPBLAS_STATUS_SUCCESS; + hipblasStatus_t status_get = HIPBLAS_STATUS_SUCCESS; + + // argument sanity check, quick return if input parameters are invalid before allocating invalid + // memory + if(rows < 0 || cols < 0 || lda <= 0 || ldb <= 0 || ldc <= 0) + { + status = HIPBLAS_STATUS_INVALID_VALUE; + return status; + } + + // Naming: dK is in GPU (device) memory. 
hK is in CPU (host) memory + host_vector ha(cols * lda); + host_vector hb(cols * ldb); + host_vector hb_ref(cols * ldb); + host_vector hc(cols * ldc); + + device_vector dc(cols * ldc); + + double gpu_time_used, cpu_time_used; + double hipblasBandwidth, cpu_bandwidth; + double rocblas_error = 0.0; + + hipblasHandle_t handle; + hipblasCreate(&handle); + + hipStream_t stream; + hipblasGetStream(handle, &stream); + + // Initial Data on CPU + srand(1); + hipblas_init(ha, rows, cols, lda); + hipblas_init(hb, rows, cols, ldb); + hb_ref = hb; + for(int i = 0; i < cols * ldc; i++) + { + hc[i] = 100 + i; + }; + CHECK_HIP_ERROR(hipMemcpy(dc, hc.data(), sizeof(T) * ldc * cols, hipMemcpyHostToDevice)); + for(int i = 0; i < cols * ldc; i++) + { + hc[i] = 99.0; + }; + + /* ===================================================================== + ROCBLAS + =================================================================== */ + + status_set = hipblasSetMatrixAsyncFn( + rows, cols, sizeof(T), (void*)ha.data(), lda, (void*)dc, ldc, stream); + status_get = hipblasGetMatrixAsyncFn( + rows, cols, sizeof(T), (void*)dc, ldc, (void*)hb.data(), ldb, stream); + + hipStreamSynchronize(stream); + + if(status_set != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status_set; + } + + if(status_get != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status_get; + } + + if(argus.unit_check) + { + /* ===================================================================== + CPU BLAS + =================================================================== */ + + // reference calculation + for(int i1 = 0; i1 < rows; i1++) + { + for(int i2 = 0; i2 < cols; i2++) + { + hb_ref[i1 + i2 * ldb] = ha[i1 + i2 * lda]; + } + } + + // enable unit check, notice unit check is not invasive, but norm check is, + // unit check and norm check can not be interchanged their order + if(argus.unit_check) + { + unit_check_general(rows, cols, ldb, hb.data(), hb_ref.data()); + } + } + + 
hipblasDestroy(handle); + return HIPBLAS_STATUS_SUCCESS; +} diff --git a/clients/include/testing_set_get_vector.hpp b/clients/include/testing_set_get_vector.hpp index ba81bad3a..35fab5388 100644 --- a/clients/include/testing_set_get_vector.hpp +++ b/clients/include/testing_set_get_vector.hpp @@ -10,6 +10,7 @@ #include "cblas_interface.h" #include "hipblas.hpp" +#include "hipblas_fortran.hpp" #include "norm.h" #include "unit.h" #include "utility.h" @@ -21,6 +22,10 @@ using namespace std; template hipblasStatus_t testing_set_get_vector(Arguments argus) { + bool FORTRAN = argus.fortran; + auto hipblasSetVectorFn = FORTRAN ? hipblasSetVectorFortran : hipblasSetVector; + auto hipblasGetVectorFn = FORTRAN ? hipblasGetVectorFortran : hipblasGetVector; + int M = argus.M; int incx = argus.incx; int incy = argus.incy; @@ -54,19 +59,16 @@ hipblasStatus_t testing_set_get_vector(Arguments argus) } // Naming: dK is in GPU (device) memory. hK is in CPU (host) memory - vector hx(M * incx); - vector hy(M * incy); - vector hy_ref(M * incy); + host_vector hx(M * incx); + host_vector hy(M * incy); + host_vector hy_ref(M * incy); - T* db; + device_vector db(M * incd); hipblasHandle_t handle; hipblasCreate(&handle); - // allocate memory on device - CHECK_HIP_ERROR(hipMalloc(&db, M * incd * sizeof(T))); - // Initial Data on CPU srand(1); hipblas_init(hx, 1, M, incx); @@ -76,22 +78,20 @@ hipblasStatus_t testing_set_get_vector(Arguments argus) /* ===================================================================== ROCBLAS =================================================================== */ - status_set = hipblasSetVector(M, sizeof(T), (void*)hx.data(), incx, (void*)db, incd); + status_set = hipblasSetVectorFn(M, sizeof(T), (void*)hx.data(), incx, (void*)db, incd); - status_get = hipblasGetVector(M, sizeof(T), (void*)db, incd, (void*)hy.data(), incy); + status_get = hipblasGetVectorFn(M, sizeof(T), (void*)db, incd, (void*)hy.data(), incy); if(status_set != HIPBLAS_STATUS_SUCCESS) { - 
CHECK_HIP_ERROR(hipFree(db)); hipblasDestroy(handle); - return status; + return status_set; } if(status_get != HIPBLAS_STATUS_SUCCESS) { - CHECK_HIP_ERROR(hipFree(db)); hipblasDestroy(handle); - return status; + return status_get; } if(argus.unit_check) @@ -114,7 +114,6 @@ hipblasStatus_t testing_set_get_vector(Arguments argus) } } - CHECK_HIP_ERROR(hipFree(db)); hipblasDestroy(handle); return HIPBLAS_STATUS_SUCCESS; } diff --git a/clients/include/testing_set_get_vector_async.hpp b/clients/include/testing_set_get_vector_async.hpp new file mode 100644 index 000000000..79b4bc7cb --- /dev/null +++ b/clients/include/testing_set_get_vector_async.hpp @@ -0,0 +1,109 @@ +/* ************************************************************************ + * Copyright 2016-2020 Advanced Micro Devices, Inc. + * + * ************************************************************************ */ + +#include +#include +#include +#include + +#include "cblas_interface.h" +#include "hipblas.hpp" +#include "hipblas_fortran.hpp" +#include "norm.h" +#include "unit.h" +#include "utility.h" + +using namespace std; + +/* ============================================================================================ */ + +template +hipblasStatus_t testing_set_get_vector_async(Arguments argus) +{ + bool FORTRAN = argus.fortran; + auto hipblasSetVectorAsyncFn = FORTRAN ? hipblasSetVectorAsyncFortran : hipblasSetVectorAsync; + auto hipblasGetVectorAsyncFn = FORTRAN ? 
hipblasGetVectorAsyncFortran : hipblasGetVectorAsync; + + int M = argus.M; + int incx = argus.incx; + int incy = argus.incy; + int incd = argus.incd; + + hipblasStatus_t status = HIPBLAS_STATUS_SUCCESS; + hipblasStatus_t status_set = HIPBLAS_STATUS_SUCCESS; + hipblasStatus_t status_get = HIPBLAS_STATUS_SUCCESS; + + // argument sanity check, quick return if input parameters are invalid before allocating invalid + // memory + if(M < 0 || incx <= 0 || incy <= 0 || incd <= 0) + { + status = HIPBLAS_STATUS_INVALID_VALUE; + return status; + } + + // Naming: dK is in GPU (device) memory. hK is in CPU (host) memory + host_vector hx(M * incx); + host_vector hy(M * incy); + host_vector hy_ref(M * incy); + + device_vector db(M * incd); + + hipblasHandle_t handle; + hipblasCreate(&handle); + + hipStream_t stream; + hipblasGetStream(handle, &stream); + + // Initial Data on CPU + srand(1); + hipblas_init(hx, 1, M, incx); + hipblas_init(hy, 1, M, incy); + hy_ref = hy; + + /* ===================================================================== + ROCBLAS + =================================================================== */ + status_set + = hipblasSetVectorAsyncFn(M, sizeof(T), (void*)hx.data(), incx, (void*)db, incd, stream); + status_get + = hipblasGetVectorAsyncFn(M, sizeof(T), (void*)db, incd, (void*)hy.data(), incy, stream); + + hipStreamSynchronize(stream); + + if(status_set != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status_set; + } + + if(status_get != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status_get; + } + + if(argus.unit_check) + { + /* ===================================================================== + CPU BLAS + =================================================================== */ + + // reference calculation + for(int i = 0; i < M; i++) + { + hy_ref[i * incy] = hx[i * incx]; + } + + // enable unit check, notice unit check is not invasive, but norm check is, + // unit check and norm check can not be interchanged 
their order + if(argus.unit_check) + { + unit_check_general(1, M, incy, hy.data(), hy_ref.data()); + } + } + + hipblasDestroy(handle); + return HIPBLAS_STATUS_SUCCESS; +} diff --git a/library/include/hipblas.h b/library/include/hipblas.h index c2eb8fe46..91ac6053f 100644 --- a/library/include/hipblas.h +++ b/library/include/hipblas.h @@ -256,6 +256,30 @@ HIPBLAS_EXPORT hipblasStatus_t HIPBLAS_EXPORT hipblasStatus_t hipblasGetMatrix(int rows, int cols, int elemSize, const void* A, int lda, void* B, int ldb); +HIPBLAS_EXPORT hipblasStatus_t hipblasSetVectorAsync( + int n, int elem_size, const void* x, int incx, void* y, int incy, hipStream_t stream); + +HIPBLAS_EXPORT hipblasStatus_t hipblasGetVectorAsync( + int n, int elem_size, const void* x, int incx, void* y, int incy, hipStream_t stream); + +HIPBLAS_EXPORT hipblasStatus_t hipblasSetMatrixAsync(int rows, + int cols, + int elem_size, + const void* A, + int lda, + void* B, + int ldb, + hipStream_t stream); + +HIPBLAS_EXPORT hipblasStatus_t hipblasGetMatrixAsync(int rows, + int cols, + int elem_size, + const void* A, + int lda, + void* B, + int ldb, + hipStream_t stream); + // amax HIPBLAS_EXPORT hipblasStatus_t hipblasIsamax(hipblasHandle_t handle, int n, const float* x, int incx, int* result); diff --git a/library/src/hcc_detail/hipblas.cpp b/library/src/hcc_detail/hipblas.cpp index 3bc39d908..3bc28e7da 100644 --- a/library/src/hcc_detail/hipblas.cpp +++ b/library/src/hcc_detail/hipblas.cpp @@ -339,6 +339,34 @@ hipblasStatus_t return rocBLASStatusToHIPStatus(rocblas_get_matrix(rows, cols, elemSize, A, lda, B, ldb)); } +hipblasStatus_t hipblasSetVectorAsync( + int n, int elemSize, const void* x, int incx, void* y, int incy, hipStream_t stream) +{ + return rocBLASStatusToHIPStatus( + rocblas_set_vector_async(n, elemSize, x, incx, y, incy, stream)); +} + +hipblasStatus_t hipblasGetVectorAsync( + int n, int elemSize, const void* x, int incx, void* y, int incy, hipStream_t stream) +{ + return rocBLASStatusToHIPStatus( 
+ rocblas_get_vector_async(n, elemSize, x, incx, y, incy, stream)); +} + +hipblasStatus_t hipblasSetMatrixAsync( + int rows, int cols, int elemSize, const void* A, int lda, void* B, int ldb, hipStream_t stream) +{ + return rocBLASStatusToHIPStatus( + rocblas_set_matrix_async(rows, cols, elemSize, A, lda, B, ldb, stream)); +} + +hipblasStatus_t hipblasGetMatrixAsync( + int rows, int cols, int elemSize, const void* A, int lda, void* B, int ldb, hipStream_t stream) +{ + return rocBLASStatusToHIPStatus( + rocblas_get_matrix_async(rows, cols, elemSize, A, lda, B, ldb, stream)); +} + // amax hipblasStatus_t hipblasIsamax(hipblasHandle_t handle, int n, const float* x, int incx, int* result) { diff --git a/library/src/hipblas_module.f90 b/library/src/hipblas_module.f90 index 962aaf7f1..fc3bc1531 100644 --- a/library/src/hipblas_module.f90 +++ b/library/src/hipblas_module.f90 @@ -205,7 +205,73 @@ function hipblasGetMatrix(rows, cols, elemSize, A, lda, B, ldb) & integer(c_int), value :: ldb end function hipblasGetMatrix end interface - + + interface + function hipblasSetVectorAsync(n, elemSize, x, incx, y, incy, stream) & + result(c_int) & + bind(c, name = 'hipblasSetVectorAsync') + use iso_c_binding + implicit none + integer(c_int), value :: n + integer(c_int), value :: elemSize + type(c_ptr), value :: x + integer(c_int), value :: incx + type(c_ptr), value :: y + integer(c_int), value :: incy + type(c_ptr), value :: stream + end function hipblasSetVectorAsync + end interface + + interface + function hipblasGetVectorAsync(n, elemSize, x, incx, y, incy, stream) & + result(c_int) & + bind(c, name = 'hipblasGetVectorAsync') + use iso_c_binding + implicit none + integer(c_int), value :: n + integer(c_int), value :: elemSize + type(c_ptr), value :: x + integer(c_int), value :: incx + type(c_ptr), value :: y + integer(c_int), value :: incy + type(c_ptr), value :: stream + end function hipblasGetVectorAsync + end interface + + interface + function hipblasSetMatrixAsync(rows, cols, 
elemSize, A, lda, B, ldb, stream) & + result(c_int) & + bind(c, name = 'hipblasSetMatrixAsync') + use iso_c_binding + implicit none + integer(c_int), value :: rows + integer(c_int), value :: cols + integer(c_int), value :: elemSize + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: stream + end function hipblasSetMatrixAsync + end interface + + interface + function hipblasGetMatrixAsync(rows, cols, elemSize, A, lda, B, ldb, stream) & + result(c_int) & + bind(c, name = 'hipblasGetMatrixAsync') + use iso_c_binding + implicit none + integer(c_int), value :: rows + integer(c_int), value :: cols + integer(c_int), value :: elemSize + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: stream + end function hipblasGetMatrixAsync + end interface + !--------! ! blas 1 ! !--------! diff --git a/library/src/nvcc_detail/hipblas.cpp b/library/src/nvcc_detail/hipblas.cpp index 29a41d6f4..3873509a7 100644 --- a/library/src/nvcc_detail/hipblas.cpp +++ b/library/src/nvcc_detail/hipblas.cpp @@ -286,6 +286,32 @@ hipblasStatus_t return hipCUBLASStatusToHIPStatus(cublasGetMatrix(rows, cols, elemSize, A, lda, B, ldb)); } +hipblasStatus_t hipblasSetVectorAsync( + int n, int elemSize, const void* x, int incx, void* y, int incy, hipStream_t stream) +{ + return hipCUBLASStatusToHIPStatus(cublasSetVectorAsync(n, elemSize, x, incx, y, incy, stream)); +} + +hipblasStatus_t hipblasGetVectorAsync( + int n, int elemSize, const void* x, int incx, void* y, int incy, hipStream_t stream) +{ + return hipCUBLASStatusToHIPStatus(cublasGetVectorAsync(n, elemSize, x, incx, y, incy, stream)); +} + +hipblasStatus_t hipblasSetMatrixAsync( + int rows, int cols, int elemSize, const void* A, int lda, void* B, int ldb, hipStream_t stream) +{ + return hipCUBLASStatusToHIPStatus( + cublasSetMatrixAsync(rows, cols, elemSize, A, lda, B, ldb, 
stream)); +} + +hipblasStatus_t hipblasGetMatrixAsync( + int rows, int cols, int elemSize, const void* A, int lda, void* B, int ldb, hipStream_t stream) +{ + return hipCUBLASStatusToHIPStatus( + cublasGetMatrixAsync(rows, cols, elemSize, A, lda, B, ldb, stream)); +} + // amax hipblasStatus_t hipblasIsamax(hipblasHandle_t handle, int n, const float* x, int incx, int* result) { From 7f3a9edbed3ed3341772ccce8682a6a7c516c159 Mon Sep 17 00:00:00 2001 From: tfalders <58866654+tfalders@users.noreply.github.com> Date: Tue, 23 Jun 2020 16:51:26 -0600 Subject: [PATCH 8/9] rocSOLVER support 3 (#235) * Added getri * Added template specializations for getri * Added test cases for getri * Removed getri and getri_strided_batched --- clients/common/cblas_interface.cpp | 47 ++ .../hipblas_template_specialization.cpp | 677 +++++++++++------- clients/gtest/CMakeLists.txt | 1 + clients/gtest/getri_batched_gtest.cpp | 157 ++++ clients/include/cblas_interface.h | 3 + clients/include/hipblas.hpp | 12 + clients/include/hipblas_fortran.hpp | 41 ++ clients/include/hipblas_fortran_solver.f90 | 77 ++ clients/include/testing_getri_batched.hpp | 158 ++++ library/include/hipblas.h | 41 ++ library/src/hcc_detail/hipblas.cpp | 117 +++ library/src/hipblas_module.f90 | 77 ++ library/src/nvcc_detail/hipblas.cpp | 71 ++ 13 files changed, 1202 insertions(+), 277 deletions(-) create mode 100644 clients/gtest/getri_batched_gtest.cpp create mode 100644 clients/include/testing_getri_batched.hpp diff --git a/clients/common/cblas_interface.cpp b/clients/common/cblas_interface.cpp index 9a2fcdd4b..978433869 100644 --- a/clients/common/cblas_interface.cpp +++ b/clients/common/cblas_interface.cpp @@ -53,6 +53,18 @@ void zgetrs_(char* trans, int* ldb, int* info); +void sgetri_(int* n, float* A, int* lda, int* ipiv, float* work, int* lwork, int* info); +void dgetri_(int* n, double* A, int* lda, int* ipiv, double* work, int* lwork, int* info); +void cgetri_( + int* n, hipblasComplex* A, int* lda, int* ipiv, 
hipblasComplex* work, int* lwork, int* info); +void zgetri_(int* n, + hipblasDoubleComplex* A, + int* lda, + int* ipiv, + hipblasDoubleComplex* work, + int* lwork, + int* info); + void sgeqrf_(int* m, int* n, float* A, int* lda, float* tau, float* work, int* lwork, int* info); void dgeqrf_(int* m, int* n, double* A, int* lda, double* tau, double* work, int* lwork, int* info); void cgeqrf_(int* m, @@ -3163,6 +3175,41 @@ int cblas_getrs(char trans, return info; } +// getri +template <> +int cblas_getri(int n, float* A, int lda, int* ipiv, float* work, int lwork) +{ + int info; + sgetri_(&n, A, &lda, ipiv, work, &lwork, &info); + return info; +} + +template <> +int cblas_getri(int n, double* A, int lda, int* ipiv, double* work, int lwork) +{ + int info; + dgetri_(&n, A, &lda, ipiv, work, &lwork, &info); + return info; +} + +template <> +int cblas_getri( + int n, hipblasComplex* A, int lda, int* ipiv, hipblasComplex* work, int lwork) +{ + int info; + cgetri_(&n, A, &lda, ipiv, work, &lwork, &info); + return info; +} + +template <> +int cblas_getri( + int n, hipblasDoubleComplex* A, int lda, int* ipiv, hipblasDoubleComplex* work, int lwork) +{ + int info; + zgetri_(&n, A, &lda, ipiv, work, &lwork, &info); + return info; +} + // geqrf template <> int cblas_geqrf(int m, int n, float* A, int lda, float* tau, float* work, int lwork) diff --git a/clients/common/hipblas_template_specialization.cpp b/clients/common/hipblas_template_specialization.cpp index 0d59b9ae6..a0ff958ed 100644 --- a/clients/common/hipblas_template_specialization.cpp +++ b/clients/common/hipblas_template_specialization.cpp @@ -9861,6 +9861,63 @@ hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); } +// getri_batched +template <> +hipblasStatus_t hipblasGetriBatched(hipblasHandle_t handle, + const int n, + float* const A[], + const int lda, + int* ipiv, + float* const C[], + const int ldc, + int* info, + 
const int batchCount) +{ + return hipblasSgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetriBatched(hipblasHandle_t handle, + const int n, + double* const A[], + const int lda, + int* ipiv, + double* const C[], + const int ldc, + int* info, + const int batchCount) +{ + return hipblasDgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetriBatched(hipblasHandle_t handle, + const int n, + hipblasComplex* const A[], + const int lda, + int* ipiv, + hipblasComplex* const C[], + const int ldc, + int* info, + const int batchCount) +{ + return hipblasCgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetriBatched(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* const A[], + const int lda, + int* ipiv, + hipblasDoubleComplex* const C[], + const int ldc, + int* info, + const int batchCount) +{ + return hipblasZgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batchCount); +} + // geqrf template <> hipblasStatus_t hipblasGeqrf(hipblasHandle_t handle, @@ -19660,11 +19717,11 @@ hipblasStatus_t hipblasGetrf( template <> hipblasStatus_t hipblasGetrf(hipblasHandle_t handle, - const int n, - hipblasDoubleComplex* A, - const int lda, - int* ipiv, - int* info) + const int n, + hipblasDoubleComplex* A, + const int lda, + int* ipiv, + int* info) { return hipblasZgetrfFortran(handle, n, A, lda, ipiv, info); } @@ -19672,48 +19729,48 @@ hipblasStatus_t hipblasGetrf(hipblasHandle_t h // getrf_batched template <> hipblasStatus_t hipblasGetrfBatched(hipblasHandle_t handle, - const int n, - float* const A[], - const int lda, - int* ipiv, - int* info, - const int batchCount) + const int n, + float* const A[], + const int lda, + int* ipiv, + int* info, + const int batchCount) { return hipblasSgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batchCount); } template <> hipblasStatus_t 
hipblasGetrfBatched(hipblasHandle_t handle, - const int n, - double* const A[], - const int lda, - int* ipiv, - int* info, - const int batchCount) + const int n, + double* const A[], + const int lda, + int* ipiv, + int* info, + const int batchCount) { return hipblasDgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batchCount); } template <> hipblasStatus_t hipblasGetrfBatched(hipblasHandle_t handle, - const int n, - hipblasComplex* const A[], - const int lda, - int* ipiv, - int* info, - const int batchCount) + const int n, + hipblasComplex* const A[], + const int lda, + int* ipiv, + int* info, + const int batchCount) { return hipblasCgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batchCount); } template <> hipblasStatus_t hipblasGetrfBatched(hipblasHandle_t handle, - const int n, - hipblasDoubleComplex* const A[], - const int lda, - int* ipiv, - int* info, - const int batchCount) + const int n, + hipblasDoubleComplex* const A[], + const int lda, + int* ipiv, + int* info, + const int batchCount) { return hipblasZgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batchCount); } @@ -19721,117 +19778,121 @@ hipblasStatus_t hipblasGetrfBatched(hipblasHandle_t // getrf_strided_batched template <> hipblasStatus_t hipblasGetrfStridedBatched(hipblasHandle_t handle, - const int n, - float* A, - const int lda, - const int strideA, - int* ipiv, - const int strideP, - int* info, - const int batchCount) + const int n, + float* A, + const int lda, + const int strideA, + int* ipiv, + const int strideP, + int* info, + const int batchCount) { - return hipblasSgetrfStridedBatchedFortran(handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); + return hipblasSgetrfStridedBatchedFortran( + handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); } template <> hipblasStatus_t hipblasGetrfStridedBatched(hipblasHandle_t handle, - const int n, - double* A, - const int lda, - const int strideA, - int* ipiv, - const int strideP, - int* info, - const int batchCount) + const 
int n, + double* A, + const int lda, + const int strideA, + int* ipiv, + const int strideP, + int* info, + const int batchCount) { - return hipblasDgetrfStridedBatchedFortran(handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); + return hipblasDgetrfStridedBatchedFortran( + handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); } template <> hipblasStatus_t hipblasGetrfStridedBatched(hipblasHandle_t handle, - const int n, - hipblasComplex* A, - const int lda, - const int strideA, - int* ipiv, - const int strideP, - int* info, - const int batchCount) + const int n, + hipblasComplex* A, + const int lda, + const int strideA, + int* ipiv, + const int strideP, + int* info, + const int batchCount) { - return hipblasCgetrfStridedBatchedFortran(handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); + return hipblasCgetrfStridedBatchedFortran( + handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); } template <> hipblasStatus_t hipblasGetrfStridedBatched(hipblasHandle_t handle, - const int n, - hipblasDoubleComplex* A, - const int lda, - const int strideA, - int* ipiv, - const int strideP, - int* info, - const int batchCount) + const int n, + hipblasDoubleComplex* A, + const int lda, + const int strideA, + int* ipiv, + const int strideP, + int* info, + const int batchCount) { - return hipblasZgetrfStridedBatchedFortran(handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); + return hipblasZgetrfStridedBatchedFortran( + handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); } // getrs template <> hipblasStatus_t hipblasGetrs(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - float* A, - const int lda, - const int* ipiv, - float* B, - const int ldb, - int* info) + const hipblasOperation_t trans, + const int n, + const int nrhs, + float* A, + const int lda, + const int* ipiv, + float* B, + const int ldb, + int* info) { return hipblasSgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, 
info); } template <> hipblasStatus_t hipblasGetrs(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - double* A, - const int lda, - const int* ipiv, - double* B, - const int ldb, - int* info) + const hipblasOperation_t trans, + const int n, + const int nrhs, + double* A, + const int lda, + const int* ipiv, + double* B, + const int ldb, + int* info) { return hipblasDgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info); } template <> hipblasStatus_t hipblasGetrs(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - hipblasComplex* A, - const int lda, - const int* ipiv, - hipblasComplex* B, - const int ldb, - int* info) -{ - return hipblasCgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info); -} - -template <> -hipblasStatus_t hipblasGetrs(hipblasHandle_t handle, const hipblasOperation_t trans, const int n, const int nrhs, - hipblasDoubleComplex* A, + hipblasComplex* A, const int lda, const int* ipiv, - hipblasDoubleComplex* B, + hipblasComplex* B, const int ldb, int* info) +{ + return hipblasCgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info); +} + +template <> +hipblasStatus_t hipblasGetrs(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasDoubleComplex* A, + const int lda, + const int* ipiv, + hipblasDoubleComplex* B, + const int ldb, + int* info) { return hipblasZgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info); } @@ -19839,84 +19900,88 @@ hipblasStatus_t hipblasGetrs(hipblasHandle_t // getrs_batched template <> hipblasStatus_t hipblasGetrsBatched(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - float* const A[], - const int lda, - const int* ipiv, - float* const B[], - const int ldb, - int* info, - const int batchCount) + const hipblasOperation_t trans, + const int n, + const int nrhs, + float* const A[], + const int lda, + const int* ipiv, 
+ float* const B[], + const int ldb, + int* info, + const int batchCount) { - return hipblasSgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); + return hipblasSgetrsBatchedFortran( + handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); } template <> hipblasStatus_t hipblasGetrsBatched(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - double* const A[], - const int lda, - const int* ipiv, - double* const B[], - const int ldb, - int* info, - const int batchCount) + const hipblasOperation_t trans, + const int n, + const int nrhs, + double* const A[], + const int lda, + const int* ipiv, + double* const B[], + const int ldb, + int* info, + const int batchCount) { - return hipblasDgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); + return hipblasDgetrsBatchedFortran( + handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); } template <> hipblasStatus_t hipblasGetrsBatched(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - hipblasComplex* const A[], - const int lda, - const int* ipiv, - hipblasComplex* const B[], - const int ldb, - int* info, - const int batchCount) + const hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasComplex* const A[], + const int lda, + const int* ipiv, + hipblasComplex* const B[], + const int ldb, + int* info, + const int batchCount) { - return hipblasCgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); + return hipblasCgetrsBatchedFortran( + handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); } template <> hipblasStatus_t hipblasGetrsBatched(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - hipblasDoubleComplex* const A[], - const int lda, - const int* ipiv, - hipblasDoubleComplex* const B[], - const int ldb, - int* info, - const int batchCount) + const 
hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasDoubleComplex* const A[], + const int lda, + const int* ipiv, + hipblasDoubleComplex* const B[], + const int ldb, + int* info, + const int batchCount) { - return hipblasZgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); + return hipblasZgetrsBatchedFortran( + handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); } // getrs_strided_batched template <> hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - float* A, - const int lda, - const int strideA, - const int* ipiv, - const int strideP, - float* B, - const int ldb, - const int strideB, - int* info, - const int batchCount) + const hipblasOperation_t trans, + const int n, + const int nrhs, + float* A, + const int lda, + const int strideA, + const int* ipiv, + const int strideP, + float* B, + const int ldb, + const int strideB, + int* info, + const int batchCount) { return hipblasSgetrsStridedBatchedFortran( handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); @@ -19924,19 +19989,19 @@ hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t template <> hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - double* A, - const int lda, - const int strideA, - const int* ipiv, - const int strideP, - double* B, - const int ldb, - const int strideB, - int* info, - const int batchCount) + const hipblasOperation_t trans, + const int n, + const int nrhs, + double* A, + const int lda, + const int strideA, + const int* ipiv, + const int strideP, + double* B, + const int ldb, + const int strideB, + int* info, + const int batchCount) { return hipblasDgetrsStridedBatchedFortran( handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); @@ -19944,89 +20009,147 @@ hipblasStatus_t 
hipblasGetrsStridedBatched(hipblasHandle_t template <> hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasComplex* A, + const int lda, + const int strideA, + const int* ipiv, + const int strideP, + hipblasComplex* B, + const int ldb, + const int strideB, + int* info, + const int batchCount) +{ + return hipblasCgetrsStridedBatchedFortran( + handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); +} + +template <> +hipblasStatus_t + hipblasGetrsStridedBatched(hipblasHandle_t handle, const hipblasOperation_t trans, const int n, const int nrhs, - hipblasComplex* A, + hipblasDoubleComplex* A, const int lda, const int strideA, const int* ipiv, const int strideP, - hipblasComplex* B, + hipblasDoubleComplex* B, const int ldb, const int strideB, int* info, const int batchCount) { - return hipblasCgetrsStridedBatchedFortran( + return hipblasZgetrsStridedBatchedFortran( handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); } +// getri_batched template <> -hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - hipblasDoubleComplex* A, - const int lda, - const int strideA, - const int* ipiv, - const int strideP, - hipblasDoubleComplex* B, - const int ldb, - const int strideB, - int* info, - const int batchCount) +hipblasStatus_t hipblasGetriBatched(hipblasHandle_t handle, + const int n, + float* const A[], + const int lda, + int* ipiv, + float* const C[], + const int ldc, + int* info, + const int batchCount) { - return hipblasZgetrsStridedBatchedFortran( - handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); + return hipblasSgetriBatchedFortran(handle, n, A, lda, ipiv, C, ldc, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetriBatched(hipblasHandle_t handle, + const int n, + double* 
const A[], + const int lda, + int* ipiv, + double* const C[], + const int ldc, + int* info, + const int batchCount) +{ + return hipblasDgetriBatchedFortran(handle, n, A, lda, ipiv, C, ldc, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetriBatched(hipblasHandle_t handle, + const int n, + hipblasComplex* const A[], + const int lda, + int* ipiv, + hipblasComplex* const C[], + const int ldc, + int* info, + const int batchCount) +{ + return hipblasCgetriBatchedFortran(handle, n, A, lda, ipiv, C, ldc, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetriBatched(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* const A[], + const int lda, + int* ipiv, + hipblasDoubleComplex* const C[], + const int ldc, + int* info, + const int batchCount) +{ + return hipblasZgetriBatchedFortran(handle, n, A, lda, ipiv, C, ldc, info, batchCount); } // geqrf template <> hipblasStatus_t hipblasGeqrf(hipblasHandle_t handle, - const int m, - const int n, - float* A, - const int lda, - float* ipiv, - int* info) + const int m, + const int n, + float* A, + const int lda, + float* ipiv, + int* info) { return hipblasSgeqrfFortran(handle, m, n, A, lda, ipiv, info); } template <> hipblasStatus_t hipblasGeqrf(hipblasHandle_t handle, - const int m, - const int n, - double* A, - const int lda, - double* ipiv, - int* info) + const int m, + const int n, + double* A, + const int lda, + double* ipiv, + int* info) { return hipblasDgeqrfFortran(handle, m, n, A, lda, ipiv, info); } template <> hipblasStatus_t hipblasGeqrf(hipblasHandle_t handle, - const int m, - const int n, - hipblasComplex* A, - const int lda, - hipblasComplex* ipiv, - int* info) + const int m, + const int n, + hipblasComplex* A, + const int lda, + hipblasComplex* ipiv, + int* info) { return hipblasCgeqrfFortran(handle, m, n, A, lda, ipiv, info); } template <> hipblasStatus_t hipblasGeqrf(hipblasHandle_t handle, - const int m, - const int n, - hipblasDoubleComplex* A, - const int lda, - 
hipblasDoubleComplex* ipiv, - int* info) + const int m, + const int n, + hipblasDoubleComplex* A, + const int lda, + hipblasDoubleComplex* ipiv, + int* info) { return hipblasZgeqrfFortran(handle, m, n, A, lda, ipiv, info); } @@ -20034,52 +20157,52 @@ hipblasStatus_t hipblasGeqrf(hipblasHandle_t h // geqrf_batched template <> hipblasStatus_t hipblasGeqrfBatched(hipblasHandle_t handle, - const int m, - const int n, - float* const A[], - const int lda, - float* const ipiv[], - int* info, - const int batchCount) + const int m, + const int n, + float* const A[], + const int lda, + float* const ipiv[], + int* info, + const int batchCount) { return hipblasSgeqrfBatchedFortran(handle, m, n, A, lda, ipiv, info, batchCount); } template <> hipblasStatus_t hipblasGeqrfBatched(hipblasHandle_t handle, - const int m, - const int n, - double* const A[], - const int lda, - double* const ipiv[], - int* info, - const int batchCount) + const int m, + const int n, + double* const A[], + const int lda, + double* const ipiv[], + int* info, + const int batchCount) { return hipblasDgeqrfBatchedFortran(handle, m, n, A, lda, ipiv, info, batchCount); } template <> hipblasStatus_t hipblasGeqrfBatched(hipblasHandle_t handle, - const int m, - const int n, - hipblasComplex* const A[], - const int lda, - hipblasComplex* const ipiv[], - int* info, - const int batchCount) + const int m, + const int n, + hipblasComplex* const A[], + const int lda, + hipblasComplex* const ipiv[], + int* info, + const int batchCount) { return hipblasCgeqrfBatchedFortran(handle, m, n, A, lda, ipiv, info, batchCount); } template <> hipblasStatus_t hipblasGeqrfBatched(hipblasHandle_t handle, - const int m, - const int n, - hipblasDoubleComplex* const A[], - const int lda, - hipblasDoubleComplex* const ipiv[], - int* info, - const int batchCount) + const int m, + const int n, + hipblasDoubleComplex* const A[], + const int lda, + hipblasDoubleComplex* const ipiv[], + int* info, + const int batchCount) { return 
hipblasZgeqrfBatchedFortran(handle, m, n, A, lda, ipiv, info, batchCount); } @@ -20087,15 +20210,15 @@ hipblasStatus_t hipblasGeqrfBatched(hipblasHandle_t // geqrf_strided_batched template <> hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t handle, - const int m, - const int n, - float* A, - const int lda, - const int strideA, - float* ipiv, - const int strideP, - int* info, - const int batchCount) + const int m, + const int n, + float* A, + const int lda, + const int strideA, + float* ipiv, + const int strideP, + int* info, + const int batchCount) { return hipblasSgeqrfStridedBatchedFortran( handle, m, n, A, lda, strideA, ipiv, strideP, info, batchCount); @@ -20103,15 +20226,15 @@ hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t handle, template <> hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t handle, - const int m, - const int n, - double* A, - const int lda, - const int strideA, - double* ipiv, - const int strideP, - int* info, - const int batchCount) + const int m, + const int n, + double* A, + const int lda, + const int strideA, + double* ipiv, + const int strideP, + int* info, + const int batchCount) { return hipblasDgeqrfStridedBatchedFortran( handle, m, n, A, lda, strideA, ipiv, strideP, info, batchCount); @@ -20119,15 +20242,15 @@ hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t handle, template <> hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t handle, - const int m, - const int n, - hipblasComplex* A, - const int lda, - const int strideA, - hipblasComplex* ipiv, - const int strideP, - int* info, - const int batchCount) + const int m, + const int n, + hipblasComplex* A, + const int lda, + const int strideA, + hipblasComplex* ipiv, + const int strideP, + int* info, + const int batchCount) { return hipblasCgeqrfStridedBatchedFortran( handle, m, n, A, lda, strideA, ipiv, strideP, info, batchCount); @@ -20135,15 +20258,15 @@ hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t template <> 
hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t handle, - const int m, - const int n, - hipblasDoubleComplex* A, - const int lda, - const int strideA, - hipblasDoubleComplex* ipiv, - const int strideP, - int* info, - const int batchCount) + const int m, + const int n, + hipblasDoubleComplex* A, + const int lda, + const int strideA, + hipblasDoubleComplex* ipiv, + const int strideP, + int* info, + const int batchCount) { return hipblasZgeqrfStridedBatchedFortran( handle, m, n, A, lda, strideA, ipiv, strideP, info, batchCount); diff --git a/clients/gtest/CMakeLists.txt b/clients/gtest/CMakeLists.txt index 3cb23d56f..9f2b88ee9 100644 --- a/clients/gtest/CMakeLists.txt +++ b/clients/gtest/CMakeLists.txt @@ -99,6 +99,7 @@ if( BUILD_WITH_SOLVER ) getrs_gtest.cpp getrs_batched_gtest.cpp getrs_strided_batched_gtest.cpp + getri_batched_gtest.cpp geqrf_gtest.cpp geqrf_batched_gtest.cpp geqrf_strided_batched_gtest.cpp diff --git a/clients/gtest/getri_batched_gtest.cpp b/clients/gtest/getri_batched_gtest.cpp new file mode 100644 index 000000000..257546f08 --- /dev/null +++ b/clients/gtest/getri_batched_gtest.cpp @@ -0,0 +1,157 @@ +/* ************************************************************************ + * Copyright 2016-2020 Advanced Micro Devices, Inc. 
+ * + * ************************************************************************ */ + +#include "testing_getri_batched.hpp" +#include "utility.h" +#include +#include +#include +#include + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; +using namespace std; + +typedef std::tuple, double, int, bool> getri_batched_tuple; + +const vector> matrix_size_range + = {{-1, 1}, {10, 10}, {10, 20}, {500, 600}, {1024, 1024}}; + +const vector stride_scale_range = {2.5}; + +const vector batch_count_range = {-1, 0, 1, 2}; + +const vector is_fortran = {false, true}; + +Arguments setup_getri_batched_arguments(getri_batched_tuple tup) +{ + vector matrix_size = std::get<0>(tup); + double stride_scale = std::get<1>(tup); + int batch_count = std::get<2>(tup); + bool fortran = std::get<3>(tup); + + Arguments arg; + + arg.N = matrix_size[0]; + arg.lda = matrix_size[1]; + + arg.stride_scale = stride_scale; + arg.batch_count = batch_count; + + arg.fortran = fortran; + + return arg; +} + +class getri_batched_gtest : public ::TestWithParam +{ +protected: + getri_batched_gtest() {} + virtual ~getri_batched_gtest() {} + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_P(getri_batched_gtest, getri_batched_gtest_float) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getri_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getri_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); + } + } +} + +TEST_P(getri_batched_gtest, getri_batched_gtest_double) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. 
+ + Arguments arg = setup_getri_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getri_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); + } + } +} + +TEST_P(getri_batched_gtest, getri_batched_gtest_float_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getri_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getri_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); + } + } +} + +TEST_P(getri_batched_gtest, getri_batched_gtest_double_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. 
+ + Arguments arg = setup_getri_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getri_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); + } + } +} + +// notice we are using vector of vector +// so each elment in xxx_range is a vector, +// ValuesIn takes each element (a vector), combines them, and feeds them to test_p +// The combinations are { {M, N, lda, ldb}, stride_scale, batch_count } + +INSTANTIATE_TEST_CASE_P(hipblasGetriBatched, + getri_batched_gtest, + Combine(ValuesIn(matrix_size_range), + ValuesIn(stride_scale_range), + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/include/cblas_interface.h b/clients/include/cblas_interface.h index cd12209a1..9382491e6 100644 --- a/clients/include/cblas_interface.h +++ b/clients/include/cblas_interface.h @@ -438,6 +438,9 @@ int cblas_getrf(int m, int n, T* A, int lda, int* ipiv); template int cblas_getrs(char trans, int n, int nrhs, T* A, int lda, int* ipiv, T* B, int ldb); +template +int cblas_getri(int n, T* A, int lda, int* ipiv, T* work, int lwork); + template int cblas_geqrf(int m, int n, T* A, int lda, T* tau, T* work, int lwork); /* ============================================================================================ */ diff --git a/clients/include/hipblas.hpp b/clients/include/hipblas.hpp index dfcb0ee41..81b1003da 100644 --- a/clients/include/hipblas.hpp +++ b/clients/include/hipblas.hpp @@ -1900,6 +1900,18 @@ hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, int* info, const int batchCount); +// getri +template +hipblasStatus_t hipblasGetriBatched(hipblasHandle_t handle, + const int n, + T* const A[], + const int lda, + int* ipiv, + T* const C[], + const int ldc, + int* info, + const int batchCount); + // geqrf template hipblasStatus_t hipblasGeqrf( diff --git 
a/clients/include/hipblas_fortran.hpp b/clients/include/hipblas_fortran.hpp index 50bf6f0c5..e7980cc74 100644 --- a/clients/include/hipblas_fortran.hpp +++ b/clients/include/hipblas_fortran.hpp @@ -6716,6 +6716,47 @@ hipblasStatus_t hipblasZgetrsStridedBatchedFortran(hipblasHandle_t hand int* info, const int batch_count); +// getri_batched +hipblasStatus_t hipblasSgetriBatchedFortran(hipblasHandle_t handle, + const int n, + float* const A[], + const int lda, + int* ipiv, + float* const C[], + const int ldc, + int* info, + const int batch_count); + +hipblasStatus_t hipblasDgetriBatchedFortran(hipblasHandle_t handle, + const int n, + double* const A[], + const int lda, + int* ipiv, + double* const C[], + const int ldc, + int* info, + const int batch_count); + +hipblasStatus_t hipblasCgetriBatchedFortran(hipblasHandle_t handle, + const int n, + hipblasComplex* const A[], + const int lda, + int* ipiv, + hipblasComplex* const C[], + const int ldc, + int* info, + const int batch_count); + +hipblasStatus_t hipblasZgetriBatchedFortran(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* const A[], + const int lda, + int* ipiv, + hipblasDoubleComplex* const C[], + const int ldc, + int* info, + const int batch_count); + // geqrf hipblasStatus_t hipblasSgeqrfFortran(hipblasHandle_t handle, const int m, diff --git a/clients/include/hipblas_fortran_solver.f90 b/clients/include/hipblas_fortran_solver.f90 index f6d5958f3..099490143 100644 --- a/clients/include/hipblas_fortran_solver.f90 +++ b/clients/include/hipblas_fortran_solver.f90 @@ -514,6 +514,83 @@ function hipblasZgetrsStridedBatchedFortran(handle, trans, n, nrhs, A, lda, stri ipiv, stride_P, B, ldb, stride_B, info, batch_count) end function hipblasZgetrsStridedBatchedFortran + ! 
getri_batched + function hipblasSgetriBatchedFortran(handle, n, A, lda, ipiv, C, ldc, info, batch_count) & + result(res) & + bind(c, name = 'hipblasSgetriBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: C + integer(c_int), value :: ldc + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasSgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batch_count) + end function hipblasSgetriBatchedFortran + + function hipblasDgetriBatchedFortran(handle, n, A, lda, ipiv, C, ldc, info, batch_count) & + result(res) & + bind(c, name = 'hipblasDgetriBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: C + integer(c_int), value :: ldc + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasDgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batch_count) + end function hipblasDgetriBatchedFortran + + function hipblasCgetriBatchedFortran(handle, n, A, lda, ipiv, C, ldc, info, batch_count) & + result(res) & + bind(c, name = 'hipblasCgetriBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: C + integer(c_int), value :: ldc + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasCgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batch_count) + end function hipblasCgetriBatchedFortran + + function hipblasZgetriBatchedFortran(handle, n, A, lda, ipiv, C, ldc, info, 
batch_count) & + result(res) & + bind(c, name = 'hipblasZgetriBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: C + integer(c_int), value :: ldc + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasZgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batch_count) + end function hipblasZgetriBatchedFortran + ! geqrf function hipblasSgeqrfFortran(handle, m, n, A, lda, tau, info) & result(res) & diff --git a/clients/include/testing_getri_batched.hpp b/clients/include/testing_getri_batched.hpp new file mode 100644 index 000000000..be3644faa --- /dev/null +++ b/clients/include/testing_getri_batched.hpp @@ -0,0 +1,158 @@ +/* ************************************************************************ + * Copyright 2016-2020 Advanced Micro Devices, Inc. + * + * ************************************************************************ */ + +#include +#include +#include +#include + +#include "cblas_interface.h" +#include "flops.h" +#include "hipblas.hpp" +#include "norm.h" +#include "unit.h" +#include "utility.h" + +using namespace std; + +template +hipblasStatus_t testing_getri_batched(Arguments argus) +{ + bool FORTRAN = argus.fortran; + auto hipblasGetriBatchedFn + = FORTRAN ? hipblasGetriBatched : hipblasGetriBatched; + + int M = argus.N; + int N = argus.N; + int lda = argus.lda; + int batch_count = argus.batch_count; + + int strideP = min(M, N); + int A_size = lda * N; + int Ipiv_size = strideP * batch_count; + + hipblasStatus_t status = HIPBLAS_STATUS_SUCCESS; + + // Check to prevent memory allocation error + if(M < 0 || N < 0 || lda < M || batch_count < 0) + { + return HIPBLAS_STATUS_INVALID_VALUE; + } + if(batch_count == 0) + { + return HIPBLAS_STATUS_SUCCESS; + } + + // Naming: dK is in GPU (device) memory. 
hK is in CPU (host) memory + host_vector hA[batch_count]; + host_vector hA1[batch_count]; + host_vector hC[batch_count]; + host_vector hIpiv(Ipiv_size); + host_vector hIpiv1(Ipiv_size); + host_vector hInfo(batch_count); + host_vector hInfo1(batch_count); + + device_batch_vector bA(batch_count, A_size); + device_batch_vector bC(batch_count, A_size); + + device_vector dA(batch_count); + device_vector dC(batch_count); + device_vector dIpiv(Ipiv_size); + device_vector dInfo(batch_count); + + double gpu_time_used, cpu_time_used; + double hipblasGflops, cblas_gflops; + double rocblas_error; + + hipblasHandle_t handle; + hipblasCreate(&handle); + + // Initial hA on CPU + srand(1); + for(int b = 0; b < batch_count; b++) + { + hA[b] = host_vector(A_size); + hA1[b] = host_vector(A_size); + hC[b] = host_vector(A_size); + int* hIpivb = hIpiv.data() + b * strideP; + + hipblas_init(hA[b], M, N, lda); + + // scale A to avoid singularities + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + if(i == j) + hA[b][i + j * lda] += 400; + else + hA[b][i + j * lda] -= 4; + } + } + + // perform LU factorization on A + hInfo[b] = cblas_getrf(M, N, hA[b].data(), lda, hIpivb); + + // Copy data from CPU to device + CHECK_HIP_ERROR(hipMemcpy(bA[b], hA[b].data(), A_size * sizeof(T), hipMemcpyHostToDevice)); + } + + CHECK_HIP_ERROR(hipMemcpy(dA, bA, batch_count * sizeof(T*), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dC, bC, batch_count * sizeof(T*), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dIpiv, hIpiv, Ipiv_size * sizeof(int), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemset(dInfo, 0, batch_count * sizeof(int))); + + /* ===================================================================== + HIPBLAS + =================================================================== */ + + status = hipblasGetriBatchedFn(handle, N, dA, lda, dIpiv, dC, lda, dInfo, batch_count); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status; + } + 
+ // Copy output from device to CPU + for(int b = 0; b < batch_count; b++) + CHECK_HIP_ERROR(hipMemcpy(hA1[b].data(), bC[b], A_size * sizeof(T), hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR( + hipMemcpy(hIpiv1.data(), dIpiv, Ipiv_size * sizeof(int), hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR( + hipMemcpy(hInfo1.data(), dInfo, batch_count * sizeof(int), hipMemcpyDeviceToHost)); + + if(argus.unit_check) + { + /* ===================================================================== + CPU LAPACK + =================================================================== */ + + for(int b = 0; b < batch_count; b++) + { + // Workspace query + host_vector work(1); + cblas_getri(N, hA[b].data(), lda, hIpiv.data() + b * strideP, work.data(), -1); + int lwork = type2int(work[0]); + + // Perform inversion + work = host_vector(lwork); + hInfo[b] + = cblas_getri(N, hA[b].data(), lda, hIpiv.data() + b * strideP, work.data(), lwork); + + if(argus.unit_check) + { + U eps = std::numeric_limits::epsilon(); + double tolerance = eps * 2000; + + double e = norm_check_general('F', M, N, lda, hA[b].data(), hA1[b].data()); + unit_check_error(e, tolerance); + } + } + } + + hipblasDestroy(handle); + return HIPBLAS_STATUS_SUCCESS; +} diff --git a/library/include/hipblas.h b/library/include/hipblas.h index 91ac6053f..7782d5698 100644 --- a/library/include/hipblas.h +++ b/library/include/hipblas.h @@ -6562,6 +6562,47 @@ HIPBLAS_EXPORT hipblasStatus_t hipblasZgetrsStridedBatched(hipblasHandle_t int* info, const int batch_count); +// getri_batched +HIPBLAS_EXPORT hipblasStatus_t hipblasSgetriBatched(hipblasHandle_t handle, + const int n, + float* const A[], + const int lda, + int* ipiv, + float* const C[], + const int ldc, + int* info, + const int batch_count); + +HIPBLAS_EXPORT hipblasStatus_t hipblasDgetriBatched(hipblasHandle_t handle, + const int n, + double* const A[], + const int lda, + int* ipiv, + double* const C[], + const int ldc, + int* info, + const int batch_count); + +HIPBLAS_EXPORT 
hipblasStatus_t hipblasCgetriBatched(hipblasHandle_t handle, + const int n, + hipblasComplex* const A[], + const int lda, + int* ipiv, + hipblasComplex* const C[], + const int ldc, + int* info, + const int batch_count); + +HIPBLAS_EXPORT hipblasStatus_t hipblasZgetriBatched(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* const A[], + const int lda, + int* ipiv, + hipblasDoubleComplex* const C[], + const int ldc, + int* info, + const int batch_count); + // geqrf HIPBLAS_EXPORT hipblasStatus_t hipblasSgeqrf(hipblasHandle_t handle, const int m, diff --git a/library/src/hcc_detail/hipblas.cpp b/library/src/hcc_detail/hipblas.cpp index 3bc28e7da..09e0f1b2a 100644 --- a/library/src/hcc_detail/hipblas.cpp +++ b/library/src/hcc_detail/hipblas.cpp @@ -12320,6 +12320,50 @@ rocblas_status rocsolver_zgeqrf_ptr_batched(rocblas_handle handle rocblas_double_complex* const ipiv[], const rocblas_int batch_count); +rocblas_status rocsolver_sgetri_outofplace_batched(rocblas_handle handle, + const rocblas_int n, + float* const A[], + const rocblas_int lda, + rocblas_int* ipiv, + const rocblas_stride strideP, + float* const C[], + const rocblas_int ldc, + rocblas_int* info, + const rocblas_int batch_count); + +rocblas_status rocsolver_dgetri_outofplace_batched(rocblas_handle handle, + const rocblas_int n, + double* const A[], + const rocblas_int lda, + rocblas_int* ipiv, + const rocblas_stride strideP, + double* const C[], + const rocblas_int ldc, + rocblas_int* info, + const rocblas_int batch_count); + +rocblas_status rocsolver_cgetri_outofplace_batched(rocblas_handle handle, + const rocblas_int n, + rocblas_float_complex* const A[], + const rocblas_int lda, + rocblas_int* ipiv, + const rocblas_stride strideP, + rocblas_float_complex* const C[], + const rocblas_int ldc, + rocblas_int* info, + const rocblas_int batch_count); + +rocblas_status rocsolver_zgetri_outofplace_batched(rocblas_handle handle, + const rocblas_int n, + rocblas_double_complex* const A[], + const 
rocblas_int lda, + rocblas_int* ipiv, + const rocblas_stride strideP, + rocblas_double_complex* const C[], + const rocblas_int ldc, + rocblas_int* info, + const rocblas_int batch_count); + #ifdef __cplusplus } #endif @@ -13031,6 +13075,79 @@ hipblasStatus_t hipblasZgetrsStridedBatched(hipblasHandle_t handle, batch_count)); } +// getri_batched +hipblasStatus_t hipblasSgetriBatched(hipblasHandle_t handle, + const int n, + float* const A[], + const int lda, + int* ipiv, + float* const C[], + const int ldc, + int* info, + const int batch_count) +{ + return rocBLASStatusToHIPStatus(rocsolver_sgetri_outofplace_batched( + (rocblas_handle)handle, n, A, lda, ipiv, n, C, ldc, info, batch_count)); +} + +hipblasStatus_t hipblasDgetriBatched(hipblasHandle_t handle, + const int n, + double* const A[], + const int lda, + int* ipiv, + double* const C[], + const int ldc, + int* info, + const int batch_count) +{ + return rocBLASStatusToHIPStatus(rocsolver_dgetri_outofplace_batched( + (rocblas_handle)handle, n, A, lda, ipiv, n, C, ldc, info, batch_count)); +} + +hipblasStatus_t hipblasCgetriBatched(hipblasHandle_t handle, + const int n, + hipblasComplex* const A[], + const int lda, + int* ipiv, + hipblasComplex* const C[], + const int ldc, + int* info, + const int batch_count) +{ + return rocBLASStatusToHIPStatus(rocsolver_cgetri_outofplace_batched((rocblas_handle)handle, + n, + (rocblas_float_complex**)A, + lda, + ipiv, + n, + (rocblas_float_complex**)C, + ldc, + info, + batch_count)); +} + +hipblasStatus_t hipblasZgetriBatched(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* const A[], + const int lda, + int* ipiv, + hipblasDoubleComplex* const C[], + const int ldc, + int* info, + const int batch_count) +{ + return rocBLASStatusToHIPStatus(rocsolver_zgetri_outofplace_batched((rocblas_handle)handle, + n, + (rocblas_double_complex**)A, + lda, + ipiv, + n, + (rocblas_double_complex**)C, + ldc, + info, + batch_count)); +} + // geqrf hipblasStatus_t 
hipblasSgeqrf(hipblasHandle_t handle, const int m, diff --git a/library/src/hipblas_module.f90 b/library/src/hipblas_module.f90 index fc3bc1531..8487d7799 100644 --- a/library/src/hipblas_module.f90 +++ b/library/src/hipblas_module.f90 @@ -12431,6 +12431,83 @@ function hipblasZgetrsStridedBatched(handle, trans, n, nrhs, A, lda, stride_A, i end function hipblasZgetrsStridedBatched end interface + ! getri_batched + interface + function hipblasSgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasSgetriBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: C + integer(c_int), value :: ldc + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasSgetriBatched + end interface + + interface + function hipblasDgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasDgetriBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: C + integer(c_int), value :: ldc + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasDgetriBatched + end interface + + interface + function hipblasCgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasCgetriBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: C + integer(c_int), value :: ldc + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function 
hipblasCgetriBatched + end interface + + interface + function hipblasZgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasZgetriBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: C + integer(c_int), value :: ldc + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasZgetriBatched + end interface + ! geqrf interface function hipblasSgeqrf(handle, m, n, A, lda, tau, info) & diff --git a/library/src/nvcc_detail/hipblas.cpp b/library/src/nvcc_detail/hipblas.cpp index 3873509a7..e3dfd58f0 100644 --- a/library/src/nvcc_detail/hipblas.cpp +++ b/library/src/nvcc_detail/hipblas.cpp @@ -9405,6 +9405,77 @@ hipblasStatus_t hipblasZgetrsStridedBatched(hipblasHandle_t handle, return HIPBLAS_STATUS_NOT_SUPPORTED; } +// getri_batched +hipblasStatus_t hipblasSgetriBatched(hipblasHandle_t handle, + const int n, + float* const A[], + const int lda, + int* ipiv, + float* const C[], + const int ldc, + int* info, + const int batch_count) +{ + return hipCUBLASStatusToHIPStatus( + cublasSgetriBatched((cublasHandle_t)handle, n, A, lda, ipiv, C, ldc, info, batch_count)); +} + +hipblasStatus_t hipblasDgetriBatched(hipblasHandle_t handle, + const int n, + double* const A[], + const int lda, + int* ipiv, + double* const C[], + const int ldc, + int* info, + const int batch_count) +{ + return hipCUBLASStatusToHIPStatus( + cublasDgetriBatched((cublasHandle_t)handle, n, A, lda, ipiv, C, ldc, info, batch_count)); +} + +hipblasStatus_t hipblasCgetriBatched(hipblasHandle_t handle, + const int n, + hipblasComplex* const A[], + const int lda, + int* ipiv, + hipblasComplex* const C[], + const int ldc, + int* info, + const int batch_count) +{ + return hipCUBLASStatusToHIPStatus(cublasCgetriBatched((cublasHandle_t)handle, + n, + 
(cuComplex**)A, + lda, + ipiv, + (cuComplex**)C, + ldc, + info, + batch_count)); +} + +hipblasStatus_t hipblasZgetriBatched(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* const A[], + const int lda, + int* ipiv, + hipblasDoubleComplex* const C[], + const int ldc, + int* info, + const int batch_count) +{ + return hipCUBLASStatusToHIPStatus(cublasZgetriBatched((cublasHandle_t)handle, + n, + (cuDoubleComplex**)A, + lda, + ipiv, + (cuDoubleComplex**)C, + ldc, + info, + batch_count)); +} + // geqrf hipblasStatus_t hipblasSgeqrf(hipblasHandle_t handle, const int m, From 9627d71267877410d06c37456f46131af4499fce Mon Sep 17 00:00:00 2001 From: daineAMD Date: Wed, 8 Jul 2020 09:51:03 -0600 Subject: [PATCH 9/9] Version for master branch release. --- CMakeLists.txt | 2 +- bump_develop_version.sh | 4 ++-- bump_master_version.sh | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8f01e0d76..59e12d620 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,7 +45,7 @@ include( ROCMInstallTargets ) include( ROCMPackageConfigHelpers ) include( ROCMInstallSymlinks ) -set ( VERSION_STRING "0.31.0" ) +set ( VERSION_STRING "0.32.0" ) rocm_setup_version( VERSION ${VERSION_STRING} ) # Append our library helper cmake path and the cmake path for hip (for convenience) diff --git a/bump_develop_version.sh b/bump_develop_version.sh index 546ea721b..ee3e1f638 100755 --- a/bump_develop_version.sh +++ b/bump_develop_version.sh @@ -5,7 +5,7 @@ # - run this script in master branch # - after running this script merge master into develop -OLD_HIPBLAS_VERSION="0.30.0" -NEW_HIPBLAS_VERSION="0.31.0" +OLD_HIPBLAS_VERSION="0.32.0" +NEW_HIPBLAS_VERSION="0.33.0" sed -i "s/${OLD_HIPBLAS_VERSION}/${NEW_HIPBLAS_VERSION}/g" CMakeLists.txt diff --git a/bump_master_version.sh b/bump_master_version.sh index bda9710ed..ef66be87a 100755 --- a/bump_master_version.sh +++ b/bump_master_version.sh @@ -6,7 +6,7 @@ # - after running this script and 
merging develop into master, run bump_develop_version.sh in master and # merge master into develop -OLD_HIPBLAS_VERSION="0.29.0" -NEW_HIPBLAS_VERSION="0.30.0" +OLD_HIPBLAS_VERSION="0.31.0" +NEW_HIPBLAS_VERSION="0.32.0" sed -i "s/${OLD_HIPBLAS_VERSION}/${NEW_HIPBLAS_VERSION}/g" CMakeLists.txt