Refactor scalar reductions to use common execution policy (#573)
jjwilke authored Sep 19, 2022
1 parent e5fc3cf commit e5f90b7
Showing 8 changed files with 268 additions and 478 deletions.
76 changes: 76 additions & 0 deletions src/cunumeric/execution_policy/reduction/scalar_reduction.cuh
@@ -0,0 +1,76 @@
/* Copyright 2021-2022 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

#pragma once

#include "cunumeric/execution_policy/reduction/scalar_reduction.h"
#include "cunumeric/cuda_help.h"

namespace cunumeric {
namespace scalar_reduction_impl {

template <class AccessorRD, class Kernel, class LHS, class Tag>
static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
scalar_unary_red_kernel(
size_t volume, size_t iters, AccessorRD out, Kernel kernel, LHS identity, Tag tag)
{
auto value = identity;
for (size_t idx = 0; idx < iters; idx++) {
const size_t offset = (idx * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x;
if (offset < volume) { kernel(value, offset, tag); }
}
// Every thread in the thread block must participate in the exchange to get correct results
reduce_output(out, value);
}

template <typename Buffer, typename RedAcc>
static __global__ void __launch_bounds__(1, 1) copy_kernel(Buffer result, RedAcc out)
{
out.reduce(0, result.read());
}

} // namespace scalar_reduction_impl

template <class LG_OP, class Tag>
struct ScalarReductionPolicy<VariantKind::GPU, LG_OP, Tag> {
template <class AccessorRD, class LHS, class Kernel>
void __attribute__((visibility("hidden"))) operator()(size_t volume,
AccessorRD& out,
const LHS& identity,
Kernel&& kernel)
{
auto stream = get_cached_stream();

const size_t blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
DeviceScalarReductionBuffer<LG_OP> result(stream);
size_t shmem_size = THREADS_PER_BLOCK / 32 * sizeof(LHS);

if (blocks >= MAX_REDUCTION_CTAS) {
const size_t iters = (blocks + MAX_REDUCTION_CTAS - 1) / MAX_REDUCTION_CTAS;
scalar_reduction_impl::
scalar_unary_red_kernel<<<MAX_REDUCTION_CTAS, THREADS_PER_BLOCK, shmem_size, stream>>>(
volume, iters, result, std::forward<Kernel>(kernel), identity, Tag{});
} else {
scalar_reduction_impl::
scalar_unary_red_kernel<<<blocks, THREADS_PER_BLOCK, shmem_size, stream>>>(
volume, 1, result, std::forward<Kernel>(kernel), identity, Tag{});
}
scalar_reduction_impl::copy_kernel<<<1, 1, 0, stream>>>(result, out);
CHECK_CUDA_STREAM(stream);
}
};

} // namespace cunumeric
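
For intuition, the GPU specialization above caps the grid at MAX_REDUCTION_CTAS and has each thread loop `iters` times over strided offsets before the block-level reduce_output exchange. The following host-side check is illustrative only and not part of this diff; the constants are assumed placeholders rather than the library's actual THREADS_PER_BLOCK and MAX_REDUCTION_CTAS values. It verifies that the offset formula used by scalar_unary_red_kernel visits every element of [0, volume) exactly once.

// Illustrative host-side coverage check (not part of this diff).
#include <cassert>
#include <vector>

int main()
{
  const size_t threads_per_block = 128, max_ctas = 64, volume = 1000003;
  const size_t blocks = (volume + threads_per_block - 1) / threads_per_block;
  const size_t grid   = blocks >= max_ctas ? max_ctas : blocks;
  const size_t iters  = blocks >= max_ctas ? (blocks + max_ctas - 1) / max_ctas : 1;

  std::vector<int> hits(volume, 0);
  for (size_t idx = 0; idx < iters; ++idx)
    for (size_t block = 0; block < grid; ++block)
      for (size_t thread = 0; thread < threads_per_block; ++thread) {
        // Same offset formula as the device kernel, with gridDim.x = grid.
        const size_t offset = (idx * grid + block) * threads_per_block + thread;
        if (offset < volume) ++hits[offset];
      }
  for (int h : hits) assert(h == 1);  // every element folded exactly once
  return 0;
}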
50 changes: 50 additions & 0 deletions src/cunumeric/execution_policy/reduction/scalar_reduction.h
@@ -0,0 +1,50 @@
/* Copyright 2021-2022 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

#pragma once

#include "cunumeric/cunumeric.h"

namespace cunumeric {

template <VariantKind KIND, class LG_OP, class Tag = void>
struct ScalarReductionPolicy {
// No C++20 yet. This is just here to illustrate the expected concept
// that all kernels passed to this execution policy should satisfy.
struct KernelConcept {
// Every operator should take a scalar LHS as the
// target of the reduction and an index representing the point
// in the iteration space being added into the reduction.
template <class LHS>
void operator()(LHS& lhs, size_t idx)
{
// LHS <- op[idx]
}
};
};

template <class LG_OP, class Tag>
struct ScalarReductionPolicy<VariantKind::CPU, LG_OP, Tag> {
template <class AccessorRD, class LHS, class Kernel>
void operator()(size_t volume, AccessorRD& out, const LHS& identity, Kernel&& kernel)
{
auto result = identity;
for (size_t idx = 0; idx < volume; ++idx) { kernel(result, idx, Tag{}); }
out.reduce(0, result);
}
};

} // namespace cunumeric
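
As a usage sketch (not part of this diff), a caller whose kernel conforms to KernelConcept might route a dense sum through the CPU policy as below. DenseTag, scalar_sum_cpu, and the raw input pointer are illustrative assumptions; real callers go through scalar_unary_red_template with store accessors and UnaryRedOp.

// Illustrative only: DenseTag and scalar_sum_cpu are hypothetical names.
struct DenseTag {};

template <class LG_OP, class AccessorRD, class T>
void scalar_sum_cpu(size_t volume, AccessorRD& out, const T* in)
{
  // The lambda matches KernelConcept: fold element idx into the running lhs.
  auto kernel = [=](T& lhs, size_t idx, DenseTag) {
    LG_OP::template fold<true>(lhs, in[idx]);
  };
  ScalarReductionPolicy<VariantKind::CPU, LG_OP, DenseTag>()(
    volume, out, LG_OP::identity, kernel);
}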
44 changes: 44 additions & 0 deletions src/cunumeric/execution_policy/reduction/scalar_reduction_omp.h
@@ -0,0 +1,44 @@
/* Copyright 2021-2022 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

#pragma once

#include "cunumeric/execution_policy/reduction/scalar_reduction.h"
#include "cunumeric/omp_help.h"

#include <omp.h>

namespace cunumeric {

template <class LG_OP, class Tag>
struct ScalarReductionPolicy<VariantKind::OMP, LG_OP, Tag> {
template <class AccessorRD, class LHS, class Kernel>
void operator()(size_t volume, AccessorRD& out, const LHS& identity, Kernel&& kernel)
{
const auto max_threads = omp_get_max_threads();
ThreadLocalStorage<LHS> locals(max_threads);
for (auto idx = 0; idx < max_threads; ++idx) locals[idx] = identity;
#pragma omp parallel
{
const int tid = omp_get_thread_num();
#pragma omp for schedule(static)
for (size_t idx = 0; idx < volume; ++idx) { kernel(locals[tid], idx, Tag{}); }
}
for (auto idx = 0; idx < max_threads; ++idx) out.reduce(0, locals[idx]);
}
};

} // namespace cunumeric
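
The OMP specialization gives each thread its own accumulator and combines the per-thread results serially afterwards, so no atomics are needed on a shared scalar. A standalone sketch of the same pattern with plain OpenMP follows; it is illustrative only, uses no Legate types, and omits any padding the library's ThreadLocalStorage helper may add against false sharing.

// Illustrative only: per-thread accumulators plus a serial combine.
#include <omp.h>
#include <cstdio>
#include <vector>

int main()
{
  const size_t volume = 1000000;
  std::vector<double> in(volume, 1.0);

  const int max_threads = omp_get_max_threads();
  std::vector<double> locals(max_threads, 0.0);  // one slot per thread
#pragma omp parallel
  {
    const int tid = omp_get_thread_num();
#pragma omp for schedule(static)
    for (size_t idx = 0; idx < volume; ++idx) locals[tid] += in[idx];
  }
  double result = 0.0;  // serial combine, mirroring the out.reduce(0, ...) loop
  for (double v : locals) result += v;
  std::printf("%f\n", result);
  return 0;
}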
98 changes: 0 additions & 98 deletions src/cunumeric/unary/scalar_unary_red.cc
@@ -19,104 +19,6 @@

namespace cunumeric {

using namespace Legion;
using namespace legate;

template <UnaryRedCode OP_CODE, LegateTypeCode CODE, int DIM>
struct ScalarUnaryRedImplBody<VariantKind::CPU, OP_CODE, CODE, DIM> {
using OP = UnaryRedOp<OP_CODE, CODE>;
using LG_OP = typename OP::OP;
using RHS = legate_type_of<CODE>;

template <UnaryRedCode _OP_CODE = OP_CODE,
std::enable_if_t<!is_arg_reduce<_OP_CODE>::value>* = nullptr>
void operator()(OP func,
AccessorRD<LG_OP, true, 1> out,
AccessorRO<RHS, DIM> in,
const Rect<DIM>& rect,
const Pitches<DIM - 1>& pitches,
bool dense,
const Point<DIM>& shape) const
{
auto result = LG_OP::identity;
const size_t volume = rect.volume();
if (dense) {
auto inptr = in.ptr(rect);
for (size_t idx = 0; idx < volume; ++idx)
OP::template fold<true>(result, OP::convert(inptr[idx]));
} else {
for (size_t idx = 0; idx < volume; ++idx) {
auto p = pitches.unflatten(idx, rect.lo);
OP::template fold<true>(result, OP::convert(in[p]));
}
}
out.reduce(0, result);
}

template <UnaryRedCode _OP_CODE = OP_CODE,
std::enable_if_t<is_arg_reduce<_OP_CODE>::value>* = nullptr>
void operator()(OP func,
AccessorRD<LG_OP, true, 1> out,
AccessorRO<RHS, DIM> in,
const Rect<DIM>& rect,
const Pitches<DIM - 1>& pitches,
bool dense,
const Point<DIM>& shape) const
{
auto result = LG_OP::identity;
const size_t volume = rect.volume();
if (dense) {
auto inptr = in.ptr(rect);
for (size_t idx = 0; idx < volume; ++idx) {
auto p = pitches.unflatten(idx, rect.lo);
OP::template fold<true>(result, OP::convert(p, shape, inptr[idx]));
}
} else {
for (size_t idx = 0; idx < volume; ++idx) {
auto p = pitches.unflatten(idx, rect.lo);
OP::template fold<true>(result, OP::convert(p, shape, in[p]));
}
}
out.reduce(0, result);
}
};

template <LegateTypeCode CODE, int DIM>
struct ScalarUnaryRedImplBody<VariantKind::CPU, UnaryRedCode::CONTAINS, CODE, DIM> {
using OP = UnaryRedOp<UnaryRedCode::SUM, LegateTypeCode::BOOL_LT>;
using LG_OP = typename OP::OP;
using RHS = legate_type_of<CODE>;

void operator()(AccessorRD<LG_OP, true, 1> out,
AccessorRO<RHS, DIM> in,
const Store& to_find_scalar,
const Rect<DIM>& rect,
const Pitches<DIM - 1>& pitches,
bool dense) const
{
auto result = LG_OP::identity;
const auto to_find = to_find_scalar.scalar<RHS>();
const size_t volume = rect.volume();
if (dense) {
auto inptr = in.ptr(rect);
for (size_t idx = 0; idx < volume; ++idx)
if (inptr[idx] == to_find) {
result = true;
break;
}
} else {
for (size_t idx = 0; idx < volume; ++idx) {
auto point = pitches.unflatten(idx, rect.lo);
if (in[point] == to_find) {
result = true;
break;
}
}
}
out.reduce(0, result);
}
};

/*static*/ void ScalarUnaryRedTask::cpu_variant(TaskContext& context)
{
scalar_unary_red_template<VariantKind::CPU>(context);