Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ILP64 reference library for LAPACK on Windows #917

Merged
merged 2 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
782 changes: 370 additions & 412 deletions clients/common/cblas_interface.cpp

Large diffs are not rendered by default.

22 changes: 16 additions & 6 deletions clients/include/cblas_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -802,22 +802,32 @@ void ref_trmm(hipblasSideMode_t side,

// potrf
// Host reference potrf declaration, shown as a before/after diff pair: the
// `int` signature is the removed line and the `int64_t` signature is the one
// that replaces it (64-bit sizes and leading dimension for the ILP64 build).
// NOTE(review): presumably forwards to LAPACK ?potrf (Cholesky factorization)
// and returns its info code; the definition lives in cblas_interface.cpp,
// whose diff is not rendered here -- confirm there.
template <typename T>
int ref_potrf(char uplo, int m, T* A, int lda);
int64_t ref_potrf(char uplo, int64_t m, T* A, int64_t lda);

// Host reference getrf declaration, shown as a before/after diff pair: the
// `int` signature is the removed line, the `int64_t` signature replaces it.
// The pivot array changes from int* to int64_t*, so callers in the solver
// tests pass an int64_t pivot buffer. The visible callers store the return
// value into their hInfo entries.
// NOTE(review): presumably wraps LAPACK ?getrf (LU factorization with
// partial pivoting) -- confirm against cblas_interface.cpp.
template <typename T>
int ref_getrf(int m, int n, T* A, int lda, int* ipiv);
int64_t ref_getrf(int64_t m, int64_t n, T* A, int64_t lda, int64_t* ipiv);

// Host reference getrs declaration, shown as a before/after diff pair: the
// single-line `int` signature is the removed line, the wrapped `int64_t`
// signature replaces it (sizes, leading dimensions and pivots become 64-bit).
// NOTE(review): no caller is visible in this view; presumably wraps LAPACK
// ?getrs (solve using an existing LU factorization) and returns info --
// confirm against cblas_interface.cpp.
template <typename T>
int ref_getrs(char trans, int n, int nrhs, T* A, int lda, int* ipiv, T* B, int ldb);
int64_t ref_getrs(
char trans, int64_t n, int64_t nrhs, T* A, int64_t lda, int64_t* ipiv, T* B, int64_t ldb);

// Host reference getri declaration, shown as a before/after diff pair: the
// `int` signature is the removed line, the `int64_t` signature replaces it.
// The batched getri test calls this twice per batch: first with a one-element
// `work` and lwork = -1 as a workspace query (the size is then read back from
// work[0]), then with the sized workspace, storing the return value in hInfo.
// NOTE(review): presumably wraps LAPACK ?getri (inverse from LU factors) --
// confirm against cblas_interface.cpp.
template <typename T>
int ref_getri(int n, T* A, int lda, int* ipiv, T* work, int lwork);
int64_t ref_getri(int64_t n, T* A, int64_t lda, int64_t* ipiv, T* work, int64_t lwork);

// Host reference geqrf declaration, shown as a before/after diff pair: the
// `int` signature is the removed line, the `int64_t` signature replaces it
// (sizes, leading dimension and lwork become 64-bit).
// NOTE(review): no caller is visible in this view; presumably wraps LAPACK
// ?geqrf (QR factorization, `tau` receiving the reflector scalars) and
// returns info -- confirm against cblas_interface.cpp.
template <typename T>
int ref_geqrf(int m, int n, T* A, int lda, T* tau, T* work, int lwork);
int64_t ref_geqrf(int64_t m, int64_t n, T* A, int64_t lda, T* tau, T* work, int64_t lwork);

// Host reference gels declaration, shown as a before/after diff pair: the
// single-line `int` signature is the removed line; the `int64_t` signature,
// wrapped one parameter per line, replaces it.
// NOTE(review): no caller is visible in this view; presumably wraps LAPACK
// ?gels (least-squares solve) and returns info -- confirm against
// cblas_interface.cpp.
template <typename T>
int ref_gels(char trans, int m, int n, int nrhs, T* A, int lda, T* B, int ldb, T* work, int lwork);
int64_t ref_gels(char trans,
int64_t m,
int64_t n,
int64_t nrhs,
T* A,
int64_t lda,
T* B,
int64_t ldb,
T* work,
int64_t lwork);

#endif

Expand Down
19 changes: 11 additions & 8 deletions clients/include/solver/testing_getrf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ void testing_getrf(const Arguments& arg)
}

