diff --git a/.githooks/pre-commit b/.githooks/pre-commit index 13f070903..87fee17b1 100755 --- a/.githooks/pre-commit +++ b/.githooks/pre-commit @@ -4,9 +4,7 @@ # are installed, and if so, uses the installed version to format # the staged changes. -set -x - -format=/opt/rocm/hcc/bin/clang-format +export PATH=/usr/bin:/bin:/opt/rocm/llvm/bin:/opt/rocm/hcc/bin # Redirect stdout to stderr. exec >&2 @@ -48,11 +46,11 @@ for file in $files; do done # if clang-format exists, run it on C/C++ files -if [[ -x $format ]]; then +if command -v clang-format >/dev/null; then for file in $files; do if [[ -e $file ]] && echo $file | grep -Eq '\.c$|\.h$|\.hpp$|\.cpp$|\.cl$|\.h\.in$|\.hpp\.in$|\.cpp\.in$'; then - echo "$format $file" - "$format" -i -style=file "$file" + echo "clang-format $file" + clang-format -i -style=file "$file" git add -u "$file" fi done diff --git a/CMakeLists.txt b/CMakeLists.txt index dca786e4a..59e12d620 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,7 +45,7 @@ include( ROCMInstallTargets ) include( ROCMPackageConfigHelpers ) include( ROCMInstallSymlinks ) -set ( VERSION_STRING "0.30.0" ) +set ( VERSION_STRING "0.32.0" ) rocm_setup_version( VERSION ${VERSION_STRING} ) # Append our library helper cmake path and the cmake path for hip (for convenience) diff --git a/bump_develop_version.sh b/bump_develop_version.sh index 546ea721b..ee3e1f638 100755 --- a/bump_develop_version.sh +++ b/bump_develop_version.sh @@ -5,7 +5,7 @@ # - run this script in master branch # - after running this script merge master into develop -OLD_HIPBLAS_VERSION="0.30.0" -NEW_HIPBLAS_VERSION="0.31.0" +OLD_HIPBLAS_VERSION="0.32.0" +NEW_HIPBLAS_VERSION="0.33.0" sed -i "s/${OLD_HIPBLAS_VERSION}/${NEW_HIPBLAS_VERSION}/g" CMakeLists.txt diff --git a/bump_master_version.sh b/bump_master_version.sh index bda9710ed..ef66be87a 100755 --- a/bump_master_version.sh +++ b/bump_master_version.sh @@ -6,7 +6,7 @@ # - after running this script and merging develop into master, run 
bump_develop_version.sh in master and # merge master into develop -OLD_HIPBLAS_VERSION="0.29.0" -NEW_HIPBLAS_VERSION="0.30.0" +OLD_HIPBLAS_VERSION="0.31.0" +NEW_HIPBLAS_VERSION="0.32.0" sed -i "s/${OLD_HIPBLAS_VERSION}/${NEW_HIPBLAS_VERSION}/g" CMakeLists.txt diff --git a/clients/CMakeLists.txt b/clients/CMakeLists.txt index e639b1831..296ca17e2 100644 --- a/clients/CMakeLists.txt +++ b/clients/CMakeLists.txt @@ -30,8 +30,17 @@ set(hipblas_f90_source_clients include/hipblas_fortran.f90 ) +set(hipblas_f90_source_clients_solver + include/hipblas_fortran_solver.f90 +) + if( BUILD_CLIENTS_TESTS OR BUILD_CLIENTS_SAMPLES ) - add_library(hipblas_fortran_client ${hipblas_f90_source_clients}) + if( BUILD_WITH_SOLVER) + add_library(hipblas_fortran_client ${hipblas_f90_source_clients} ${hipblas_f90_source_clients_solver}) + else() + add_library(hipblas_fortran_client ${hipblas_f90_source_clients}) + endif() + add_dependencies(hipblas_fortran_client hipblas_fortran) include_directories(${CMAKE_BINARY_DIR}/include) endif( ) diff --git a/clients/common/cblas_interface.cpp b/clients/common/cblas_interface.cpp index 9a2fcdd4b..978433869 100644 --- a/clients/common/cblas_interface.cpp +++ b/clients/common/cblas_interface.cpp @@ -53,6 +53,18 @@ void zgetrs_(char* trans, int* ldb, int* info); +void sgetri_(int* n, float* A, int* lda, int* ipiv, float* work, int* lwork, int* info); +void dgetri_(int* n, double* A, int* lda, int* ipiv, double* work, int* lwork, int* info); +void cgetri_( + int* n, hipblasComplex* A, int* lda, int* ipiv, hipblasComplex* work, int* lwork, int* info); +void zgetri_(int* n, + hipblasDoubleComplex* A, + int* lda, + int* ipiv, + hipblasDoubleComplex* work, + int* lwork, + int* info); + void sgeqrf_(int* m, int* n, float* A, int* lda, float* tau, float* work, int* lwork, int* info); void dgeqrf_(int* m, int* n, double* A, int* lda, double* tau, double* work, int* lwork, int* info); void cgeqrf_(int* m, @@ -3163,6 +3175,41 @@ int cblas_getrs(char trans, return 
info; } +// getri +template <> +int cblas_getri(int n, float* A, int lda, int* ipiv, float* work, int lwork) +{ + int info; + sgetri_(&n, A, &lda, ipiv, work, &lwork, &info); + return info; +} + +template <> +int cblas_getri(int n, double* A, int lda, int* ipiv, double* work, int lwork) +{ + int info; + dgetri_(&n, A, &lda, ipiv, work, &lwork, &info); + return info; +} + +template <> +int cblas_getri( + int n, hipblasComplex* A, int lda, int* ipiv, hipblasComplex* work, int lwork) +{ + int info; + cgetri_(&n, A, &lda, ipiv, work, &lwork, &info); + return info; +} + +template <> +int cblas_getri( + int n, hipblasDoubleComplex* A, int lda, int* ipiv, hipblasDoubleComplex* work, int lwork) +{ + int info; + zgetri_(&n, A, &lda, ipiv, work, &lwork, &info); + return info; +} + // geqrf template <> int cblas_geqrf(int m, int n, float* A, int lda, float* tau, float* work, int lwork) diff --git a/clients/common/hipblas_template_specialization.cpp b/clients/common/hipblas_template_specialization.cpp index 57144d4e6..a0ff958ed 100644 --- a/clients/common/hipblas_template_specialization.cpp +++ b/clients/common/hipblas_template_specialization.cpp @@ -12,26 +12,6 @@ * \brief provide template functions interfaces to ROCBLAS C89 interfaces */ -/* - * =========================================================================== - * level 1 BLAS - * =========================================================================== - */ - -/* ************************************************************************ - * Copyright 2016-2020 Advanced Micro Devices, Inc. 
- * - * ************************************************************************/ - -#include "hipblas.h" -#include "hipblas.hpp" -#include "hipblas_fortran.hpp" -#include - -/*!\file - * \brief provide template functions interfaces to ROCBLAS C89 interfaces -*/ - /* * =========================================================================== * level 1 BLAS @@ -9881,6 +9861,63 @@ hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); } +// getri_batched +template <> +hipblasStatus_t hipblasGetriBatched(hipblasHandle_t handle, + const int n, + float* const A[], + const int lda, + int* ipiv, + float* const C[], + const int ldc, + int* info, + const int batchCount) +{ + return hipblasSgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetriBatched(hipblasHandle_t handle, + const int n, + double* const A[], + const int lda, + int* ipiv, + double* const C[], + const int ldc, + int* info, + const int batchCount) +{ + return hipblasDgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetriBatched(hipblasHandle_t handle, + const int n, + hipblasComplex* const A[], + const int lda, + int* ipiv, + hipblasComplex* const C[], + const int ldc, + int* info, + const int batchCount) +{ + return hipblasCgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetriBatched(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* const A[], + const int lda, + int* ipiv, + hipblasDoubleComplex* const C[], + const int ldc, + int* info, + const int batchCount) +{ + return hipblasZgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batchCount); +} + // geqrf template <> hipblasStatus_t hipblasGeqrf(hipblasHandle_t handle, @@ -19654,519 +19691,585 @@ hipblasStatus_t batchCount); } -// #ifdef __HIP_PLATFORM_SOLVER__ - -// // getrf 
-// template <> -// hipblasStatus_t hipblasGetrf( -// hipblasHandle_t handle, const int n, float* A, const int lda, int* ipiv, int* info) -// { -// return hipblasSgetrfFortran(handle, n, A, lda, ipiv, info); -// } - -// template <> -// hipblasStatus_t hipblasGetrf( -// hipblasHandle_t handle, const int n, double* A, const int lda, int* ipiv, int* info) -// { -// return hipblasDgetrfFortran(handle, n, A, lda, ipiv, info); -// } - -// template <> -// hipblasStatus_t hipblasGetrf( -// hipblasHandle_t handle, const int n, hipblasComplex* A, const int lda, int* ipiv, int* info) -// { -// return hipblasCgetrfFortran(handle, n, A, lda, ipiv, info); -// } - -// template <> -// hipblasStatus_t hipblasGetrf(hipblasHandle_t handle, -// const int n, -// hipblasDoubleComplex* A, -// const int lda, -// int* ipiv, -// int* info) -// { -// return hipblasZgetrfFortran(handle, n, A, lda, ipiv, info); -// } - -// // getrf_batched -// template <> -// hipblasStatus_t hipblasGetrfBatched(hipblasHandle_t handle, -// const int n, -// float* const A[], -// const int lda, -// int* ipiv, -// int* info, -// const int batchCount) -// { -// return hipblasSgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrfBatched(hipblasHandle_t handle, -// const int n, -// double* const A[], -// const int lda, -// int* ipiv, -// int* info, -// const int batchCount) -// { -// return hipblasDgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrfBatched(hipblasHandle_t handle, -// const int n, -// hipblasComplex* const A[], -// const int lda, -// int* ipiv, -// int* info, -// const int batchCount) -// { -// return hipblasCgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrfBatched(hipblasHandle_t handle, -// const int n, -// hipblasDoubleComplex* const A[], -// const int lda, -// int* ipiv, -// int* info, -// 
const int batchCount) -// { -// return hipblasZgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batchCount); -// } - -// // getrf_strided_batched -// template <> -// hipblasStatus_t hipblasGetrfStridedBatched(hipblasHandle_t handle, -// const int n, -// float* A, -// const int lda, -// const int strideA, -// int* ipiv, -// const int strideP, -// int* info, -// const int batchCount) -// { -// return hipblasSgetrfStridedBatchedFortran(handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrfStridedBatched(hipblasHandle_t handle, -// const int n, -// double* A, -// const int lda, -// const int strideA, -// int* ipiv, -// const int strideP, -// int* info, -// const int batchCount) -// { -// return hipblasDgetrfStridedBatchedFortran(handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrfStridedBatched(hipblasHandle_t handle, -// const int n, -// hipblasComplex* A, -// const int lda, -// const int strideA, -// int* ipiv, -// const int strideP, -// int* info, -// const int batchCount) -// { -// return hipblasCgetrfStridedBatchedFortran(handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrfStridedBatched(hipblasHandle_t handle, -// const int n, -// hipblasDoubleComplex* A, -// const int lda, -// const int strideA, -// int* ipiv, -// const int strideP, -// int* info, -// const int batchCount) -// { -// return hipblasZgetrfStridedBatchedFortran(handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); -// } - -// // getrs -// template <> -// hipblasStatus_t hipblasGetrs(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// float* A, -// const int lda, -// const int* ipiv, -// float* B, -// const int ldb, -// int* info) -// { -// return hipblasSgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info); -// } - -// template <> -// 
hipblasStatus_t hipblasGetrs(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// double* A, -// const int lda, -// const int* ipiv, -// double* B, -// const int ldb, -// int* info) -// { -// return hipblasDgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info); -// } - -// template <> -// hipblasStatus_t hipblasGetrs(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// hipblasComplex* A, -// const int lda, -// const int* ipiv, -// hipblasComplex* B, -// const int ldb, -// int* info) -// { -// return hipblasCgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info); -// } - -// template <> -// hipblasStatus_t hipblasGetrs(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// hipblasDoubleComplex* A, -// const int lda, -// const int* ipiv, -// hipblasDoubleComplex* B, -// const int ldb, -// int* info) -// { -// return hipblasZgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info); -// } - -// // getrs_batched -// template <> -// hipblasStatus_t hipblasGetrsBatched(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// float* const A[], -// const int lda, -// const int* ipiv, -// float* const B[], -// const int ldb, -// int* info, -// const int batchCount) -// { -// return hipblasSgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrsBatched(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// double* const A[], -// const int lda, -// const int* ipiv, -// double* const B[], -// const int ldb, -// int* info, -// const int batchCount) -// { -// return hipblasDgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrsBatched(hipblasHandle_t 
handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// hipblasComplex* const A[], -// const int lda, -// const int* ipiv, -// hipblasComplex* const B[], -// const int ldb, -// int* info, -// const int batchCount) -// { -// return hipblasCgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrsBatched(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// hipblasDoubleComplex* const A[], -// const int lda, -// const int* ipiv, -// hipblasDoubleComplex* const B[], -// const int ldb, -// int* info, -// const int batchCount) -// { -// return hipblasZgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); -// } - -// // getrs_strided_batched -// template <> -// hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// float* A, -// const int lda, -// const int strideA, -// const int* ipiv, -// const int strideP, -// float* B, -// const int ldb, -// const int strideB, -// int* info, -// const int batchCount) -// { -// return hipblasSgetrsStridedBatchedFortran( -// handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// double* A, -// const int lda, -// const int strideA, -// const int* ipiv, -// const int strideP, -// double* B, -// const int ldb, -// const int strideB, -// int* info, -// const int batchCount) -// { -// return hipblasDgetrsStridedBatchedFortran( -// handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const 
int n, -// const int nrhs, -// hipblasComplex* A, -// const int lda, -// const int strideA, -// const int* ipiv, -// const int strideP, -// hipblasComplex* B, -// const int ldb, -// const int strideB, -// int* info, -// const int batchCount) -// { -// return hipblasCgetrsStridedBatchedFortran( -// handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, -// const hipblasOperation_t trans, -// const int n, -// const int nrhs, -// hipblasDoubleComplex* A, -// const int lda, -// const int strideA, -// const int* ipiv, -// const int strideP, -// hipblasDoubleComplex* B, -// const int ldb, -// const int strideB, -// int* info, -// const int batchCount) -// { -// return hipblasZgetrsStridedBatchedFortran( -// handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); -// } - -// // geqrf -// template <> -// hipblasStatus_t hipblasGeqrf(hipblasHandle_t handle, -// const int m, -// const int n, -// float* A, -// const int lda, -// float* ipiv, -// int* info) -// { -// return hipblasSgeqrfFortran(handle, m, n, A, lda, ipiv, info); -// } - -// template <> -// hipblasStatus_t hipblasGeqrf(hipblasHandle_t handle, -// const int m, -// const int n, -// double* A, -// const int lda, -// double* ipiv, -// int* info) -// { -// return hipblasDgeqrfFortran(handle, m, n, A, lda, ipiv, info); -// } - -// template <> -// hipblasStatus_t hipblasGeqrf(hipblasHandle_t handle, -// const int m, -// const int n, -// hipblasComplex* A, -// const int lda, -// hipblasComplex* ipiv, -// int* info) -// { -// return hipblasCgeqrfFortran(handle, m, n, A, lda, ipiv, info); -// } - -// template <> -// hipblasStatus_t hipblasGeqrf(hipblasHandle_t handle, -// const int m, -// const int n, -// hipblasDoubleComplex* A, -// const int lda, -// hipblasDoubleComplex* ipiv, -// int* info) -// { -// return hipblasZgeqrfFortran(handle, m, n, A, lda, 
ipiv, info); -// } - -// // geqrf_batched -// template <> -// hipblasStatus_t hipblasGeqrfBatched(hipblasHandle_t handle, -// const int m, -// const int n, -// float* const A[], -// const int lda, -// float* const ipiv[], -// int* info, -// const int batchCount) -// { -// return hipblasSgeqrfBatchedFortran(handle, m, n, A, lda, ipiv, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGeqrfBatched(hipblasHandle_t handle, -// const int m, -// const int n, -// double* const A[], -// const int lda, -// double* const ipiv[], -// int* info, -// const int batchCount) -// { -// return hipblasDgeqrfBatchedFortran(handle, m, n, A, lda, ipiv, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGeqrfBatched(hipblasHandle_t handle, -// const int m, -// const int n, -// hipblasComplex* const A[], -// const int lda, -// hipblasComplex* const ipiv[], -// int* info, -// const int batchCount) -// { -// return hipblasCgeqrfBatchedFortran(handle, m, n, A, lda, ipiv, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGeqrfBatched(hipblasHandle_t handle, -// const int m, -// const int n, -// hipblasDoubleComplex* const A[], -// const int lda, -// hipblasDoubleComplex* const ipiv[], -// int* info, -// const int batchCount) -// { -// return hipblasZgeqrfBatchedFortran(handle, m, n, A, lda, ipiv, info, batchCount); -// } - -// // geqrf_strided_batched -// template <> -// hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t handle, -// const int m, -// const int n, -// float* A, -// const int lda, -// const int strideA, -// float* ipiv, -// const int strideP, -// int* info, -// const int batchCount) -// { -// return hipblasSgeqrfStridedBatchedFortran( -// handle, m, n, A, lda, strideA, ipiv, strideP, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t handle, -// const int m, -// const int n, -// double* A, -// const int lda, -// const int strideA, -// double* ipiv, -// const 
int strideP, -// int* info, -// const int batchCount) -// { -// return hipblasDgeqrfStridedBatchedFortran( -// handle, m, n, A, lda, strideA, ipiv, strideP, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t handle, -// const int m, -// const int n, -// hipblasComplex* A, -// const int lda, -// const int strideA, -// hipblasComplex* ipiv, -// const int strideP, -// int* info, -// const int batchCount) -// { -// return hipblasCgeqrfStridedBatchedFortran( -// handle, m, n, A, lda, strideA, ipiv, strideP, info, batchCount); -// } - -// template <> -// hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t handle, -// const int m, -// const int n, -// hipblasDoubleComplex* A, -// const int lda, -// const int strideA, -// hipblasDoubleComplex* ipiv, -// const int strideP, -// int* info, -// const int batchCount) -// { -// return hipblasZgeqrfStridedBatchedFortran( -// handle, m, n, A, lda, strideA, ipiv, strideP, info, batchCount); -// } - -// #endif +#ifdef __HIP_PLATFORM_SOLVER__ + +// getrf +template <> +hipblasStatus_t hipblasGetrf( + hipblasHandle_t handle, const int n, float* A, const int lda, int* ipiv, int* info) +{ + return hipblasSgetrfFortran(handle, n, A, lda, ipiv, info); +} + +template <> +hipblasStatus_t hipblasGetrf( + hipblasHandle_t handle, const int n, double* A, const int lda, int* ipiv, int* info) +{ + return hipblasDgetrfFortran(handle, n, A, lda, ipiv, info); +} + +template <> +hipblasStatus_t hipblasGetrf( + hipblasHandle_t handle, const int n, hipblasComplex* A, const int lda, int* ipiv, int* info) +{ + return hipblasCgetrfFortran(handle, n, A, lda, ipiv, info); +} + +template <> +hipblasStatus_t hipblasGetrf(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* A, + const int lda, + int* ipiv, + int* info) +{ + return hipblasZgetrfFortran(handle, n, A, lda, ipiv, info); +} + +// getrf_batched +template <> +hipblasStatus_t hipblasGetrfBatched(hipblasHandle_t handle, + const int 
n, + float* const A[], + const int lda, + int* ipiv, + int* info, + const int batchCount) +{ + return hipblasSgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrfBatched(hipblasHandle_t handle, + const int n, + double* const A[], + const int lda, + int* ipiv, + int* info, + const int batchCount) +{ + return hipblasDgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrfBatched(hipblasHandle_t handle, + const int n, + hipblasComplex* const A[], + const int lda, + int* ipiv, + int* info, + const int batchCount) +{ + return hipblasCgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrfBatched(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* const A[], + const int lda, + int* ipiv, + int* info, + const int batchCount) +{ + return hipblasZgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batchCount); +} + +// getrf_strided_batched +template <> +hipblasStatus_t hipblasGetrfStridedBatched(hipblasHandle_t handle, + const int n, + float* A, + const int lda, + const int strideA, + int* ipiv, + const int strideP, + int* info, + const int batchCount) +{ + return hipblasSgetrfStridedBatchedFortran( + handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrfStridedBatched(hipblasHandle_t handle, + const int n, + double* A, + const int lda, + const int strideA, + int* ipiv, + const int strideP, + int* info, + const int batchCount) +{ + return hipblasDgetrfStridedBatchedFortran( + handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrfStridedBatched(hipblasHandle_t handle, + const int n, + hipblasComplex* A, + const int lda, + const int strideA, + int* ipiv, + const int strideP, + int* info, + const int batchCount) +{ + return hipblasCgetrfStridedBatchedFortran( + handle, n, A, lda, strideA, 
ipiv, strideP, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrfStridedBatched(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* A, + const int lda, + const int strideA, + int* ipiv, + const int strideP, + int* info, + const int batchCount) +{ + return hipblasZgetrfStridedBatchedFortran( + handle, n, A, lda, strideA, ipiv, strideP, info, batchCount); +} + +// getrs +template <> +hipblasStatus_t hipblasGetrs(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + float* A, + const int lda, + const int* ipiv, + float* B, + const int ldb, + int* info) +{ + return hipblasSgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info); +} + +template <> +hipblasStatus_t hipblasGetrs(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + double* A, + const int lda, + const int* ipiv, + double* B, + const int ldb, + int* info) +{ + return hipblasDgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info); +} + +template <> +hipblasStatus_t hipblasGetrs(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasComplex* A, + const int lda, + const int* ipiv, + hipblasComplex* B, + const int ldb, + int* info) +{ + return hipblasCgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info); +} + +template <> +hipblasStatus_t hipblasGetrs(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasDoubleComplex* A, + const int lda, + const int* ipiv, + hipblasDoubleComplex* B, + const int ldb, + int* info) +{ + return hipblasZgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info); +} + +// getrs_batched +template <> +hipblasStatus_t hipblasGetrsBatched(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + float* const A[], + const int lda, + const int* ipiv, + float* const B[], + const int ldb, + int* info, + const int batchCount) +{ 
+ return hipblasSgetrsBatchedFortran( + handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrsBatched(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + double* const A[], + const int lda, + const int* ipiv, + double* const B[], + const int ldb, + int* info, + const int batchCount) +{ + return hipblasDgetrsBatchedFortran( + handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrsBatched(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasComplex* const A[], + const int lda, + const int* ipiv, + hipblasComplex* const B[], + const int ldb, + int* info, + const int batchCount) +{ + return hipblasCgetrsBatchedFortran( + handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrsBatched(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasDoubleComplex* const A[], + const int lda, + const int* ipiv, + hipblasDoubleComplex* const B[], + const int ldb, + int* info, + const int batchCount) +{ + return hipblasZgetrsBatchedFortran( + handle, trans, n, nrhs, A, lda, ipiv, B, ldb, info, batchCount); +} + +// getrs_strided_batched +template <> +hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + float* A, + const int lda, + const int strideA, + const int* ipiv, + const int strideP, + float* B, + const int ldb, + const int strideB, + int* info, + const int batchCount) +{ + return hipblasSgetrsStridedBatchedFortran( + handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + double* A, + const int lda, + const int 
strideA, + const int* ipiv, + const int strideP, + double* B, + const int ldb, + const int strideB, + int* info, + const int batchCount) +{ + return hipblasDgetrsStridedBatchedFortran( + handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasComplex* A, + const int lda, + const int strideA, + const int* ipiv, + const int strideP, + hipblasComplex* B, + const int ldb, + const int strideB, + int* info, + const int batchCount) +{ + return hipblasCgetrsStridedBatchedFortran( + handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); +} + +template <> +hipblasStatus_t + hipblasGetrsStridedBatched(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + hipblasDoubleComplex* A, + const int lda, + const int strideA, + const int* ipiv, + const int strideP, + hipblasDoubleComplex* B, + const int ldb, + const int strideB, + int* info, + const int batchCount) +{ + return hipblasZgetrsStridedBatchedFortran( + handle, trans, n, nrhs, A, lda, strideA, ipiv, strideP, B, ldb, strideB, info, batchCount); +} + +// getri_batched +template <> +hipblasStatus_t hipblasGetriBatched(hipblasHandle_t handle, + const int n, + float* const A[], + const int lda, + int* ipiv, + float* const C[], + const int ldc, + int* info, + const int batchCount) +{ + return hipblasSgetriBatchedFortran(handle, n, A, lda, ipiv, C, ldc, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetriBatched(hipblasHandle_t handle, + const int n, + double* const A[], + const int lda, + int* ipiv, + double* const C[], + const int ldc, + int* info, + const int batchCount) +{ + return hipblasDgetriBatchedFortran(handle, n, A, lda, ipiv, C, ldc, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetriBatched(hipblasHandle_t handle, + const 
int n, + hipblasComplex* const A[], + const int lda, + int* ipiv, + hipblasComplex* const C[], + const int ldc, + int* info, + const int batchCount) +{ + return hipblasCgetriBatchedFortran(handle, n, A, lda, ipiv, C, ldc, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGetriBatched(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* const A[], + const int lda, + int* ipiv, + hipblasDoubleComplex* const C[], + const int ldc, + int* info, + const int batchCount) +{ + return hipblasZgetriBatchedFortran(handle, n, A, lda, ipiv, C, ldc, info, batchCount); +} + +// geqrf +template <> +hipblasStatus_t hipblasGeqrf(hipblasHandle_t handle, + const int m, + const int n, + float* A, + const int lda, + float* ipiv, + int* info) +{ + return hipblasSgeqrfFortran(handle, m, n, A, lda, ipiv, info); +} + +template <> +hipblasStatus_t hipblasGeqrf(hipblasHandle_t handle, + const int m, + const int n, + double* A, + const int lda, + double* ipiv, + int* info) +{ + return hipblasDgeqrfFortran(handle, m, n, A, lda, ipiv, info); +} + +template <> +hipblasStatus_t hipblasGeqrf(hipblasHandle_t handle, + const int m, + const int n, + hipblasComplex* A, + const int lda, + hipblasComplex* ipiv, + int* info) +{ + return hipblasCgeqrfFortran(handle, m, n, A, lda, ipiv, info); +} + +template <> +hipblasStatus_t hipblasGeqrf(hipblasHandle_t handle, + const int m, + const int n, + hipblasDoubleComplex* A, + const int lda, + hipblasDoubleComplex* ipiv, + int* info) +{ + return hipblasZgeqrfFortran(handle, m, n, A, lda, ipiv, info); +} + +// geqrf_batched +template <> +hipblasStatus_t hipblasGeqrfBatched(hipblasHandle_t handle, + const int m, + const int n, + float* const A[], + const int lda, + float* const ipiv[], + int* info, + const int batchCount) +{ + return hipblasSgeqrfBatchedFortran(handle, m, n, A, lda, ipiv, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGeqrfBatched(hipblasHandle_t handle, + const int m, + const int n, + double* const A[], + const 
int lda, + double* const ipiv[], + int* info, + const int batchCount) +{ + return hipblasDgeqrfBatchedFortran(handle, m, n, A, lda, ipiv, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGeqrfBatched(hipblasHandle_t handle, + const int m, + const int n, + hipblasComplex* const A[], + const int lda, + hipblasComplex* const ipiv[], + int* info, + const int batchCount) +{ + return hipblasCgeqrfBatchedFortran(handle, m, n, A, lda, ipiv, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGeqrfBatched(hipblasHandle_t handle, + const int m, + const int n, + hipblasDoubleComplex* const A[], + const int lda, + hipblasDoubleComplex* const ipiv[], + int* info, + const int batchCount) +{ + return hipblasZgeqrfBatchedFortran(handle, m, n, A, lda, ipiv, info, batchCount); +} + +// geqrf_strided_batched +template <> +hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t handle, + const int m, + const int n, + float* A, + const int lda, + const int strideA, + float* ipiv, + const int strideP, + int* info, + const int batchCount) +{ + return hipblasSgeqrfStridedBatchedFortran( + handle, m, n, A, lda, strideA, ipiv, strideP, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t handle, + const int m, + const int n, + double* A, + const int lda, + const int strideA, + double* ipiv, + const int strideP, + int* info, + const int batchCount) +{ + return hipblasDgeqrfStridedBatchedFortran( + handle, m, n, A, lda, strideA, ipiv, strideP, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t handle, + const int m, + const int n, + hipblasComplex* A, + const int lda, + const int strideA, + hipblasComplex* ipiv, + const int strideP, + int* info, + const int batchCount) +{ + return hipblasCgeqrfStridedBatchedFortran( + handle, m, n, A, lda, strideA, ipiv, strideP, info, batchCount); +} + +template <> +hipblasStatus_t hipblasGeqrfStridedBatched(hipblasHandle_t handle, + const int m, 
+ const int n, + hipblasDoubleComplex* A, + const int lda, + const int strideA, + hipblasDoubleComplex* ipiv, + const int strideP, + int* info, + const int batchCount) +{ + return hipblasZgeqrfStridedBatchedFortran( + handle, m, n, A, lda, strideA, ipiv, strideP, info, batchCount); +} + +#endif diff --git a/clients/common/utility.cpp b/clients/common/utility.cpp index a82d35fff..d6b2a768f 100644 --- a/clients/common/utility.cpp +++ b/clients/common/utility.cpp @@ -32,6 +32,30 @@ char type2char() // return 'z'; // } +template <> +int type2int(float val) +{ + return (int)val; +} + +template <> +int type2int(double val) +{ + return (int)val; +} + +template <> +int type2int(hipblasComplex val) +{ + return (int)val.real(); +} + +template <> +int type2int(hipblasDoubleComplex val) +{ + return (int)val.real(); +} + #ifdef __cplusplus extern "C" { #endif diff --git a/clients/gtest/CMakeLists.txt b/clients/gtest/CMakeLists.txt index 956d6cd5d..9f2b88ee9 100644 --- a/clients/gtest/CMakeLists.txt +++ b/clients/gtest/CMakeLists.txt @@ -86,6 +86,7 @@ set(hipblas_test_source syrk_gtest.cpp syr2k_gtest.cpp trsm_gtest.cpp + trsm_ex_gtest.cpp trmm_gtest.cpp trtri_gtest.cpp ) @@ -98,6 +99,7 @@ if( BUILD_WITH_SOLVER ) getrs_gtest.cpp getrs_batched_gtest.cpp getrs_strided_batched_gtest.cpp + getri_batched_gtest.cpp geqrf_gtest.cpp geqrf_batched_gtest.cpp geqrf_strided_batched_gtest.cpp diff --git a/clients/gtest/geqrf_batched_gtest.cpp b/clients/gtest/geqrf_batched_gtest.cpp index 79885a6fb..cd2bca749 100644 --- a/clients/gtest/geqrf_batched_gtest.cpp +++ b/clients/gtest/geqrf_batched_gtest.cpp @@ -16,7 +16,7 @@ using ::testing::Values; using ::testing::ValuesIn; using namespace std; -typedef std::tuple, double, int> geqrf_batched_tuple; +typedef std::tuple, double, int, bool> geqrf_batched_tuple; const vector> matrix_size_range = {{-1, -1, 1, 1}, {10, 10, 10, 10}, {10, 10, 20, 100}, {600, 500, 600, 600}}; @@ -25,11 +25,14 @@ const vector stride_scale_range = {2.5}; const vector 
batch_count_range = {-1, 0, 1, 2}; +const vector is_fortran = {false, true}; + Arguments setup_geqrf_batched_arguments(geqrf_batched_tuple tup) { vector matrix_size = std::get<0>(tup); double stride_scale = std::get<1>(tup); int batch_count = std::get<2>(tup); + bool fortran = std::get<3>(tup); Arguments arg; @@ -41,6 +44,8 @@ Arguments setup_geqrf_batched_arguments(geqrf_batched_tuple tup) arg.stride_scale = stride_scale; arg.batch_count = batch_count; + arg.fortran = fortran; + return arg; } @@ -60,7 +65,7 @@ TEST_P(geqrf_batched_gtest, geqrf_batched_gtest_float) Arguments arg = setup_geqrf_batched_arguments(GetParam()); - hipblasStatus_t status = testing_geqrf_batched(arg); + hipblasStatus_t status = testing_geqrf_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -82,7 +87,51 @@ TEST_P(geqrf_batched_gtest, geqrf_batched_gtest_double) Arguments arg = setup_geqrf_batched_arguments(GetParam()); - hipblasStatus_t status = testing_geqrf_batched(arg); + hipblasStatus_t status = testing_geqrf_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.M < 0 || arg.N < 0 || arg.lda < arg.M || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); + } + } +} + +TEST_P(geqrf_batched_gtest, geqrf_batched_gtest_float_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_geqrf_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_geqrf_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.M < 0 || arg.N < 0 || arg.lda < arg.M || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); + } + } +} + +TEST_P(geqrf_batched_gtest, geqrf_batched_gtest_double_complex) +{ + // GetParam returns a tuple. 
The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_geqrf_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_geqrf_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -106,4 +155,5 @@ INSTANTIATE_TEST_CASE_P(hipblasGeqrfBatched, geqrf_batched_gtest, Combine(ValuesIn(matrix_size_range), ValuesIn(stride_scale_range), - ValuesIn(batch_count_range))); + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/gtest/geqrf_gtest.cpp b/clients/gtest/geqrf_gtest.cpp index 4b71597f4..bbcc9d741 100644 --- a/clients/gtest/geqrf_gtest.cpp +++ b/clients/gtest/geqrf_gtest.cpp @@ -16,7 +16,7 @@ using ::testing::Values; using ::testing::ValuesIn; using namespace std; -typedef std::tuple, double, int> geqrf_tuple; +typedef std::tuple, double, int, bool> geqrf_tuple; const vector> matrix_size_range = {{-1, -1, 1, 1}, {10, 10, 10, 10}, {10, 10, 20, 100}, {600, 500, 600, 600}}; @@ -25,11 +25,14 @@ const vector stride_scale_range = {2.5}; const vector batch_count_range = {1}; +const vector is_fortran = {false, true}; + Arguments setup_geqrf_arguments(geqrf_tuple tup) { vector matrix_size = std::get<0>(tup); double stride_scale = std::get<1>(tup); int batch_count = std::get<2>(tup); + bool fortran = std::get<3>(tup); Arguments arg; @@ -41,6 +44,8 @@ Arguments setup_geqrf_arguments(geqrf_tuple tup) arg.stride_scale = stride_scale; arg.batch_count = batch_count; + arg.fortran = fortran; + return arg; } @@ -60,7 +65,7 @@ TEST_P(geqrf_gtest, geqrf_gtest_float) Arguments arg = setup_geqrf_arguments(GetParam()); - hipblasStatus_t status = testing_geqrf(arg); + hipblasStatus_t status = testing_geqrf(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -82,7 +87,51 @@ TEST_P(geqrf_gtest, geqrf_gtest_double) Arguments arg = setup_geqrf_arguments(GetParam()); - hipblasStatus_t status = testing_geqrf(arg); + hipblasStatus_t status = testing_geqrf(arg); + + if(status != 
HIPBLAS_STATUS_SUCCESS) + { + if(arg.M < 0 || arg.N < 0 || arg.lda < arg.M) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(geqrf_gtest, geqrf_gtest_float_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_geqrf_arguments(GetParam()); + + hipblasStatus_t status = testing_geqrf(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.M < 0 || arg.N < 0 || arg.lda < arg.M) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(geqrf_gtest, geqrf_gtest_double_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_geqrf_arguments(GetParam()); + + hipblasStatus_t status = testing_geqrf(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -106,4 +155,5 @@ INSTANTIATE_TEST_CASE_P(hipblasGeqrf, geqrf_gtest, Combine(ValuesIn(matrix_size_range), ValuesIn(stride_scale_range), - ValuesIn(batch_count_range))); + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/gtest/geqrf_strided_batched_gtest.cpp b/clients/gtest/geqrf_strided_batched_gtest.cpp index 2985af85a..6a8fc0e44 100644 --- a/clients/gtest/geqrf_strided_batched_gtest.cpp +++ b/clients/gtest/geqrf_strided_batched_gtest.cpp @@ -16,7 +16,7 @@ using ::testing::Values; using ::testing::ValuesIn; using namespace std; -typedef std::tuple, double, int> geqrf_strided_batched_tuple; +typedef std::tuple, double, int, bool> geqrf_strided_batched_tuple; const vector> matrix_size_range = {{-1, -1, 1, 1}, {10, 10, 10, 10}, {10, 10, 20, 100}, {600, 500, 600, 600}}; @@ -25,11 +25,14 @@ const vector stride_scale_range = {2.5}; const vector batch_count_range 
= {-1, 0, 1, 2}; +const vector is_fortran = {false, true}; + Arguments setup_geqrf_strided_batched_arguments(geqrf_strided_batched_tuple tup) { vector matrix_size = std::get<0>(tup); double stride_scale = std::get<1>(tup); int batch_count = std::get<2>(tup); + bool fortran = std::get<3>(tup); Arguments arg; @@ -41,6 +44,8 @@ Arguments setup_geqrf_strided_batched_arguments(geqrf_strided_batched_tuple tup) arg.stride_scale = stride_scale; arg.batch_count = batch_count; + arg.fortran = fortran; + return arg; } @@ -60,7 +65,7 @@ TEST_P(geqrf_strided_batched_gtest, geqrf_strided_batched_gtest_float) Arguments arg = setup_geqrf_strided_batched_arguments(GetParam()); - hipblasStatus_t status = testing_geqrf_strided_batched(arg); + hipblasStatus_t status = testing_geqrf_strided_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -82,7 +87,51 @@ TEST_P(geqrf_strided_batched_gtest, geqrf_strided_batched_gtest_double) Arguments arg = setup_geqrf_strided_batched_arguments(GetParam()); - hipblasStatus_t status = testing_geqrf_strided_batched(arg); + hipblasStatus_t status = testing_geqrf_strided_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.M < 0 || arg.N < 0 || arg.lda < arg.M || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(geqrf_strided_batched_gtest, geqrf_strided_batched_gtest_float_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. 
+ + Arguments arg = setup_geqrf_strided_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_geqrf_strided_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.M < 0 || arg.N < 0 || arg.lda < arg.M || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(geqrf_strided_batched_gtest, geqrf_strided_batched_gtest_double_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_geqrf_strided_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_geqrf_strided_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -106,4 +155,5 @@ INSTANTIATE_TEST_CASE_P(hipblasGeqrfStridedBatched, geqrf_strided_batched_gtest, Combine(ValuesIn(matrix_size_range), ValuesIn(stride_scale_range), - ValuesIn(batch_count_range))); + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/gtest/getrf_batched_gtest.cpp b/clients/gtest/getrf_batched_gtest.cpp index a386aee4f..284a4ca50 100644 --- a/clients/gtest/getrf_batched_gtest.cpp +++ b/clients/gtest/getrf_batched_gtest.cpp @@ -16,7 +16,7 @@ using ::testing::Values; using ::testing::ValuesIn; using namespace std; -typedef std::tuple, double, int> getrf_batched_tuple; +typedef std::tuple, double, int, bool> getrf_batched_tuple; const vector> matrix_size_range = {{-1, -1, 1, 1}, {10, 10, 10, 10}, @@ -28,11 +28,14 @@ const vector stride_scale_range = {2.5}; const vector batch_count_range = {-1, 0, 1, 2}; +const vector is_fortran = {false, true}; + Arguments setup_getrf_batched_arguments(getrf_batched_tuple tup) { vector matrix_size = std::get<0>(tup); double stride_scale = std::get<1>(tup); int batch_count = std::get<2>(tup); + bool fortran = std::get<3>(tup); Arguments arg; @@ -44,6 +47,8 @@ Arguments 
setup_getrf_batched_arguments(getrf_batched_tuple tup) arg.stride_scale = stride_scale; arg.batch_count = batch_count; + arg.fortran = fortran; + return arg; } @@ -63,7 +68,7 @@ TEST_P(getrf_batched_gtest, getrf_batched_gtest_float) Arguments arg = setup_getrf_batched_arguments(GetParam()); - hipblasStatus_t status = testing_getrf_batched(arg); + hipblasStatus_t status = testing_getrf_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -85,7 +90,51 @@ TEST_P(getrf_batched_gtest, getrf_batched_gtest_double) Arguments arg = setup_getrf_batched_arguments(GetParam()); - hipblasStatus_t status = testing_getrf_batched(arg); + hipblasStatus_t status = testing_getrf_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); + } + } +} + +TEST_P(getrf_batched_gtest, getrf_batched_gtest_float_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getrf_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getrf_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); + } + } +} + +TEST_P(getrf_batched_gtest, getrf_batched_gtest_double_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. 
+ + Arguments arg = setup_getrf_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getrf_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -109,4 +158,5 @@ INSTANTIATE_TEST_CASE_P(hipblasGetrfBatched, getrf_batched_gtest, Combine(ValuesIn(matrix_size_range), ValuesIn(stride_scale_range), - ValuesIn(batch_count_range))); + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/gtest/getrf_gtest.cpp b/clients/gtest/getrf_gtest.cpp index 701a904bb..fe99ca5e6 100644 --- a/clients/gtest/getrf_gtest.cpp +++ b/clients/gtest/getrf_gtest.cpp @@ -16,7 +16,7 @@ using ::testing::Values; using ::testing::ValuesIn; using namespace std; -typedef std::tuple, double, int> getrf_tuple; +typedef std::tuple, double, int, bool> getrf_tuple; const vector> matrix_size_range = {{-1, -1, 1, 1}, {10, 10, 10, 10}, @@ -28,11 +28,14 @@ const vector stride_scale_range = {2.5}; const vector batch_count_range = {1}; +const vector is_fortran = {false, true}; + Arguments setup_getrf_arguments(getrf_tuple tup) { vector matrix_size = std::get<0>(tup); double stride_scale = std::get<1>(tup); int batch_count = std::get<2>(tup); + bool fortran = std::get<3>(tup); Arguments arg; @@ -44,6 +47,8 @@ Arguments setup_getrf_arguments(getrf_tuple tup) arg.stride_scale = stride_scale; arg.batch_count = batch_count; + arg.fortran = fortran; + return arg; } @@ -63,7 +68,7 @@ TEST_P(getrf_gtest, getrf_gtest_float) Arguments arg = setup_getrf_arguments(GetParam()); - hipblasStatus_t status = testing_getrf(arg); + hipblasStatus_t status = testing_getrf(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -85,7 +90,51 @@ TEST_P(getrf_gtest, getrf_gtest_double) Arguments arg = setup_getrf_arguments(GetParam()); - hipblasStatus_t status = testing_getrf(arg); + hipblasStatus_t status = testing_getrf(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + 
EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(getrf_gtest, getrf_gtest_float_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getrf_arguments(GetParam()); + + hipblasStatus_t status = testing_getrf(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(getrf_gtest, getrf_gtest_double_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getrf_arguments(GetParam()); + + hipblasStatus_t status = testing_getrf(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -109,4 +158,5 @@ INSTANTIATE_TEST_CASE_P(hipblasGetrf, getrf_gtest, Combine(ValuesIn(matrix_size_range), ValuesIn(stride_scale_range), - ValuesIn(batch_count_range))); + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/gtest/getrf_strided_batched_gtest.cpp b/clients/gtest/getrf_strided_batched_gtest.cpp index e5129c43c..30105f277 100644 --- a/clients/gtest/getrf_strided_batched_gtest.cpp +++ b/clients/gtest/getrf_strided_batched_gtest.cpp @@ -16,7 +16,7 @@ using ::testing::Values; using ::testing::ValuesIn; using namespace std; -typedef std::tuple, double, int> getrf_strided_batched_tuple; +typedef std::tuple, double, int, bool> getrf_strided_batched_tuple; const vector> matrix_size_range = {{-1, -1, 1, 1}, {10, 10, 10, 10}, @@ -28,11 +28,14 @@ const vector stride_scale_range = {2.5}; const vector batch_count_range = {-1, 0, 1, 2}; +const vector is_fortran = {false, true}; + Arguments setup_getrf_strided_batched_arguments(getrf_strided_batched_tuple tup) { vector matrix_size = std::get<0>(tup); double stride_scale 
= std::get<1>(tup); int batch_count = std::get<2>(tup); + bool fortran = std::get<3>(tup); Arguments arg; @@ -44,6 +47,8 @@ Arguments setup_getrf_strided_batched_arguments(getrf_strided_batched_tuple tup) arg.stride_scale = stride_scale; arg.batch_count = batch_count; + arg.fortran = fortran; + return arg; } @@ -63,7 +68,7 @@ TEST_P(getrf_strided_batched_gtest, getrf_strided_batched_gtest_float) Arguments arg = setup_getrf_strided_batched_arguments(GetParam()); - hipblasStatus_t status = testing_getrf_strided_batched(arg); + hipblasStatus_t status = testing_getrf_strided_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -85,7 +90,51 @@ TEST_P(getrf_strided_batched_gtest, getrf_strided_batched_gtest_double) Arguments arg = setup_getrf_strided_batched_arguments(GetParam()); - hipblasStatus_t status = testing_getrf_strided_batched(arg); + hipblasStatus_t status = testing_getrf_strided_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(getrf_strided_batched_gtest, getrf_strided_batched_gtest_float_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getrf_strided_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getrf_strided_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(getrf_strided_batched_gtest, getrf_strided_batched_gtest_double_complex) +{ + // GetParam returns a tuple. 
The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getrf_strided_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getrf_strided_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -109,4 +158,5 @@ INSTANTIATE_TEST_CASE_P(hipblasGetrfStridedBatched, getrf_strided_batched_gtest, Combine(ValuesIn(matrix_size_range), ValuesIn(stride_scale_range), - ValuesIn(batch_count_range))); + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/gtest/getri_batched_gtest.cpp b/clients/gtest/getri_batched_gtest.cpp new file mode 100644 index 000000000..257546f08 --- /dev/null +++ b/clients/gtest/getri_batched_gtest.cpp @@ -0,0 +1,157 @@ +/* ************************************************************************ + * Copyright 2016-2020 Advanced Micro Devices, Inc. + * + * ************************************************************************ */ + +#include "testing_getri_batched.hpp" +#include "utility.h" +#include +#include +#include +#include + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; +using namespace std; + +typedef std::tuple, double, int, bool> getri_batched_tuple; + +const vector> matrix_size_range + = {{-1, 1}, {10, 10}, {10, 20}, {500, 600}, {1024, 1024}}; + +const vector stride_scale_range = {2.5}; + +const vector batch_count_range = {-1, 0, 1, 2}; + +const vector is_fortran = {false, true}; + +Arguments setup_getri_batched_arguments(getri_batched_tuple tup) +{ + vector matrix_size = std::get<0>(tup); + double stride_scale = std::get<1>(tup); + int batch_count = std::get<2>(tup); + bool fortran = std::get<3>(tup); + + Arguments arg; + + arg.N = matrix_size[0]; + arg.lda = matrix_size[1]; + + arg.stride_scale = stride_scale; + arg.batch_count = batch_count; + + arg.fortran = fortran; + + return arg; +} + +class getri_batched_gtest : public ::TestWithParam +{ 
+protected: + getri_batched_gtest() {} + virtual ~getri_batched_gtest() {} + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_P(getri_batched_gtest, getri_batched_gtest_float) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getri_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getri_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); + } + } +} + +TEST_P(getri_batched_gtest, getri_batched_gtest_double) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getri_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getri_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); + } + } +} + +TEST_P(getri_batched_gtest, getri_batched_gtest_float_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getri_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getri_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); + } + } +} + +TEST_P(getri_batched_gtest, getri_batched_gtest_double_complex) +{ + // GetParam returns a tuple. 
The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getri_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getri_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); + } + } +} + +// notice we are using vector of vector +// so each element in xxx_range is a vector, +// ValuesIn takes each element (a vector), combines them, and feeds them to test_p +// The combinations are { {N, lda}, stride_scale, batch_count, is_fortran } + +INSTANTIATE_TEST_CASE_P(hipblasGetriBatched, + getri_batched_gtest, + Combine(ValuesIn(matrix_size_range), + ValuesIn(stride_scale_range), + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/gtest/getrs_batched_gtest.cpp b/clients/gtest/getrs_batched_gtest.cpp index 6f79e55d8..e2910ec34 100644 --- a/clients/gtest/getrs_batched_gtest.cpp +++ b/clients/gtest/getrs_batched_gtest.cpp @@ -16,7 +16,7 @@ using ::testing::Values; using ::testing::ValuesIn; using namespace std; -typedef std::tuple, double, int> getrs_batched_tuple; +typedef std::tuple, double, int, bool> getrs_batched_tuple; const vector> matrix_size_range = {{-1, 1, 1}, {10, 20, 100}, {500, 600, 600}, {1024, 1024, 1024}}; @@ -25,11 +25,14 @@ const vector stride_scale_range = {2.5}; const vector batch_count_range = {-1, 0, 1, 2}; +const vector is_fortran = {false, true}; + Arguments setup_getrs_batched_arguments(getrs_batched_tuple tup) { vector matrix_size = std::get<0>(tup); double stride_scale = std::get<1>(tup); int batch_count = std::get<2>(tup); + bool fortran = std::get<3>(tup); Arguments arg; @@ -40,6 +43,8 @@ Arguments setup_getrs_batched_arguments(getrs_batched_tuple tup) arg.stride_scale = stride_scale; arg.batch_count = batch_count; + arg.fortran = fortran; + return arg; } 
@@ -59,7 +64,7 @@ TEST_P(getrs_batched_gtest, getrs_batched_gtest_float) Arguments arg = setup_getrs_batched_arguments(GetParam()); - hipblasStatus_t status = testing_getrs_batched(arg); + hipblasStatus_t status = testing_getrs_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -81,7 +86,51 @@ TEST_P(getrs_batched_gtest, getrs_batched_gtest_double) Arguments arg = setup_getrs_batched_arguments(GetParam()); - hipblasStatus_t status = testing_getrs_batched(arg); + hipblasStatus_t status = testing_getrs_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.ldb < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); + } + } +} + +TEST_P(getrs_batched_gtest, getrs_batched_gtest_float_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getrs_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getrs_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.ldb < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); + } + } +} + +TEST_P(getrs_batched_gtest, getrs_batched_gtest_double_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. 
+ + Arguments arg = setup_getrs_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getrs_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -105,4 +154,5 @@ INSTANTIATE_TEST_CASE_P(hipblasGetrsBatched, getrs_batched_gtest, Combine(ValuesIn(matrix_size_range), ValuesIn(stride_scale_range), - ValuesIn(batch_count_range))); + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/gtest/getrs_gtest.cpp b/clients/gtest/getrs_gtest.cpp index 214c78b07..5bc30725a 100644 --- a/clients/gtest/getrs_gtest.cpp +++ b/clients/gtest/getrs_gtest.cpp @@ -16,7 +16,7 @@ using ::testing::Values; using ::testing::ValuesIn; using namespace std; -typedef std::tuple, double, int> getrs_tuple; +typedef std::tuple, double, int, bool> getrs_tuple; const vector> matrix_size_range = {{-1, 1, 1}, {10, 20, 100}, {500, 600, 600}, {1024, 1024, 1024}}; @@ -25,11 +25,14 @@ const vector stride_scale_range = {2.5}; const vector batch_count_range = {1}; +const vector is_fortran = {false, true}; + Arguments setup_getrs_arguments(getrs_tuple tup) { vector matrix_size = std::get<0>(tup); double stride_scale = std::get<1>(tup); int batch_count = std::get<2>(tup); + bool fortran = std::get<3>(tup); Arguments arg; @@ -40,6 +43,8 @@ Arguments setup_getrs_arguments(getrs_tuple tup) arg.stride_scale = stride_scale; arg.batch_count = batch_count; + arg.fortran = fortran; + return arg; } @@ -59,7 +64,7 @@ TEST_P(getrs_gtest, getrs_gtest_float) Arguments arg = setup_getrs_arguments(GetParam()); - hipblasStatus_t status = testing_getrs(arg); + hipblasStatus_t status = testing_getrs(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -81,7 +86,51 @@ TEST_P(getrs_gtest, getrs_gtest_double) Arguments arg = setup_getrs_arguments(GetParam()); - hipblasStatus_t status = testing_getrs(arg); + hipblasStatus_t status = testing_getrs(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.ldb < arg.N) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, 
status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(getrs_gtest, getrs_gtest_float_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getrs_arguments(GetParam()); + + hipblasStatus_t status = testing_getrs(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.ldb < arg.N) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(getrs_gtest, getrs_gtest_double_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getrs_arguments(GetParam()); + + hipblasStatus_t status = testing_getrs(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -105,4 +154,5 @@ INSTANTIATE_TEST_CASE_P(hipblasGetrs, getrs_gtest, Combine(ValuesIn(matrix_size_range), ValuesIn(stride_scale_range), - ValuesIn(batch_count_range))); + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/gtest/getrs_strided_batched_gtest.cpp b/clients/gtest/getrs_strided_batched_gtest.cpp index f4b1b2761..d1b13c535 100644 --- a/clients/gtest/getrs_strided_batched_gtest.cpp +++ b/clients/gtest/getrs_strided_batched_gtest.cpp @@ -16,7 +16,7 @@ using ::testing::Values; using ::testing::ValuesIn; using namespace std; -typedef std::tuple, double, int> getrs_strided_batched_tuple; +typedef std::tuple, double, int, bool> getrs_strided_batched_tuple; const vector> matrix_size_range = {{-1, 1, 1}, {10, 20, 100}, {500, 600, 600}, {1024, 1024, 1024}}; @@ -25,11 +25,14 @@ const vector stride_scale_range = {2.5}; const vector batch_count_range = {-1, 0, 1, 2}; +const vector is_fortran = {false, true}; + Arguments 
setup_getrs_strided_batched_arguments(getrs_strided_batched_tuple tup) { vector matrix_size = std::get<0>(tup); double stride_scale = std::get<1>(tup); int batch_count = std::get<2>(tup); + bool fortran = std::get<3>(tup); Arguments arg; @@ -40,6 +43,8 @@ Arguments setup_getrs_strided_batched_arguments(getrs_strided_batched_tuple tup) arg.stride_scale = stride_scale; arg.batch_count = batch_count; + arg.fortran = fortran; + return arg; } @@ -59,7 +64,7 @@ TEST_P(getrs_strided_batched_gtest, getrs_strided_batched_gtest_float) Arguments arg = setup_getrs_strided_batched_arguments(GetParam()); - hipblasStatus_t status = testing_getrs_strided_batched(arg); + hipblasStatus_t status = testing_getrs_strided_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -81,7 +86,51 @@ TEST_P(getrs_strided_batched_gtest, getrs_strided_batched_gtest_double) Arguments arg = setup_getrs_strided_batched_arguments(GetParam()); - hipblasStatus_t status = testing_getrs_strided_batched(arg); + hipblasStatus_t status = testing_getrs_strided_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.ldb < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(getrs_strided_batched_gtest, getrs_strided_batched_gtest_float_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. 
+ + Arguments arg = setup_getrs_strided_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getrs_strided_batched(arg); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.N < 0 || arg.lda < arg.N || arg.ldb < arg.N || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(getrs_strided_batched_gtest, getrs_strided_batched_gtest_double_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments), which will be passed to testing routine. + + Arguments arg = setup_getrs_strided_batched_arguments(GetParam()); + + hipblasStatus_t status = testing_getrs_strided_batched(arg); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -105,4 +154,5 @@ INSTANTIATE_TEST_CASE_P(hipblasGetrsStridedBatched, getrs_strided_batched_gtest, Combine(ValuesIn(matrix_size_range), ValuesIn(stride_scale_range), - ValuesIn(batch_count_range))); + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/gtest/set_get_matrix_gtest.cpp b/clients/gtest/set_get_matrix_gtest.cpp index a95b4ec3c..f120fe78b 100644 --- a/clients/gtest/set_get_matrix_gtest.cpp +++ b/clients/gtest/set_get_matrix_gtest.cpp @@ -4,6 +4,7 @@ * ************************************************************************ */ #include "testing_set_get_matrix.hpp" +#include "testing_set_get_matrix_async.hpp" #include "utility.h" #include #include @@ -20,7 +21,7 @@ using namespace std; // only GCC/VS 2010 comes with std::tr1::tuple, but it is unnecessary, std::tuple is good enough; -typedef std::tuple, vector> set_get_matrix_tuple; +typedef std::tuple, vector, bool> set_get_matrix_tuple; /* ===================================================================== README: This file contains testers to verify the correctness of @@ -66,6 +67,9 @@ const vector> lda_ldb_ldc_range = {{3, 3, 3}, {5, 5, 3}, {5, 5, 4}, {5, 5, 5}}; + +const bool 
is_fortran[] = {false, true}; + /* ===============Google Unit Test==================================================== */ /* ===================================================================== @@ -123,23 +127,32 @@ TEST_P(set_matrix_get_matrix_gtest, float) // if not success, then the input argument is problematic, so detect the error message if(status != HIPBLAS_STATUS_SUCCESS) { - if(arg.rows < 0) + if(arg.rows < 0 || arg.cols <= 0 || arg.lda <= 0 || arg.ldb <= 0 || arg.ldc <= 0) { EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); } - else if(arg.cols <= 0) - { - EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); - } - else if(arg.lda <= 0) - { - EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); - } - else if(arg.ldb <= 0) + else { - EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); // fail } - else if(arg.ldc <= 0) + } +} + +TEST_P(set_matrix_get_matrix_gtest, async_float) +{ + // GetParam return a tuple. Tee setup routine unpack the tuple + // and initializes arg(Arguments) which will be passed to testing routine + // The Arguments data struture have physical meaning associated. + // while the tuple is non-intuitive. 
+ + Arguments arg = setup_set_get_matrix_arguments(GetParam()); + + hipblasStatus_t status = testing_set_get_matrix_async(arg); + + // if not success, then the input argument is problematic, so detect the error message + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.rows < 0 || arg.cols <= 0 || arg.lda <= 0 || arg.ldb <= 0 || arg.ldc <= 0) { EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); } @@ -157,4 +170,6 @@ TEST_P(set_matrix_get_matrix_gtest, float) INSTANTIATE_TEST_CASE_P(hipblasAuxiliary_small, set_matrix_get_matrix_gtest, - Combine(ValuesIn(rows_cols_range), ValuesIn(lda_ldb_ldc_range))); + Combine(ValuesIn(rows_cols_range), + ValuesIn(lda_ldb_ldc_range), + ValuesIn(is_fortran))); diff --git a/clients/gtest/set_get_vector_gtest.cpp b/clients/gtest/set_get_vector_gtest.cpp index 9580534f0..2fdda1ccf 100644 --- a/clients/gtest/set_get_vector_gtest.cpp +++ b/clients/gtest/set_get_vector_gtest.cpp @@ -4,6 +4,7 @@ * ************************************************************************ */ #include "testing_set_get_vector.hpp" +#include "testing_set_get_vector_async.hpp" #include "utility.h" #include #include @@ -18,7 +19,7 @@ using namespace std; // only GCC/VS 2010 comes with std::tr1::tuple, but it is unnecessary, std::tuple is good enough; -typedef std::tuple> set_get_vector_tuple; +typedef std::tuple, bool> set_get_vector_tuple; /* ===================================================================== README: This file contains testers to verify the correctness of @@ -55,6 +56,8 @@ const vector> incx_incy_incd_range = {{1, 1, 1}, {3, 3, 1}, {3, 3, 3}}; +const bool is_fortran[] = {false, true}; + /* ===============Google Unit Test==================================================== */ /* ===================================================================== @@ -114,19 +117,32 @@ TEST_P(set_vector_get_vector_gtest, float) // if not success, then the input argument is problematic, so detect the error message if(status != HIPBLAS_STATUS_SUCCESS) { - if(arg.M 
< 0) - { - EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); - } - else if(arg.incx <= 0) + if(arg.M < 0 || arg.incx <= 0 || arg.incy <= 0 || arg.incd <= 0) { EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); } - else if(arg.incy <= 0) + else { - EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); // fail } - else if(arg.incd <= 0) + } +} + +TEST_P(set_vector_get_vector_gtest, async_float) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments) which will be passed to testing routine + // The Arguments data structure has physical meaning associated, + // while the tuple is non-intuitive. + + Arguments arg = setup_set_get_vector_arguments(GetParam()); + + hipblasStatus_t status = testing_set_get_vector_async(arg); + + // if not success, then the input argument is problematic, so detect the error message + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.M < 0 || arg.incx <= 0 || arg.incy <= 0 || arg.incd <= 0) { EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); } @@ -144,4 +160,6 @@ TEST_P(set_vector_get_vector_gtest, float) INSTANTIATE_TEST_CASE_P(rocblas_auxiliary_small, set_vector_get_vector_gtest, - Combine(ValuesIn(M_range), ValuesIn(incx_incy_incd_range))); + Combine(ValuesIn(M_range), + ValuesIn(incx_incy_incd_range), + ValuesIn(is_fortran))); diff --git a/clients/gtest/trsm_ex_gtest.cpp b/clients/gtest/trsm_ex_gtest.cpp new file mode 100644 index 000000000..2b27c2615 --- /dev/null +++ b/clients/gtest/trsm_ex_gtest.cpp @@ -0,0 +1,325 @@ +/* ************************************************************************ + * Copyright 2016-2020 Advanced Micro Devices, Inc. 
+ * + * ************************************************************************ */ + +#include "testing_trsm_batched_ex.hpp" +#include "testing_trsm_ex.hpp" +#include "testing_trsm_strided_batched_ex.hpp" +#include "utility.h" +#include +#include +#include +#include + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; +using namespace std; + +// only GCC/VS 2010 comes with std::tr1::tuple, but it is unnecessary, std::tuple is good enough; + +typedef std::tuple, double, vector, double, int, bool> trsm_ex_tuple; + +/* ===================================================================== +README: This file contains testers to verify the correctness of + BLAS routines with google test + + It is supposed to be played/used by advance / expert users + Normal users only need to get the library routines without testers + =================================================================== */ + +/* ===================================================================== +Advance users only: BrainStorm the parameters but do not make artificial one which invalidates the +matrix. +like lda pairs with M, and "lda must >= M". case "lda < M" will be guarded by argument-checkers +inside API of course. +Yet, the goal of this file is to verify result correctness not argument-checkers. 
+ +Representative sampling is sufficient, endless brute-force sampling is not necessary +=================================================================== */ + +// vector of vector, each vector is a {M, N, lda, ldb}; +// add/delete as a group +const vector> full_matrix_size_range = { + {192, 192, 192, 192}, {640, 640, 960, 960}, + // {1000, 1000, 1000, 1000}, + // {2000, 2000, 2000, 2000}, +}; + +const vector alpha_range = {1.0, -5.0}; + +// vector of vector, each pair is a {side, uplo, transA, diag}; +// side has two option "Lefe (L), Right (R)" +// uplo has two "Lower (L), Upper (U)" +// transA has three ("Nontranspose (N), conjTranspose(C), transpose (T)") +// for single/double precision, 'C'(conjTranspose) will downgraded to 'T' (transpose) automatically +// in strsm/dtrsm, +// so we use 'C' +// Diag has two options ("Non-unit (N), Unit (U)") + +// Each letter is capitalizied, e.g. do not use 'l', but use 'L' instead. + +const vector> side_uplo_transA_diag_range = { + {'L', 'L', 'N', 'N'}, + {'R', 'L', 'N', 'N'}, + {'L', 'U', 'C', 'N'}, +}; + +const vector stride_scale_range = {2.5}; + +const vector batch_count_range = {-1, 0, 1, 2}; + +const bool is_fortran[] = {false, true}; + +/* ===============Google Unit Test==================================================== */ + +/* ===================================================================== + BLAS-3 trsm: +=================================================================== */ + +/* ============================Setup Arguments======================================= */ + +// Please use "class Arguments" (see utility.hpp) to pass parameters to templated testers; +// Some routines may not touch/use certain "members" of objects "argus". +// like BLAS-1 Scal does not have lda, BLAS-2 GEMV does not have ldb, ldc; +// That is fine. These testers & routines will leave untouched members alone. 
+// Do not use std::tuple to directly pass parameters to testers +// by std:tuple, you have unpack it with extreme care for each one by like "std::get<0>" which is +// not intuitive and error-prone + +Arguments setup_trsm_ex_arguments(trsm_ex_tuple tup) +{ + + vector matrix_size = std::get<0>(tup); + double alpha = std::get<1>(tup); + vector side_uplo_transA_diag = std::get<2>(tup); + double stride_scale = std::get<3>(tup); + int batch_count = std::get<4>(tup); + bool fortran = std::get<5>(tup); + + Arguments arg; + + // see the comments about matrix_size_range above + arg.M = matrix_size[0]; + arg.N = matrix_size[1]; + arg.lda = matrix_size[2]; + arg.ldb = matrix_size[3]; + + arg.alpha = alpha; + + arg.side_option = side_uplo_transA_diag[0]; + arg.uplo_option = side_uplo_transA_diag[1]; + arg.transA_option = side_uplo_transA_diag[2]; + arg.diag_option = side_uplo_transA_diag[3]; + + arg.timing = 1; + + arg.stride_scale = stride_scale; + arg.batch_count = batch_count; + + arg.fortran = fortran; + + return arg; +} + +class trsm_ex_gtest : public ::TestWithParam +{ +protected: + trsm_ex_gtest() {} + virtual ~trsm_ex_gtest() {} + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_P(trsm_ex_gtest, trsm_ex_gtest_float) +{ + // GetParam return a tuple. Tee setup routine unpack the tuple + // and initializes arg(Arguments) which will be passed to testing routine + // The Arguments data struture have physical meaning associated. + // while the tuple is non-intuitive. + + Arguments arg = setup_trsm_ex_arguments(GetParam()); + arg.compute_type = HIPBLAS_R_32F; + + hipblasStatus_t status = testing_trsm_ex(arg); + + // if not success, then the input argument is problematic, so detect the error message + if(status != HIPBLAS_STATUS_SUCCESS) + { + + if(arg.M < 0 || arg.N < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else if(arg.side_option == 'L' ? 
arg.lda < arg.M : arg.lda < arg.N) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else if(arg.ldb < arg.M) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(trsm_ex_gtest, trsm_gtest_ex_double_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments) which will be passed to testing routine + // The Arguments data structure has physical meaning associated, + // while the tuple is non-intuitive. + + Arguments arg = setup_trsm_ex_arguments(GetParam()); + arg.compute_type = HIPBLAS_C_64F; + + hipblasStatus_t status = testing_trsm_ex(arg); + + // if not success, then the input argument is problematic, so detect the error message + if(status != HIPBLAS_STATUS_SUCCESS) + { + + if(arg.M < 0 || arg.N < 0 || arg.ldb < arg.M + || (arg.side_option == 'L' ? arg.lda < arg.M : arg.lda < arg.N)) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(trsm_ex_gtest, trsm_batched_ex_gtest_float) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments) which will be passed to testing routine + // The Arguments data structure has physical meaning associated, + // while the tuple is non-intuitive. + + Arguments arg = setup_trsm_ex_arguments(GetParam()); + arg.compute_type = HIPBLAS_R_32F; + + hipblasStatus_t status = testing_trsm_batched_ex(arg); + + // if not success, then the input argument is problematic, so detect the error message + if(status != HIPBLAS_STATUS_SUCCESS) + { + + if(arg.M < 0 || arg.N < 0 || arg.ldb < arg.M + || (arg.side_option == 'L' ? 
arg.lda < arg.M : arg.lda < arg.N) || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(trsm_ex_gtest, trsm_batched_ex_gtest_double_complex) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments) which will be passed to testing routine + // The Arguments data structure has physical meaning associated, + // while the tuple is non-intuitive. + + Arguments arg = setup_trsm_ex_arguments(GetParam()); + arg.compute_type = HIPBLAS_C_64F; + + hipblasStatus_t status = testing_trsm_batched_ex(arg); + + // if not success, then the input argument is problematic, so detect the error message + if(status != HIPBLAS_STATUS_SUCCESS) + { + + if(arg.M < 0 || arg.N < 0 || arg.ldb < arg.M + || (arg.side_option == 'L' ? arg.lda < arg.M : arg.lda < arg.N) || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(trsm_ex_gtest, trsm_strided_batched_ex_gtest_float) +{ + // GetParam returns a tuple. The setup routine unpacks the tuple + // and initializes arg(Arguments) which will be passed to testing routine + // The Arguments data structure has physical meaning associated, + // while the tuple is non-intuitive. + + Arguments arg = setup_trsm_ex_arguments(GetParam()); + arg.compute_type = HIPBLAS_R_32F; + + hipblasStatus_t status = testing_trsm_strided_batched_ex(arg); + + // if not success, then the input argument is problematic, so detect the error message + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.M < 0 || arg.N < 0 || arg.ldb < arg.M + || (arg.side_option == 'L' ? 
arg.lda < arg.M : arg.lda < arg.N) || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +TEST_P(trsm_ex_gtest, trsm_strided_batched_ex_gtest_double_complex) +{ + // GetParam return a tuple. Tee setup routine unpack the tuple + // and initializes arg(Arguments) which will be passed to testing routine + // The Arguments data struture have physical meaning associated. + // while the tuple is non-intuitive. + + Arguments arg = setup_trsm_ex_arguments(GetParam()); + arg.compute_type = HIPBLAS_C_64F; + + hipblasStatus_t status = testing_trsm_strided_batched_ex(arg); + + // if not success, then the input argument is problematic, so detect the error message + if(status != HIPBLAS_STATUS_SUCCESS) + { + if(arg.M < 0 || arg.N < 0 || arg.ldb < arg.M + || (arg.side_option == 'L' ? arg.lda < arg.M : arg.lda < arg.N) || arg.batch_count < 0) + { + EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status); + } + else + { + EXPECT_EQ(HIPBLAS_STATUS_NOT_SUPPORTED, status); // for cuda + } + } +} + +// notice we are using vector of vector +// so each elment in xxx_range is a avector, +// ValuesIn take each element (a vector) and combine them and feed them to test_p +// The combinations are { {M, N, lda, ldb}, alpha, {side, uplo, transA, diag} } + +// THis function mainly test the scope of matrix_size. the scope of side_uplo_transA_diag_range is +// small +// Testing order: side_uplo_transA_xx first, alpha_range second, full_matrix_size last +// i.e fix the matrix size and alpha, test all the side_uplo_transA_xx first. 
+INSTANTIATE_TEST_CASE_P(hipblasTrsm_matrix_size, + trsm_ex_gtest, + Combine(ValuesIn(full_matrix_size_range), + ValuesIn(alpha_range), + ValuesIn(side_uplo_transA_diag_range), + ValuesIn(stride_scale_range), + ValuesIn(batch_count_range), + ValuesIn(is_fortran))); diff --git a/clients/include/cblas_interface.h b/clients/include/cblas_interface.h index cd12209a1..9382491e6 100644 --- a/clients/include/cblas_interface.h +++ b/clients/include/cblas_interface.h @@ -438,6 +438,9 @@ int cblas_getrf(int m, int n, T* A, int lda, int* ipiv); template int cblas_getrs(char trans, int n, int nrhs, T* A, int lda, int* ipiv, T* B, int ldb); +template +int cblas_getri(int n, T* A, int lda, int* ipiv, T* work, int lwork); + template int cblas_geqrf(int m, int n, T* A, int lda, T* tau, T* work, int lwork); /* ============================================================================================ */ diff --git a/clients/include/hipblas.hpp b/clients/include/hipblas.hpp index dfcb0ee41..81b1003da 100644 --- a/clients/include/hipblas.hpp +++ b/clients/include/hipblas.hpp @@ -1900,6 +1900,18 @@ hipblasStatus_t hipblasGetrsStridedBatched(hipblasHandle_t handle, int* info, const int batchCount); +// getri +template +hipblasStatus_t hipblasGetriBatched(hipblasHandle_t handle, + const int n, + T* const A[], + const int lda, + int* ipiv, + T* const C[], + const int ldc, + int* info, + const int batchCount); + // geqrf template hipblasStatus_t hipblasGeqrf( diff --git a/clients/include/hipblas_fortran.f90 b/clients/include/hipblas_fortran.f90 index 3fc5c4f87..722167d16 100644 --- a/clients/include/hipblas_fortran.f90 +++ b/clients/include/hipblas_fortran.f90 @@ -4,6 +4,137 @@ module hipblas_interface contains + !--------! + ! Aux ! + !--------! 
+ function hipblasSetVectorFortran(n, elemSize, x, incx, y, incy) & + result(res) & + bind(c, name = 'hipblasSetVectorFortran') + use iso_c_binding + implicit none + integer(c_int), value :: n + integer(c_int), value :: elemSize + type(c_ptr), value :: x + integer(c_int), value :: incx + type(c_ptr), value :: y + integer(c_int), value :: incy + integer(c_int) :: res + res = hipblasSetVector(n, elemSize, x, incx, y, incy) + end function hipblasSetVectorFortran + + function hipblasGetVectorFortran(n, elemSize, x, incx, y, incy) & + result(res) & + bind(c, name = 'hipblasGetVectorFortran') + use iso_c_binding + implicit none + integer(c_int), value :: n + integer(c_int), value :: elemSize + type(c_ptr), value :: x + integer(c_int), value :: incx + type(c_ptr), value :: y + integer(c_int), value :: incy + integer(c_int) :: res + res = hipblasGetVector(n, elemSize, x, incx, y, incy) + end function hipblasGetVectorFortran + + function hipblasSetMatrixFortran(rows, cols, elemSize, A, lda, B, ldb) & + result(res) & + bind(c, name = 'hipblasSetMatrixFortran') + use iso_c_binding + implicit none + integer(c_int), value :: rows + integer(c_int), value :: cols + integer(c_int), value :: elemSize + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: B + integer(c_int), value :: ldb + integer(c_int) :: res + res = hipblasSetMatrix(rows, cols, elemSize, A, lda, B, ldb) + end function hipblasSetMatrixFortran + + function hipblasGetMatrixFortran(rows, cols, elemSize, A, lda, B, ldb) & + result(res) & + bind(c, name = 'hipblasGetMatrixFortran') + use iso_c_binding + implicit none + integer(c_int), value :: rows + integer(c_int), value :: cols + integer(c_int), value :: elemSize + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: B + integer(c_int), value :: ldb + integer(c_int) :: res + res = hipblasGetMatrix(rows, cols, elemSize, A, lda, B, ldb) + end function hipblasGetMatrixFortran + + function 
hipblasSetVectorAsyncFortran(n, elemSize, x, incx, y, incy, stream) & + result(res) & + bind(c, name = 'hipblasSetVectorAsyncFortran') + use iso_c_binding + implicit none + integer(c_int), value :: n + integer(c_int), value :: elemSize + type(c_ptr), value :: x + integer(c_int), value :: incx + type(c_ptr), value :: y + integer(c_int), value :: incy + type(c_ptr), value :: stream + integer(c_int) :: res + res = hipblasSetVectorAsync(n, elemSize, x, incx, y, incy, stream) + end function hipblasSetVectorAsyncFortran + + function hipblasGetVectorAsyncFortran(n, elemSize, x, incx, y, incy, stream) & + result(res) & + bind(c, name = 'hipblasGetVectorAsyncFortran') + use iso_c_binding + implicit none + integer(c_int), value :: n + integer(c_int), value :: elemSize + type(c_ptr), value :: x + integer(c_int), value :: incx + type(c_ptr), value :: y + integer(c_int), value :: incy + type(c_ptr), value :: stream + integer(c_int) :: res + res = hipblasGetVectorAsync(n, elemSize, x, incx, y, incy, stream) + end function hipblasGetVectorAsyncFortran + + function hipblasSetMatrixAsyncFortran(rows, cols, elemSize, A, lda, B, ldb, stream) & + result(res) & + bind(c, name = 'hipblasSetMatrixAsyncFortran') + use iso_c_binding + implicit none + integer(c_int), value :: rows + integer(c_int), value :: cols + integer(c_int), value :: elemSize + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: stream + integer(c_int) :: res + res = hipblasSetMatrixAsync(rows, cols, elemSize, A, lda, B, ldb, stream) + end function hipblasSetMatrixAsyncFortran + + function hipblasGetMatrixAsyncFortran(rows, cols, elemSize, A, lda, B, ldb, stream) & + result(res) & + bind(c, name = 'hipblasGetMatrixAsyncFortran') + use iso_c_binding + implicit none + integer(c_int), value :: rows + integer(c_int), value :: cols + integer(c_int), value :: elemSize + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), 
value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: stream + integer(c_int) :: res + res = hipblasGetMatrixAsync(rows, cols, elemSize, A, lda, B, ldb, stream) + end function hipblasGetMatrixAsyncFortran + !--------! ! blas 1 ! !--------! @@ -12113,91 +12244,91 @@ function hipblasGemmStridedBatchedExFortran(handle, transA, transB, m, n, k, alp batch_count, compute_type, algo, solution_index, flags) end function hipblasGemmStridedBatchedExFortran - ! ! trsmEx - ! function hipblasTrsmExFortran(handle, side, uplo, transA, diag, m, n, alpha, A, lda, & - ! B, ldb, invA, invA_size, compute_type) & - ! result(res) & - ! bind(c, name = 'hipblasTrsmExFortran') - ! use iso_c_binding - ! use hipblas_enums - ! implicit none - ! type(c_ptr), value :: handle - ! integer(kind(HIPBLAS_SIDE_LEFT)), value :: side - ! integer(kind(HIPBLAS_FILL_MODE_UPPER)), value :: uplo - ! integer(kind(HIPBLAS_OP_N)), value :: transA - ! integer(kind(HIPBLAS_DIAG_UNIT)), value :: diag - ! integer(c_int), value :: m - ! integer(c_int), value :: n - ! type(c_ptr), value :: alpha - ! type(c_ptr), value :: A - ! integer(c_int), value :: lda - ! type(c_ptr), value :: B - ! integer(c_int), value :: ldb - ! type(c_ptr), value :: invA - ! integer(c_int), value :: invA_size - ! integer(kind(HIPBLAS_R_16F)), value :: compute_type - ! integer(c_int) :: res - ! res = hipblasTrsmEx(handle, side, uplo, transA, diag, m, n, alpha,& - ! A, lda, B, ldb, invA, invA_size, compute_type) - ! end function hipblasTrsmExFortran - - ! function hipblasTrsmBatchedExFortran(handle, side, uplo, transA, diag, m, n, alpha, A, lda, & - ! B, ldb, batch_count, invA, invA_size, compute_type) & - ! result(res) & - ! bind(c, name = 'hipblasTrsmBatchedExFortran') - ! use iso_c_binding - ! use hipblas_enums - ! implicit none - ! type(c_ptr), value :: handle - ! integer(kind(HIPBLAS_SIDE_LEFT)), value :: side - ! integer(kind(HIPBLAS_FILL_MODE_UPPER)), value :: uplo - ! integer(kind(HIPBLAS_OP_N)), value :: transA - ! 
integer(kind(HIPBLAS_DIAG_UNIT)), value :: diag - ! integer(c_int), value :: m - ! integer(c_int), value :: n - ! type(c_ptr), value :: alpha - ! type(c_ptr), value :: A - ! integer(c_int), value :: lda - ! type(c_ptr), value :: B - ! integer(c_int), value :: ldb - ! integer(c_int), value :: batch_count - ! type(c_ptr), value :: invA - ! integer(c_int), value :: invA_size - ! integer(kind(HIPBLAS_R_16F)), value :: compute_type - ! integer(c_int) :: res - ! res = hipblasTrsmBatchedEx(handle, side, uplo, transA, diag, m, n, alpha,& - ! A, lda, B, ldb, batch_count, invA, invA_size, compute_type) - ! end function hipblasTrsmBatchedExFortran - - ! function hipblasTrsmStridedBatchedExFortran(handle, side, uplo, transA, diag, m, n, alpha, A, lda, stride_A, & - ! B, ldb, stride_B, batch_count, invA, invA_size, stride_invA, compute_type) & - ! result(res) & - ! bind(c, name = 'hipblasTrsmStridedBatchedExFortran') - ! use iso_c_binding - ! use hipblas_enums - ! implicit none - ! type(c_ptr), value :: handle - ! integer(kind(HIPBLAS_SIDE_LEFT)), value :: side - ! integer(kind(HIPBLAS_FILL_MODE_UPPER)), value :: uplo - ! integer(kind(HIPBLAS_OP_N)), value :: transA - ! integer(kind(HIPBLAS_DIAG_UNIT)), value :: diag - ! integer(c_int), value :: m - ! integer(c_int), value :: n - ! type(c_ptr), value :: alpha - ! type(c_ptr), value :: A - ! integer(c_int), value :: lda - ! integer(c_int64_t), value :: stride_A - ! type(c_ptr), value :: B - ! integer(c_int), value :: ldb - ! integer(c_int64_t), value :: stride_B - ! integer(c_int), value :: batch_count - ! type(c_ptr), value :: invA - ! integer(c_int), value :: invA_size - ! integer(c_int64_t), value :: stride_invA - ! integer(kind(HIPBLAS_R_16F)), value :: compute_type - ! integer(c_int) :: res - ! res = hipblasTrsmStridedBatchedEx(handle, side, uplo, transA, diag, m, n, alpha,& - ! A, lda, stride_A, B, ldb, stride_B, batch_count, invA, invA_size, stride_invA, compute_type) - ! 
end function hipblasTrsmStridedBatchedExFortran + ! trsmEx + function hipblasTrsmExFortran(handle, side, uplo, transA, diag, m, n, alpha, A, lda, & + B, ldb, invA, invA_size, compute_type) & + result(res) & + bind(c, name = 'hipblasTrsmExFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_SIDE_LEFT)), value :: side + integer(kind(HIPBLAS_FILL_MODE_UPPER)), value :: uplo + integer(kind(HIPBLAS_OP_N)), value :: transA + integer(kind(HIPBLAS_DIAG_UNIT)), value :: diag + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: alpha + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: invA + integer(c_int), value :: invA_size + integer(kind(HIPBLAS_R_16F)), value :: compute_type + integer(c_int) :: res + res = hipblasTrsmEx(handle, side, uplo, transA, diag, m, n, alpha,& + A, lda, B, ldb, invA, invA_size, compute_type) + end function hipblasTrsmExFortran + + function hipblasTrsmBatchedExFortran(handle, side, uplo, transA, diag, m, n, alpha, A, lda, & + B, ldb, batch_count, invA, invA_size, compute_type) & + result(res) & + bind(c, name = 'hipblasTrsmBatchedExFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_SIDE_LEFT)), value :: side + integer(kind(HIPBLAS_FILL_MODE_UPPER)), value :: uplo + integer(kind(HIPBLAS_OP_N)), value :: transA + integer(kind(HIPBLAS_DIAG_UNIT)), value :: diag + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: alpha + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: B + integer(c_int), value :: ldb + integer(c_int), value :: batch_count + type(c_ptr), value :: invA + integer(c_int), value :: invA_size + integer(kind(HIPBLAS_R_16F)), value :: compute_type + integer(c_int) :: res + res = hipblasTrsmBatchedEx(handle, side, uplo, transA, diag, m, 
n, alpha,& + A, lda, B, ldb, batch_count, invA, invA_size, compute_type) + end function hipblasTrsmBatchedExFortran + + function hipblasTrsmStridedBatchedExFortran(handle, side, uplo, transA, diag, m, n, alpha, A, lda, stride_A, & + B, ldb, stride_B, batch_count, invA, invA_size, stride_invA, compute_type) & + result(res) & + bind(c, name = 'hipblasTrsmStridedBatchedExFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_SIDE_LEFT)), value :: side + integer(kind(HIPBLAS_FILL_MODE_UPPER)), value :: uplo + integer(kind(HIPBLAS_OP_N)), value :: transA + integer(kind(HIPBLAS_DIAG_UNIT)), value :: diag + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: alpha + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int64_t), value :: stride_A + type(c_ptr), value :: B + integer(c_int), value :: ldb + integer(c_int64_t), value :: stride_B + integer(c_int), value :: batch_count + type(c_ptr), value :: invA + integer(c_int), value :: invA_size + integer(c_int64_t), value :: stride_invA + integer(kind(HIPBLAS_R_16F)), value :: compute_type + integer(c_int) :: res + res = hipblasTrsmStridedBatchedEx(handle, side, uplo, transA, diag, m, n, alpha,& + A, lda, stride_A, B, ldb, stride_B, batch_count, invA, invA_size, stride_invA, compute_type) + end function hipblasTrsmStridedBatchedExFortran end module hipblas_interface diff --git a/clients/include/hipblas_fortran.hpp b/clients/include/hipblas_fortran.hpp index 70d66319a..e7980cc74 100644 --- a/clients/include/hipblas_fortran.hpp +++ b/clients/include/hipblas_fortran.hpp @@ -12,6 +12,33 @@ */ extern "C" { +/* ========== + * Aux + * ========== */ +hipblasStatus_t + hipblasSetVectorFortran(int n, int elemSize, const void* x, int incx, void* y, int incy); + +hipblasStatus_t + hipblasGetVectorFortran(int n, int elemSize, const void* x, int incx, void* y, int incy); + +hipblasStatus_t hipblasSetMatrixFortran( + int rows, int 
cols, int elemSize, const void* A, int lda, void* B, int ldb); + +hipblasStatus_t hipblasGetMatrixFortran( + int rows, int cols, int elemSize, const void* A, int lda, void* B, int ldb); + +hipblasStatus_t hipblasSetVectorAsyncFortran( + int n, int elemSize, const void* x, int incx, void* y, int incy, hipStream_t stream); + +hipblasStatus_t hipblasGetVectorAsyncFortran( + int n, int elemSize, const void* x, int incx, void* y, int incy, hipStream_t stream); + +hipblasStatus_t hipblasSetMatrixAsyncFortran( + int rows, int cols, int elemSize, const void* A, int lda, void* B, int ldb, hipStream_t stream); + +hipblasStatus_t hipblasGetMatrixAsyncFortran( + int rows, int cols, int elemSize, const void* A, int lda, void* B, int ldb, hipStream_t stream); + /* ========== * L1 * ========== */ @@ -6066,106 +6093,483 @@ hipblasStatus_t hipblasZtrsmStridedBatchedFortran(hipblasHandle_t ha int strideB, int batch_count); -// getrf -hipblasStatus_t hipblasSgetrf( - hipblasHandle_t handle, const int n, float* A, const int lda, int* ipiv, int* info); - -hipblasStatus_t hipblasDgetrf( - hipblasHandle_t handle, const int n, double* A, const int lda, int* ipiv, int* info); +// gemm +hipblasStatus_t hipblasHgemmFortran(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + const hipblasHalf* alpha, + const hipblasHalf* A, + int lda, + const hipblasHalf* B, + int ldb, + const hipblasHalf* beta, + hipblasHalf* C, + int ldc); -hipblasStatus_t hipblasCgetrf( - hipblasHandle_t handle, const int n, hipblasComplex* A, const int lda, int* ipiv, int* info); +hipblasStatus_t hipblasSgemmFortran(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + const float* A, + int lda, + const float* B, + int ldb, + const float* beta, + float* C, + int ldc); -hipblasStatus_t hipblasZgetrfFortran(hipblasHandle_t handle, - const int n, - hipblasDoubleComplex* A, - const int lda, 
- int* ipiv, - int* info); +hipblasStatus_t hipblasDgemmFortran(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + const double* alpha, + const double* A, + int lda, + const double* B, + int ldb, + const double* beta, + double* C, + int ldc); -// getrf_batched -hipblasStatus_t hipblasSgetrfBatchedFortran(hipblasHandle_t handle, - const int n, - float* const A[], - const int lda, - int* ipiv, - int* info, - const int batch_count); +hipblasStatus_t hipblasCgemmFortran(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + const hipblasComplex* alpha, + const hipblasComplex* A, + int lda, + const hipblasComplex* B, + int ldb, + const hipblasComplex* beta, + hipblasComplex* C, + int ldc); -hipblasStatus_t hipblasDgetrfBatchedFortran(hipblasHandle_t handle, - const int n, - double* const A[], - const int lda, - int* ipiv, - int* info, - const int batch_count); +hipblasStatus_t hipblasZgemmFortran(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + const hipblasDoubleComplex* alpha, + const hipblasDoubleComplex* A, + int lda, + const hipblasDoubleComplex* B, + int ldb, + const hipblasDoubleComplex* beta, + hipblasDoubleComplex* C, + int ldc); -hipblasStatus_t hipblasCgetrfBatchedFortran(hipblasHandle_t handle, - const int n, - hipblasComplex* const A[], - const int lda, - int* ipiv, - int* info, - const int batch_count); +// gemm batched +hipblasStatus_t hipblasHgemmBatchedFortran(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + const hipblasHalf* alpha, + const hipblasHalf* const A[], + int lda, + const hipblasHalf* const B[], + int ldb, + const hipblasHalf* beta, + hipblasHalf* const C[], + int ldc, + int batchCount); -hipblasStatus_t hipblasZgetrfBatchedFortran(hipblasHandle_t handle, - const int n, - hipblasDoubleComplex* const A[], - 
const int lda, - int* ipiv, - int* info, - const int batch_count); +hipblasStatus_t hipblasSgemmBatchedFortran(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + const float* const A[], + int lda, + const float* const B[], + int ldb, + const float* beta, + float* const C[], + int ldc, + int batchCount); -// getrf_strided_batched -hipblasStatus_t hipblasSgetrfStridedBatchedFortran(hipblasHandle_t handle, - const int n, - float* A, - const int lda, - const int strideA, - int* ipiv, - const int strideP, - int* info, - const int batch_count); +hipblasStatus_t hipblasDgemmBatchedFortran(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + const double* alpha, + const double* const A[], + int lda, + const double* const B[], + int ldb, + const double* beta, + double* const C[], + int ldc, + int batchCount); -hipblasStatus_t hipblasDgetrfStridedBatchedFortran(hipblasHandle_t handle, - const int n, - double* A, - const int lda, - const int strideA, - int* ipiv, - const int strideP, - int* info, - const int batch_count); +hipblasStatus_t hipblasCgemmBatchedFortran(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + const hipblasComplex* alpha, + const hipblasComplex* const A[], + int lda, + const hipblasComplex* const B[], + int ldb, + const hipblasComplex* beta, + hipblasComplex* const C[], + int ldc, + int batchCount); -hipblasStatus_t hipblasCgetrfStridedBatchedFortran(hipblasHandle_t handle, - const int n, - hipblasComplex* A, - const int lda, - const int strideA, - int* ipiv, - const int strideP, - int* info, - const int batch_count); +hipblasStatus_t hipblasZgemmBatchedFortran(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + const hipblasDoubleComplex* alpha, + const hipblasDoubleComplex* const A[], + int lda, + 
const hipblasDoubleComplex* const B[], + int ldb, + const hipblasDoubleComplex* beta, + hipblasDoubleComplex* const C[], + int ldc, + int batchCount); -hipblasStatus_t hipblasZgetrfStridedBatchedFortran(hipblasHandle_t handle, - const int n, - hipblasDoubleComplex* A, - const int lda, - const int strideA, - int* ipiv, - const int strideP, - int* info, - const int batch_count); +// gemm_strided_batched +hipblasStatus_t hipblasHgemmStridedBatchedFortran(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + const hipblasHalf* alpha, + const hipblasHalf* A, + int lda, + long long bsa, + const hipblasHalf* B, + int ldb, + long long bsb, + const hipblasHalf* beta, + hipblasHalf* C, + int ldc, + long long bsc, + int batchCount); -// getrs -hipblasStatus_t hipblasSgetrsFortran(hipblasHandle_t handle, - const hipblasOperation_t trans, - const int n, - const int nrhs, - float* A, - const int lda, - const int* ipiv, - float* B, +hipblasStatus_t hipblasSgemmStridedBatchedFortran(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + const float* A, + int lda, + long long bsa, + const float* B, + int ldb, + long long bsb, + const float* beta, + float* C, + int ldc, + long long bsc, + int batchCount); + +hipblasStatus_t hipblasDgemmStridedBatchedFortran(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + const double* alpha, + const double* A, + int lda, + long long bsa, + const double* B, + int ldb, + long long bsb, + const double* beta, + double* C, + int ldc, + long long bsc, + int batchCount); + +hipblasStatus_t hipblasCgemmStridedBatchedFortran(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + const hipblasComplex* alpha, + const hipblasComplex* A, + int lda, + long long bsa, + const hipblasComplex* B, + int ldb, + long 
long bsb, + const hipblasComplex* beta, + hipblasComplex* C, + int ldc, + long long bsc, + int batchCount); + +hipblasStatus_t hipblasZgemmStridedBatchedFortran(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + const hipblasDoubleComplex* alpha, + const hipblasDoubleComplex* A, + int lda, + long long bsa, + const hipblasDoubleComplex* B, + int ldb, + long long bsb, + const hipblasDoubleComplex* beta, + hipblasDoubleComplex* C, + int ldc, + long long bsc, + int batchCount); + +// gemmex +hipblasStatus_t hipblasGemmExFortran(hipblasHandle_t handle, + hipblasOperation_t trans_a, + hipblasOperation_t trans_b, + int m, + int n, + int k, + const void* alpha, + const void* a, + hipblasDatatype_t a_type, + int lda, + const void* b, + hipblasDatatype_t b_type, + int ldb, + const void* beta, + void* c, + hipblasDatatype_t c_type, + int ldc, + hipblasDatatype_t compute_type, + hipblasGemmAlgo_t algo); + +hipblasStatus_t hipblasGemmBatchedExFortran(hipblasHandle_t handle, + hipblasOperation_t trans_a, + hipblasOperation_t trans_b, + int m, + int n, + int k, + const void* alpha, + const void* a[], + hipblasDatatype_t a_type, + int lda, + const void* b[], + hipblasDatatype_t b_type, + int ldb, + const void* beta, + void* c[], + hipblasDatatype_t c_type, + int ldc, + int batch_count, + hipblasDatatype_t compute_type, + hipblasGemmAlgo_t algo); + +hipblasStatus_t hipblasGemmStridedBatchedExFortran(hipblasHandle_t handle, + hipblasOperation_t trans_a, + hipblasOperation_t trans_b, + int m, + int n, + int k, + const void* alpha, + const void* a, + hipblasDatatype_t a_type, + int lda, + int stride_A, + const void* b, + hipblasDatatype_t b_type, + int ldb, + int stride_B, + const void* beta, + void* c, + hipblasDatatype_t c_type, + int ldc, + int stride_C, + int batch_count, + hipblasDatatype_t compute_type, + hipblasGemmAlgo_t algo); + +// trsm_ex +hipblasStatus_t hipblasTrsmExFortran(hipblasHandle_t handle, + 
hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + void* B, + int ldb, + const void* invA, + int invA_size, + hipblasDatatype_t compute_type); + +hipblasStatus_t hipblasTrsmBatchedExFortran(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + void* B, + int ldb, + int batch_count, + const void* invA, + int invA_size, + hipblasDatatype_t compute_type); + +hipblasStatus_t hipblasTrsmStridedBatchedExFortran(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + int stride_A, + void* B, + int ldb, + int stride_B, + int batch_count, + const void* invA, + int invA_size, + int stride_invA, + hipblasDatatype_t compute_type); + +/* ========== + * Solver + * ========== */ + +// getrf +hipblasStatus_t hipblasSgetrfFortran( + hipblasHandle_t handle, const int n, float* A, const int lda, int* ipiv, int* info); + +hipblasStatus_t hipblasDgetrfFortran( + hipblasHandle_t handle, const int n, double* A, const int lda, int* ipiv, int* info); + +hipblasStatus_t hipblasCgetrfFortran( + hipblasHandle_t handle, const int n, hipblasComplex* A, const int lda, int* ipiv, int* info); + +hipblasStatus_t hipblasZgetrfFortran(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* A, + const int lda, + int* ipiv, + int* info); + +// getrf_batched +hipblasStatus_t hipblasSgetrfBatchedFortran(hipblasHandle_t handle, + const int n, + float* const A[], + const int lda, + int* ipiv, + int* info, + const int batch_count); + +hipblasStatus_t hipblasDgetrfBatchedFortran(hipblasHandle_t handle, + const int n, + double* const A[], + const int lda, + int* ipiv, + int* info, + const int batch_count); + 
+hipblasStatus_t hipblasCgetrfBatchedFortran(hipblasHandle_t handle, + const int n, + hipblasComplex* const A[], + const int lda, + int* ipiv, + int* info, + const int batch_count); + +hipblasStatus_t hipblasZgetrfBatchedFortran(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* const A[], + const int lda, + int* ipiv, + int* info, + const int batch_count); + +// getrf_strided_batched +hipblasStatus_t hipblasSgetrfStridedBatchedFortran(hipblasHandle_t handle, + const int n, + float* A, + const int lda, + const int stride_A, + int* ipiv, + const int stride_P, + int* info, + const int batch_count); + +hipblasStatus_t hipblasDgetrfStridedBatchedFortran(hipblasHandle_t handle, + const int n, + double* A, + const int lda, + const int stride_A, + int* ipiv, + const int stride_P, + int* info, + const int batch_count); + +hipblasStatus_t hipblasCgetrfStridedBatchedFortran(hipblasHandle_t handle, + const int n, + hipblasComplex* A, + const int lda, + const int stride_A, + int* ipiv, + const int stride_P, + int* info, + const int batch_count); + +hipblasStatus_t hipblasZgetrfStridedBatchedFortran(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* A, + const int lda, + const int stride_A, + int* ipiv, + const int stride_P, + int* info, + const int batch_count); + +// getrs +hipblasStatus_t hipblasSgetrsFortran(hipblasHandle_t handle, + const hipblasOperation_t trans, + const int n, + const int nrhs, + float* A, + const int lda, + const int* ipiv, + float* B, const int ldb, int* info); @@ -6258,12 +6662,12 @@ hipblasStatus_t hipblasSgetrsStridedBatchedFortran(hipblasHandle_t hand const int nrhs, float* A, const int lda, - const int strideA, + const int stride_A, const int* ipiv, - const int strideP, + const int stride_P, float* B, const int ldb, - const int strideB, + const int stride_B, int* info, const int batch_count); @@ -6273,12 +6677,12 @@ hipblasStatus_t hipblasDgetrsStridedBatchedFortran(hipblasHandle_t hand const int nrhs, double* A, const int 
lda, - const int strideA, + const int stride_A, const int* ipiv, - const int strideP, + const int stride_P, double* B, const int ldb, - const int strideB, + const int stride_B, int* info, const int batch_count); @@ -6288,12 +6692,12 @@ hipblasStatus_t hipblasCgetrsStridedBatchedFortran(hipblasHandle_t hand const int nrhs, hipblasComplex* A, const int lda, - const int strideA, + const int stride_A, const int* ipiv, - const int strideP, + const int stride_P, hipblasComplex* B, const int ldb, - const int strideB, + const int stride_B, int* info, const int batch_count); @@ -6303,448 +6707,170 @@ hipblasStatus_t hipblasZgetrsStridedBatchedFortran(hipblasHandle_t hand const int nrhs, hipblasDoubleComplex* A, const int lda, - const int strideA, + const int stride_A, const int* ipiv, - const int strideP, + const int stride_P, hipblasDoubleComplex* B, const int ldb, - const int strideB, + const int stride_B, int* info, const int batch_count); -// geqrf -hipblasStatus_t hipblasSgeqrfFortran(hipblasHandle_t handle, - const int m, - const int n, - float* A, - const int lda, - float* ipiv, - int* info); - -hipblasStatus_t hipblasDgeqrfFortran(hipblasHandle_t handle, - const int m, - const int n, - double* A, - const int lda, - double* ipiv, - int* info); - -hipblasStatus_t hipblasCgeqrfFortran(hipblasHandle_t handle, - const int m, - const int n, - hipblasComplex* A, - const int lda, - hipblasComplex* ipiv, - int* info); - -hipblasStatus_t hipblasZgeqrfFortran(hipblasHandle_t handle, - const int m, - const int n, - hipblasDoubleComplex* A, - const int lda, - hipblasDoubleComplex* ipiv, - int* info); - -// geqrf_batched -hipblasStatus_t hipblasSgeqrfBatchedFortran(hipblasHandle_t handle, - const int m, +// getri_batched +hipblasStatus_t hipblasSgetriBatchedFortran(hipblasHandle_t handle, const int n, float* const A[], const int lda, - float* const ipiv[], + int* ipiv, + float* const C[], + const int ldc, int* info, const int batch_count); -hipblasStatus_t 
hipblasDgeqrfBatchedFortran(hipblasHandle_t handle, - const int m, +hipblasStatus_t hipblasDgetriBatchedFortran(hipblasHandle_t handle, const int n, double* const A[], const int lda, - double* const ipiv[], + int* ipiv, + double* const C[], + const int ldc, int* info, const int batch_count); -hipblasStatus_t hipblasCgeqrfBatchedFortran(hipblasHandle_t handle, - const int m, +hipblasStatus_t hipblasCgetriBatchedFortran(hipblasHandle_t handle, const int n, hipblasComplex* const A[], const int lda, - hipblasComplex* const ipiv[], + int* ipiv, + hipblasComplex* const C[], + const int ldc, int* info, const int batch_count); -hipblasStatus_t hipblasZgeqrfBatchedFortran(hipblasHandle_t handle, - const int m, +hipblasStatus_t hipblasZgetriBatchedFortran(hipblasHandle_t handle, const int n, hipblasDoubleComplex* const A[], const int lda, - hipblasDoubleComplex* const ipiv[], + int* ipiv, + hipblasDoubleComplex* const C[], + const int ldc, int* info, const int batch_count); -// geqrf_strided_batched -hipblasStatus_t hipblasSgeqrfStridedBatchedFortran(hipblasHandle_t handle, - const int m, - const int n, - float* A, - const int lda, - const int strideA, - float* ipiv, - const int strideP, - int* info, - const int batch_count); - -hipblasStatus_t hipblasDgeqrfStridedBatchedFortran(hipblasHandle_t handle, - const int m, - const int n, - double* A, - const int lda, - const int strideA, - double* ipiv, - const int strideP, - int* info, - const int batch_count); - -hipblasStatus_t hipblasCgeqrfStridedBatchedFortran(hipblasHandle_t handle, - const int m, - const int n, - hipblasComplex* A, - const int lda, - const int strideA, - hipblasComplex* ipiv, - const int strideP, - int* info, - const int batch_count); - -hipblasStatus_t hipblasZgeqrfStridedBatchedFortran(hipblasHandle_t handle, - const int m, - const int n, - hipblasDoubleComplex* A, - const int lda, - const int strideA, - hipblasDoubleComplex* ipiv, - const int strideP, - int* info, - const int batch_count); - -// gemm 
-hipblasStatus_t hipblasHgemmFortran(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, - int n, - int k, - const hipblasHalf* alpha, - const hipblasHalf* A, - int lda, - const hipblasHalf* B, - int ldb, - const hipblasHalf* beta, - hipblasHalf* C, - int ldc); - -hipblasStatus_t hipblasSgemmFortran(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, - int n, - int k, - const float* alpha, - const float* A, - int lda, - const float* B, - int ldb, - const float* beta, - float* C, - int ldc); - -hipblasStatus_t hipblasDgemmFortran(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, - int n, - int k, - const double* alpha, - const double* A, - int lda, - const double* B, - int ldb, - const double* beta, - double* C, - int ldc); - -hipblasStatus_t hipblasCgemmFortran(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, - int n, - int k, - const hipblasComplex* alpha, - const hipblasComplex* A, - int lda, - const hipblasComplex* B, - int ldb, - const hipblasComplex* beta, - hipblasComplex* C, - int ldc); - -hipblasStatus_t hipblasZgemmFortran(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, - int n, - int k, - const hipblasDoubleComplex* alpha, - const hipblasDoubleComplex* A, - int lda, - const hipblasDoubleComplex* B, - int ldb, - const hipblasDoubleComplex* beta, - hipblasDoubleComplex* C, - int ldc); - -// gemm batched -hipblasStatus_t hipblasHgemmBatchedFortran(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, - int n, - int k, - const hipblasHalf* alpha, - const hipblasHalf* const A[], - int lda, - const hipblasHalf* const B[], - int ldb, - const hipblasHalf* beta, - hipblasHalf* const C[], - int ldc, - int batchCount); - -hipblasStatus_t hipblasSgemmBatchedFortran(hipblasHandle_t handle, - hipblasOperation_t transa, - 
hipblasOperation_t transb, - int m, - int n, - int k, - const float* alpha, - const float* const A[], - int lda, - const float* const B[], - int ldb, - const float* beta, - float* const C[], - int ldc, - int batchCount); +// geqrf +hipblasStatus_t hipblasSgeqrfFortran(hipblasHandle_t handle, + const int m, + const int n, + float* A, + const int lda, + float* tau, + int* info); -hipblasStatus_t hipblasDgemmBatchedFortran(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, - int n, - int k, - const double* alpha, - const double* const A[], - int lda, - const double* const B[], - int ldb, - const double* beta, - double* const C[], - int ldc, - int batchCount); +hipblasStatus_t hipblasDgeqrfFortran(hipblasHandle_t handle, + const int m, + const int n, + double* A, + const int lda, + double* tau, + int* info); -hipblasStatus_t hipblasCgemmBatchedFortran(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, - int n, - int k, - const hipblasComplex* alpha, - const hipblasComplex* const A[], - int lda, - const hipblasComplex* const B[], - int ldb, - const hipblasComplex* beta, - hipblasComplex* const C[], - int ldc, - int batchCount); +hipblasStatus_t hipblasCgeqrfFortran(hipblasHandle_t handle, + const int m, + const int n, + hipblasComplex* A, + const int lda, + hipblasComplex* tau, + int* info); -hipblasStatus_t hipblasZgemmBatchedFortran(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, - int n, - int k, - const hipblasDoubleComplex* alpha, - const hipblasDoubleComplex* const A[], - int lda, - const hipblasDoubleComplex* const B[], - int ldb, - const hipblasDoubleComplex* beta, - hipblasDoubleComplex* const C[], - int ldc, - int batchCount); +hipblasStatus_t hipblasZgeqrfFortran(hipblasHandle_t handle, + const int m, + const int n, + hipblasDoubleComplex* A, + const int lda, + hipblasDoubleComplex* tau, + int* info); -// gemm_strided_batched 
-hipblasStatus_t hipblasHgemmStridedBatchedFortran(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, - int n, - int k, - const hipblasHalf* alpha, - const hipblasHalf* A, - int lda, - long long bsa, - const hipblasHalf* B, - int ldb, - long long bsb, - const hipblasHalf* beta, - hipblasHalf* C, - int ldc, - long long bsc, - int batchCount); +// geqrf_batched +hipblasStatus_t hipblasSgeqrfBatchedFortran(hipblasHandle_t handle, + const int m, + const int n, + float* const A[], + const int lda, + float* const tau[], + int* info, + const int batch_count); -hipblasStatus_t hipblasSgemmStridedBatchedFortran(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, - int n, - int k, - const float* alpha, - const float* A, - int lda, - long long bsa, - const float* B, - int ldb, - long long bsb, - const float* beta, - float* C, - int ldc, - long long bsc, - int batchCount); +hipblasStatus_t hipblasDgeqrfBatchedFortran(hipblasHandle_t handle, + const int m, + const int n, + double* const A[], + const int lda, + double* const tau[], + int* info, + const int batch_count); -hipblasStatus_t hipblasDgemmStridedBatchedFortran(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, - int n, - int k, - const double* alpha, - const double* A, - int lda, - long long bsa, - const double* B, - int ldb, - long long bsb, - const double* beta, - double* C, - int ldc, - long long bsc, - int batchCount); +hipblasStatus_t hipblasCgeqrfBatchedFortran(hipblasHandle_t handle, + const int m, + const int n, + hipblasComplex* const A[], + const int lda, + hipblasComplex* const tau[], + int* info, + const int batch_count); -hipblasStatus_t hipblasCgemmStridedBatchedFortran(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, - int n, - int k, - const hipblasComplex* alpha, - const hipblasComplex* A, - int lda, - long long bsa, - const hipblasComplex* B, 
- int ldb, - long long bsb, - const hipblasComplex* beta, - hipblasComplex* C, - int ldc, - long long bsc, - int batchCount); +hipblasStatus_t hipblasZgeqrfBatchedFortran(hipblasHandle_t handle, + const int m, + const int n, + hipblasDoubleComplex* const A[], + const int lda, + hipblasDoubleComplex* const tau[], + int* info, + const int batch_count); -hipblasStatus_t hipblasZgemmStridedBatchedFortran(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, - int n, - int k, - const hipblasDoubleComplex* alpha, - const hipblasDoubleComplex* A, - int lda, - long long bsa, - const hipblasDoubleComplex* B, - int ldb, - long long bsb, - const hipblasDoubleComplex* beta, - hipblasDoubleComplex* C, - int ldc, - long long bsc, - int batchCount); +// geqrf_strided_batched +hipblasStatus_t hipblasSgeqrfStridedBatchedFortran(hipblasHandle_t handle, + const int m, + const int n, + float* A, + const int lda, + const int stride_A, + float* tau, + const int stride_T, + int* info, + const int batch_count); -// gemmex -hipblasStatus_t hipblasGemmExFortran(hipblasHandle_t handle, - hipblasOperation_t trans_a, - hipblasOperation_t trans_b, - int m, - int n, - int k, - const void* alpha, - const void* a, - hipblasDatatype_t a_type, - int lda, - const void* b, - hipblasDatatype_t b_type, - int ldb, - const void* beta, - void* c, - hipblasDatatype_t c_type, - int ldc, - hipblasDatatype_t compute_type, - hipblasGemmAlgo_t algo); +hipblasStatus_t hipblasDgeqrfStridedBatchedFortran(hipblasHandle_t handle, + const int m, + const int n, + double* A, + const int lda, + const int stride_A, + double* tau, + const int stride_T, + int* info, + const int batch_count); -hipblasStatus_t hipblasGemmBatchedExFortran(hipblasHandle_t handle, - hipblasOperation_t trans_a, - hipblasOperation_t trans_b, - int m, - int n, - int k, - const void* alpha, - const void* a[], - hipblasDatatype_t a_type, - int lda, - const void* b[], - hipblasDatatype_t b_type, - int ldb, - const 
void* beta, - void* c[], - hipblasDatatype_t c_type, - int ldc, - int batch_count, - hipblasDatatype_t compute_type, - hipblasGemmAlgo_t algo); +hipblasStatus_t hipblasCgeqrfStridedBatchedFortran(hipblasHandle_t handle, + const int m, + const int n, + hipblasComplex* A, + const int lda, + const int stride_A, + hipblasComplex* tau, + const int stride_T, + int* info, + const int batch_count); -hipblasStatus_t hipblasGemmStridedBatchedExFortran(hipblasHandle_t handle, - hipblasOperation_t trans_a, - hipblasOperation_t trans_b, - int m, - int n, - int k, - const void* alpha, - const void* a, - hipblasDatatype_t a_type, - int lda, - int stride_A, - const void* b, - hipblasDatatype_t b_type, - int ldb, - int stride_B, - const void* beta, - void* c, - hipblasDatatype_t c_type, - int ldc, - int stride_C, - int batch_count, - hipblasDatatype_t compute_type, - hipblasGemmAlgo_t algo); +hipblasStatus_t hipblasZgeqrfStridedBatchedFortran(hipblasHandle_t handle, + const int m, + const int n, + hipblasDoubleComplex* A, + const int lda, + const int stride_A, + hipblasDoubleComplex* tau, + const int stride_T, + int* info, + const int batch_count); } #endif diff --git a/clients/include/hipblas_fortran_solver.f90 b/clients/include/hipblas_fortran_solver.f90 new file mode 100644 index 000000000..099490143 --- /dev/null +++ b/clients/include/hipblas_fortran_solver.f90 @@ -0,0 +1,825 @@ +module hipblas_interface + use iso_c_binding + use hipblas + + contains + + !--------! + ! Solver ! + !--------! + + ! 
getrf + function hipblasSgetrfFortran(handle, n, A, lda, ipiv, info) & + result(res) & + bind(c, name = 'hipblasSgetrfFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + integer(c_int) :: res + res = hipblasSgetrf(handle, n, A, lda, ipiv, info) + end function hipblasSgetrfFortran + + function hipblasDgetrfFortran(handle, n, A, lda, ipiv, info) & + result(res) & + bind(c, name = 'hipblasDgetrfFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + integer(c_int) :: res + res = hipblasDgetrf(handle, n, A, lda, ipiv, info) + end function hipblasDgetrfFortran + + function hipblasCgetrfFortran(handle, n, A, lda, ipiv, info) & + result(res) & + bind(c, name = 'hipblasCgetrfFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + integer(c_int) :: res + res = hipblasCgetrf(handle, n, A, lda, ipiv, info) + end function hipblasCgetrfFortran + + function hipblasZgetrfFortran(handle, n, A, lda, ipiv, info) & + result(res) & + bind(c, name = 'hipblasZgetrfFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + integer(c_int) :: res + res = hipblasZgetrf(handle, n, A, lda, ipiv, info) + end function hipblasZgetrfFortran + + ! 
getrf_batched + function hipblasSgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batch_count) & + result(res) & + bind(c, name = 'hipblasSgetrfBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasSgetrfBatched(handle, n, A, lda, ipiv, info, batch_count) + end function hipblasSgetrfBatchedFortran + + function hipblasDgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batch_count) & + result(res) & + bind(c, name = 'hipblasDgetrfBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasDgetrfBatched(handle, n, A, lda, ipiv, info, batch_count) + end function hipblasDgetrfBatchedFortran + + function hipblasCgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batch_count) & + result(res) & + bind(c, name = 'hipblasCgetrfBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasCgetrfBatched(handle, n, A, lda, ipiv, info, batch_count) + end function hipblasCgetrfBatchedFortran + + function hipblasZgetrfBatchedFortran(handle, n, A, lda, ipiv, info, batch_count) & + result(res) & + bind(c, name = 'hipblasZgetrfBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + 
integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasZgetrfBatched(handle, n, A, lda, ipiv, info, batch_count) + end function hipblasZgetrfBatchedFortran + + ! getrf_strided_batched + function hipblasSgetrfStridedBatchedFortran(handle, n, A, lda, stride_A,& + ipiv, stride_P, info, batch_count) & + result(res) & + bind(c, name = 'hipblasSgetrfStridedBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasSgetrfStridedBatched(handle, n, A, lda, stride_A,& + ipiv, stride_P, info, batch_count) + end function hipblasSgetrfStridedBatchedFortran + + function hipblasDgetrfStridedBatchedFortran(handle, n, A, lda, stride_A,& + ipiv, stride_P, info, batch_count) & + result(res) & + bind(c, name = 'hipblasDgetrfStridedBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasDgetrfStridedBatched(handle, n, A, lda, stride_A,& + ipiv, stride_P, info, batch_count) + end function hipblasDgetrfStridedBatchedFortran + + function hipblasCgetrfStridedBatchedFortran(handle, n, A, lda, stride_A,& + ipiv, stride_P, info, batch_count) & + result(res) & + bind(c, name = 'hipblasCgetrfStridedBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), 
value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasCgetrfStridedBatched(handle, n, A, lda, stride_A,& + ipiv, stride_P, info, batch_count) + end function hipblasCgetrfStridedBatchedFortran + + function hipblasZgetrfStridedBatchedFortran(handle, n, A, lda, stride_A,& + ipiv, stride_P, info, batch_count) & + result(res) & + bind(c, name = 'hipblasZgetrfStridedBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasZgetrfStridedBatched(handle, n, A, lda, stride_A,& + ipiv, stride_P, info, batch_count) + end function hipblasZgetrfStridedBatchedFortran + + ! 
getrs + function hipblasSgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv,& + B, ldb, info) & + result(res) & + bind(c, name = 'hipblasSgetrsFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + integer(c_int) :: res + res = hipblasSgetrs(handle, trans, n, nrhs, A, lda,& + ipiv, B, ldb, info) + end function hipblasSgetrsFortran + + function hipblasDgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv,& + B, ldb, info) & + result(res) & + bind(c, name = 'hipblasDgetrsFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + integer(c_int) :: res + res = hipblasDgetrs(handle, trans, n, nrhs, A, lda,& + ipiv, B, ldb, info) + end function hipblasDgetrsFortran + + function hipblasCgetrsFortran(handle, trans, n, nrhs, A, lda, ipiv,& + B, ldb, info) & + result(res) & + bind(c, name = 'hipblasCgetrsFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + integer(c_int) :: res + res = hipblasCgetrs(handle, trans, n, nrhs, A, lda,& + ipiv, B, ldb, info) + end function hipblasCgetrsFortran + + function hipblasZgetrsFortran(handle, 
trans, n, nrhs, A, lda, ipiv,& + B, ldb, info) & + result(res) & + bind(c, name = 'hipblasZgetrsFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + integer(c_int) :: res + res = hipblasZgetrs(handle, trans, n, nrhs, A, lda,& + ipiv, B, ldb, info) + end function hipblasZgetrsFortran + + ! getrs_batched + function hipblasSgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv,& + B, ldb, info, batch_count) & + result(res) & + bind(c, name = 'hipblasSgetrsBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasSgetrsBatched(handle, trans, n, nrhs, A, lda,& + ipiv, B, ldb, info, batch_count) + end function hipblasSgetrsBatchedFortran + + function hipblasDgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv,& + B, ldb, info, batch_count) & + result(res) & + bind(c, name = 'hipblasDgetrsBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = 
hipblasDgetrsBatched(handle, trans, n, nrhs, A, lda,& + ipiv, B, ldb, info, batch_count) + end function hipblasDgetrsBatchedFortran + + function hipblasCgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv,& + B, ldb, info, batch_count) & + result(res) & + bind(c, name = 'hipblasCgetrsBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasCgetrsBatched(handle, trans, n, nrhs, A, lda,& + ipiv, B, ldb, info, batch_count) + end function hipblasCgetrsBatchedFortran + + function hipblasZgetrsBatchedFortran(handle, trans, n, nrhs, A, lda, ipiv,& + B, ldb, info, batch_count) & + result(res) & + bind(c, name = 'hipblasZgetrsBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasZgetrsBatched(handle, trans, n, nrhs, A, lda,& + ipiv, B, ldb, info, batch_count) + end function hipblasZgetrsBatchedFortran + + ! 
getrs_strided_batched + function hipblasSgetrsStridedBatchedFortran(handle, trans, n, nrhs, A, lda, stride_A, ipiv,& + stride_P, B, ldb, stride_B, info, batch_count) & + result(res) & + bind(c, name = 'hipblasSgetrsStridedBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: B + integer(c_int), value :: ldb + integer(c_int), value :: stride_B + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasSgetrsStridedBatched(handle, trans, n, nrhs, A, lda, stride_A,& + ipiv, stride_P, B, ldb, stride_B, info, batch_count) + end function hipblasSgetrsStridedBatchedFortran + + function hipblasDgetrsStridedBatchedFortran(handle, trans, n, nrhs, A, lda, stride_A, ipiv,& + stride_P, B, ldb, stride_B, info, batch_count) & + result(res) & + bind(c, name = 'hipblasDgetrsStridedBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: B + integer(c_int), value :: ldb + integer(c_int), value :: stride_B + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasDgetrsStridedBatched(handle, trans, n, nrhs, A, lda, stride_A,& + ipiv, stride_P, B, ldb, stride_B, info, batch_count) + end function hipblasDgetrsStridedBatchedFortran + + function hipblasCgetrsStridedBatchedFortran(handle, trans, n, nrhs, A, lda, stride_A, ipiv,& + 
stride_P, B, ldb, stride_B, info, batch_count) & + result(res) & + bind(c, name = 'hipblasCgetrsStridedBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: B + integer(c_int), value :: ldb + integer(c_int), value :: stride_B + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasCgetrsStridedBatched(handle, trans, n, nrhs, A, lda, stride_A,& + ipiv, stride_P, B, ldb, stride_B, info, batch_count) + end function hipblasCgetrsStridedBatchedFortran + + function hipblasZgetrsStridedBatchedFortran(handle, trans, n, nrhs, A, lda, stride_A, ipiv,& + stride_P, B, ldb, stride_B, info, batch_count) & + result(res) & + bind(c, name = 'hipblasZgetrsStridedBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: B + integer(c_int), value :: ldb + integer(c_int), value :: stride_B + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasZgetrsStridedBatched(handle, trans, n, nrhs, A, lda, stride_A,& + ipiv, stride_P, B, ldb, stride_B, info, batch_count) + end function hipblasZgetrsStridedBatchedFortran + + ! 
getri_batched + function hipblasSgetriBatchedFortran(handle, n, A, lda, ipiv, C, ldc, info, batch_count) & + result(res) & + bind(c, name = 'hipblasSgetriBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: C + integer(c_int), value :: ldc + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasSgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batch_count) + end function hipblasSgetriBatchedFortran + + function hipblasDgetriBatchedFortran(handle, n, A, lda, ipiv, C, ldc, info, batch_count) & + result(res) & + bind(c, name = 'hipblasDgetriBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: C + integer(c_int), value :: ldc + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasDgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batch_count) + end function hipblasDgetriBatchedFortran + + function hipblasCgetriBatchedFortran(handle, n, A, lda, ipiv, C, ldc, info, batch_count) & + result(res) & + bind(c, name = 'hipblasCgetriBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: C + integer(c_int), value :: ldc + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasCgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batch_count) + end function hipblasCgetriBatchedFortran + + function hipblasZgetriBatchedFortran(handle, n, A, lda, ipiv, C, ldc, info, 
batch_count) & + result(res) & + bind(c, name = 'hipblasZgetriBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: C + integer(c_int), value :: ldc + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasZgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batch_count) + end function hipblasZgetriBatchedFortran + + ! geqrf + function hipblasSgeqrfFortran(handle, m, n, A, lda, tau, info) & + result(res) & + bind(c, name = 'hipblasSgeqrfFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int) :: res + res = hipblasSgeqrf(handle, m, n, A, lda, tau, info) + end function hipblasSgeqrfFortran + + function hipblasDgeqrfFortran(handle, m, n, A, lda, tau, info) & + result(res) & + bind(c, name = 'hipblasDgeqrfFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int) :: res + res = hipblasDgeqrf(handle, m, n, A, lda, tau, info) + end function hipblasDgeqrfFortran + + function hipblasCgeqrfFortran(handle, m, n, A, lda, tau, info) & + result(res) & + bind(c, name = 'hipblasCgeqrfFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int) :: res + res = 
hipblasCgeqrf(handle, m, n, A, lda, tau, info) + end function hipblasCgeqrfFortran + + function hipblasZgeqrfFortran(handle, m, n, A, lda, tau, info) & + result(res) & + bind(c, name = 'hipblasZgeqrfFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int) :: res + res = hipblasZgeqrf(handle, m, n, A, lda, tau, info) + end function hipblasZgeqrfFortran + + ! geqrf_batched + function hipblasSgeqrfBatchedFortran(handle, m, n, A, lda, tau, info, batch_count) & + result(res) & + bind(c, name = 'hipblasSgeqrfBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasSgeqrfBatched(handle, m, n, A, lda, tau, info, batch_count) + end function hipblasSgeqrfBatchedFortran + + function hipblasDgeqrfBatchedFortran(handle, m, n, A, lda, tau, info, batch_count) & + result(res) & + bind(c, name = 'hipblasDgeqrfBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasDgeqrfBatched(handle, m, n, A, lda, tau, info, batch_count) + end function hipblasDgeqrfBatchedFortran + + function hipblasCgeqrfBatchedFortran(handle, m, n, A, lda, tau, info, batch_count) & + result(res) & + bind(c, name = 'hipblasCgeqrfBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + 
type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasCgeqrfBatched(handle, m, n, A, lda, tau, info, batch_count) + end function hipblasCgeqrfBatchedFortran + + function hipblasZgeqrfBatchedFortran(handle, m, n, A, lda, tau, info, batch_count) & + result(res) & + bind(c, name = 'hipblasZgeqrfBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasZgeqrfBatched(handle, m, n, A, lda, tau, info, batch_count) + end function hipblasZgeqrfBatchedFortran + + ! geqrf_strided_batched + function hipblasSgeqrfStridedBatchedFortran(handle, m, n, A, lda, stride_A,& + tau, stride_T, info, batch_count) & + result(res) & + bind(c, name = 'hipblasSgeqrfStridedBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: tau + integer(c_int), value :: stride_T + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasSgeqrfStridedBatched(handle, m, n, A, lda, stride_A,& + tau, stride_T, info, batch_count) + end function hipblasSgeqrfStridedBatchedFortran + +function hipblasDgeqrfStridedBatchedFortran(handle, m, n, A, lda, stride_A,& + tau, stride_T, info, batch_count) & + result(res) & + bind(c, name = 'hipblasDgeqrfStridedBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), 
value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: tau + integer(c_int), value :: stride_T + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasDgeqrfStridedBatched(handle, m, n, A, lda, stride_A,& + tau, stride_T, info, batch_count) +end function hipblasDgeqrfStridedBatchedFortran + + function hipblasCgeqrfStridedBatchedFortran(handle, m, n, A, lda, stride_A,& + tau, stride_T, info, batch_count) & + result(res) & + bind(c, name = 'hipblasCgeqrfStridedBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: tau + integer(c_int), value :: stride_T + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasCgeqrfStridedBatched(handle, m, n, A, lda, stride_A,& + tau, stride_T, info, batch_count) + end function hipblasCgeqrfStridedBatchedFortran + + function hipblasZgeqrfStridedBatchedFortran(handle, m, n, A, lda, stride_A,& + tau, stride_T, info, batch_count) & + result(res) & + bind(c, name = 'hipblasZgeqrfStridedBatchedFortran') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: tau + integer(c_int), value :: stride_T + type(c_ptr), value :: info + integer(c_int), value :: batch_count + integer(c_int) :: res + res = hipblasZgeqrfStridedBatched(handle, m, n, A, lda, stride_A,& + tau, stride_T, info, batch_count) + end function hipblasZgeqrfStridedBatchedFortran + +end module hipblas_interface \ No newline 
at end of file diff --git a/clients/include/testing_geqrf.hpp b/clients/include/testing_geqrf.hpp index bd67c625f..e5b49e505 100644 --- a/clients/include/testing_geqrf.hpp +++ b/clients/include/testing_geqrf.hpp @@ -17,9 +17,12 @@ using namespace std; -template +template hipblasStatus_t testing_geqrf(Arguments argus) { + bool FORTRAN = argus.fortran; + auto hipblasGeqrfFn = FORTRAN ? hipblasGeqrf : hipblasGeqrf; + int M = argus.M; int N = argus.N; int lda = argus.lda; @@ -56,6 +59,18 @@ hipblasStatus_t testing_geqrf(Arguments argus) srand(1); hipblas_init(hA, M, N, lda); + // scale A to avoid singularities + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + if(i == j) + hA[i + j * lda] += 400; + else + hA[i + j * lda] -= 4; + } + } + // Copy data from CPU to device CHECK_HIP_ERROR(hipMemcpy(dA, hA.data(), A_size * sizeof(T), hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemset(dIpiv, 0, Ipiv_size * sizeof(T))); @@ -64,7 +79,7 @@ hipblasStatus_t testing_geqrf(Arguments argus) HIPBLAS =================================================================== */ - status = hipblasGeqrf(handle, M, N, dA, lda, dIpiv, &info); + status = hipblasGeqrfFn(handle, M, N, dA, lda, dIpiv, &info); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -85,7 +100,7 @@ hipblasStatus_t testing_geqrf(Arguments argus) // Workspace query host_vector work(1); cblas_geqrf(M, N, hA.data(), lda, hIpiv.data(), work.data(), -1); - int lwork = (int)work[0]; + int lwork = type2int(work[0]); // Perform factorization work = host_vector(lwork); @@ -93,14 +108,14 @@ hipblasStatus_t testing_geqrf(Arguments argus) if(argus.unit_check) { - T eps = std::numeric_limits::epsilon(); + U eps = std::numeric_limits::epsilon(); double tolerance = eps * 2000; - double e1 = norm_check_general('M', M, N, lda, hA.data(), hA1.data()); + double e1 = norm_check_general('F', M, N, lda, hA.data(), hA1.data()); unit_check_error(e1, tolerance); double e2 - = norm_check_general('M', min(M, N), 1, min(M, N), hIpiv.data(), 
hIpiv1.data()); + = norm_check_general('F', min(M, N), 1, min(M, N), hIpiv.data(), hIpiv1.data()); unit_check_error(e2, tolerance); } } diff --git a/clients/include/testing_geqrf_batched.hpp b/clients/include/testing_geqrf_batched.hpp index f24944665..5d9b8b5c8 100644 --- a/clients/include/testing_geqrf_batched.hpp +++ b/clients/include/testing_geqrf_batched.hpp @@ -17,9 +17,12 @@ using namespace std; -template +template hipblasStatus_t testing_geqrf_batched(Arguments argus) { + bool FORTRAN = argus.fortran; + auto hipblasGeqrfBatchedFn = FORTRAN ? hipblasGeqrfBatched : hipblasGeqrfBatched; + int M = argus.M; int N = argus.N; int lda = argus.lda; @@ -71,6 +74,18 @@ hipblasStatus_t testing_geqrf_batched(Arguments argus) hipblas_init(hA[b], M, N, lda); + // scale A to avoid singularities + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + if(i == j) + hA[b][i + j * lda] += 400; + else + hA[b][i + j * lda] -= 4; + } + } + // Copy data from CPU to device CHECK_HIP_ERROR(hipMemcpy(bA[b], hA[b].data(), A_size * sizeof(T), hipMemcpyHostToDevice)); CHECK_HIP_ERROR( @@ -84,7 +99,7 @@ hipblasStatus_t testing_geqrf_batched(Arguments argus) HIPBLAS =================================================================== */ - status = hipblasGeqrfBatched(handle, M, N, dA, lda, dIpiv, &info, batch_count); + status = hipblasGeqrfBatchedFn(handle, M, N, dA, lda, dIpiv, &info, batch_count); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -109,7 +124,7 @@ hipblasStatus_t testing_geqrf_batched(Arguments argus) // Workspace query host_vector work(1); cblas_geqrf(M, N, hA[0].data(), lda, hIpiv[0].data(), work.data(), -1); - int lwork = (int)work[0]; + int lwork = type2int(work[0]); // Perform factorization work = host_vector(lwork); @@ -119,14 +134,14 @@ hipblasStatus_t testing_geqrf_batched(Arguments argus) if(argus.unit_check) { - T eps = std::numeric_limits::epsilon(); + U eps = std::numeric_limits::epsilon(); double tolerance = eps * 2000; - double e1 = norm_check_general('M', 
M, N, lda, hA[b].data(), hA1[b].data()); + double e1 = norm_check_general('F', M, N, lda, hA[b].data(), hA1[b].data()); unit_check_error(e1, tolerance); double e2 = norm_check_general( - 'M', min(M, N), 1, min(M, N), hIpiv[b].data(), hIpiv1[b].data()); + 'F', min(M, N), 1, min(M, N), hIpiv[b].data(), hIpiv1[b].data()); unit_check_error(e2, tolerance); } } diff --git a/clients/include/testing_geqrf_strided_batched.hpp b/clients/include/testing_geqrf_strided_batched.hpp index 447476f38..654db1c49 100644 --- a/clients/include/testing_geqrf_strided_batched.hpp +++ b/clients/include/testing_geqrf_strided_batched.hpp @@ -17,9 +17,12 @@ using namespace std; -template +template hipblasStatus_t testing_geqrf_strided_batched(Arguments argus) { + bool FORTRAN = argus.fortran; + auto hipblasGeqrfStridedBatchedFn = FORTRAN ? hipblasGeqrfStridedBatched : hipblasGeqrfStridedBatched; + int M = argus.M; int N = argus.N; int lda = argus.lda; @@ -67,6 +70,18 @@ hipblasStatus_t testing_geqrf_strided_batched(Arguments argus) T* hAb = hA.data() + b * strideA; hipblas_init(hAb, M, N, lda); + + // scale A to avoid singularities + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + if(i == j) + hAb[i + j * lda] += 400; + else + hAb[i + j * lda] -= 4; + } + } } // Copy data from CPU to device @@ -77,7 +92,7 @@ hipblasStatus_t testing_geqrf_strided_batched(Arguments argus) HIPBLAS =================================================================== */ - status = hipblasGeqrfStridedBatched( + status = hipblasGeqrfStridedBatchedFn( handle, M, N, dA, lda, strideA, dIpiv, strideP, &info, batch_count); if(status != HIPBLAS_STATUS_SUCCESS) @@ -99,7 +114,7 @@ hipblasStatus_t testing_geqrf_strided_batched(Arguments argus) // Workspace query host_vector work(1); cblas_geqrf(M, N, hA.data(), lda, hIpiv.data(), work.data(), -1); - int lwork = (int)work[0]; + int lwork = type2int(work[0]); // Perform factorization work = host_vector(lwork); @@ -110,14 +125,14 @@ hipblasStatus_t 
testing_geqrf_strided_batched(Arguments argus) if(argus.unit_check) { - T eps = std::numeric_limits::epsilon(); + U eps = std::numeric_limits::epsilon(); double tolerance = eps * 2000; double e1 = norm_check_general( - 'M', M, N, lda, hA.data() + b * strideA, hA1.data() + b * strideA); + 'F', M, N, lda, hA.data() + b * strideA, hA1.data() + b * strideA); unit_check_error(e1, tolerance); - double e2 = norm_check_general('M', + double e2 = norm_check_general('F', min(M, N), 1, strideP, diff --git a/clients/include/testing_getrf.hpp b/clients/include/testing_getrf.hpp index b075b84b6..af61325af 100644 --- a/clients/include/testing_getrf.hpp +++ b/clients/include/testing_getrf.hpp @@ -17,9 +17,12 @@ using namespace std; -template +template hipblasStatus_t testing_getrf(Arguments argus) { + bool FORTRAN = argus.fortran; + auto hipblasGetrfFn = FORTRAN ? hipblasGetrf : hipblasGetrf; + int M = argus.N; int N = argus.N; int lda = argus.lda; @@ -58,6 +61,18 @@ hipblasStatus_t testing_getrf(Arguments argus) srand(1); hipblas_init(hA, M, N, lda); + // scale A to avoid singularities + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + if(i == j) + hA[i + j * lda] += 400; + else + hA[i + j * lda] -= 4; + } + } + // Copy data from CPU to device CHECK_HIP_ERROR(hipMemcpy(dA, hA.data(), A_size * sizeof(T), hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemset(dIpiv, 0, Ipiv_size * sizeof(int))); @@ -67,7 +82,7 @@ hipblasStatus_t testing_getrf(Arguments argus) HIPBLAS =================================================================== */ - status = hipblasGetrf(handle, N, dA, lda, dIpiv, dInfo); + status = hipblasGetrfFn(handle, N, dA, lda, dIpiv, dInfo); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -91,10 +106,10 @@ hipblasStatus_t testing_getrf(Arguments argus) if(argus.unit_check) { - T eps = std::numeric_limits::epsilon(); + U eps = std::numeric_limits::epsilon(); double tolerance = eps * 2000; - double e = norm_check_general('M', M, N, lda, hA.data(), hA1.data()); + 
double e = norm_check_general('F', M, N, lda, hA.data(), hA1.data()); unit_check_error(e, tolerance); } } diff --git a/clients/include/testing_getrf_batched.hpp b/clients/include/testing_getrf_batched.hpp index 060ccd76a..9518e29bd 100644 --- a/clients/include/testing_getrf_batched.hpp +++ b/clients/include/testing_getrf_batched.hpp @@ -17,9 +17,12 @@ using namespace std; -template +template hipblasStatus_t testing_getrf_batched(Arguments argus) { + bool FORTRAN = argus.fortran; + auto hipblasGetrfBatchedFn = FORTRAN ? hipblasGetrfBatched : hipblasGetrfBatched; + int M = argus.N; int N = argus.N; int lda = argus.lda; @@ -71,6 +74,18 @@ hipblasStatus_t testing_getrf_batched(Arguments argus) hipblas_init(hA[b], M, N, lda); + // scale A to avoid singularities + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + if(i == j) + hA[b][i + j * lda] += 400; + else + hA[b][i + j * lda] -= 4; + } + } + // Copy data from CPU to device CHECK_HIP_ERROR(hipMemcpy(bA[b], hA[b].data(), A_size * sizeof(T), hipMemcpyHostToDevice)); } @@ -83,7 +98,7 @@ hipblasStatus_t testing_getrf_batched(Arguments argus) HIPBLAS =================================================================== */ - status = hipblasGetrfBatched(handle, N, dA, lda, dIpiv, dInfo, batch_count); + status = hipblasGetrfBatchedFn(handle, N, dA, lda, dIpiv, dInfo, batch_count); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -111,10 +126,10 @@ hipblasStatus_t testing_getrf_batched(Arguments argus) if(argus.unit_check) { - T eps = std::numeric_limits::epsilon(); + U eps = std::numeric_limits::epsilon(); double tolerance = eps * 2000; - double e = norm_check_general('M', M, N, lda, hA[b].data(), hA1[b].data()); + double e = norm_check_general('F', M, N, lda, hA[b].data(), hA1[b].data()); unit_check_error(e, tolerance); } } diff --git a/clients/include/testing_getrf_strided_batched.hpp b/clients/include/testing_getrf_strided_batched.hpp index 013da1a0e..30156266b 100644 --- 
a/clients/include/testing_getrf_strided_batched.hpp +++ b/clients/include/testing_getrf_strided_batched.hpp @@ -17,9 +17,12 @@ using namespace std; -template +template hipblasStatus_t testing_getrf_strided_batched(Arguments argus) { + bool FORTRAN = argus.fortran; + auto hipblasGetrfStridedBatchedFn = FORTRAN ? hipblasGetrfStridedBatched : hipblasGetrfStridedBatched; + int M = argus.N; int N = argus.N; int lda = argus.lda; @@ -69,6 +72,18 @@ hipblasStatus_t testing_getrf_strided_batched(Arguments argus) T* hAb = hA.data() + b * strideA; hipblas_init(hAb, M, N, lda); + + // scale A to avoid singularities + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + if(i == j) + hAb[i + j * lda] += 400; + else + hAb[i + j * lda] -= 4; + } + } } // Copy data from CPU to device @@ -80,7 +95,7 @@ hipblasStatus_t testing_getrf_strided_batched(Arguments argus) HIPBLAS =================================================================== */ - status = hipblasGetrfStridedBatched( + status = hipblasGetrfStridedBatchedFn( handle, N, dA, lda, strideA, dIpiv, strideP, dInfo, batch_count); if(status != HIPBLAS_STATUS_SUCCESS) @@ -108,11 +123,11 @@ hipblasStatus_t testing_getrf_strided_batched(Arguments argus) if(argus.unit_check) { - T eps = std::numeric_limits::epsilon(); + U eps = std::numeric_limits::epsilon(); double tolerance = eps * 2000; double e = norm_check_general( - 'M', M, N, lda, hA.data() + b * strideA, hA1.data() + b * strideA); + 'F', M, N, lda, hA.data() + b * strideA, hA1.data() + b * strideA); unit_check_error(e, tolerance); } } diff --git a/clients/include/testing_getri_batched.hpp b/clients/include/testing_getri_batched.hpp new file mode 100644 index 000000000..be3644faa --- /dev/null +++ b/clients/include/testing_getri_batched.hpp @@ -0,0 +1,158 @@ +/* ************************************************************************ + * Copyright 2016-2020 Advanced Micro Devices, Inc. 
+ * + * ************************************************************************ */ + +#include +#include +#include +#include + +#include "cblas_interface.h" +#include "flops.h" +#include "hipblas.hpp" +#include "norm.h" +#include "unit.h" +#include "utility.h" + +using namespace std; + +template +hipblasStatus_t testing_getri_batched(Arguments argus) +{ + bool FORTRAN = argus.fortran; + auto hipblasGetriBatchedFn + = FORTRAN ? hipblasGetriBatched : hipblasGetriBatched; + + int M = argus.N; + int N = argus.N; + int lda = argus.lda; + int batch_count = argus.batch_count; + + int strideP = min(M, N); + int A_size = lda * N; + int Ipiv_size = strideP * batch_count; + + hipblasStatus_t status = HIPBLAS_STATUS_SUCCESS; + + // Check to prevent memory allocation error + if(M < 0 || N < 0 || lda < M || batch_count < 0) + { + return HIPBLAS_STATUS_INVALID_VALUE; + } + if(batch_count == 0) + { + return HIPBLAS_STATUS_SUCCESS; + } + + // Naming: dK is in GPU (device) memory. hK is in CPU (host) memory + host_vector hA[batch_count]; + host_vector hA1[batch_count]; + host_vector hC[batch_count]; + host_vector hIpiv(Ipiv_size); + host_vector hIpiv1(Ipiv_size); + host_vector hInfo(batch_count); + host_vector hInfo1(batch_count); + + device_batch_vector bA(batch_count, A_size); + device_batch_vector bC(batch_count, A_size); + + device_vector dA(batch_count); + device_vector dC(batch_count); + device_vector dIpiv(Ipiv_size); + device_vector dInfo(batch_count); + + double gpu_time_used, cpu_time_used; + double hipblasGflops, cblas_gflops; + double rocblas_error; + + hipblasHandle_t handle; + hipblasCreate(&handle); + + // Initial hA on CPU + srand(1); + for(int b = 0; b < batch_count; b++) + { + hA[b] = host_vector(A_size); + hA1[b] = host_vector(A_size); + hC[b] = host_vector(A_size); + int* hIpivb = hIpiv.data() + b * strideP; + + hipblas_init(hA[b], M, N, lda); + + // scale A to avoid singularities + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + if(i == j) + 
hA[b][i + j * lda] += 400; + else + hA[b][i + j * lda] -= 4; + } + } + + // perform LU factorization on A + hInfo[b] = cblas_getrf(M, N, hA[b].data(), lda, hIpivb); + + // Copy data from CPU to device + CHECK_HIP_ERROR(hipMemcpy(bA[b], hA[b].data(), A_size * sizeof(T), hipMemcpyHostToDevice)); + } + + CHECK_HIP_ERROR(hipMemcpy(dA, bA, batch_count * sizeof(T*), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dC, bC, batch_count * sizeof(T*), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dIpiv, hIpiv, Ipiv_size * sizeof(int), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemset(dInfo, 0, batch_count * sizeof(int))); + + /* ===================================================================== + HIPBLAS + =================================================================== */ + + status = hipblasGetriBatchedFn(handle, N, dA, lda, dIpiv, dC, lda, dInfo, batch_count); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status; + } + + // Copy output from device to CPU + for(int b = 0; b < batch_count; b++) + CHECK_HIP_ERROR(hipMemcpy(hA1[b].data(), bC[b], A_size * sizeof(T), hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR( + hipMemcpy(hIpiv1.data(), dIpiv, Ipiv_size * sizeof(int), hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR( + hipMemcpy(hInfo1.data(), dInfo, batch_count * sizeof(int), hipMemcpyDeviceToHost)); + + if(argus.unit_check) + { + /* ===================================================================== + CPU LAPACK + =================================================================== */ + + for(int b = 0; b < batch_count; b++) + { + // Workspace query + host_vector work(1); + cblas_getri(N, hA[b].data(), lda, hIpiv.data() + b * strideP, work.data(), -1); + int lwork = type2int(work[0]); + + // Perform inversion + work = host_vector(lwork); + hInfo[b] + = cblas_getri(N, hA[b].data(), lda, hIpiv.data() + b * strideP, work.data(), lwork); + + if(argus.unit_check) + { + U eps = std::numeric_limits::epsilon(); + double tolerance = eps 
* 2000; + + double e = norm_check_general('F', M, N, lda, hA[b].data(), hA1[b].data()); + unit_check_error(e, tolerance); + } + } + } + + hipblasDestroy(handle); + return HIPBLAS_STATUS_SUCCESS; +} diff --git a/clients/include/testing_getrs.hpp b/clients/include/testing_getrs.hpp index 5aa13e48e..78757657b 100644 --- a/clients/include/testing_getrs.hpp +++ b/clients/include/testing_getrs.hpp @@ -17,9 +17,12 @@ using namespace std; -template +template hipblasStatus_t testing_getrs(Arguments argus) { + bool FORTRAN = argus.fortran; + auto hipblasGetrsFn = FORTRAN ? hipblasGetrs : hipblasGetrs; + int N = argus.N; int lda = argus.lda; int ldb = argus.ldb; @@ -61,15 +64,15 @@ hipblasStatus_t testing_getrs(Arguments argus) hipblas_init(hA, N, N, lda); hipblas_init(hX, N, 1, ldb); - // Put hA entries into range [0, 1], make diagonally dominant + // scale A to avoid singularities for(int i = 0; i < N; i++) { for(int j = 0; j < N; j++) { - hA[i + j * lda] = (hA[i + j * lda] - 1.0) / 10.0; - if(i == j) - hA[i + j * lda] *= 100; + hA[i + j * lda] += 400; + else + hA[i + j * lda] -= 4; } } @@ -94,7 +97,7 @@ hipblasStatus_t testing_getrs(Arguments argus) HIPBLAS =================================================================== */ - status = hipblasGetrs(handle, op, N, 1, dA, lda, dIpiv, dB, ldb, &info); + status = hipblasGetrsFn(handle, op, N, 1, dA, lda, dIpiv, dB, ldb, &info); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -122,10 +125,10 @@ hipblasStatus_t testing_getrs(Arguments argus) if(argus.unit_check) { - T eps = std::numeric_limits::epsilon(); + U eps = std::numeric_limits::epsilon(); double tolerance = N * eps * 100; - double e = norm_check_general('M', N, 1, ldb, hB.data(), hB1.data()); + double e = norm_check_general('F', N, 1, ldb, hB.data(), hB1.data()); unit_check_error(e, tolerance); } } diff --git a/clients/include/testing_getrs_batched.hpp b/clients/include/testing_getrs_batched.hpp index 5cabffc32..0e50ca900 100644 --- 
a/clients/include/testing_getrs_batched.hpp +++ b/clients/include/testing_getrs_batched.hpp @@ -17,9 +17,12 @@ using namespace std; -template +template hipblasStatus_t testing_getrs_batched(Arguments argus) { + bool FORTRAN = argus.fortran; + auto hipblasGetrsBatchedFn = FORTRAN ? hipblasGetrsBatched : hipblasGetrsBatched; + int N = argus.N; int lda = argus.lda; int ldb = argus.ldb; @@ -78,15 +81,15 @@ hipblasStatus_t testing_getrs_batched(Arguments argus) hipblas_init(hA[b], N, N, lda); hipblas_init(hX[b], N, 1, ldb); - // Put hA entries into range [0, 1], make diagonally dominant + // scale A to avoid singularities for(int i = 0; i < N; i++) { for(int j = 0; j < N; j++) { - hA[b][i + j * lda] = (hA[b][i + j * lda] - 1.0) / 10.0; - if(i == j) - hA[b][i + j * lda] *= 100; + hA[b][i + j * lda] += 400; + else + hA[b][i + j * lda] -= 4; } } @@ -116,7 +119,7 @@ hipblasStatus_t testing_getrs_batched(Arguments argus) HIPBLAS =================================================================== */ - status = hipblasGetrsBatched(handle, op, N, 1, dA, lda, dIpiv, dB, ldb, &info, batch_count); + status = hipblasGetrsBatchedFn(handle, op, N, 1, dA, lda, dIpiv, dB, ldb, &info, batch_count); if(status != HIPBLAS_STATUS_SUCCESS) { @@ -143,10 +146,10 @@ hipblasStatus_t testing_getrs_batched(Arguments argus) if(argus.unit_check) { - T eps = std::numeric_limits::epsilon(); + U eps = std::numeric_limits::epsilon(); double tolerance = N * eps * 100; - double e = norm_check_general('M', N, 1, ldb, hB[b].data(), hB1[b].data()); + double e = norm_check_general('F', N, 1, ldb, hB[b].data(), hB1[b].data()); unit_check_error(e, tolerance); } } diff --git a/clients/include/testing_getrs_strided_batched.hpp b/clients/include/testing_getrs_strided_batched.hpp index 50416eb81..df760f0ac 100644 --- a/clients/include/testing_getrs_strided_batched.hpp +++ b/clients/include/testing_getrs_strided_batched.hpp @@ -17,9 +17,12 @@ using namespace std; -template +template hipblasStatus_t 
testing_getrs_strided_batched(Arguments argus) { + bool FORTRAN = argus.fortran; + auto hipblasGetrsStridedBatchedFn = FORTRAN ? hipblasGetrsStridedBatched : hipblasGetrsStridedBatched; + int N = argus.N; int lda = argus.lda; int ldb = argus.ldb; @@ -78,15 +81,15 @@ hipblasStatus_t testing_getrs_strided_batched(Arguments argus) hipblas_init(hAb, N, N, lda); hipblas_init(hXb, N, 1, ldb); - // Put hA entries into range [0, 1], make diagonally dominant + // scale A to avoid singularities for(int i = 0; i < N; i++) { for(int j = 0; j < N; j++) { - hAb[i + j * lda] = (hAb[i + j * lda] - 1.0) / 10.0; - if(i == j) - hAb[i + j * lda] *= 100; + hAb[i + j * lda] += 400; + else + hAb[i + j * lda] -= 4; } } @@ -111,7 +114,7 @@ hipblasStatus_t testing_getrs_strided_batched(Arguments argus) HIPBLAS =================================================================== */ - status = hipblasGetrsStridedBatched( + status = hipblasGetrsStridedBatchedFn( handle, op, N, 1, dA, lda, strideA, dIpiv, strideP, dB, ldb, strideB, &info, batch_count); if(status != HIPBLAS_STATUS_SUCCESS) @@ -144,11 +147,11 @@ hipblasStatus_t testing_getrs_strided_batched(Arguments argus) if(argus.unit_check) { - T eps = std::numeric_limits::epsilon(); + U eps = std::numeric_limits::epsilon(); double tolerance = N * eps * 100; double e = norm_check_general( - 'M', N, 1, ldb, hB.data() + b * strideB, hB1.data() + b * strideB); + 'F', N, 1, ldb, hB.data() + b * strideB, hB1.data() + b * strideB); unit_check_error(e, tolerance); } } diff --git a/clients/include/testing_set_get_matrix.hpp b/clients/include/testing_set_get_matrix.hpp index f7c07fff4..1e5c86aaf 100644 --- a/clients/include/testing_set_get_matrix.hpp +++ b/clients/include/testing_set_get_matrix.hpp @@ -11,6 +11,7 @@ #include "cblas_interface.h" #include "flops.h" #include "hipblas.hpp" +#include "hipblas_fortran.hpp" #include "norm.h" #include "unit.h" #include "utility.h" @@ -22,6 +23,10 @@ using namespace std; template hipblasStatus_t 
testing_set_get_matrix(Arguments argus) { + bool FORTRAN = argus.fortran; + auto hipblasSetMatrixFn = FORTRAN ? hipblasSetMatrixFortran : hipblasSetMatrix; + auto hipblasGetMatrixFn = FORTRAN ? hipblasGetMatrixFortran : hipblasGetMatrix; + int rows = argus.rows; int cols = argus.cols; int lda = argus.lda; @@ -61,12 +66,12 @@ hipblasStatus_t testing_set_get_matrix(Arguments argus) } // Naming: dK is in GPU (device) memory. hK is in CPU (host) memory - vector ha(cols * lda); - vector hb(cols * ldb); - vector hb_ref(cols * ldb); - vector hc(cols * ldc); + host_vector ha(cols * lda); + host_vector hb(cols * ldb); + host_vector hb_ref(cols * ldb); + host_vector hc(cols * ldc); - T* dc; + device_vector dc(cols * ldc); double gpu_time_used, cpu_time_used; double hipblasBandwidth, cpu_bandwidth; @@ -76,9 +81,6 @@ hipblasStatus_t testing_set_get_matrix(Arguments argus) hipblasCreate(&handle); - // allocate memory on device - CHECK_HIP_ERROR(hipMalloc(&dc, cols * ldc * sizeof(T))); - // Initial Data on CPU srand(1); hipblas_init(ha, rows, cols, lda); @@ -98,13 +100,18 @@ hipblasStatus_t testing_set_get_matrix(Arguments argus) ROCBLAS =================================================================== */ - status_set = hipblasSetMatrix(rows, cols, sizeof(T), (void*)ha.data(), lda, (void*)dc, ldc); - status_get = hipblasGetMatrix(rows, cols, sizeof(T), (void*)dc, ldc, (void*)hb.data(), ldb); - if(status_set != HIPBLAS_STATUS_SUCCESS || status_get != HIPBLAS_STATUS_SUCCESS) + status_set = hipblasSetMatrixFn(rows, cols, sizeof(T), (void*)ha.data(), lda, (void*)dc, ldc); + status_get = hipblasGetMatrixFn(rows, cols, sizeof(T), (void*)dc, ldc, (void*)hb.data(), ldb); + if(status_set != HIPBLAS_STATUS_SUCCESS) { - CHECK_HIP_ERROR(hipFree(dc)); hipblasDestroy(handle); - return status; + return status_set; + } + + if(status_get != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status_get; } if(argus.unit_check) @@ -130,7 +137,6 @@ hipblasStatus_t 
testing_set_get_matrix(Arguments argus) } } - CHECK_HIP_ERROR(hipFree(dc)); hipblasDestroy(handle); return HIPBLAS_STATUS_SUCCESS; } diff --git a/clients/include/testing_set_get_matrix_async.hpp b/clients/include/testing_set_get_matrix_async.hpp new file mode 100644 index 000000000..747f35c18 --- /dev/null +++ b/clients/include/testing_set_get_matrix_async.hpp @@ -0,0 +1,129 @@ +/* ************************************************************************ + * Copyright 2016-2020 Advanced Micro Devices, Inc. + * + * ************************************************************************ */ + +#include +#include +#include +#include + +#include "cblas_interface.h" +#include "flops.h" +#include "hipblas.hpp" +#include "hipblas_fortran.hpp" +#include "norm.h" +#include "unit.h" +#include "utility.h" + +using namespace std; + +/* ============================================================================================ */ + +template +hipblasStatus_t testing_set_get_matrix_async(Arguments argus) +{ + bool FORTRAN = argus.fortran; + auto hipblasSetMatrixAsyncFn = FORTRAN ? hipblasSetMatrixAsyncFortran : hipblasSetMatrixAsync; + auto hipblasGetMatrixAsyncFn = FORTRAN ? hipblasGetMatrixAsyncFortran : hipblasGetMatrixAsync; + + int rows = argus.rows; + int cols = argus.cols; + int lda = argus.lda; + int ldb = argus.ldb; + int ldc = argus.ldc; + + hipblasStatus_t status = HIPBLAS_STATUS_SUCCESS; + hipblasStatus_t status_set = HIPBLAS_STATUS_SUCCESS; + hipblasStatus_t status_get = HIPBLAS_STATUS_SUCCESS; + + // argument sanity check, quick return if input parameters are invalid before allocating invalid + // memory + if(rows < 0 || cols < 0 || lda <= 0 || ldb <= 0 || ldc <= 0) + { + status = HIPBLAS_STATUS_INVALID_VALUE; + return status; + } + + // Naming: dK is in GPU (device) memory. 
hK is in CPU (host) memory + host_vector ha(cols * lda); + host_vector hb(cols * ldb); + host_vector hb_ref(cols * ldb); + host_vector hc(cols * ldc); + + device_vector dc(cols * ldc); + + double gpu_time_used, cpu_time_used; + double hipblasBandwidth, cpu_bandwidth; + double rocblas_error = 0.0; + + hipblasHandle_t handle; + hipblasCreate(&handle); + + hipStream_t stream; + hipblasGetStream(handle, &stream); + + // Initial Data on CPU + srand(1); + hipblas_init(ha, rows, cols, lda); + hipblas_init(hb, rows, cols, ldb); + hb_ref = hb; + for(int i = 0; i < cols * ldc; i++) + { + hc[i] = 100 + i; + }; + CHECK_HIP_ERROR(hipMemcpy(dc, hc.data(), sizeof(T) * ldc * cols, hipMemcpyHostToDevice)); + for(int i = 0; i < cols * ldc; i++) + { + hc[i] = 99.0; + }; + + /* ===================================================================== + ROCBLAS + =================================================================== */ + + status_set = hipblasSetMatrixAsyncFn( + rows, cols, sizeof(T), (void*)ha.data(), lda, (void*)dc, ldc, stream); + status_get = hipblasGetMatrixAsyncFn( + rows, cols, sizeof(T), (void*)dc, ldc, (void*)hb.data(), ldb, stream); + + hipStreamSynchronize(stream); + + if(status_set != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status_set; + } + + if(status_get != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status_get; + } + + if(argus.unit_check) + { + /* ===================================================================== + CPU BLAS + =================================================================== */ + + // reference calculation + for(int i1 = 0; i1 < rows; i1++) + { + for(int i2 = 0; i2 < cols; i2++) + { + hb_ref[i1 + i2 * ldb] = ha[i1 + i2 * lda]; + } + } + + // enable unit check, notice unit check is not invasive, but norm check is, + // unit check and norm check can not be interchanged their order + if(argus.unit_check) + { + unit_check_general(rows, cols, ldb, hb.data(), hb_ref.data()); + } + } + + 
hipblasDestroy(handle); + return HIPBLAS_STATUS_SUCCESS; +} diff --git a/clients/include/testing_set_get_vector.hpp b/clients/include/testing_set_get_vector.hpp index ba81bad3a..35fab5388 100644 --- a/clients/include/testing_set_get_vector.hpp +++ b/clients/include/testing_set_get_vector.hpp @@ -10,6 +10,7 @@ #include "cblas_interface.h" #include "hipblas.hpp" +#include "hipblas_fortran.hpp" #include "norm.h" #include "unit.h" #include "utility.h" @@ -21,6 +22,10 @@ using namespace std; template hipblasStatus_t testing_set_get_vector(Arguments argus) { + bool FORTRAN = argus.fortran; + auto hipblasSetVectorFn = FORTRAN ? hipblasSetVectorFortran : hipblasSetVector; + auto hipblasGetVectorFn = FORTRAN ? hipblasGetVectorFortran : hipblasGetVector; + int M = argus.M; int incx = argus.incx; int incy = argus.incy; @@ -54,19 +59,16 @@ hipblasStatus_t testing_set_get_vector(Arguments argus) } // Naming: dK is in GPU (device) memory. hK is in CPU (host) memory - vector hx(M * incx); - vector hy(M * incy); - vector hy_ref(M * incy); + host_vector hx(M * incx); + host_vector hy(M * incy); + host_vector hy_ref(M * incy); - T* db; + device_vector db(M * incd); hipblasHandle_t handle; hipblasCreate(&handle); - // allocate memory on device - CHECK_HIP_ERROR(hipMalloc(&db, M * incd * sizeof(T))); - // Initial Data on CPU srand(1); hipblas_init(hx, 1, M, incx); @@ -76,22 +78,20 @@ hipblasStatus_t testing_set_get_vector(Arguments argus) /* ===================================================================== ROCBLAS =================================================================== */ - status_set = hipblasSetVector(M, sizeof(T), (void*)hx.data(), incx, (void*)db, incd); + status_set = hipblasSetVectorFn(M, sizeof(T), (void*)hx.data(), incx, (void*)db, incd); - status_get = hipblasGetVector(M, sizeof(T), (void*)db, incd, (void*)hy.data(), incy); + status_get = hipblasGetVectorFn(M, sizeof(T), (void*)db, incd, (void*)hy.data(), incy); if(status_set != HIPBLAS_STATUS_SUCCESS) { - 
CHECK_HIP_ERROR(hipFree(db)); hipblasDestroy(handle); - return status; + return status_set; } if(status_get != HIPBLAS_STATUS_SUCCESS) { - CHECK_HIP_ERROR(hipFree(db)); hipblasDestroy(handle); - return status; + return status_get; } if(argus.unit_check) @@ -114,7 +114,6 @@ hipblasStatus_t testing_set_get_vector(Arguments argus) } } - CHECK_HIP_ERROR(hipFree(db)); hipblasDestroy(handle); return HIPBLAS_STATUS_SUCCESS; } diff --git a/clients/include/testing_set_get_vector_async.hpp b/clients/include/testing_set_get_vector_async.hpp new file mode 100644 index 000000000..79b4bc7cb --- /dev/null +++ b/clients/include/testing_set_get_vector_async.hpp @@ -0,0 +1,109 @@ +/* ************************************************************************ + * Copyright 2016-2020 Advanced Micro Devices, Inc. + * + * ************************************************************************ */ + +#include +#include +#include +#include + +#include "cblas_interface.h" +#include "hipblas.hpp" +#include "hipblas_fortran.hpp" +#include "norm.h" +#include "unit.h" +#include "utility.h" + +using namespace std; + +/* ============================================================================================ */ + +template +hipblasStatus_t testing_set_get_vector_async(Arguments argus) +{ + bool FORTRAN = argus.fortran; + auto hipblasSetVectorAsyncFn = FORTRAN ? hipblasSetVectorAsyncFortran : hipblasSetVectorAsync; + auto hipblasGetVectorAsyncFn = FORTRAN ? 
hipblasGetVectorAsyncFortran : hipblasGetVectorAsync; + + int M = argus.M; + int incx = argus.incx; + int incy = argus.incy; + int incd = argus.incd; + + hipblasStatus_t status = HIPBLAS_STATUS_SUCCESS; + hipblasStatus_t status_set = HIPBLAS_STATUS_SUCCESS; + hipblasStatus_t status_get = HIPBLAS_STATUS_SUCCESS; + + // argument sanity check, quick return if input parameters are invalid before allocating invalid + // memory + if(M < 0 || incx <= 0 || incy <= 0 || incd <= 0) + { + status = HIPBLAS_STATUS_INVALID_VALUE; + return status; + } + + // Naming: dK is in GPU (device) memory. hK is in CPU (host) memory + host_vector hx(M * incx); + host_vector hy(M * incy); + host_vector hy_ref(M * incy); + + device_vector db(M * incd); + + hipblasHandle_t handle; + hipblasCreate(&handle); + + hipStream_t stream; + hipblasGetStream(handle, &stream); + + // Initial Data on CPU + srand(1); + hipblas_init(hx, 1, M, incx); + hipblas_init(hy, 1, M, incy); + hy_ref = hy; + + /* ===================================================================== + ROCBLAS + =================================================================== */ + status_set + = hipblasSetVectorAsyncFn(M, sizeof(T), (void*)hx.data(), incx, (void*)db, incd, stream); + status_get + = hipblasGetVectorAsyncFn(M, sizeof(T), (void*)db, incd, (void*)hy.data(), incy, stream); + + hipStreamSynchronize(stream); + + if(status_set != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status_set; + } + + if(status_get != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status_get; + } + + if(argus.unit_check) + { + /* ===================================================================== + CPU BLAS + =================================================================== */ + + // reference calculation + for(int i = 0; i < M; i++) + { + hy_ref[i * incy] = hx[i * incx]; + } + + // enable unit check, notice unit check is not invasive, but norm check is, + // unit check and norm check can not be interchanged 
their order + if(argus.unit_check) + { + unit_check_general(1, M, incy, hy.data(), hy_ref.data()); + } + } + + hipblasDestroy(handle); + return HIPBLAS_STATUS_SUCCESS; +} diff --git a/clients/include/testing_trsm_batched_ex.hpp b/clients/include/testing_trsm_batched_ex.hpp new file mode 100644 index 000000000..58c485596 --- /dev/null +++ b/clients/include/testing_trsm_batched_ex.hpp @@ -0,0 +1,271 @@ +/* ************************************************************************ + * Copyright 2016-2020 Advanced Micro Devices, Inc. + * + * ************************************************************************ */ + +#include +#include +#include +#include + +#include "cblas_interface.h" +#include "flops.h" +#include "hipblas.hpp" +#include "hipblas_fortran.hpp" +#include "norm.h" +#include "unit.h" +#include "utility.h" + +using namespace std; + +#define TRSM_BLOCK 128 + +/* ============================================================================================ */ + +template +hipblasStatus_t testing_trsm_batched_ex(Arguments argus) +{ + bool FORTRAN = argus.fortran; + auto hipblasTrsmBatchedExFn = FORTRAN ? hipblasTrsmBatchedExFortran : hipblasTrsmBatchedEx; + + int M = argus.M; + int N = argus.N; + int lda = argus.lda; + int ldb = argus.ldb; + + char char_side = argus.side_option; + char char_uplo = argus.uplo_option; + char char_transA = argus.transA_option; + char char_diag = argus.diag_option; + T alpha = argus.alpha; + int batch_count = argus.batch_count; + + hipblasSideMode_t side = char2hipblas_side(char_side); + hipblasFillMode_t uplo = char2hipblas_fill(char_uplo); + hipblasOperation_t transA = char2hipblas_operation(char_transA); + hipblasDiagType_t diag = char2hipblas_diagonal(char_diag); + + int K = (side == HIPBLAS_SIDE_LEFT ? 
M : N); + int A_size = lda * K; + int B_size = ldb * N; + + hipblasStatus_t status = HIPBLAS_STATUS_SUCCESS; + + // check here to prevent undefined memory allocation error + // TODO: Workaround for cuda tests, not actually testing return values + if(M < 0 || N < 0 || lda < K || ldb < M || batch_count < 0) + { + return HIPBLAS_STATUS_INVALID_VALUE; + } + if(!M || !N || !lda || !ldb || !batch_count) + { + return HIPBLAS_STATUS_SUCCESS; + } + // Naming: dK is in GPU (device) memory. hK is in CPU (host) memory + host_vector hA[batch_count]; + host_vector hB[batch_count]; + host_vector hB_copy[batch_count]; + host_vector hX[batch_count]; + + device_batch_vector bA(batch_count, A_size); + device_batch_vector bB(batch_count, B_size); + device_batch_vector binvA(batch_count, TRSM_BLOCK * K); + + device_vector dA(batch_count); + device_vector dB(batch_count); + device_vector dinvA(batch_count); + + int last = batch_count - 1; + if(!dA || !dB || !dinvA || (!bA[last] && A_size) || (!bB[last] && B_size) || !binvA[last]) + { + return HIPBLAS_STATUS_ALLOC_FAILED; + } + + double gpu_time_used, cpu_time_used; + double hipblasGflops, cblas_gflops; + + hipblasHandle_t handle; + hipblasCreate(&handle); + + // Initial hA on CPU + srand(1); + for(int b = 0; b < batch_count; b++) + { + hA[b] = host_vector(A_size); + hB[b] = host_vector(B_size); + hB_copy[b] = host_vector(B_size); + hX[b] = host_vector(B_size); + + hipblas_init_symmetric(hA[b], K, lda); + // pad untouched area into zero + for(int i = K; i < lda; i++) + { + for(int j = 0; j < K; j++) + { + hA[b][i + j * lda] = 0.0; + } + } + + // proprocess the matrix to avoid ill-conditioned matrix + vector ipiv(K); + cblas_getrf(K, K, hA[b].data(), lda, ipiv.data()); + for(int i = 0; i < K; i++) + { + for(int j = i; j < K; j++) + { + hA[b][i + j * lda] = hA[b][j + i * lda]; + if(diag == HIPBLAS_DIAG_UNIT) + { + if(i == j) + hA[b][i + j * lda] = 1.0; + } + } + } + + // Initial hB, hX on CPU + hipblas_init(hB[b], M, N, ldb); + // pad 
untouched area into zero + for(int i = M; i < ldb; i++) + { + for(int j = 0; j < N; j++) + { + hB[b][i + j * ldb] = 0.0; + } + } + hX[b] = hB[b]; // original solution hX + + // Calculate hB = hA*hX; + cblas_trmm(side, + uplo, + transA, + diag, + M, + N, + T(1.0) / alpha, + (const T*)hA[b].data(), + lda, + hB[b].data(), + ldb); + + hB_copy[b] = hB[b]; + + // copy data from CPU to device + CHECK_HIP_ERROR(hipMemcpy(bA[b], hA[b], sizeof(T) * A_size, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(bB[b], hB[b], sizeof(T) * B_size, hipMemcpyHostToDevice)); + } + CHECK_HIP_ERROR(hipMemcpy(dA, bA, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dB, bB, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dinvA, binvA, sizeof(T*) * batch_count, hipMemcpyHostToDevice)); + + /* ===================================================================== + HIPBLAS + =================================================================== */ + + int stride_A = TRSM_BLOCK * lda + TRSM_BLOCK; + int stride_invA = TRSM_BLOCK * TRSM_BLOCK; + int blocks = K / TRSM_BLOCK; + + for(int b = 0; b < batch_count; b++) + { + if(blocks > 0) + { + status = hipblasTrtriStridedBatched(handle, + uplo, + diag, + TRSM_BLOCK, + bA[b], + lda, + stride_A, + binvA[b], + TRSM_BLOCK, + stride_invA, + blocks); + if(status != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status; + } + } + + if(K % TRSM_BLOCK != 0 || blocks == 0) + { + status = hipblasTrtriStridedBatched(handle, + uplo, + diag, + K - TRSM_BLOCK * blocks, + bA[b] + stride_A * blocks, + lda, + stride_A, + binvA[b] + stride_invA * blocks, + TRSM_BLOCK, + stride_invA, + 1); + if(status != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status; + } + } + } + + status = hipblasTrsmBatchedExFn(handle, + side, + uplo, + transA, + diag, + M, + N, + &alpha, + dA, + lda, + dB, + ldb, + batch_count, + dinvA, + TRSM_BLOCK * K, + argus.compute_type); + if(status != 
HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status; + } + + // copy output from device to CPU + for(int b = 0; b < batch_count; b++) + CHECK_HIP_ERROR(hipMemcpy(hB[b], bB[b], sizeof(T) * B_size, hipMemcpyDeviceToHost)); + + if(argus.unit_check) + { + /* ===================================================================== + CPU BLAS + =================================================================== */ + + for(int b = 0; b < batch_count; b++) + { + cblas_trsm(side, + uplo, + transA, + diag, + M, + N, + alpha, + (const T*)hA[b].data(), + lda, + hB_copy[b].data(), + ldb); + } + + // if enable norm check, norm check is invasive + real_t eps = std::numeric_limits>::epsilon(); + double tolerance = eps * 40 * M; + + for(int b = 0; b < batch_count; b++) + { + double error = norm_check_general('F', M, N, ldb, hB_copy[b].data(), hB[b].data()); + unit_check_error(error, tolerance); + } + } + + hipblasDestroy(handle); + return HIPBLAS_STATUS_SUCCESS; +} diff --git a/clients/include/testing_trsm_ex.hpp b/clients/include/testing_trsm_ex.hpp new file mode 100644 index 000000000..5c6823d74 --- /dev/null +++ b/clients/include/testing_trsm_ex.hpp @@ -0,0 +1,221 @@ +/* ************************************************************************ + * Copyright 2016-2020 Advanced Micro Devices, Inc. + * + * ************************************************************************ */ + +#include +#include +#include +#include + +#include "cblas_interface.h" +#include "flops.h" +#include "hipblas.hpp" +#include "hipblas_fortran.hpp" +#include "norm.h" +#include "unit.h" +#include "utility.h" + +using namespace std; + +#define TRSM_BLOCK 128 + +/* ============================================================================================ */ + +template +hipblasStatus_t testing_trsm_ex(Arguments argus) +{ + bool FORTRAN = argus.fortran; + auto hipblasTrsmExFn = FORTRAN ? 
hipblasTrsmExFortran : hipblasTrsmEx; + + int M = argus.M; + int N = argus.N; + int lda = argus.lda; + int ldb = argus.ldb; + + char char_side = argus.side_option; + char char_uplo = argus.uplo_option; + char char_transA = argus.transA_option; + char char_diag = argus.diag_option; + T alpha = argus.alpha; + + hipblasSideMode_t side = char2hipblas_side(char_side); + hipblasFillMode_t uplo = char2hipblas_fill(char_uplo); + hipblasOperation_t transA = char2hipblas_operation(char_transA); + hipblasDiagType_t diag = char2hipblas_diagonal(char_diag); + + int K = (side == HIPBLAS_SIDE_LEFT ? M : N); + int A_size = lda * K; + int B_size = ldb * N; + + hipblasStatus_t status = HIPBLAS_STATUS_SUCCESS; + + // check here to prevent undefined memory allocation error + if(M < 0 || N < 0 || lda < K || ldb < M) + { + return HIPBLAS_STATUS_INVALID_VALUE; + } + // Naming: dK is in GPU (device) memory. hK is in CPU (host) memory + host_vector hA(A_size); + host_vector hB(B_size); + host_vector hB_copy(B_size); + host_vector hX(B_size); + + device_vector dA(A_size); + device_vector dB(B_size); + device_vector dinvA(TRSM_BLOCK * K); + + double gpu_time_used, cpu_time_used; + double hipblasGflops, cblas_gflops; + double rocblas_error; + + hipblasHandle_t handle; + hipblasCreate(&handle); + + // Initial hA on CPU + srand(1); + hipblas_init_symmetric(hA, K, lda); + // pad untouched area into zero + for(int i = K; i < lda; i++) + { + for(int j = 0; j < K; j++) + { + hA[i + j * lda] = 0.0; + } + } + // preprocess the matrix to avoid ill-conditioned matrix + vector ipiv(K); + cblas_getrf(K, K, hA.data(), lda, ipiv.data()); + for(int i = 0; i < K; i++) + { + for(int j = i; j < K; j++) + { + hA[i + j * lda] = hA[j + i * lda]; + if(diag == HIPBLAS_DIAG_UNIT) + { + if(i == j) + hA[i + j * lda] = 1.0; + } + } + } + + // Initial hB, hX on CPU + hipblas_init(hB, M, N, ldb); + // pad untouched area into zero + for(int i = M; i < ldb; i++) + { + for(int j = 0; j < N; j++) + { + hB[i + j * ldb] = 0.0;
+ } + } + hX = hB; // original solution hX + + // Calculate hB = hA*hX; + cblas_trmm( + side, uplo, transA, diag, M, N, T(1.0) / alpha, (const T*)hA.data(), lda, hB.data(), ldb); + + hB_copy = hB; + + // copy data from CPU to device + CHECK_HIP_ERROR(hipMemcpy(dA, hA.data(), sizeof(T) * A_size, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dB, hB.data(), sizeof(T) * B_size, hipMemcpyHostToDevice)); + + int stride_A = TRSM_BLOCK * lda + TRSM_BLOCK; + int stride_invA = TRSM_BLOCK * TRSM_BLOCK; + int blocks = K / TRSM_BLOCK; + + /* ===================================================================== + HIPBLAS + =================================================================== */ + // Calculate invA + if(blocks > 0) + { + status = hipblasTrtriStridedBatched(handle, + uplo, + diag, + TRSM_BLOCK, + dA, + lda, + stride_A, + dinvA, + TRSM_BLOCK, + stride_invA, + blocks); + if(status != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status; + } + } + + if(K % TRSM_BLOCK != 0 || blocks == 0) + { + status = hipblasTrtriStridedBatched(handle, + uplo, + diag, + K - TRSM_BLOCK * blocks, + dA + stride_A * blocks, + lda, + stride_A, + dinvA + stride_invA * blocks, + TRSM_BLOCK, + stride_invA, + 1); + + // Check the remainder inversion even when blocks == 0 (K < TRSM_BLOCK); + // otherwise a failure here is overwritten by the hipblasTrsmExFn status below. + // The batched and strided-batched variants of this test check unconditionally too. + if(status != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status; + } + } + + status = hipblasTrsmExFn(handle, + side, + uplo, + transA, + diag, + M, + N, + &alpha, + dA, + lda, + dB, + ldb, + dinvA, + TRSM_BLOCK * K, + argus.compute_type); + if(status != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status; + } + + // copy output from device to CPU + CHECK_HIP_ERROR(hipMemcpy(hB.data(), dB, sizeof(T) * B_size, hipMemcpyDeviceToHost)); + + if(argus.unit_check) + { + /* ===================================================================== + CPU BLAS + =================================================================== */ + + cblas_trsm( + side, uplo, transA, diag, M, N, alpha, (const T*)hA.data(),
lda, hB_copy.data(), ldb); + + // print_matrix(hB_copy, hB, min(M, 3), min(N,3), ldb); + + // if enable norm check, norm check is invasive + real_t eps = std::numeric_limits>::epsilon(); + double tolerance = eps * 40 * M; + + double error = norm_check_general('F', M, N, ldb, hB_copy.data(), hB.data()); + unit_check_error(error, tolerance); + } + + hipblasDestroy(handle); + return HIPBLAS_STATUS_SUCCESS; +} diff --git a/clients/include/testing_trsm_strided_batched_ex.hpp b/clients/include/testing_trsm_strided_batched_ex.hpp new file mode 100644 index 000000000..2e37a4dc7 --- /dev/null +++ b/clients/include/testing_trsm_strided_batched_ex.hpp @@ -0,0 +1,258 @@ +/* ************************************************************************ + * Copyright 2016-2020 Advanced Micro Devices, Inc. + * + * ************************************************************************ */ + +#include +#include +#include +#include + +#include "cblas_interface.h" +#include "flops.h" +#include "hipblas.hpp" +#include "hipblas_fortran.hpp" +#include "norm.h" +#include "unit.h" +#include "utility.h" + +using namespace std; + +#define TRSM_BLOCK 128 + +/* ============================================================================================ */ + +template +hipblasStatus_t testing_trsm_strided_batched_ex(Arguments argus) +{ + bool FORTRAN = argus.fortran; + auto hipblasTrsmStridedBatchedExFn + = FORTRAN ? 
hipblasTrsmStridedBatchedEx : hipblasTrsmStridedBatchedEx; + + int M = argus.M; + int N = argus.N; + int lda = argus.lda; + int ldb = argus.ldb; + + char char_side = argus.side_option; + char char_uplo = argus.uplo_option; + char char_transA = argus.transA_option; + char char_diag = argus.diag_option; + T alpha = argus.alpha; + double stride_scale = argus.stride_scale; + int batch_count = argus.batch_count; + + hipblasSideMode_t side = char2hipblas_side(char_side); + hipblasFillMode_t uplo = char2hipblas_fill(char_uplo); + hipblasOperation_t transA = char2hipblas_operation(char_transA); + hipblasDiagType_t diag = char2hipblas_diagonal(char_diag); + + int K = (side == HIPBLAS_SIDE_LEFT ? M : N); + + int strideA = lda * K * stride_scale; + int strideB = ldb * N * stride_scale; + int stride_invA = TRSM_BLOCK * K; + int A_size = strideA * batch_count; + int B_size = strideB * batch_count; + int invA_size = stride_invA * batch_count; + + hipblasStatus_t status = HIPBLAS_STATUS_SUCCESS; + + // check here to prevent undefined memory allocation error + // TODO: Workaround for cuda tests, not actually testing return values + if(M < 0 || N < 0 || lda < K || ldb < M || batch_count < 0) + { + return HIPBLAS_STATUS_INVALID_VALUE; + } + if(!batch_count) + { + return HIPBLAS_STATUS_SUCCESS; + } + // Naming: dK is in GPU (device) memory. 
hK is in CPU (host) memory + host_vector hA(A_size); + host_vector hB(B_size); + host_vector hB_copy(B_size); + host_vector hX(B_size); + + device_vector dA(A_size); + device_vector dB(B_size); + device_vector dinvA(invA_size); + + double gpu_time_used, cpu_time_used; + double hipblasGflops, cblas_gflops; + + hipblasHandle_t handle; + hipblasCreate(&handle); + + // Initial hA on CPU + srand(1); + hipblas_init_symmetric(hA, K, lda, strideA, batch_count); + for(int b = 0; b < batch_count; b++) + { + T* hAb = hA.data() + b * strideA; + T* hBb = hB.data() + b * strideB; + + // pad untouched area into zero + for(int i = K; i < lda; i++) + { + for(int j = 0; j < K; j++) + { + hAb[i + j * lda] = 0.0; + } + } + + // preprocess the matrix to avoid ill-conditioned matrix + vector ipiv(K); + cblas_getrf(K, K, hAb, lda, ipiv.data()); + for(int i = 0; i < K; i++) + { + for(int j = i; j < K; j++) + { + hAb[i + j * lda] = hAb[j + i * lda]; + if(diag == HIPBLAS_DIAG_UNIT) + { + if(i == j) + hAb[i + j * lda] = 1.0; + } + } + } + + // Initial hB, hX on CPU + hipblas_init(hBb, M, N, ldb); + // pad untouched area into zero + for(int i = M; i < ldb; i++) + { + for(int j = 0; j < N; j++) + { + hBb[i + j * ldb] = 0.0; + } + } + + // Calculate hB = hA*hX; + cblas_trmm(side, uplo, transA, diag, M, N, T(1.0) / alpha, (const T*)hAb, lda, hBb, ldb); + } + hX = hB; // original solutions hX + hB_copy = hB; + + // copy data from CPU to device + CHECK_HIP_ERROR(hipMemcpy(dA, hA.data(), sizeof(T) * A_size, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dB, hB.data(), sizeof(T) * B_size, hipMemcpyHostToDevice)); + + /* ===================================================================== + HIPBLAS + =================================================================== */ + int sub_stride_A = TRSM_BLOCK * lda + TRSM_BLOCK; + int sub_stride_invA = TRSM_BLOCK * TRSM_BLOCK; + int blocks = K / TRSM_BLOCK; + + for(int b = 0; b < batch_count; b++) + { + if(blocks > 0) + { + status =
hipblasTrtriStridedBatched(handle, + uplo, + diag, + TRSM_BLOCK, + dA + b * strideA, + lda, + sub_stride_A, + dinvA + b * stride_invA, + TRSM_BLOCK, + sub_stride_invA, + blocks); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status; + } + } + + if(K % TRSM_BLOCK != 0 || blocks == 0) + { + status + = hipblasTrtriStridedBatched(handle, + uplo, + diag, + K - TRSM_BLOCK * blocks, + dA + sub_stride_A * blocks + b * strideA, + lda, + sub_stride_A, + dinvA + sub_stride_invA * blocks + b * stride_invA, + TRSM_BLOCK, + sub_stride_invA, + 1); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status; + } + } + } + + status = hipblasTrsmStridedBatchedExFn(handle, + side, + uplo, + transA, + diag, + M, + N, + &alpha, + dA, + lda, + strideA, + dB, + ldb, + strideB, + batch_count, + dinvA, + invA_size, + stride_invA, + argus.compute_type); + + if(status != HIPBLAS_STATUS_SUCCESS) + { + hipblasDestroy(handle); + return status; + } + + // copy output from device to CPU + CHECK_HIP_ERROR(hipMemcpy(hB.data(), dB, sizeof(T) * B_size, hipMemcpyDeviceToHost)); + + if(argus.unit_check) + { + /* ===================================================================== + CPU BLAS + =================================================================== */ + + for(int b = 0; b < batch_count; b++) + { + cblas_trsm(side, + uplo, + transA, + diag, + M, + N, + alpha, + (const T*)hA.data() + b * strideA, + lda, + hB_copy.data() + b * strideB, + ldb); + } + + // if enable norm check, norm check is invasive + real_t eps = std::numeric_limits>::epsilon(); + double tolerance = eps * 40 * M; + + for(int b = 0; b < batch_count; b++) + { + double error = norm_check_general( + 'F', M, N, ldb, hB_copy.data() + b * strideB, hB.data() + b * strideB); + unit_check_error(error, tolerance); + } + } + + hipblasDestroy(handle); + return HIPBLAS_STATUS_SUCCESS; +} diff --git a/clients/include/utility.h b/clients/include/utility.h index 954ee03c0..a605ecf35 
100644 --- a/clients/include/utility.h +++ b/clients/include/utility.h @@ -581,6 +581,12 @@ void prepare_triangular_solve(T* hA, int lda, T* AAT, int N, char char_uplo) template char type2char(); +/* ============================================================================================ */ +/*! \brief turn float -> int, double -> int, hipblas_float_complex.real() -> int, + * hipblas_double_complex.real() -> int */ +template +int type2int(T val); + /* ============================================================================================ */ /*! \brief Debugging purpose, print out CPU and GPU result matrix, not valid in complex number */ template , int> = 0> diff --git a/install.sh b/install.sh index 0d59634ae..dac6b27c4 100755 --- a/install.sh +++ b/install.sh @@ -24,6 +24,7 @@ function display_help() echo " [--custom-target] link against custom target (e.g. host, device)" echo " [-v|--rocm-dev] Set specific rocm-dev version" echo " [-b|--rocblas] Set specific rocblas version" + echo " [--rocblas-path] Set specific path to custom built rocblas" } # This function is helpful for dockerfiles that do not have sudo installed, but the default user is root @@ -147,18 +148,22 @@ install_packages( ) fi # Custom rocblas installation - if [[ -z ${custom_rocblas+foo} ]]; then - # Install base rocblas package unless -b/--rocblas flag is passed - library_dependencies_ubuntu+=( "rocblas" ) - library_dependencies_centos+=( "rocblas" ) - library_dependencies_fedora+=( "rocblas" ) - library_dependencies_sles+=( "rocblas" ) - else - # Install rocm-specific rocblas package - library_dependencies_ubuntu+=( "${custom_rocblas}" ) - library_dependencies_centos+=( "${custom_rocblas}" ) - library_dependencies_fedora+=( "${custom_rocblas}" ) - library_dependencies_sles+=( "${custom_rocblas}" ) + # Do not install rocblas if --rocblas-path flag is set, + # as we will be building against our own rocblas instead.
+ if [[ -z ${rocblas_path+foo} ]]; then + if [[ -z ${custom_rocblas+foo} ]]; then + # Install base rocblas package unless -b/--rocblas flag is passed + library_dependencies_ubuntu+=( "rocblas" ) + library_dependencies_centos+=( "rocblas" ) + library_dependencies_fedora+=( "rocblas" ) + library_dependencies_sles+=( "rocblas" ) + else + # Install rocm-specific rocblas package + library_dependencies_ubuntu+=( "${custom_rocblas}" ) + library_dependencies_centos+=( "${custom_rocblas}" ) + library_dependencies_fedora+=( "${custom_rocblas}" ) + library_dependencies_sles+=( "${custom_rocblas}" ) + fi fi if [[ "${build_solver}" == true ]]; then @@ -286,7 +291,7 @@ compiler=g++ # check if we have a modern version of getopt that can handle whitespace and long parameters getopt -T if [[ $? -eq 4 ]]; then - GETOPT_PARSE=$(getopt --name "${0}" --longoptions help,install,clients,no-solver,dependencies,debug,hip-clang,compiler:,cuda,cmakepp,relocatable:,rocm-dev:,rocblas:,custom-target: --options rhicndgp:v:b: -- "$@") + GETOPT_PARSE=$(getopt --name "${0}" --longoptions help,install,clients,no-solver,dependencies,debug,hip-clang,compiler:,cuda,cmakepp,relocatable:,rocm-dev:,rocblas:,rocblas-path:,custom-target: --options rhicndgp:v:b: -- "$@") else echo "Need a new version of getopt" exit 1 @@ -344,6 +349,9 @@ while true; do -b|--rocblas) custom_rocblas=${2} shift 2;; + --rocblas-path) + rocblas_path=${2} + shift 2 ;; --prefix) install_prefix=${2} shift 2 ;; @@ -438,6 +446,11 @@ pushd . 
cmake_common_options="${cmake_common_options} -DCUSTOM_TARGET=${custom_target}" fi + # custom rocblas + if [[ ${rocblas_path+foo} ]]; then + cmake_common_options="${cmake_common_options} -DCUSTOM_ROCBLAS=${rocblas_path}" + fi + # Build library if [[ "${build_relocatable}" == true ]]; then CXX=${compiler} ${cmake_executable} ${cmake_common_options} ${cmake_client_options} -DCPACK_SET_DESTDIR=OFF -DCMAKE_INSTALL_PREFIX="${rocm_path}" \ diff --git a/library/include/hipblas.h b/library/include/hipblas.h index 4631e6650..7782d5698 100644 --- a/library/include/hipblas.h +++ b/library/include/hipblas.h @@ -256,6 +256,30 @@ HIPBLAS_EXPORT hipblasStatus_t HIPBLAS_EXPORT hipblasStatus_t hipblasGetMatrix(int rows, int cols, int elemSize, const void* A, int lda, void* B, int ldb); +HIPBLAS_EXPORT hipblasStatus_t hipblasSetVectorAsync( + int n, int elem_size, const void* x, int incx, void* y, int incy, hipStream_t stream); + +HIPBLAS_EXPORT hipblasStatus_t hipblasGetVectorAsync( + int n, int elem_size, const void* x, int incx, void* y, int incy, hipStream_t stream); + +HIPBLAS_EXPORT hipblasStatus_t hipblasSetMatrixAsync(int rows, + int cols, + int elem_size, + const void* A, + int lda, + void* B, + int ldb, + hipStream_t stream); + +HIPBLAS_EXPORT hipblasStatus_t hipblasGetMatrixAsync(int rows, + int cols, + int elem_size, + const void* A, + int lda, + void* B, + int ldb, + hipStream_t stream); + // amax HIPBLAS_EXPORT hipblasStatus_t hipblasIsamax(hipblasHandle_t handle, int n, const float* x, int incx, int* result); @@ -6538,6 +6562,47 @@ HIPBLAS_EXPORT hipblasStatus_t hipblasZgetrsStridedBatched(hipblasHandle_t int* info, const int batch_count); +// getri_batched +HIPBLAS_EXPORT hipblasStatus_t hipblasSgetriBatched(hipblasHandle_t handle, + const int n, + float* const A[], + const int lda, + int* ipiv, + float* const C[], + const int ldc, + int* info, + const int batch_count); + +HIPBLAS_EXPORT hipblasStatus_t hipblasDgetriBatched(hipblasHandle_t handle, + const int n, + 
double* const A[], + const int lda, + int* ipiv, + double* const C[], + const int ldc, + int* info, + const int batch_count); + +HIPBLAS_EXPORT hipblasStatus_t hipblasCgetriBatched(hipblasHandle_t handle, + const int n, + hipblasComplex* const A[], + const int lda, + int* ipiv, + hipblasComplex* const C[], + const int ldc, + int* info, + const int batch_count); + +HIPBLAS_EXPORT hipblasStatus_t hipblasZgetriBatched(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* const A[], + const int lda, + int* ipiv, + hipblasDoubleComplex* const C[], + const int ldc, + int* info, + const int batch_count); + // geqrf HIPBLAS_EXPORT hipblasStatus_t hipblasSgeqrf(hipblasHandle_t handle, const int m, @@ -6972,6 +7037,60 @@ HIPBLAS_EXPORT hipblasStatus_t hipblasGemmStridedBatchedEx(hipblasHandle_t ha hipblasDatatype_t compute_type, hipblasGemmAlgo_t algo); +// trsm_ex +HIPBLAS_EXPORT hipblasStatus_t hipblasTrsmEx(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + void* B, + int ldb, + const void* invA, + int invA_size, + hipblasDatatype_t compute_type); + +HIPBLAS_EXPORT hipblasStatus_t hipblasTrsmBatchedEx(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + void* B, + int ldb, + int batch_count, + const void* invA, + int invA_size, + hipblasDatatype_t compute_type); + +HIPBLAS_EXPORT hipblasStatus_t hipblasTrsmStridedBatchedEx(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + int stride_A, + void* B, + int ldb, + int stride_B, + int batch_count, + const void* invA, + int invA_size, + int stride_invA, + hipblasDatatype_t compute_type); + #ifdef 
__cplusplus } #endif diff --git a/library/src/CMakeLists.txt b/library/src/CMakeLists.txt index f21a7046c..258d742d3 100755 --- a/library/src/CMakeLists.txt +++ b/library/src/CMakeLists.txt @@ -49,7 +49,12 @@ target_include_directories( hipblas # Build hipblas from source on AMD platform if( NOT CUDA_FOUND ) if( NOT TARGET rocblas ) - find_package( rocblas REQUIRED CONFIG PATHS /opt/rocm /opt/rocm/rocblas ) + if( CUSTOM_ROCBLAS ) + set ( ENV{rocblas_DIR} ${CUSTOM_ROCBLAS}) + find_package( rocblas REQUIRED CONFIG NO_CMAKE_PATH ) + else( ) + find_package( rocblas REQUIRED CONFIG PATHS /opt/rocm /opt/rocm/rocblas ) + endif( ) endif( ) target_compile_definitions( hipblas PRIVATE __HIP_PLATFORM_HCC__ ) diff --git a/library/src/hcc_detail/hipblas.cpp b/library/src/hcc_detail/hipblas.cpp index 928f98dd3..09e0f1b2a 100644 --- a/library/src/hcc_detail/hipblas.cpp +++ b/library/src/hcc_detail/hipblas.cpp @@ -10,18 +10,6 @@ #include #include -#define USE_DEVICE_POINTER_MODE(handle, cmd) \ - do \ - { \ - hipblasPointerMode_t mode; \ - hipblasGetPointerMode(handle, &mode); \ - hipblasSetPointerMode(handle, HIPBLAS_POINTER_MODE_DEVICE); \ - \ - cmd; \ - \ - hipblasSetPointerMode(handle, mode); \ - } while(0); - #ifdef __cplusplus extern "C" { #endif @@ -351,6 +339,34 @@ hipblasStatus_t return rocBLASStatusToHIPStatus(rocblas_get_matrix(rows, cols, elemSize, A, lda, B, ldb)); } +hipblasStatus_t hipblasSetVectorAsync( + int n, int elemSize, const void* x, int incx, void* y, int incy, hipStream_t stream) +{ + return rocBLASStatusToHIPStatus( + rocblas_set_vector_async(n, elemSize, x, incx, y, incy, stream)); +} + +hipblasStatus_t hipblasGetVectorAsync( + int n, int elemSize, const void* x, int incx, void* y, int incy, hipStream_t stream) +{ + return rocBLASStatusToHIPStatus( + rocblas_get_vector_async(n, elemSize, x, incx, y, incy, stream)); +} + +hipblasStatus_t hipblasSetMatrixAsync( + int rows, int cols, int elemSize, const void* A, int lda, void* B, int ldb, hipStream_t stream) 
+{ + return rocBLASStatusToHIPStatus( + rocblas_set_matrix_async(rows, cols, elemSize, A, lda, B, ldb, stream)); +} + +hipblasStatus_t hipblasGetMatrixAsync( + int rows, int cols, int elemSize, const void* A, int lda, void* B, int ldb, hipStream_t stream) +{ + return rocBLASStatusToHIPStatus( + rocblas_get_matrix_async(rows, cols, elemSize, A, lda, B, ldb, stream)); +} + // amax hipblasStatus_t hipblasIsamax(hipblasHandle_t handle, int n, const float* x, int incx, int* result) { @@ -12266,34 +12282,112 @@ hipblasStatus_t hipblasZdgmmStridedBatched(hipblasHandle_t handle, //rocSOLVER functions //-------------------------------------------------------------------------------------- +// The following functions are not included in the public API and must be declared + +#ifdef __cplusplus +extern "C" { +#endif + +rocblas_status rocsolver_sgeqrf_ptr_batched(rocblas_handle handle, + const rocblas_int m, + const rocblas_int n, + float* const A[], + const rocblas_int lda, + float* const ipiv[], + const rocblas_int batch_count); + +rocblas_status rocsolver_dgeqrf_ptr_batched(rocblas_handle handle, + const rocblas_int m, + const rocblas_int n, + double* const A[], + const rocblas_int lda, + double* const ipiv[], + const rocblas_int batch_count); + +rocblas_status rocsolver_cgeqrf_ptr_batched(rocblas_handle handle, + const rocblas_int m, + const rocblas_int n, + rocblas_float_complex* const A[], + const rocblas_int lda, + rocblas_float_complex* const ipiv[], + const rocblas_int batch_count); + +rocblas_status rocsolver_zgeqrf_ptr_batched(rocblas_handle handle, + const rocblas_int m, + const rocblas_int n, + rocblas_double_complex* const A[], + const rocblas_int lda, + rocblas_double_complex* const ipiv[], + const rocblas_int batch_count); + +rocblas_status rocsolver_sgetri_outofplace_batched(rocblas_handle handle, + const rocblas_int n, + float* const A[], + const rocblas_int lda, + rocblas_int* ipiv, + const rocblas_stride strideP, + float* const C[], + const rocblas_int ldc, 
+ rocblas_int* info, + const rocblas_int batch_count); + +rocblas_status rocsolver_dgetri_outofplace_batched(rocblas_handle handle, + const rocblas_int n, + double* const A[], + const rocblas_int lda, + rocblas_int* ipiv, + const rocblas_stride strideP, + double* const C[], + const rocblas_int ldc, + rocblas_int* info, + const rocblas_int batch_count); + +rocblas_status rocsolver_cgetri_outofplace_batched(rocblas_handle handle, + const rocblas_int n, + rocblas_float_complex* const A[], + const rocblas_int lda, + rocblas_int* ipiv, + const rocblas_stride strideP, + rocblas_float_complex* const C[], + const rocblas_int ldc, + rocblas_int* info, + const rocblas_int batch_count); + +rocblas_status rocsolver_zgetri_outofplace_batched(rocblas_handle handle, + const rocblas_int n, + rocblas_double_complex* const A[], + const rocblas_int lda, + rocblas_int* ipiv, + const rocblas_stride strideP, + rocblas_double_complex* const C[], + const rocblas_int ldc, + rocblas_int* info, + const rocblas_int batch_count); + +#ifdef __cplusplus +} +#endif + // getrf hipblasStatus_t hipblasSgetrf( hipblasHandle_t handle, const int n, float* A, const int lda, int* ipiv, int* info) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE( - handle, status = rocsolver_sgetrf((rocblas_handle)handle, n, n, A, lda, ipiv, info)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus( + rocsolver_sgetrf((rocblas_handle)handle, n, n, A, lda, ipiv, info)); } hipblasStatus_t hipblasDgetrf( hipblasHandle_t handle, const int n, double* A, const int lda, int* ipiv, int* info) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE( - handle, status = rocsolver_dgetrf((rocblas_handle)handle, n, n, A, lda, ipiv, info)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus( + rocsolver_dgetrf((rocblas_handle)handle, n, n, A, lda, ipiv, info)); } hipblasStatus_t hipblasCgetrf( hipblasHandle_t handle, const int n, hipblasComplex* A, const int lda, int* ipiv, int* 
info) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE( - handle, - status = rocsolver_cgetrf( - (rocblas_handle)handle, n, n, (rocblas_float_complex*)A, lda, ipiv, info)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus( + rocsolver_cgetrf((rocblas_handle)handle, n, n, (rocblas_float_complex*)A, lda, ipiv, info)); } hipblasStatus_t hipblasZgetrf(hipblasHandle_t handle, @@ -12303,12 +12397,8 @@ hipblasStatus_t hipblasZgetrf(hipblasHandle_t handle, int* ipiv, int* info) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE( - handle, - status = rocsolver_zgetrf( - (rocblas_handle)handle, n, n, (rocblas_double_complex*)A, lda, ipiv, info)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_zgetrf( + (rocblas_handle)handle, n, n, (rocblas_double_complex*)A, lda, ipiv, info)); } // getrf_batched @@ -12320,11 +12410,8 @@ hipblasStatus_t hipblasSgetrfBatched(hipblasHandle_t handle, int* info, const int batch_count) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_sgetrf_batched( - (rocblas_handle)handle, n, n, A, lda, ipiv, n, info, batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus( + rocsolver_sgetrf_batched((rocblas_handle)handle, n, n, A, lda, ipiv, n, info, batch_count)); } hipblasStatus_t hipblasDgetrfBatched(hipblasHandle_t handle, @@ -12335,11 +12422,8 @@ hipblasStatus_t hipblasDgetrfBatched(hipblasHandle_t handle, int* info, const int batch_count) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_dgetrf_batched( - (rocblas_handle)handle, n, n, A, lda, ipiv, n, info, batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus( + rocsolver_dgetrf_batched((rocblas_handle)handle, n, n, A, lda, ipiv, n, info, batch_count)); } hipblasStatus_t hipblasCgetrfBatched(hipblasHandle_t handle, @@ -12350,18 +12434,8 @@ hipblasStatus_t 
hipblasCgetrfBatched(hipblasHandle_t handle, int* info, const int batch_count) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_cgetrf_batched((rocblas_handle)handle, - n, - n, - (rocblas_float_complex**)A, - lda, - ipiv, - n, - info, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_cgetrf_batched( + (rocblas_handle)handle, n, n, (rocblas_float_complex**)A, lda, ipiv, n, info, batch_count)); } hipblasStatus_t hipblasZgetrfBatched(hipblasHandle_t handle, @@ -12372,18 +12446,15 @@ hipblasStatus_t hipblasZgetrfBatched(hipblasHandle_t handle, int* info, const int batch_count) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_zgetrf_batched((rocblas_handle)handle, - n, - n, - (rocblas_double_complex**)A, - lda, - ipiv, - n, - info, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_zgetrf_batched((rocblas_handle)handle, + n, + n, + (rocblas_double_complex**)A, + lda, + ipiv, + n, + info, + batch_count)); } // getrf_strided_batched @@ -12397,12 +12468,8 @@ hipblasStatus_t hipblasSgetrfStridedBatched(hipblasHandle_t handle, int* info, const int batch_count) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE( - handle, - status = rocsolver_sgetrf_strided_batched( - (rocblas_handle)handle, n, n, A, lda, strideA, ipiv, strideP, info, batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_sgetrf_strided_batched( + (rocblas_handle)handle, n, n, A, lda, strideA, ipiv, strideP, info, batch_count)); } hipblasStatus_t hipblasDgetrfStridedBatched(hipblasHandle_t handle, @@ -12415,12 +12482,8 @@ hipblasStatus_t hipblasDgetrfStridedBatched(hipblasHandle_t handle, int* info, const int batch_count) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE( - handle, - status = rocsolver_dgetrf_strided_batched( - (rocblas_handle)handle, n, n, A, lda, strideA, ipiv, 
strideP, info, batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_dgetrf_strided_batched( + (rocblas_handle)handle, n, n, A, lda, strideA, ipiv, strideP, info, batch_count)); } hipblasStatus_t hipblasCgetrfStridedBatched(hipblasHandle_t handle, @@ -12433,19 +12496,16 @@ hipblasStatus_t hipblasCgetrfStridedBatched(hipblasHandle_t handle, int* info, const int batch_count) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_cgetrf_strided_batched((rocblas_handle)handle, - n, - n, - (rocblas_float_complex*)A, - lda, - strideA, - ipiv, - strideP, - info, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_cgetrf_strided_batched((rocblas_handle)handle, + n, + n, + (rocblas_float_complex*)A, + lda, + strideA, + ipiv, + strideP, + info, + batch_count)); } hipblasStatus_t hipblasZgetrfStridedBatched(hipblasHandle_t handle, @@ -12458,19 +12518,16 @@ hipblasStatus_t hipblasZgetrfStridedBatched(hipblasHandle_t handle, int* info, const int batch_count) { - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_zgetrf_strided_batched((rocblas_handle)handle, - n, - n, - (rocblas_double_complex*)A, - lda, - strideA, - ipiv, - strideP, - info, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_zgetrf_strided_batched((rocblas_handle)handle, + n, + n, + (rocblas_double_complex*)A, + lda, + strideA, + ipiv, + strideP, + info, + batch_count)); } // getrs @@ -12504,18 +12561,8 @@ hipblasStatus_t hipblasSgetrs(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_sgetrs((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - A, - lda, - ipiv, - B, - ldb)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_sgetrs( + (rocblas_handle)handle, 
hipOperationToHCCOperation(trans), n, nrhs, A, lda, ipiv, B, ldb)); } hipblasStatus_t hipblasDgetrs(hipblasHandle_t handle, @@ -12548,18 +12595,8 @@ hipblasStatus_t hipblasDgetrs(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_dgetrs((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - A, - lda, - ipiv, - B, - ldb)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_dgetrs( + (rocblas_handle)handle, hipOperationToHCCOperation(trans), n, nrhs, A, lda, ipiv, B, ldb)); } hipblasStatus_t hipblasCgetrs(hipblasHandle_t handle, @@ -12592,18 +12629,15 @@ hipblasStatus_t hipblasCgetrs(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_cgetrs((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - (rocblas_float_complex*)A, - lda, - ipiv, - (rocblas_float_complex*)B, - ldb)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_cgetrs((rocblas_handle)handle, + hipOperationToHCCOperation(trans), + n, + nrhs, + (rocblas_float_complex*)A, + lda, + ipiv, + (rocblas_float_complex*)B, + ldb)); } hipblasStatus_t hipblasZgetrs(hipblasHandle_t handle, @@ -12636,18 +12670,15 @@ hipblasStatus_t hipblasZgetrs(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_zgetrs((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - (rocblas_double_complex*)A, - lda, - ipiv, - (rocblas_double_complex*)B, - ldb)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_zgetrs((rocblas_handle)handle, + hipOperationToHCCOperation(trans), + n, + nrhs, + (rocblas_double_complex*)A, + lda, + ipiv, + (rocblas_double_complex*)B, + ldb)); } // getrs_batched @@ -12684,20 +12715,17 @@ hipblasStatus_t 
hipblasSgetrsBatched(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_sgetrs_batched((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - A, - lda, - ipiv, - n, - B, - ldb, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_sgetrs_batched((rocblas_handle)handle, + hipOperationToHCCOperation(trans), + n, + nrhs, + A, + lda, + ipiv, + n, + B, + ldb, + batch_count)); } hipblasStatus_t hipblasDgetrsBatched(hipblasHandle_t handle, @@ -12733,20 +12761,17 @@ hipblasStatus_t hipblasDgetrsBatched(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_dgetrs_batched((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - A, - lda, - ipiv, - n, - B, - ldb, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_dgetrs_batched((rocblas_handle)handle, + hipOperationToHCCOperation(trans), + n, + nrhs, + A, + lda, + ipiv, + n, + B, + ldb, + batch_count)); } hipblasStatus_t hipblasCgetrsBatched(hipblasHandle_t handle, @@ -12782,20 +12807,17 @@ hipblasStatus_t hipblasCgetrsBatched(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_cgetrs_batched((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - (rocblas_float_complex**)A, - lda, - ipiv, - n, - (rocblas_float_complex**)B, - ldb, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_cgetrs_batched((rocblas_handle)handle, + hipOperationToHCCOperation(trans), + n, + nrhs, + (rocblas_float_complex**)A, + lda, + ipiv, + n, + (rocblas_float_complex**)B, + ldb, + batch_count)); } hipblasStatus_t hipblasZgetrsBatched(hipblasHandle_t handle, @@ -12831,20 +12853,17 @@ hipblasStatus_t hipblasZgetrsBatched(hipblasHandle_t 
handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_zgetrs_batched((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - (rocblas_double_complex**)A, - lda, - ipiv, - n, - (rocblas_double_complex**)B, - ldb, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_zgetrs_batched((rocblas_handle)handle, + hipOperationToHCCOperation(trans), + n, + nrhs, + (rocblas_double_complex**)A, + lda, + ipiv, + n, + (rocblas_double_complex**)B, + ldb, + batch_count)); } // getrs_strided_batched @@ -12884,23 +12903,20 @@ hipblasStatus_t hipblasSgetrsStridedBatched(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status - = rocsolver_sgetrs_strided_batched((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - A, - lda, - strideA, - ipiv, - strideP, - B, - ldb, - strideB, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus( + rocsolver_sgetrs_strided_batched((rocblas_handle)handle, + hipOperationToHCCOperation(trans), + n, + nrhs, + A, + lda, + strideA, + ipiv, + strideP, + B, + ldb, + strideB, + batch_count)); } hipblasStatus_t hipblasDgetrsStridedBatched(hipblasHandle_t handle, @@ -12939,23 +12955,20 @@ hipblasStatus_t hipblasDgetrsStridedBatched(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status - = rocsolver_dgetrs_strided_batched((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - A, - lda, - strideA, - ipiv, - strideP, - B, - ldb, - strideB, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus( + rocsolver_dgetrs_strided_batched((rocblas_handle)handle, + hipOperationToHCCOperation(trans), + n, + nrhs, + A, + lda, + strideA, + ipiv, + strideP, + B, + ldb, + strideB, + batch_count)); } hipblasStatus_t 
hipblasCgetrsStridedBatched(hipblasHandle_t handle, @@ -12994,23 +13007,20 @@ hipblasStatus_t hipblasCgetrsStridedBatched(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status - = rocsolver_cgetrs_strided_batched((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - (rocblas_float_complex*)A, - lda, - strideA, - ipiv, - strideP, - (rocblas_float_complex*)B, - ldb, - strideB, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus( + rocsolver_cgetrs_strided_batched((rocblas_handle)handle, + hipOperationToHCCOperation(trans), + n, + nrhs, + (rocblas_float_complex*)A, + lda, + strideA, + ipiv, + strideP, + (rocblas_float_complex*)B, + ldb, + strideB, + batch_count)); } hipblasStatus_t hipblasZgetrsStridedBatched(hipblasHandle_t handle, @@ -13049,23 +13059,93 @@ hipblasStatus_t hipblasZgetrsStridedBatched(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status - = rocsolver_zgetrs_strided_batched((rocblas_handle)handle, - hipOperationToHCCOperation(trans), - n, - nrhs, - (rocblas_double_complex*)A, - lda, - strideA, - ipiv, - strideP, - (rocblas_double_complex*)B, - ldb, - strideB, - batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus( + rocsolver_zgetrs_strided_batched((rocblas_handle)handle, + hipOperationToHCCOperation(trans), + n, + nrhs, + (rocblas_double_complex*)A, + lda, + strideA, + ipiv, + strideP, + (rocblas_double_complex*)B, + ldb, + strideB, + batch_count)); +} + +// getri_batched +hipblasStatus_t hipblasSgetriBatched(hipblasHandle_t handle, + const int n, + float* const A[], + const int lda, + int* ipiv, + float* const C[], + const int ldc, + int* info, + const int batch_count) +{ + return rocBLASStatusToHIPStatus(rocsolver_sgetri_outofplace_batched( + (rocblas_handle)handle, n, A, lda, ipiv, n, C, ldc, info, batch_count)); +} + +hipblasStatus_t 
hipblasDgetriBatched(hipblasHandle_t handle, + const int n, + double* const A[], + const int lda, + int* ipiv, + double* const C[], + const int ldc, + int* info, + const int batch_count) +{ + return rocBLASStatusToHIPStatus(rocsolver_dgetri_outofplace_batched( + (rocblas_handle)handle, n, A, lda, ipiv, n, C, ldc, info, batch_count)); +} + +hipblasStatus_t hipblasCgetriBatched(hipblasHandle_t handle, + const int n, + hipblasComplex* const A[], + const int lda, + int* ipiv, + hipblasComplex* const C[], + const int ldc, + int* info, + const int batch_count) +{ + return rocBLASStatusToHIPStatus(rocsolver_cgetri_outofplace_batched((rocblas_handle)handle, + n, + (rocblas_float_complex**)A, + lda, + ipiv, + n, + (rocblas_float_complex**)C, + ldc, + info, + batch_count)); +} + +hipblasStatus_t hipblasZgetriBatched(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* const A[], + const int lda, + int* ipiv, + hipblasDoubleComplex* const C[], + const int ldc, + int* info, + const int batch_count) +{ + return rocBLASStatusToHIPStatus(rocsolver_zgetri_outofplace_batched((rocblas_handle)handle, + n, + (rocblas_double_complex**)A, + lda, + ipiv, + n, + (rocblas_double_complex**)C, + ldc, + info, + batch_count)); } // geqrf @@ -13092,10 +13172,7 @@ hipblasStatus_t hipblasSgeqrf(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_sgeqrf((rocblas_handle)handle, m, n, A, lda, tau)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_sgeqrf((rocblas_handle)handle, m, n, A, lda, tau)); } hipblasStatus_t hipblasDgeqrf(hipblasHandle_t handle, @@ -13121,10 +13198,7 @@ hipblasStatus_t hipblasDgeqrf(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_dgeqrf((rocblas_handle)handle, m, n, A, lda, tau)); - return rocBLASStatusToHIPStatus(status); + return 
rocBLASStatusToHIPStatus(rocsolver_dgeqrf((rocblas_handle)handle, m, n, A, lda, tau)); } hipblasStatus_t hipblasCgeqrf(hipblasHandle_t handle, @@ -13150,17 +13224,8 @@ hipblasStatus_t hipblasCgeqrf(hipblasHandle_t handle, else *info = 0; - // rocsolver_status status; - // USE_DEVICE_POINTER_MODE(handle, status = rocsolver_cgeqrf( - // (rocblas_handle)handle, - // m, - // n, - // (rocblas_float_complex*)A, - // lda, - // (rocblas_float_complex*)tau)); - // return rocBLASStatusToHIPStatus(status); - - return HIPBLAS_STATUS_NOT_SUPPORTED; + return rocBLASStatusToHIPStatus(rocsolver_cgeqrf( + (rocblas_handle)handle, m, n, (rocblas_float_complex*)A, lda, (rocblas_float_complex*)tau)); } hipblasStatus_t hipblasZgeqrf(hipblasHandle_t handle, @@ -13186,17 +13251,12 @@ hipblasStatus_t hipblasZgeqrf(hipblasHandle_t handle, else *info = 0; - // rocsolver_status status; - // USE_DEVICE_POINTER_MODE(handle, status = rocsolver_zgeqrf( - // (rocblas_handle)handle, - // m, - // n, - // (rocblas_double_complex*)A, - // lda, - // (rocblas_double_complex*)tau)); - // return rocBLASStatusToHIPStatus(status); - - return HIPBLAS_STATUS_NOT_SUPPORTED; + return rocBLASStatusToHIPStatus(rocsolver_zgeqrf((rocblas_handle)handle, + m, + n, + (rocblas_double_complex*)A, + lda, + (rocblas_double_complex*)tau)); } // geqrf_batched @@ -13226,31 +13286,8 @@ hipblasStatus_t hipblasSgeqrfBatched(hipblasHandle_t handle, else *info = 0; - int perArray = std::min(m, n); - float* ipiv = NULL; - if(perArray > 0) - hipMalloc(&ipiv, batch_count * perArray * sizeof(float)); - - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_sgeqrf_batched( - (rocblas_handle)handle, m, n, A, lda, ipiv, perArray, batch_count)); - - if(status == rocblas_status_success) - { - // TO DO: Copy ipiv into tau using a fast kernel call - float* hTau[batch_count]; - hipMemcpy(hTau, tau, sizeof(float*) * batch_count, hipMemcpyDeviceToHost); - - for(int b = 0; b < batch_count; b++) - hipMemcpy( - hTau[b], 
ipiv + b * perArray, sizeof(float) * perArray, hipMemcpyDeviceToDevice); - } - - if(perArray > 0) - hipFree(ipiv); - - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus( + rocsolver_sgeqrf_ptr_batched((rocblas_handle)handle, m, n, A, lda, tau, batch_count)); } hipblasStatus_t hipblasDgeqrfBatched(hipblasHandle_t handle, @@ -13279,31 +13316,8 @@ hipblasStatus_t hipblasDgeqrfBatched(hipblasHandle_t handle, else *info = 0; - int perArray = std::min(m, n); - double* ipiv = NULL; - if(perArray > 0) - hipMalloc(&ipiv, batch_count * perArray * sizeof(double)); - - rocsolver_status status; - USE_DEVICE_POINTER_MODE(handle, - status = rocsolver_dgeqrf_batched( - (rocblas_handle)handle, m, n, A, lda, ipiv, perArray, batch_count)); - - if(status == rocblas_status_success) - { - // TO DO: Copy ipiv into tau using a fast kernel call - double* hTau[batch_count]; - hipMemcpy(hTau, tau, sizeof(double*) * batch_count, hipMemcpyDeviceToHost); - - for(int b = 0; b < batch_count; b++) - hipMemcpy( - hTau[b], ipiv + b * perArray, sizeof(double) * perArray, hipMemcpyDeviceToDevice); - } - - if(perArray > 0) - hipFree(ipiv); - - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus( + rocsolver_dgeqrf_ptr_batched((rocblas_handle)handle, m, n, A, lda, tau, batch_count)); } hipblasStatus_t hipblasCgeqrfBatched(hipblasHandle_t handle, @@ -13332,38 +13346,13 @@ hipblasStatus_t hipblasCgeqrfBatched(hipblasHandle_t handle, else *info = 0; - // int perArray = std::min(m, n); - // rocblas_float_complex* ipiv = NULL; - // if (perArray > 0) - // hipMalloc(&ipiv, batch_count * perArray * sizeof(rocblas_float_complex)); - - // rocsolver_status status; - // USE_DEVICE_POINTER_MODE(handle, status = rocsolver_cgeqrf_batched( - // (rocblas_handle)handle, - // m, - // n, - // (rocblas_float_complex**)A, - // lda, - // (rocblas_float_complex*)ipiv, - // perArray, - // batch_count)); - - // if (status == rocblas_status_success) - // { - // // TO DO: Copy ipiv 
into tau using a fast kernel call - // rocblas_float_complex* hTau[batch_count]; - // hipMemcpy(hTau, tau, sizeof(rocblas_float_complex*) * batch_count, hipMemcpyDeviceToHost); - - // for(int b = 0; b < batch_count; b++) - // hipMemcpy(hTau[b], ipiv + b * perArray, sizeof(rocblas_float_complex) * perArray, hipMemcpyDeviceToDevice); - // } - - // if (perArray > 0) - // hipFree(ipiv); - - // return rocBLASStatusToHIPStatus(status); - - return HIPBLAS_STATUS_NOT_SUPPORTED; + return rocBLASStatusToHIPStatus(rocsolver_cgeqrf_ptr_batched((rocblas_handle)handle, + m, + n, + (rocblas_float_complex**)A, + lda, + (rocblas_float_complex**)tau, + batch_count)); } hipblasStatus_t hipblasZgeqrfBatched(hipblasHandle_t handle, @@ -13392,38 +13381,13 @@ hipblasStatus_t hipblasZgeqrfBatched(hipblasHandle_t handle, else *info = 0; - // int perArray = std::min(m, n); - // rocblas_double_complex* ipiv = NULL; - // if (perArray > 0) - // hipMalloc(&ipiv, batch_count * perArray * sizeof(rocblas_double_complex)); - - // rocsolver_status status; - // USE_DEVICE_POINTER_MODE(handle, status = rocsolver_zgeqrf_batched( - // (rocblas_handle)handle, - // m, - // n, - // (rocblas_double_complex**)A, - // lda, - // (rocblas_double_complex*)ipiv, - // perArray, - // batch_count)); - - // if (status == rocblas_status_success) - // { - // // TO DO: Copy ipiv into tau using a fast kernel call - // rocblas_double_complex* hTau[batch_count]; - // hipMemcpy(hTau, tau, sizeof(rocblas_double_complex*) * batch_count, hipMemcpyDeviceToHost); - - // for(int b = 0; b < batch_count; b++) - // hipMemcpy(hTau[b], ipiv + b * perArray, sizeof(rocblas_double_complex) * perArray, hipMemcpyDeviceToDevice); - // } - - // if (perArray > 0) - // hipFree(ipiv); - - // return rocBLASStatusToHIPStatus(status); - - return HIPBLAS_STATUS_NOT_SUPPORTED; + return rocBLASStatusToHIPStatus(rocsolver_zgeqrf_ptr_batched((rocblas_handle)handle, + m, + n, + (rocblas_double_complex**)A, + lda, + (rocblas_double_complex**)tau, + 
batch_count)); } // geqrf_strided_batched @@ -13455,12 +13419,8 @@ hipblasStatus_t hipblasSgeqrfStridedBatched(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE( - handle, - status = rocsolver_sgeqrf_strided_batched( - (rocblas_handle)handle, m, n, A, lda, strideA, tau, strideT, batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_sgeqrf_strided_batched( + (rocblas_handle)handle, m, n, A, lda, strideA, tau, strideT, batch_count)); } hipblasStatus_t hipblasDgeqrfStridedBatched(hipblasHandle_t handle, @@ -13491,12 +13451,8 @@ hipblasStatus_t hipblasDgeqrfStridedBatched(hipblasHandle_t handle, else *info = 0; - rocsolver_status status; - USE_DEVICE_POINTER_MODE( - handle, - status = rocsolver_dgeqrf_strided_batched( - (rocblas_handle)handle, m, n, A, lda, strideA, tau, strideT, batch_count)); - return rocBLASStatusToHIPStatus(status); + return rocBLASStatusToHIPStatus(rocsolver_dgeqrf_strided_batched( + (rocblas_handle)handle, m, n, A, lda, strideA, tau, strideT, batch_count)); } hipblasStatus_t hipblasCgeqrfStridedBatched(hipblasHandle_t handle, @@ -13527,20 +13483,15 @@ hipblasStatus_t hipblasCgeqrfStridedBatched(hipblasHandle_t handle, else *info = 0; - // rocsolver_status status; - // USE_DEVICE_POINTER_MODE(handle, status = rocsolver_cgeqrf_strided_batched( - // (rocblas_handle)handle, - // m, - // n, - // (rocblas_float_complex*)A, - // lda, - // strideA, - // (rocblas_float_complex*)tau, - // strideT, - // batch_count)); - // return rocBLASStatusToHIPStatus(status); - - return HIPBLAS_STATUS_NOT_SUPPORTED; + return rocBLASStatusToHIPStatus(rocsolver_cgeqrf_strided_batched((rocblas_handle)handle, + m, + n, + (rocblas_float_complex*)A, + lda, + strideA, + (rocblas_float_complex*)tau, + strideT, + batch_count)); } hipblasStatus_t hipblasZgeqrfStridedBatched(hipblasHandle_t handle, @@ -13571,20 +13522,15 @@ hipblasStatus_t hipblasZgeqrfStridedBatched(hipblasHandle_t handle, 
else *info = 0; - // rocsolver_status status; - // USE_DEVICE_POINTER_MODE(handle, status = rocsolver_zgeqrf_strided_batched( - // (rocblas_handle)handle, - // m, - // n, - // (rocblas_double_complex*)A, - // lda, - // strideA, - // (rocblas_double_complex*)tau, - // strideT, - // batch_count)); - // return rocBLASStatusToHIPStatus(status); - - return HIPBLAS_STATUS_NOT_SUPPORTED; + return rocBLASStatusToHIPStatus(rocsolver_zgeqrf_strided_batched((rocblas_handle)handle, + m, + n, + (rocblas_double_complex*)A, + lda, + strideA, + (rocblas_double_complex*)tau, + strideT, + batch_count)); } #endif @@ -14176,6 +14122,7 @@ hipblasStatus_t hipblasZgemmStridedBatched(hipblasHandle_t handle, } #endif +// gemm_ex extern "C" hipblasStatus_t hipblasGemmEx(hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb, @@ -14335,3 +14282,115 @@ extern "C" hipblasStatus_t hipblasGemmStridedBatchedEx(hipblasHandle_t handle solution_index, flags)); } + +// trsm_ex +extern "C" hipblasStatus_t hipblasTrsmEx(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + void* B, + int ldb, + const void* invA, + int invA_size, + hipblasDatatype_t compute_type) +{ + return rocBLASStatusToHIPStatus(rocblas_trsm_ex((rocblas_handle)handle, + hipSideToHCCSide(side), + hipFillToHCCFill(uplo), + hipOperationToHCCOperation(transA), + hipDiagonalToHCCDiagonal(diag), + m, + n, + alpha, + A, + lda, + B, + ldb, + invA, + invA_size, + HIPDatatypeToRocblasDatatype(compute_type))); +} + +extern "C" hipblasStatus_t hipblasTrsmBatchedEx(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + void* B, + int ldb, + int batch_count, + const void* invA, + int invA_size, + hipblasDatatype_t compute_type) +{ + return 
rocBLASStatusToHIPStatus( + rocblas_trsm_batched_ex((rocblas_handle)handle, + hipSideToHCCSide(side), + hipFillToHCCFill(uplo), + hipOperationToHCCOperation(transA), + hipDiagonalToHCCDiagonal(diag), + m, + n, + alpha, + A, + lda, + B, + ldb, + batch_count, + invA, + invA_size, + HIPDatatypeToRocblasDatatype(compute_type))); +} + +extern "C" hipblasStatus_t hipblasTrsmStridedBatchedEx(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + int stride_A, + void* B, + int ldb, + int stride_B, + int batch_count, + const void* invA, + int invA_size, + int stride_invA, + hipblasDatatype_t compute_type) +{ + return rocBLASStatusToHIPStatus( + rocblas_trsm_strided_batched_ex((rocblas_handle)handle, + hipSideToHCCSide(side), + hipFillToHCCFill(uplo), + hipOperationToHCCOperation(transA), + hipDiagonalToHCCDiagonal(diag), + m, + n, + alpha, + A, + lda, + stride_A, + B, + ldb, + stride_B, + batch_count, + invA, + invA_size, + stride_invA, + HIPDatatypeToRocblasDatatype(compute_type))); +} diff --git a/library/src/hipblas_module.f90 b/library/src/hipblas_module.f90 index 25ebf7ee7..8487d7799 100644 --- a/library/src/hipblas_module.f90 +++ b/library/src/hipblas_module.f90 @@ -205,7 +205,73 @@ function hipblasGetMatrix(rows, cols, elemSize, A, lda, B, ldb) & integer(c_int), value :: ldb end function hipblasGetMatrix end interface - + + interface + function hipblasSetVectorAsync(n, elemSize, x, incx, y, incy, stream) & + result(c_int) & + bind(c, name = 'hipblasSetVectorAsync') + use iso_c_binding + implicit none + integer(c_int), value :: n + integer(c_int), value :: elemSize + type(c_ptr), value :: x + integer(c_int), value :: incx + type(c_ptr), value :: y + integer(c_int), value :: incy + type(c_ptr), value :: stream + end function hipblasSetVectorAsync + end interface + + interface + function hipblasGetVectorAsync(n, elemSize, x, incx, y, 
incy, stream) & + result(c_int) & + bind(c, name = 'hipblasGetVectorAsync') + use iso_c_binding + implicit none + integer(c_int), value :: n + integer(c_int), value :: elemSize + type(c_ptr), value :: x + integer(c_int), value :: incx + type(c_ptr), value :: y + integer(c_int), value :: incy + type(c_ptr), value :: stream + end function hipblasGetVectorAsync + end interface + + interface + function hipblasSetMatrixAsync(rows, cols, elemSize, A, lda, B, ldb, stream) & + result(c_int) & + bind(c, name = 'hipblasSetMatrixAsync') + use iso_c_binding + implicit none + integer(c_int), value :: rows + integer(c_int), value :: cols + integer(c_int), value :: elemSize + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: stream + end function hipblasSetMatrixAsync + end interface + + interface + function hipblasGetMatrixAsync(rows, cols, elemSize, A, lda, B, ldb, stream) & + result(c_int) & + bind(c, name = 'hipblasGetMatrixAsync') + use iso_c_binding + implicit none + integer(c_int), value :: rows + integer(c_int), value :: cols + integer(c_int), value :: elemSize + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: stream + end function hipblasGetMatrixAsync + end interface + !--------! ! blas 1 ! !--------! @@ -11871,4 +11937,802 @@ function hipblasTrsmStridedBatchedEx(handle, side, uplo, transA, diag, m, n, alp end function hipblasTrsmStridedBatchedEx end interface + !--------! + ! Solver ! + !--------! + + ! 
getrf + interface + function hipblasSgetrf(handle, n, A, lda, ipiv, info) & + result(c_int) & + bind(c, name = 'hipblasSgetrf') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + end function hipblasSgetrf + end interface + + interface + function hipblasDgetrf(handle, n, A, lda, ipiv, info) & + result(c_int) & + bind(c, name = 'hipblasDgetrf') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + end function hipblasDgetrf + end interface + + interface + function hipblasCgetrf(handle, n, A, lda, ipiv, info) & + result(c_int) & + bind(c, name = 'hipblasCgetrf') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + end function hipblasCgetrf + end interface + + interface + function hipblasZgetrf(handle, n, A, lda, ipiv, info) & + result(c_int) & + bind(c, name = 'hipblasZgetrf') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + end function hipblasZgetrf + end interface + + ! 
getrf_batched + interface + function hipblasSgetrfBatched(handle, n, A, lda, ipiv, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasSgetrfBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasSgetrfBatched + end interface + + interface + function hipblasDgetrfBatched(handle, n, A, lda, ipiv, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasDgetrfBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasDgetrfBatched + end interface + + interface + function hipblasCgetrfBatched(handle, n, A, lda, ipiv, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasCgetrfBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasCgetrfBatched + end interface + + interface + function hipblasZgetrfBatched(handle, n, A, lda, ipiv, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasZgetrfBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasZgetrfBatched + end interface + + ! 
getrf_strided_batched + interface + function hipblasSgetrfStridedBatched(handle, n, A, lda, stride_A,& + ipiv, stride_P, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasSgetrfStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasSgetrfStridedBatched + end interface + + interface + function hipblasDgetrfStridedBatched(handle, n, A, lda, stride_A,& + ipiv, stride_P, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasDgetrfStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasDgetrfStridedBatched + end interface + + interface + function hipblasCgetrfStridedBatched(handle, n, A, lda, stride_A,& + ipiv, stride_P, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasCgetrfStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasCgetrfStridedBatched + end interface + + interface + function hipblasZgetrfStridedBatched(handle, n, A, lda, stride_A,& + ipiv, stride_P, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasZgetrfStridedBatched') + use iso_c_binding + 
use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasZgetrfStridedBatched + end interface + + ! getrs + interface + function hipblasSgetrs(handle, trans, n, nrhs, A, lda, ipiv,& + B, ldb, info) & + result(c_int) & + bind(c, name = 'hipblasSgetrs') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + end function hipblasSgetrs + end interface + + interface + function hipblasDgetrs(handle, trans, n, nrhs, A, lda, ipiv,& + B, ldb, info) & + result(c_int) & + bind(c, name = 'hipblasDgetrs') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + end function hipblasDgetrs + end interface + + interface + function hipblasCgetrs(handle, trans, n, nrhs, A, lda, ipiv,& + B, ldb, info) & + result(c_int) & + bind(c, name = 'hipblasCgetrs') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + 
type(c_ptr), value :: info + end function hipblasCgetrs + end interface + + interface + function hipblasZgetrs(handle, trans, n, nrhs, A, lda, ipiv,& + B, ldb, info) & + result(c_int) & + bind(c, name = 'hipblasZgetrs') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + end function hipblasZgetrs + end interface + + ! getrs_batched + interface + function hipblasSgetrsBatched(handle, trans, n, nrhs, A, lda, ipiv,& + B, ldb, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasSgetrsBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasSgetrsBatched + end interface + + interface + function hipblasDgetrsBatched(handle, trans, n, nrhs, A, lda, ipiv,& + B, ldb, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasDgetrsBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasDgetrsBatched + end interface + + interface + function hipblasCgetrsBatched(handle, trans, n, nrhs, A, lda, ipiv,& + 
B, ldb, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasCgetrsBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasCgetrsBatched + end interface + + interface + function hipblasZgetrsBatched(handle, trans, n, nrhs, A, lda, ipiv,& + B, ldb, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasZgetrsBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: B + integer(c_int), value :: ldb + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasZgetrsBatched + end interface + + ! 
getrs_strided_batched + interface + function hipblasSgetrsStridedBatched(handle, trans, n, nrhs, A, lda, stride_A, ipiv,& + stride_P, B, ldb, stride_B, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasSgetrsStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: B + integer(c_int), value :: ldb + integer(c_int), value :: stride_B + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasSgetrsStridedBatched + end interface + + interface + function hipblasDgetrsStridedBatched(handle, trans, n, nrhs, A, lda, stride_A, ipiv,& + stride_P, B, ldb, stride_B, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasDgetrsStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: B + integer(c_int), value :: ldb + integer(c_int), value :: stride_B + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasDgetrsStridedBatched + end interface + + interface + function hipblasCgetrsStridedBatched(handle, trans, n, nrhs, A, lda, stride_A, ipiv,& + stride_P, B, ldb, stride_B, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasCgetrsStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + 
integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: B + integer(c_int), value :: ldb + integer(c_int), value :: stride_B + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasCgetrsStridedBatched + end interface + + interface + function hipblasZgetrsStridedBatched(handle, trans, n, nrhs, A, lda, stride_A, ipiv,& + stride_P, B, ldb, stride_B, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasZgetrsStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(kind(HIPBLAS_OP_N)), value :: trans + integer(c_int), value :: n + integer(c_int), value :: nrhs + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: ipiv + integer(c_int), value :: stride_P + type(c_ptr), value :: B + integer(c_int), value :: ldb + integer(c_int), value :: stride_B + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasZgetrsStridedBatched + end interface + + ! 
getri_batched + interface + function hipblasSgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasSgetriBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: C + integer(c_int), value :: ldc + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasSgetriBatched + end interface + + interface + function hipblasDgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasDgetriBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: C + integer(c_int), value :: ldc + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasDgetriBatched + end interface + + interface + function hipblasCgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasCgetriBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + type(c_ptr), value :: C + integer(c_int), value :: ldc + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasCgetriBatched + end interface + + interface + function hipblasZgetriBatched(handle, n, A, lda, ipiv, C, ldc, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasZgetriBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: ipiv + 
type(c_ptr), value :: C + integer(c_int), value :: ldc + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasZgetriBatched + end interface + + ! geqrf + interface + function hipblasSgeqrf(handle, m, n, A, lda, tau, info) & + result(c_int) & + bind(c, name = 'hipblasSgeqrf') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + end function hipblasSgeqrf + end interface + + interface + function hipblasDgeqrf(handle, m, n, A, lda, tau, info) & + result(c_int) & + bind(c, name = 'hipblasDgeqrf') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + end function hipblasDgeqrf + end interface + + interface + function hipblasCgeqrf(handle, m, n, A, lda, tau, info) & + result(c_int) & + bind(c, name = 'hipblasCgeqrf') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + end function hipblasCgeqrf + end interface + + interface + function hipblasZgeqrf(handle, m, n, A, lda, tau, info) & + result(c_int) & + bind(c, name = 'hipblasZgeqrf') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + end function hipblasZgeqrf + end interface + + ! 
geqrf_batched + interface + function hipblasSgeqrfBatched(handle, m, n, A, lda, tau, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasSgeqrfBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasSgeqrfBatched + end interface + + interface + function hipblasDgeqrfBatched(handle, m, n, A, lda, tau, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasDgeqrfBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasDgeqrfBatched + end interface + + interface + function hipblasCgeqrfBatched(handle, m, n, A, lda, tau, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasCgeqrfBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasCgeqrfBatched + end interface + + interface + function hipblasZgeqrfBatched(handle, m, n, A, lda, tau, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasZgeqrfBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + type(c_ptr), value :: tau + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function 
hipblasZgeqrfBatched + end interface + + ! geqrf_strided_batched + interface + function hipblasSgeqrfStridedBatched(handle, m, n, A, lda, stride_A,& + tau, stride_T, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasSgeqrfStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: tau + integer(c_int), value :: stride_T + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasSgeqrfStridedBatched + end interface + + interface + function hipblasDgeqrfStridedBatched(handle, m, n, A, lda, stride_A,& + tau, stride_T, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasDgeqrfStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: tau + integer(c_int), value :: stride_T + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasDgeqrfStridedBatched + end interface + + interface + function hipblasCgeqrfStridedBatched(handle, m, n, A, lda, stride_A,& + tau, stride_T, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasCgeqrfStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: tau + integer(c_int), value :: stride_T + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasCgeqrfStridedBatched + end interface + + interface + function hipblasZgeqrfStridedBatched(handle, m, n, A, lda, 
stride_A,& + tau, stride_T, info, batch_count) & + result(c_int) & + bind(c, name = 'hipblasZgeqrfStridedBatched') + use iso_c_binding + use hipblas_enums + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: m + integer(c_int), value :: n + type(c_ptr), value :: A + integer(c_int), value :: lda + integer(c_int), value :: stride_A + type(c_ptr), value :: tau + integer(c_int), value :: stride_T + type(c_ptr), value :: info + integer(c_int), value :: batch_count + end function hipblasZgeqrfStridedBatched + end interface + end module hipblas \ No newline at end of file diff --git a/library/src/nvcc_detail/hipblas.cpp b/library/src/nvcc_detail/hipblas.cpp index d7f26e84f..e3dfd58f0 100644 --- a/library/src/nvcc_detail/hipblas.cpp +++ b/library/src/nvcc_detail/hipblas.cpp @@ -286,6 +286,32 @@ hipblasStatus_t return hipCUBLASStatusToHIPStatus(cublasGetMatrix(rows, cols, elemSize, A, lda, B, ldb)); } +hipblasStatus_t hipblasSetVectorAsync( + int n, int elemSize, const void* x, int incx, void* y, int incy, hipStream_t stream) +{ + return hipCUBLASStatusToHIPStatus(cublasSetVectorAsync(n, elemSize, x, incx, y, incy, stream)); +} + +hipblasStatus_t hipblasGetVectorAsync( + int n, int elemSize, const void* x, int incx, void* y, int incy, hipStream_t stream) +{ + return hipCUBLASStatusToHIPStatus(cublasGetVectorAsync(n, elemSize, x, incx, y, incy, stream)); +} + +hipblasStatus_t hipblasSetMatrixAsync( + int rows, int cols, int elemSize, const void* A, int lda, void* B, int ldb, hipStream_t stream) +{ + return hipCUBLASStatusToHIPStatus( + cublasSetMatrixAsync(rows, cols, elemSize, A, lda, B, ldb, stream)); +} + +hipblasStatus_t hipblasGetMatrixAsync( + int rows, int cols, int elemSize, const void* A, int lda, void* B, int ldb, hipStream_t stream) +{ + return hipCUBLASStatusToHIPStatus( + cublasGetMatrixAsync(rows, cols, elemSize, A, lda, B, ldb, stream)); +} + // amax hipblasStatus_t hipblasIsamax(hipblasHandle_t handle, int n, const float* x, int incx, 
int* result) { @@ -9379,6 +9405,77 @@ hipblasStatus_t hipblasZgetrsStridedBatched(hipblasHandle_t handle, return HIPBLAS_STATUS_NOT_SUPPORTED; } +// getri_batched +hipblasStatus_t hipblasSgetriBatched(hipblasHandle_t handle, + const int n, + float* const A[], + const int lda, + int* ipiv, + float* const C[], + const int ldc, + int* info, + const int batch_count) +{ + return hipCUBLASStatusToHIPStatus( + cublasSgetriBatched((cublasHandle_t)handle, n, A, lda, ipiv, C, ldc, info, batch_count)); +} + +hipblasStatus_t hipblasDgetriBatched(hipblasHandle_t handle, + const int n, + double* const A[], + const int lda, + int* ipiv, + double* const C[], + const int ldc, + int* info, + const int batch_count) +{ + return hipCUBLASStatusToHIPStatus( + cublasDgetriBatched((cublasHandle_t)handle, n, A, lda, ipiv, C, ldc, info, batch_count)); +} + +hipblasStatus_t hipblasCgetriBatched(hipblasHandle_t handle, + const int n, + hipblasComplex* const A[], + const int lda, + int* ipiv, + hipblasComplex* const C[], + const int ldc, + int* info, + const int batch_count) +{ + return hipCUBLASStatusToHIPStatus(cublasCgetriBatched((cublasHandle_t)handle, + n, + (cuComplex**)A, + lda, + ipiv, + (cuComplex**)C, + ldc, + info, + batch_count)); +} + +hipblasStatus_t hipblasZgetriBatched(hipblasHandle_t handle, + const int n, + hipblasDoubleComplex* const A[], + const int lda, + int* ipiv, + hipblasDoubleComplex* const C[], + const int ldc, + int* info, + const int batch_count) +{ + return hipCUBLASStatusToHIPStatus(cublasZgetriBatched((cublasHandle_t)handle, + n, + (cuDoubleComplex**)A, + lda, + ipiv, + (cuDoubleComplex**)C, + ldc, + info, + batch_count)); +} + // geqrf hipblasStatus_t hipblasSgeqrf(hipblasHandle_t handle, const int m, @@ -10064,6 +10161,7 @@ hipblasStatus_t hipblasZgemmStridedBatched(hipblasHandle_t handle, } #endif +// gemm_ex extern "C" hipblasStatus_t hipblasGemmEx(hipblasHandle_t handle, hipblasOperation_t transa, hipblasOperation_t transb, @@ -10197,3 +10295,66 @@ extern 
"C" hipblasStatus_t hipblasGemmStridedBatchedEx(hipblasHandle_t handle HIPDatatypeToCudaDatatype(compute_type), HIPGemmAlgoToCudaGemmAlgo(algo))); } + +// trsm_ex +extern "C" hipblasStatus_t hipblasTrsmEx(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + void* B, + int ldb, + const void* invA, + int invA_size, + hipblasDatatype_t compute_type) +{ + return HIPBLAS_STATUS_NOT_SUPPORTED; +} + +extern "C" hipblasStatus_t hipblasTrsmBatchedEx(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + void* B, + int ldb, + int batch_count, + const void* invA, + int invA_size, + hipblasDatatype_t compute_type) +{ + return HIPBLAS_STATUS_NOT_SUPPORTED; +} + +extern "C" hipblasStatus_t hipblasTrsmStridedBatchedEx(hipblasHandle_t handle, + hipblasSideMode_t side, + hipblasFillMode_t uplo, + hipblasOperation_t transA, + hipblasDiagType_t diag, + int m, + int n, + const void* alpha, + void* A, + int lda, + int stride_A, + void* B, + int ldb, + int stride_B, + int batch_count, + const void* invA, + int invA_size, + int stride_invA, + hipblasDatatype_t compute_type) +{ + return HIPBLAS_STATUS_NOT_SUPPORTED; +}