Skip to content

Commit

Permalink
ILP64 reference library for LAPACK on Windows (#917)
Browse files Browse the repository at this point in the history
  • Loading branch information
daineAMD authored Sep 27, 2024
1 parent 611a10d commit db58bab
Show file tree
Hide file tree
Showing 18 changed files with 539 additions and 512 deletions.
782 changes: 370 additions & 412 deletions clients/common/cblas_interface.cpp

Large diffs are not rendered by default.

22 changes: 16 additions & 6 deletions clients/include/cblas_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -802,22 +802,32 @@ void ref_trmm(hipblasSideMode_t side,

// potrf
template <typename T>
int ref_potrf(char uplo, int m, T* A, int lda);
int64_t ref_potrf(char uplo, int64_t m, T* A, int64_t lda);

template <typename T>
int ref_getrf(int m, int n, T* A, int lda, int* ipiv);
int64_t ref_getrf(int64_t m, int64_t n, T* A, int64_t lda, int64_t* ipiv);

template <typename T>
int ref_getrs(char trans, int n, int nrhs, T* A, int lda, int* ipiv, T* B, int ldb);
int64_t ref_getrs(
char trans, int64_t n, int64_t nrhs, T* A, int64_t lda, int64_t* ipiv, T* B, int64_t ldb);

template <typename T>
int ref_getri(int n, T* A, int lda, int* ipiv, T* work, int lwork);
int64_t ref_getri(int64_t n, T* A, int64_t lda, int64_t* ipiv, T* work, int64_t lwork);

template <typename T>
int ref_geqrf(int m, int n, T* A, int lda, T* tau, T* work, int lwork);
int64_t ref_geqrf(int64_t m, int64_t n, T* A, int64_t lda, T* tau, T* work, int64_t lwork);

template <typename T>
int ref_gels(char trans, int m, int n, int nrhs, T* A, int lda, T* B, int ldb, T* work, int lwork);
int64_t ref_gels(char trans,
int64_t m,
int64_t n,
int64_t nrhs,
T* A,
int64_t lda,
T* B,
int64_t ldb,
T* work,
int64_t lwork);

#endif

Expand Down
1 change: 0 additions & 1 deletion clients/include/solver/testing_gels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
*
* ************************************************************************ */

#include "gtest/gtest.h"
#include <fstream>
#include <iostream>
#include <stdlib.h>
Expand Down
1 change: 0 additions & 1 deletion clients/include/solver/testing_gels_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
*
* ************************************************************************ */

#include "gtest/gtest.h"
#include <fstream>
#include <iostream>
#include <stdlib.h>
Expand Down
1 change: 0 additions & 1 deletion clients/include/solver/testing_gels_strided_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
*
* ************************************************************************ */

#include "gtest/gtest.h"
#include <fstream>
#include <iostream>
#include <stdlib.h>
Expand Down
1 change: 0 additions & 1 deletion clients/include/solver/testing_geqrf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
*
* ************************************************************************ */

#include "gtest/gtest.h"
#include <fstream>
#include <iostream>
#include <stdlib.h>
Expand Down
1 change: 0 additions & 1 deletion clients/include/solver/testing_geqrf_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
*
* ************************************************************************ */

#include "gtest/gtest.h"
#include <fstream>
#include <iostream>
#include <stdlib.h>
Expand Down
19 changes: 11 additions & 8 deletions clients/include/solver/testing_getrf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ void testing_getrf(const Arguments& arg)
}

// Naming: dK is in GPU (device) memory. hK is in CPU (host) memory
host_matrix<T> hA(M, N, lda);
host_matrix<T> hA1(M, N, lda);
host_matrix<int> hIpiv(1, Ipiv_size, 1);
host_matrix<int> hIpiv1(1, Ipiv_size, 1);
host_vector<int> hInfo(1);
host_vector<int> hInfo1(1);
host_matrix<T> hA(M, N, lda);
host_matrix<T> hA1(M, N, lda);
host_matrix<int> hIpiv(1, Ipiv_size, 1);
host_matrix<int64_t> hIpiv64(1, Ipiv_size, 1);
host_vector<int> hInfo(1);
host_vector<int> hInfo1(1);

// Allocate device memory
device_matrix<T> dA(M, N, lda);
Expand Down Expand Up @@ -145,13 +145,16 @@ void testing_getrf(const Arguments& arg)

// Copy output from device to CPU
CHECK_HIP_ERROR(hA1.transfer_from(dA));
CHECK_HIP_ERROR(hIpiv1.transfer_from(dIpiv));
CHECK_HIP_ERROR(hIpiv.transfer_from(dIpiv));
CHECK_HIP_ERROR(hipMemcpy(hInfo1, dInfo, sizeof(int), hipMemcpyDeviceToHost));

/* =====================================================================
CPU LAPACK
=================================================================== */
hInfo[0] = ref_getrf(M, N, hA.data(), lda, hIpiv.data());
for(int i = 0; i < Ipiv_size; i++)
hIpiv64[0][i] = hIpiv[0][i];

hInfo[0] = ref_getrf(M, N, hA.data(), lda, hIpiv64.data());

hipblas_error = norm_check_general<T>('F', M, N, lda, hA, hA1);
if(arg.unit_check)
Expand Down
11 changes: 7 additions & 4 deletions clients/include/solver/testing_getrf_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,15 @@ void testing_getrf_batched(const Arguments& arg)
host_batch_matrix<T> hA(M, N, lda, batch_count);
host_batch_matrix<T> hA1(M, N, lda, batch_count);
host_vector<int> hIpiv(Ipiv_size);
host_vector<int> hIpiv1(Ipiv_size);
host_vector<int64_t> hIpiv64(Ipiv_size);
host_vector<int> hInfo(batch_count);
host_vector<int> hInfo1(batch_count);

// Check host memory allocation
CHECK_HIP_ERROR(hA.memcheck());
CHECK_HIP_ERROR(hA1.memcheck());
CHECK_HIP_ERROR(hIpiv.memcheck());
CHECK_HIP_ERROR(hIpiv1.memcheck());
CHECK_HIP_ERROR(hIpiv64.memcheck());

// Allocate device memory
device_batch_matrix<T> dA(M, N, lda, batch_count);
Expand Down Expand Up @@ -171,16 +171,19 @@ void testing_getrf_batched(const Arguments& arg)

// Copy output from device to CPU
CHECK_HIP_ERROR(hA1.transfer_from(dA));
CHECK_HIP_ERROR(hIpiv1.transfer_from(dIpiv));
CHECK_HIP_ERROR(hIpiv.transfer_from(dIpiv));
CHECK_HIP_ERROR(
hipMemcpy(hInfo1.data(), dInfo, batch_count * sizeof(int), hipMemcpyDeviceToHost));

/* =====================================================================
CPU LAPACK
=================================================================== */
for(int i = 0; i < Ipiv_size; i++)
hIpiv64[i] = hIpiv[i];

for(int b = 0; b < batch_count; b++)
{
hInfo[b] = ref_getrf(M, N, hA[b], lda, hIpiv.data() + b * strideP);
hInfo[b] = ref_getrf(M, N, hA[b], lda, hIpiv64.data() + b * strideP);
}

hipblas_error = norm_check_general<T>('F', M, N, lda, hA, hA1, batch_count);
Expand Down
10 changes: 5 additions & 5 deletions clients/include/solver/testing_getrf_npvt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ void testing_getrf_npvt(const Arguments& arg)
}

