Skip to content

Commit

Permalink
Removed extra copy for transpose arrays in dot() (#1477)
Browse files Browse the repository at this point in the history
* Removed extra copy for strided arrays in dot()

* Added support of strided arrays

* Added support of strided out array

* Fix handling of 1d and 2d arrays
  • Loading branch information
antonwolfy authored Jul 17, 2023
1 parent 771653b commit 326c451
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 77 deletions.
1 change: 1 addition & 0 deletions .github/workflows/conda-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ env:
test_special.py
test_umath.py
test_usm_type.py
third_party/cupy/linalg_tests/test_product.py
third_party/cupy/math_tests/test_explog.py
third_party/cupy/math_tests/test_misc.py
third_party/cupy/math_tests/test_trigonometric.py
Expand Down
159 changes: 87 additions & 72 deletions dpnp/backend/kernels/dpnp_krnl_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <dpnp_iface.hpp>

namespace mkl_blas = oneapi::mkl::blas;
namespace mkl_blas_cm = oneapi::mkl::blas::column_major;
namespace mkl_blas_rm = oneapi::mkl::blas::row_major;
namespace mkl_lapack = oneapi::mkl::lapack;

Expand Down Expand Up @@ -227,12 +228,10 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
DPCTLSyclEventRef event_ref = nullptr;
sycl::queue q = *(reinterpret_cast<sycl::queue *>(q_ref));

DPNPC_ptr_adapter<_DataType_input1> input1_ptr(q_ref, input1_in,
input1_size);
DPNPC_ptr_adapter<_DataType_input2> input2_ptr(q_ref, input2_in,
input2_size);
_DataType_input1 *input1 = input1_ptr.get_ptr();
_DataType_input2 *input2 = input2_ptr.get_ptr();
_DataType_input1 *input1 =
static_cast<_DataType_input1 *>(const_cast<void *>(input1_in));
_DataType_input2 *input2 =
static_cast<_DataType_input2 *>(const_cast<void *>(input2_in));
_DataType_output *result = reinterpret_cast<_DataType_output *>(result_out);

if (!input1_size || !input2_size) {
Expand All @@ -257,10 +256,12 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
// if both arrays are vectors
if ((input1_ndim == 1) && (input2_ndim == 1)) {
assert(input1_size == input2_size);

sycl::event event = dot(q, result, input1, input2, input1_strides[0],
input2_strides[0], input1_size);
event.wait();
return event_ref;

event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
return DPCTLEvent_Copy(event_ref);
}

// 1D vector
Expand Down Expand Up @@ -297,13 +298,17 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
size_t ext_result_ndim =
((input1_ndim == 1) || (input2_ndim == 1)) ? 2 : result_ndim;
shape_elem_type *ext_result_shape = new shape_elem_type[ext_result_ndim];
shape_elem_type *ext_result_strides = new shape_elem_type[ext_result_ndim];
if ((input1_ndim == 1) || (input2_ndim == 1)) {
ext_result_shape[0] = ext_input1_shape[0];
ext_result_shape[1] = ext_input2_shape[1];
ext_result_strides[0] = 0;
ext_result_strides[1] = result_strides[0];
}
else {
for (size_t i = 0; i < ext_result_ndim; ++i) {
ext_result_shape[i] = result_shape[i];
ext_result_strides[i] = result_strides[i];
}
}

Expand All @@ -316,80 +321,89 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
// check if GEMM can be executed (strides)
// TODO: rewrite the condition in general case for ndims > 2
// (looks like there are such another cases)

if (ext_input1_ndim == 2 && ext_input2_ndim == 2) {
// The trans and size parameters behave differently in previous versions of
// GEMM, so only the new version is supported; with an old version the
// computation goes the common way.
#if INTEL_MKL_VERSION >= 20210004
// is mat1 F-contiguous, C-contiguous
bool mat1_f_contig =
(((ext_input1_shape[0] == 1) || (ext_input1_strides[0] == 1)) &&
((ext_input1_shape[1] == 1) ||
(ext_input1_strides[1] == ext_input1_shape[0])));
bool mat1_c_contig =
(((ext_input1_shape[1] == 1) || (ext_input1_strides[1] == 1)) &&
((ext_input1_shape[0] == 1) ||
(ext_input1_strides[0] == ext_input1_shape[1])));
// is mat2 F-contiguous, C-contiguous
bool mat2_f_contig =
(((ext_input2_shape[0] == 1) || (ext_input2_strides[0] == 1)) &&
((ext_input2_shape[1] == 1) ||
(ext_input2_strides[1] == ext_input2_shape[0])));
bool mat2_c_contig =
(((ext_input2_shape[1] == 1) || (ext_input2_strides[1] == 1)) &&
((ext_input2_shape[0] == 1) ||
(ext_input2_strides[0] == ext_input2_shape[1])));

if ((mat1_f_contig || mat1_c_contig) &&
(mat2_f_contig || mat2_c_contig)) {
oneapi::mkl::transpose trans1 =
(mat1_f_contig && !mat1_c_contig)
? oneapi::mkl::transpose::trans
: oneapi::mkl::transpose::nontrans;
oneapi::mkl::transpose trans2 =
(mat2_f_contig && !mat2_c_contig)
? oneapi::mkl::transpose::trans
: oneapi::mkl::transpose::nontrans;
// OneMKL gemm supports only arrays contiguous on inner dimension,
// so stride for at least one dimension should be equal to 1
if ((ext_input1_strides[0] == 1 || ext_input1_strides[1] == 1) &&
(ext_input2_strides[0] == 1 || ext_input2_strides[1] == 1) &&
(ext_result_strides[0] == 1 || ext_result_strides[1] == 1))
{
const bool isRowmA =
(ext_input1_strides[1] == 1 || ext_input1_strides[0] == 0);
const bool isRowmB =
(ext_input2_strides[1] == 1 || ext_input2_strides[1] == 0);
const bool isRowmC =
(ext_result_strides[1] == 1 || ext_result_strides[0] == 0);

oneapi::mkl::transpose transA =
(isRowmA != isRowmC) ? oneapi::mkl::transpose::trans
: oneapi::mkl::transpose::nontrans;
oneapi::mkl::transpose transB =
(isRowmB != isRowmC) ? oneapi::mkl::transpose::trans
: oneapi::mkl::transpose::nontrans;

const size_t size_m = ext_input1_shape[0];
const size_t size_n = ext_input2_shape[1];
const size_t size_k = ext_input1_shape[1];

const std::int64_t lda =
trans1 == oneapi::mkl::transpose::nontrans
? ext_input1_strides[0]
: ext_input1_strides[1];
const std::int64_t ldb =
trans2 == oneapi::mkl::transpose::nontrans
? ext_input2_strides[0]
: ext_input2_strides[1];

// The definition of ldc will be different for a result with
// non-standard (not c-contiguous) strides:
// const std::int64_t ldc =
//     result_strides[0] == 1 ? result_strides[1] : result_strides[0];
const std::int64_t ldc = size_n;
auto getLdaLdc = [](const bool isRown, shape_elem_type *strides,
shape_elem_type *shapes) {
if (isRown) {
return (strides[0] != 0) ? strides[0] : shapes[1];
}
return strides[1];
};

const std::int64_t lda = static_cast<std::int64_t>(
getLdaLdc(isRowmA, ext_input1_strides, ext_input1_shape));
const std::int64_t ldb = static_cast<std::int64_t>(
isRowmB ? ext_input2_strides[0] : ext_input2_strides[1]);
const std::int64_t ldc = static_cast<std::int64_t>(
getLdaLdc(isRowmC, ext_result_strides, ext_result_shape));

constexpr _DataType_output alpha = 1;
constexpr _DataType_output beta = 0;

std::stringstream error_msg;
std::int64_t info = 0;

try {
sycl::event event = mkl_blas_rm::gemm(
q, trans1, trans2, size_m, size_n, size_k,
_DataType_output(1), // alpha
input1, lda, input2, ldb,
_DataType_output(0), // beta
result, ldc);
event.wait();
delete[] ext_input1_shape;
delete[] ext_input1_strides;
delete[] ext_input2_shape;
delete[] ext_input2_strides;
delete[] ext_result_shape;

return event_ref;
if (isRowmC) {
mkl_blas_rm::gemm(q, transA, transB, size_m, size_n,
size_k, alpha, input1, lda, input2,
ldb, beta, result, ldc)
.wait();
}
else {
mkl_blas_cm::gemm(q, transA, transB, size_m, size_n,
size_k, alpha, input1, lda, input2,
ldb, beta, result, ldc)
.wait();
}
} catch (mkl_lapack::exception const &e) {
error_msg << "Unexpected MKL exception caught during "
"gemm() call:\nreason: "
<< e.what() << "\ninfo: " << e.info();
info = e.info();
} catch (const std::exception &e) {
// do nothing, proceed to general case
error_msg << "Unexpected SYCL exception caught during "
"gemm() call:\n"
<< e.what();
info = -1;
}
#endif

if (info != 0) // an unexpected error occurred
{
throw std::runtime_error(error_msg.str());
}

delete[] ext_input1_shape;
delete[] ext_input1_strides;
delete[] ext_input2_shape;
delete[] ext_input2_strides;
delete[] ext_result_shape;
delete[] ext_result_strides;
return event_ref;
}
}
}
Expand Down Expand Up @@ -437,6 +451,7 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
delete[] ext_input2_shape;
delete[] ext_input2_strides;
delete[] ext_result_shape;
delete[] ext_result_strides;

return event_ref;
}
Expand Down
2 changes: 1 addition & 1 deletion dpnp/dpnp_algo/dpnp_algo_linearalgebra.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_2in_1out_dot_t)(c_dpctl.DPCTLSyclQueueR
const shape_elem_type *, const shape_elem_type * ,
void * , const size_t, const size_t,
const shape_elem_type *, const shape_elem_type * ,
const c_dpctl.DPCTLEventVectorRef)
const c_dpctl.DPCTLEventVectorRef) except +
ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_2in_1out_matmul_t)(c_dpctl.DPCTLSyclQueueRef,
void * , const size_t, const size_t,
const shape_elem_type *, const shape_elem_type * ,
Expand Down
9 changes: 5 additions & 4 deletions dpnp/dpnp_iface_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,16 @@ def dot(x1, x2, out=None, **kwargs):
else (None, None)
)

# TODO: copy_when_strides=False (now it's done for faster implementation with transpose arrays)
x1_desc = dpnp.get_dpnp_descriptor(
x1,
copy_when_strides=True,
copy_when_strides=False,
copy_when_nondefault_queue=False,
alloc_usm_type=usm_type,
alloc_queue=queue,
)
x2_desc = dpnp.get_dpnp_descriptor(
x2,
copy_when_strides=True,
copy_when_strides=False,
copy_when_nondefault_queue=False,
alloc_usm_type=usm_type,
alloc_queue=queue,
Expand All @@ -131,7 +130,9 @@ def dot(x1, x2, out=None, **kwargs):
)
out_desc = (
dpnp.get_dpnp_descriptor(
out, copy_when_nondefault_queue=False
out,
copy_when_strides=False,
copy_when_nondefault_queue=False,
)
or None
)
Expand Down

0 comments on commit 326c451

Please sign in to comment.