// Naming: dK is in GPU (device) memory. hK is in CPU (host) memory
host_matrix<T> hA(M, N, lda);
host_matrix<T> hA1(M, N, lda);
host_matrix<int> hIpiv(1, Ipiv_size, 1);
host_matrix<int> hIpiv1(1, Ipiv_size, 1);
host_vector<int> hInfo(1);
host_vector<int> hInfo1(1);
host_matrix<T> hA(M, N, lda);
host_matrix<T> hA1(M, N, lda);
host_matrix<int> hIpiv(1, Ipiv_size, 1);
host_matrix<int64_t> hIpiv64(1, Ipiv_size, 1);
host_vector<int> hInfo(1);
host_vector<int> hInfo1(1);

// Allocate device memory
device_matrix<T> dA(M, N, lda);
Expand Down Expand Up @@ -145,13 +145,16 @@ void testing_getrf(const Arguments& arg)

// Copy output from device to CPU
CHECK_HIP_ERROR(hA1.transfer_from(dA));
CHECK_HIP_ERROR(hIpiv1.transfer_from(dIpiv));
CHECK_HIP_ERROR(hIpiv.transfer_from(dIpiv));
CHECK_HIP_ERROR(hipMemcpy(hInfo1, dInfo, sizeof(int), hipMemcpyDeviceToHost));

/* =====================================================================
CPU LAPACK
=================================================================== */
hInfo[0] = ref_getrf(M, N, hA.data(), lda, hIpiv.data());
for(int i = 0; i < Ipiv_size; i++)
hIpiv64[0][i] = hIpiv[0][i];

hInfo[0] = ref_getrf(M, N, hA.data(), lda, hIpiv64.data());

hipblas_error = norm_check_general<T>('F', M, N, lda, hA, hA1);
if(arg.unit_check)
Expand Down
11 changes: 7 additions & 4 deletions clients/include/solver/testing_getrf_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,15 @@ void testing_getrf_batched(const Arguments& arg)
host_batch_matrix<T> hA(M, N, lda, batch_count);
host_batch_matrix<T> hA1(M, N, lda, batch_count);
host_vector<int> hIpiv(Ipiv_size);
host_vector<int> hIpiv1(Ipiv_size);
host_vector<int64_t> hIpiv64(Ipiv_size);
host_vector<int> hInfo(batch_count);
host_vector<int> hInfo1(batch_count);

// Check host memory allocation
CHECK_HIP_ERROR(hA.memcheck());
CHECK_HIP_ERROR(hA1.memcheck());
CHECK_HIP_ERROR(hIpiv.memcheck());
CHECK_HIP_ERROR(hIpiv1.memcheck());
CHECK_HIP_ERROR(hIpiv64.memcheck());

// Allocate device memory
device_batch_matrix<T> dA(M, N, lda, batch_count);
Expand Down Expand Up @@ -171,16 +171,19 @@ void testing_getrf_batched(const Arguments& arg)

// Copy output from device to CPU
CHECK_HIP_ERROR(hA1.transfer_from(dA));
CHECK_HIP_ERROR(hIpiv1.transfer_from(dIpiv));
CHECK_HIP_ERROR(hIpiv.transfer_from(dIpiv));
CHECK_HIP_ERROR(
hipMemcpy(hInfo1.data(), dInfo, batch_count * sizeof(int), hipMemcpyDeviceToHost));

/* =====================================================================
CPU LAPACK
=================================================================== */
for(int i = 0; i < Ipiv_size; i++)
hIpiv64[i] = hIpiv[i];

for(int b = 0; b < batch_count; b++)
{
hInfo[b] = ref_getrf(M, N, hA[b], lda, hIpiv.data() + b * strideP);
hInfo[b] = ref_getrf(M, N, hA[b], lda, hIpiv64.data() + b * strideP);
}

hipblas_error = norm_check_general<T>('F', M, N, lda, hA, hA1, batch_count);
Expand Down
10 changes: 5 additions & 5 deletions clients/include/solver/testing_getrf_npvt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ void testing_getrf_npvt(const Arguments& arg)
}