// Naming: dK is in GPU (device) memory. hK is in CPU (host) memory
host_matrix<T> hA(M, N, lda);
host_matrix<T> hA1(M, N, lda);
host_vector<int> hIpiv(Ipiv_size);
host_vector<int> hInfo(1);
host_vector<int> hInfo1(1);
host_matrix<T> hA(M, N, lda);
host_matrix<T> hA1(M, N, lda);
host_vector<int64_t> hIpiv(Ipiv_size);
host_vector<int> hInfo(1);
host_vector<int> hInfo1(1);

// Allocate device memory
device_matrix<T> dA(M, N, lda);
Expand Down
2 changes: 1 addition & 1 deletion clients/include/solver/testing_getrf_npvt_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ void testing_getrf_npvt_batched(const Arguments& arg)
// Naming: dK is in GPU (device) memory. hK is in CPU (host) memory
host_batch_matrix<T> hA(M, N, lda, batch_count);
host_batch_matrix<T> hA1(M, N, lda, batch_count);
host_vector<int> hIpiv(Ipiv_size);
host_vector<int64_t> hIpiv(Ipiv_size);
host_vector<int> hInfo(batch_count);
host_vector<int> hInfo1(batch_count);

Expand Down
10 changes: 5 additions & 5 deletions clients/include/solver/testing_getrf_npvt_strided_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,11 @@ void testing_getrf_npvt_strided_batched(const Arguments& arg)
}

