Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: avoid additional array allocation in host to device transfer #2966

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions cpp/oneapi/dal/backend/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,7 @@ inline sycl::event memcpy_host2usm(sycl::queue& queue,
const event_vector& deps = {}) {
ONEDAL_ASSERT(is_known_usm(queue, dest_usm));

// TODO: Remove additional copy to host usm memory once
// bug in `copy` with the host memory is fixed
auto tmp_usm_host = make_unique_usm_host(queue, size);
memcpy(tmp_usm_host.get(), src_host, size);
memcpy(queue, dest_usm, tmp_usm_host.get(), size, deps).wait_and_throw();
memcpy(queue, dest_usm, src_host, size, deps).wait_and_throw();
return {};
}

Expand All @@ -236,11 +232,7 @@ inline sycl::event memcpy_usm2host(sycl::queue& queue,
const event_vector& deps = {}) {
ONEDAL_ASSERT(is_known_usm(queue, src_usm));

// TODO: Remove additional copy to host usm memory once
// bug in `copy` with the host memory is fixed
auto tmp_usm_host = make_unique_usm_host(queue, size);
memcpy(queue, tmp_usm_host.get(), src_usm, size, deps).wait_and_throw();
memcpy(dest_host, tmp_usm_host.get(), size);
memcpy(queue, dest_host, src_usm, size, deps).wait_and_throw();
return {};
}

Expand Down
149 changes: 92 additions & 57 deletions cpp/oneapi/dal/table/backend/convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,33 +227,50 @@ sycl::event convert_vector_device2host(sycl::queue& q,
// To perform conversion, we gather data from device to host in a temporary
// contiguous array and then run the host conversion function

const std::int64_t element_size_in_bytes = dal::detail::get_data_type_size(src_type);
const std::int64_t dst_element_size_in_bytes = dal::detail::get_data_type_size(dst_type);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

collect more information about the destination to see if a shortcut can be used.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly!

const std::int64_t dst_size_in_bytes =
dal::detail::check_mul_overflow(dst_element_size_in_bytes, element_count);
const std::int64_t dst_stride_in_bytes =
dal::detail::check_mul_overflow(dst_element_size_in_bytes, dst_stride);

const std::int64_t src_element_size_in_bytes = dal::detail::get_data_type_size(src_type);
const std::int64_t src_size_in_bytes =
dal::detail::check_mul_overflow(element_size_in_bytes, element_count);
dal::detail::check_mul_overflow(src_element_size_in_bytes, element_count);
const std::int64_t src_stride_in_bytes =
dal::detail::check_mul_overflow(element_size_in_bytes, src_stride);

const auto tmp_host_unique = make_unique_usm_host(q, src_size_in_bytes);

auto gather_event = gather_device2host(q,
tmp_host_unique.get(),
src_device,
element_count,
src_stride_in_bytes,
element_size_in_bytes,
deps);
gather_event.wait_and_throw();

convert_vector(dal::detail::default_host_policy{},
tmp_host_unique.get(),
dst_host,
src_type,
dst_type,
1L,
dst_stride,
element_count);

return sycl::event{};
dal::detail::check_mul_overflow(src_element_size_in_bytes, src_stride);
if (src_element_size_in_bytes == dst_element_size_in_bytes &&
Alexandr-Solovev marked this conversation as resolved.
Show resolved Hide resolved
src_size_in_bytes == dst_size_in_bytes && src_stride_in_bytes == dst_stride_in_bytes) {
auto copy_event = memcpy_usm2host(q,
dst_host,
src_device,
src_element_size_in_bytes * element_count,
deps);

return copy_event;
}
else {
const auto tmp_host_unique = make_unique_usm_host(q, src_size_in_bytes);

auto gather_event = gather_device2host(q,
tmp_host_unique.get(),
src_device,
element_count,
src_stride_in_bytes,
src_element_size_in_bytes,
deps);
gather_event.wait_and_throw();

convert_vector(dal::detail::default_host_policy{},
tmp_host_unique.get(),
dst_host,
src_type,
dst_type,
1L,
dst_stride,
element_count);

return sycl::event{};
}
}

sycl::event convert_vector_host2device(sycl::queue& q,
Expand All @@ -275,43 +292,61 @@ sycl::event convert_vector_host2device(sycl::queue& q,
// To perform conversion, we convert on the host, gathering the data in a
// temporary contiguous array, and then scatter it from host to device

const std::int64_t element_size_in_bytes = dal::detail::get_data_type_size(dst_type);
const std::int64_t dst_element_size_in_bytes = dal::detail::get_data_type_size(dst_type);
const std::int64_t dst_size_in_bytes =
dal::detail::check_mul_overflow(element_size_in_bytes, element_count);
dal::detail::check_mul_overflow(dst_element_size_in_bytes, element_count);
const std::int64_t dst_stride_in_bytes =
dal::detail::check_mul_overflow(element_size_in_bytes, dst_stride);

const auto tmp_host_unique = make_unique_usm_host(q, dst_size_in_bytes);

convert_vector(dal::detail::default_host_policy{},
src_host,
tmp_host_unique.get(),
src_type,
dst_type,
src_stride,
1L,
element_count);
const std::int64_t max_loop_range = std::numeric_limits<std::int32_t>::max();
sycl::event scatter_event;
if (element_count > max_loop_range) {
scatter_event = scatter_host2device_blocking(q,
dst_device,
tmp_host_unique.get(),
element_count,
dst_stride_in_bytes,
element_size_in_bytes,
deps);
dal::detail::check_mul_overflow(dst_element_size_in_bytes, dst_stride);

const std::int64_t src_element_size_in_bytes = dal::detail::get_data_type_size(src_type);
const std::int64_t src_size_in_bytes =
dal::detail::check_mul_overflow(src_element_size_in_bytes, element_count);
const std::int64_t src_stride_in_bytes =
dal::detail::check_mul_overflow(src_element_size_in_bytes, src_stride);

if (src_element_size_in_bytes == dst_element_size_in_bytes &&
Alexandr-Solovev marked this conversation as resolved.
Show resolved Hide resolved
src_size_in_bytes == dst_size_in_bytes && src_stride_in_bytes == dst_stride_in_bytes) {
auto copy_event = memcpy_host2usm(q,
dst_device,
src_host,
src_element_size_in_bytes * element_count,
deps);

return copy_event;
}
else {
scatter_event = scatter_host2device(q,
dst_device,
tmp_host_unique.get(),
element_count,
dst_stride_in_bytes,
element_size_in_bytes,
deps);
const auto tmp_host_unique = make_unique_usm_host(q, dst_size_in_bytes);

convert_vector(dal::detail::default_host_policy{},
src_host,
tmp_host_unique.get(),
src_type,
dst_type,
src_stride,
1L,
element_count);
const std::int64_t max_loop_range = std::numeric_limits<std::int32_t>::max();
sycl::event scatter_event;
if (element_count > max_loop_range) {
scatter_event = scatter_host2device_blocking(q,
dst_device,
tmp_host_unique.get(),
element_count,
dst_stride_in_bytes,
dst_element_size_in_bytes,
deps);
}
else {
scatter_event = scatter_host2device(q,
dst_device,
tmp_host_unique.get(),
element_count,
dst_stride_in_bytes,
dst_element_size_in_bytes,
deps);
}
return scatter_event;
}
return scatter_event;
}

void convert_vector(const detail::data_parallel_policy& policy,
Expand Down