// Naming: dK is in GPU (device) memory. hK is in CPU (host) memory
host_matrix<T> hA(M, N, lda);
host_matrix<T> hA1(M, N, lda);
host_vector<int> hIpiv(Ipiv_size);
host_vector<int> hInfo(1);
host_vector<int> hInfo1(1);
host_matrix<T> hA(M, N, lda);
host_matrix<T> hA1(M, N, lda);
host_vector<int64_t> hIpiv(Ipiv_size);
host_vector<int> hInfo(1);
host_vector<int> hInfo1(1);

// Allocate device memory
device_matrix<T> dA(M, N, lda);
Expand Down
2 changes: 1 addition & 1 deletion clients/include/solver/testing_getrf_npvt_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ void testing_getrf_npvt_batched(const Arguments& arg)
// Naming: dK is in GPU (device) memory. hK is in CPU (host) memory
host_batch_matrix<T> hA(M, N, lda, batch_count);
host_batch_matrix<T> hA1(M, N, lda, batch_count);
host_vector<int> hIpiv(Ipiv_size);
host_vector<int64_t> hIpiv(Ipiv_size);
host_vector<int> hInfo(batch_count);
host_vector<int> hInfo1(batch_count);

Expand Down
10 changes: 5 additions & 5 deletions clients/include/solver/testing_getrf_npvt_strided_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,11 @@ void testing_getrf_npvt_strided_batched(const Arguments& arg)
}

// Naming: dK is in GPU (device) memory. hK is in CPU (host) memory
host_strided_batch_matrix<T> hA(M, N, lda, strideA, batch_count);
host_strided_batch_matrix<T> hA1(M, N, lda, strideA, batch_count);
host_strided_batch_matrix<int> hIpiv(1, Ipiv_size, 1, strideP, batch_count);
host_vector<int> hInfo(batch_count);
host_vector<int> hInfo1(batch_count);
host_strided_batch_matrix<T> hA(M, N, lda, strideA, batch_count);
host_strided_batch_matrix<T> hA1(M, N, lda, strideA, batch_count);
host_strided_batch_matrix<int64_t> hIpiv(1, Ipiv_size, 1, strideP, batch_count);
host_vector<int> hInfo(batch_count);
host_vector<int> hInfo1(batch_count);

// Check host memory allocation
CHECK_HIP_ERROR(hA.memcheck());
Expand Down
11 changes: 7 additions & 4 deletions clients/include/solver/testing_getrf_strided_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,15 @@ void testing_getrf_strided_batched(const Arguments& arg)
host_strided_batch_matrix<T> hA(M, N, lda, strideA, batch_count);
host_strided_batch_matrix<T> hA1(M, N, lda, strideA, batch_count);
host_vector<int> hIpiv(Ipiv_size);
host_vector<int> hIpiv1(Ipiv_size);
host_vector<int64_t> hIpiv64(Ipiv_size);
host_vector<int> hInfo(batch_count);
host_vector<int> hInfo1(batch_count);

// Check host memory allocation
CHECK_HIP_ERROR(hA.memcheck());
CHECK_HIP_ERROR(hA1.memcheck());
CHECK_HIP_ERROR(hIpiv.memcheck());
CHECK_HIP_ERROR(hIpiv1.memcheck());
CHECK_HIP_ERROR(hIpiv64.memcheck());

device_strided_batch_matrix<T> dA(M, N, lda, strideA, batch_count);
device_vector<int> dIpiv(Ipiv_size);
Expand Down Expand Up @@ -175,16 +175,19 @@ void testing_getrf_strided_batched(const Arguments& arg)

// Copy output from device to CPU
CHECK_HIP_ERROR(hA1.transfer_from(dA));
CHECK_HIP_ERROR(hIpiv1.transfer_from(dIpiv));
CHECK_HIP_ERROR(hIpiv.transfer_from(dIpiv));
CHECK_HIP_ERROR(
hipMemcpy(hInfo1.data(), dInfo, batch_count * sizeof(int), hipMemcpyDeviceToHost));

/* =====================================================================
CPU LAPACK
=================================================================== */
for(int i = 0; i < Ipiv_size; i++)
hIpiv64[i] = hIpiv[i];

for(int b = 0; b < batch_count; b++)
{
hInfo[b] = ref_getrf(M, N, hA[b], lda, hIpiv.data() + b * strideP);
hInfo[b] = ref_getrf(M, N, hA[b], lda, hIpiv64.data() + b * strideP);
}