// Naming: dK is in GPU (device) memory. hK is in CPU (host) memory
host_strided_batch_matrix<T> hA(M, N, lda, strideA, batch_count);
host_strided_batch_matrix<T> hA1(M, N, lda, strideA, batch_count);
host_strided_batch_matrix<int> hIpiv(1, Ipiv_size, 1, strideP, batch_count);
host_vector<int> hInfo(batch_count);
host_vector<int> hInfo1(batch_count);
host_strided_batch_matrix<T> hA(M, N, lda, strideA, batch_count);
host_strided_batch_matrix<T> hA1(M, N, lda, strideA, batch_count);
host_strided_batch_matrix<int64_t> hIpiv(1, Ipiv_size, 1, strideP, batch_count);
host_vector<int> hInfo(batch_count);
host_vector<int> hInfo1(batch_count);

// Check host memory allocation
CHECK_HIP_ERROR(hA.memcheck());
Expand Down
11 changes: 7 additions & 4 deletions clients/include/solver/testing_getrf_strided_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,15 @@ void testing_getrf_strided_batched(const Arguments& arg)
host_strided_batch_matrix<T> hA(M, N, lda, strideA, batch_count);
host_strided_batch_matrix<T> hA1(M, N, lda, strideA, batch_count);
host_vector<int> hIpiv(Ipiv_size);
host_vector<int> hIpiv1(Ipiv_size);
host_vector<int64_t> hIpiv64(Ipiv_size);
host_vector<int> hInfo(batch_count);
host_vector<int> hInfo1(batch_count);

// Check host memory allocation
CHECK_HIP_ERROR(hA.memcheck());
CHECK_HIP_ERROR(hA1.memcheck());
CHECK_HIP_ERROR(hIpiv.memcheck());
CHECK_HIP_ERROR(hIpiv1.memcheck());
CHECK_HIP_ERROR(hIpiv64.memcheck());

device_strided_batch_matrix<T> dA(M, N, lda, strideA, batch_count);
device_vector<int> dIpiv(Ipiv_size);
Expand Down Expand Up @@ -175,16 +175,19 @@ void testing_getrf_strided_batched(const Arguments& arg)

// Copy output from device to CPU
CHECK_HIP_ERROR(hA1.transfer_from(dA));
CHECK_HIP_ERROR(hIpiv1.transfer_from(dIpiv));
CHECK_HIP_ERROR(hIpiv.transfer_from(dIpiv));
CHECK_HIP_ERROR(
hipMemcpy(hInfo1.data(), dInfo, batch_count * sizeof(int), hipMemcpyDeviceToHost));

/* =====================================================================
CPU LAPACK
=================================================================== */
for(int i = 0; i < Ipiv_size; i++)
hIpiv64[i] = hIpiv[i];

for(int b = 0; b < batch_count; b++)
{
hInfo[b] = ref_getrf(M, N, hA[b], lda, hIpiv.data() + b * strideP);
hInfo[b] = ref_getrf(M, N, hA[b], lda, hIpiv64.data() + b * strideP);
}

hipblas_error = norm_check_general<T>('F', M, N, lda, strideA, hA, hA1, batch_count);
Expand Down
18 changes: 12 additions & 6 deletions clients/include/solver/testing_getri_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ void testing_getri_batched(const Arguments& arg)
host_batch_matrix<T> hC(M, N, lda, batch_count);
host_batch_matrix<T> hA1(M, N, lda, batch_count);
host_vector<int> hIpiv(Ipiv_size);
host_vector<int> hIpiv1(Ipiv_size);
host_vector<int64_t> hIpiv64(Ipiv_size);
host_vector<int> hInfo(batch_count);
host_vector<int> hInfo1(batch_count);

Expand Down Expand Up @@ -205,10 +205,13 @@ void testing_getri_batched(const Arguments& arg)
}

// perform LU factorization on A
int* hIpivb = hIpiv.data() + b * strideP;
hInfo[b] = ref_getrf(M, N, hA[b], lda, hIpivb);
int64_t* hIpivb = hIpiv64.data() + b * strideP;
hInfo[b] = ref_getrf(M, N, hA[b], lda, hIpivb);
}

for(int i = 0; i < Ipiv_size; i++)
hIpiv[i] = hIpiv64[i];

CHECK_HIP_ERROR(dA.transfer_from(hA));
CHECK_HIP_ERROR(dC.transfer_from(hC));
CHECK_HIP_ERROR(hipMemcpy(dIpiv, hIpiv, Ipiv_size * sizeof(int), hipMemcpyHostToDevice));
Expand All @@ -232,23 +235,26 @@ void testing_getri_batched(const Arguments& arg)
// Copy output from device to CPU
CHECK_HIP_ERROR(hA1.transfer_from(dC));
CHECK_HIP_ERROR(
hipMemcpy(hIpiv1.data(), dIpiv, Ipiv_size * sizeof(int), hipMemcpyDeviceToHost));
hipMemcpy(hIpiv.data(), dIpiv, Ipiv_size * sizeof(int), hipMemcpyDeviceToHost));
CHECK_HIP_ERROR(
hipMemcpy(hInfo1.data(), dInfo, batch_count * sizeof(int), hipMemcpyDeviceToHost));

/* =====================================================================
CPU LAPACK
=================================================================== */
for(int i = 0; i < Ipiv_size; i++)
hIpiv64[i] = hIpiv[i];

for(int b = 0; b < batch_count; b++)
{
// Workspace query
host_vector<T> work(1);
ref_getri(N, hA[b], lda, hIpiv.data() + b * strideP, work.data(), -1);
ref_getri(N, hA[b], lda, hIpiv64.data() + b * strideP, work.data(), -1);
int lwork = type2int(work[0]);

// Perform inversion
work = host_vector<T>(lwork);
hInfo[b] = ref_getri(N, hA[b], lda, hIpiv.data() + b * strideP, work.data(), lwork);
hInfo[b] = ref_getri(N, hA[b], lda, hIpiv64.data() + b * strideP, work.data(), lwork);

hipblas_error = norm_check_general<T>('F', M, N, lda, hA[b], hA1[b]);
if(arg.unit_check)
Expand Down
10 changes: 5 additions & 5 deletions clients/include/solver/testing_getri_npvt_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,9 @@ void testing_getri_npvt_batched(const Arguments& arg)
host_batch_matrix<T> hC(M, N, lda, batch_count);
host_batch_matrix<T> hA1(M, N, lda, batch_count);

host_vector<int> hIpiv(Ipiv_size);
host_vector<int> hInfo(batch_count);
host_vector<int> hInfo1(batch_count);
host_vector<int64_t> hIpiv(Ipiv_size);
host_vector<int> hInfo(batch_count);
host_vector<int> hInfo1(batch_count);

// Check host memory allocation
CHECK_HIP_ERROR(hA.memcheck());
Expand Down Expand Up @@ -197,8 +197,8 @@ void testing_getri_npvt_batched(const Arguments& arg)
}

// perform LU factorization on A
int* hIpivb = hIpiv.data() + b * strideP;
hInfo[b] = ref_getrf(M, N, hA[b], lda, hIpivb);
int64_t* hIpivb = hIpiv.data() + b * strideP;
hInfo[b] = ref_getrf(M, N, hA[b], lda, hIpivb);
}

CHECK_HIP_ERROR(dA.transfer_from(hA));
Expand Down
Loading

0 comments on commit db58bab

Please sign in to comment.