hipblas_error = norm_check_general<T>('F', M, N, lda, strideA, hA, hA1, batch_count);
Expand Down
18 changes: 12 additions & 6 deletions clients/include/solver/testing_getri_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ void testing_getri_batched(const Arguments& arg)
host_batch_matrix<T> hC(M, N, lda, batch_count);
host_batch_matrix<T> hA1(M, N, lda, batch_count);
host_vector<int> hIpiv(Ipiv_size);
host_vector<int> hIpiv1(Ipiv_size);
host_vector<int64_t> hIpiv64(Ipiv_size);
host_vector<int> hInfo(batch_count);
host_vector<int> hInfo1(batch_count);

Expand Down Expand Up @@ -205,10 +205,13 @@ void testing_getri_batched(const Arguments& arg)
}

// perform LU factorization on A
int* hIpivb = hIpiv.data() + b * strideP;
hInfo[b] = ref_getrf(M, N, hA[b], lda, hIpivb);
int64_t* hIpivb = hIpiv64.data() + b * strideP;
hInfo[b] = ref_getrf(M, N, hA[b], lda, hIpivb);
}

for(int i = 0; i < Ipiv_size; i++)
hIpiv[i] = hIpiv64[i];

CHECK_HIP_ERROR(dA.transfer_from(hA));
CHECK_HIP_ERROR(dC.transfer_from(hC));
CHECK_HIP_ERROR(hipMemcpy(dIpiv, hIpiv, Ipiv_size * sizeof(int), hipMemcpyHostToDevice));
Expand All @@ -232,23 +235,26 @@ void testing_getri_batched(const Arguments& arg)
// Copy output from device to CPU
CHECK_HIP_ERROR(hA1.transfer_from(dC));
CHECK_HIP_ERROR(
hipMemcpy(hIpiv1.data(), dIpiv, Ipiv_size * sizeof(int), hipMemcpyDeviceToHost));
hipMemcpy(hIpiv.data(), dIpiv, Ipiv_size * sizeof(int), hipMemcpyDeviceToHost));
CHECK_HIP_ERROR(
hipMemcpy(hInfo1.data(), dInfo, batch_count * sizeof(int), hipMemcpyDeviceToHost));

/* =====================================================================
CPU LAPACK
=================================================================== */
for(int i = 0; i < Ipiv_size; i++)
hIpiv64[i] = hIpiv[i];

for(int b = 0; b < batch_count; b++)
{
// Workspace query
host_vector<T> work(1);
ref_getri(N, hA[b], lda, hIpiv.data() + b * strideP, work.data(), -1);
ref_getri(N, hA[b], lda, hIpiv64.data() + b * strideP, work.data(), -1);
int lwork = type2int(work[0]);

// Perform inversion
work = host_vector<T>(lwork);
hInfo[b] = ref_getri(N, hA[b], lda, hIpiv.data() + b * strideP, work.data(), lwork);
hInfo[b] = ref_getri(N, hA[b], lda, hIpiv64.data() + b * strideP, work.data(), lwork);

hipblas_error = norm_check_general<T>('F', M, N, lda, hA[b], hA1[b]);
if(arg.unit_check)
Expand Down
10 changes: 5 additions & 5 deletions clients/include/solver/testing_getri_npvt_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,9 @@ void testing_getri_npvt_batched(const Arguments& arg)
host_batch_matrix<T> hC(M, N, lda, batch_count);
host_batch_matrix<T> hA1(M, N, lda, batch_count);

host_vector<int> hIpiv(Ipiv_size);
host_vector<int> hInfo(batch_count);
host_vector<int> hInfo1(batch_count);
host_vector<int64_t> hIpiv(Ipiv_size);
host_vector<int> hInfo(batch_count);
host_vector<int> hInfo1(batch_count);

// Check host memory allocation
CHECK_HIP_ERROR(hA.memcheck());
Expand Down Expand Up @@ -197,8 +197,8 @@ void testing_getri_npvt_batched(const Arguments& arg)
}

// perform LU factorization on A
int* hIpivb = hIpiv.data() + b * strideP;
hInfo[b] = ref_getrf(M, N, hA[b], lda, hIpivb);
int64_t* hIpivb = hIpiv.data() + b * strideP;
hInfo[b] = ref_getrf(M, N, hA[b], lda, hIpivb);
}

CHECK_HIP_ERROR(dA.transfer_from(hA));
Expand Down
Loading