From 2ee5046e46be00583a7ac96324acfd63e826442f Mon Sep 17 00:00:00 2001 From: Piotr Rak Date: Wed, 12 Jan 2022 10:32:10 +0100 Subject: [PATCH] [WIP] Histogram CPU Signed-off-by: Piotr Rak --- dali/operators/generic/CMakeLists.txt | 1 + .../generic/histogram/CMakeLists.txt | 18 + dali/operators/generic/histogram/histogram.cc | 610 ++++++++++++++++++ dali/operators/generic/histogram/histogram.h | 102 +++ dali/operators/generic/reduce/axes_helper.h | 60 ++ dali/operators/generic/reduce/reduce.h | 39 +- dali/pipeline/operator/op_schema.h | 2 +- dali/pipeline/operator/operator.h | 4 +- dali/test/python/test_operator_histogram.py | 106 +++ 9 files changed, 903 insertions(+), 39 deletions(-) create mode 100644 dali/operators/generic/histogram/CMakeLists.txt create mode 100644 dali/operators/generic/histogram/histogram.cc create mode 100644 dali/operators/generic/histogram/histogram.h create mode 100644 dali/operators/generic/reduce/axes_helper.h create mode 100644 dali/test/python/test_operator_histogram.py diff --git a/dali/operators/generic/CMakeLists.txt b/dali/operators/generic/CMakeLists.txt index 8ed66d916ca..9ccaec501aa 100644 --- a/dali/operators/generic/CMakeLists.txt +++ b/dali/operators/generic/CMakeLists.txt @@ -15,6 +15,7 @@ # Get all the source files and dump test files add_subdirectory(erase) +add_subdirectory(histogram) add_subdirectory(reduce) add_subdirectory(slice) add_subdirectory(transpose) diff --git a/dali/operators/generic/histogram/CMakeLists.txt b/dali/operators/generic/histogram/CMakeLists.txt new file mode 100644 index 00000000000..088009d0449 --- /dev/null +++ b/dali/operators/generic/histogram/CMakeLists.txt @@ -0,0 +1,18 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Get all the source files and dump test files +collect_headers(DALI_INST_HDRS PARENT_SCOPE) +collect_sources(DALI_OPERATOR_SRCS PARENT_SCOPE) +collect_test_sources(DALI_OPERATOR_TEST_SRCS PARENT_SCOPE) \ No newline at end of file diff --git a/dali/operators/generic/histogram/histogram.cc b/dali/operators/generic/histogram/histogram.cc new file mode 100644 index 00000000000..ae6cb972d12 --- /dev/null +++ b/dali/operators/generic/histogram/histogram.cc @@ -0,0 +1,610 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "dali/operators/generic/histogram/histogram.h" + +#include "dali/kernels/common/copy.h" +#include "dali/kernels/transpose/transpose.h" +#include "dali/pipeline/operator/op_schema.h" + + +#include "dali/core/tensor_shape_print.h" + +#include "opencv2/imgproc.hpp" + +#include + +using namespace dali; +using namespace dali::hist_detail; + +#define id_(x) x + +#define HistogramOpName id_(histogram__Histogram) +#define UniformHistogramOpName id_(histogram__UniformHistogram) + +#define str_next_(x) #x +#define str_(x) str_next_(x) + +namespace { + +constexpr const char histogramOpString[] = str_(HistogramOpName); +constexpr const char unifromHistogramOpString[] = str_(UniformHistogramOpName); + +std::vector GetFlattenedRanges(int sample, const workspace_t &ws) { + std::vector ranges; + + for (int r = 1; r < ws.NumInput(); ++r) { + auto &dim_ranges = ws.template Input(r); + auto range_view = view(dim_ranges); + for (int i = 0; i < range_view[sample].num_elements(); ++i) { + ranges.push_back(range_view.tensor_data(sample)[i]); + } + } + return ranges; +} + +template +struct CVMatType { + static int get(int) { + DALI_ENFORCE(false, "Unreachable - invalid type"); + } +}; + +template <> +struct CVMatType { + static int get(int nchannel) noexcept { + return CV_MAKETYPE(CV_8U, nchannel); + } +}; + +template <> +struct CVMatType { + static int get(int nchannel) noexcept { + return CV_MAKETYPE(CV_16U, nchannel); + } +}; + +template <> +struct CVMatType { + static int get(int nchannel) noexcept { + return CV_MAKETYPE(CV_32F, nchannel); + } +}; + +template <> +struct CVMatType { + static int get(int nchannel) noexcept { + return CV_MAKETYPE(CV_64F, nchannel); + } +}; + +template +TensorListView transpose_view( + dali::ThreadPool &thread_pool, ScratchAlloc &scratch, + const TensorListView &in_view, const Coll &transpose_axes_order) { + const auto &in_shapes = in_view.shape; + + TensorListShape<> transposed_shapes; + permute_dims(transposed_shapes, in_shapes, transpose_axes_order); + std::vector tmp_pointers; + tmp_pointers.reserve(transposed_shapes.num_samples()); + + for (int i = 0; i < transposed_shapes.num_samples(); ++i) { + auto tmp = scratch.template AllocTensor(transposed_shapes[i]); + tmp_pointers.push_back(tmp.data); + } + + TensorListView transpose_out_view(std::move(tmp_pointers), + std::move(transposed_shapes)); + + for (int i = 0; i < transpose_out_view.num_samples(); ++i) { + thread_pool.AddWork([&, i](int thread_id) { + auto perm = make_span(transpose_axes_order); + kernels::Transpose(transpose_out_view[i], in_view[i], perm); + }); + } + thread_pool.RunAll(true); + return reinterpret(transpose_out_view, transpose_out_view.shape); +} + +template +void run_identity(ThreadPool &thread_pool, const TensorListView &in_view, + TensorListView &out_view) { + for (int i = 0; i < in_view.shape.num_samples(); ++i) { + thread_pool.AddWork([&, i](int thread_id) { kernels::copy(out_view[i], in_view[i]); }); + } + thread_pool.RunAll(true); +} + +} // namespace + +HistReductionAxesHelper::HistReductionAxesHelper(const OpSpec &spec) : detail::AxesHelper(spec) { + has_channel_axis_arg_ = spec.TryGetArgument(channel_axis_, "channel_axis"); + has_channel_axis_name_arg_ = spec.TryGetArgument(channel_axis_name_, "channel_axis_name"); + + + DALI_ENFORCE(!has_channel_axis_arg_ || !has_channel_axis_name_arg_, + "Arguments `channel_axis` and `channel_axis_name` are mutually exclusive"); +} + +void HistReductionAxesHelper::PrepareChannelAxisArg(const TensorLayout &layout, + const SmallVector 
&reduction_axes_mask, + int hist_dim) { + const int sample_dim = reduction_axes_mask.size(); + const bool has_channel_axis = (has_channel_axis_name_arg_ || has_channel_axis_arg_); + + if (hist_dim > 1) { + DALI_ENFORCE(has_channel_axis, + "One of arguments `channel_axis` and `channel_axis_name` should be specified for " + "multidimensional histograms!"); + + if (has_channel_axis_name_arg_) { + auto indices = GetDimIndices(layout, channel_axis_name_); + DALI_ENFORCE(indices.size() == 1, + "Exactly single axis name should be specified as `channel_axis_name`"); + channel_axis_ = indices[0]; + } else { + assert(channel_axis_ < 0 && "Not given?"); + DALI_ENFORCE(channel_axis_ < sample_dim, + make_string("Invalid axis number for argument `channel_axis` (is ", + channel_axis_, " and should be less than ", sample_dim, ")")); + } + DALI_ENFORCE( + reduction_axes_mask[channel_axis_], + make_string("Axis ", channel_axis_, + " can be eigther reduction axis (`axes`) or `channel_axis`, not both")); + } else if (has_channel_axis) { + DALI_ENFORCE(hist_dim == 1, + "None of `channel_axis` and `channel_axis_name` arguments should be specified for " + "multidimensional histograms!"); + channel_axis_ = -1; + } +} + +void HistReductionAxesHelper::PrepareReductionAxes(const TensorLayout &layout, int sample_dim, + int hist_dim) { + assert(hist_dim > 0); + + PrepareAxes(layout, sample_dim); + + SmallVector reduction_axes_mask; + reduction_axes_mask.resize(sample_dim, false); + + for (int axis : axes_) { + reduction_axes_mask[axis] = true; + } + + PrepareChannelAxisArg(layout, reduction_axes_mask, hist_dim); + + // If axes were not specified, we consider all but channel_axis_ reduction axes. + if (!has_empty_axes_arg_ && axes_.empty()) { + for (int i = 0; i < sample_dim; ++i) { + if (i != channel_axis_) { + reduction_axes_mask[i] = true; + } + } + } + + axes_order_.clear(); + axes_order_.reserve(sample_dim); + + // Collect non-reduction axes as outer tensor dimension + for (int i = 0; i < sample_dim; ++i) { + if (!reduction_axes_mask[i]) { + axes_order_.push_back(i); + } + } + + size_t num_non_reduction = axes_order_.size(); + + // Collect reduction axes as inner tensor dimension + for (int i = 0; i < sample_dim; ++i) { + if (reduction_axes_mask[i]) { + axes_order_.push_back(i); + } + } + + size_t num_reduction = axes_order_.size() - num_non_reduction; + + // For multi-dimensional histogram channel axis is most inner dimension + if (hist_dim > 1) { + assert(channel_axis_ != -1 && channel_axis_ < sample_dim); + axes_order_.push_back(channel_axis_); + } + + assert(axes_order_.size() == size_t(sample_dim)); + + non_reduction_axes_ = span(axes_order_.data(), num_non_reduction); + reduction_axes_ = span(axes_order_.data() + num_non_reduction, num_reduction); + + // TODO: verify first part of condition + is_identity_ = non_reduction_axes_.size() == sample_dim && has_empty_axes_arg_; +} + +bool HistReductionAxesHelper::NeedsTranspose() const { + for (size_t i = 0; i < axes_order_.size(); ++i) { + if (axes_order_[i] != i) { + return true; + } + } + return false; +} + +HistogramCPU::HistogramCPU(const OpSpec &spec) + : Operator(spec), + hist_detail::HistReductionAxesHelper(spec), + param_num_bins_("num_bins", spec) { + uniform_ = spec.name() == unifromHistogramOpString; + assert(spec.name() == unifromHistogramOpString || spec.name() == histogramOpString); +} + +TensorListShape<> HistogramCPU::GetBinShapes(int num_samples) const { + TensorListShape<> ret; + ret.resize(num_samples, hist_dim_); + + auto &bins = 
param_num_bins_.get(); + + for (int i = 0; i < num_samples; ++i) { + TensorShape<> bin_shape{make_span(batch_bins_[i].data(), hist_dim_)}; + ret.set_tensor_shape(i, bin_shape); + } + return ret; +} + +void HistogramCPU::PrepareReductionShapes(const TensorListShape<> &in_sh, OutputDesc &output_desc) { + // Prepare input shapes; if the reduction axes followed by the channel axis are not inner-most, + // permute the axes to create a transposed shape + TensorListShape<> transposed_shapes; + const bool needs_transpose = NeedsTranspose(); + + if (needs_transpose) { + permute_dims(transposed_shapes, in_sh, axes_order_); + } + + const TensorListShape<> &input_shapes = needs_transpose ? transposed_shapes : in_sh; + + // Prepare output tensor shapes + TensorListShape<> output_shapes, bin_shapes = GetBinShapes(input_shapes.num_samples()); + + if (!non_reduction_axes_.empty()) { + auto non_reduction_axes_shape = input_shapes.first(non_reduction_axes_.size()); + output_shapes.resize(input_shapes.num_samples(), non_reduction_axes_.size() + hist_dim_); + + for (int i = 0; i < input_shapes.num_samples(); ++i) { + auto out_sh = shape_cat(non_reduction_axes_shape[i], bin_shapes[i]); + output_shapes.set_tensor_shape(i, out_sh); + } + } else { + output_shapes = std::move(bin_shapes); + } + + // Simplify the shapes so that the histogram for each reduction can be calculated easily. + SubdivideTensorsShapes(input_shapes, output_shapes, output_desc); +} + +void HistogramCPU::SubdivideTensorsShapes(const TensorListShape<> &input_shapes, + const TensorListShape<> &output_shapes, + OutputDesc &output_desc) { + SmallVector<std::pair<int, int>, 2> in_collapse_groups, out_collapse_groups; + TensorListShape<> norm_inputs, norm_outputs; + + int reduced_start = 1; + if (non_reduction_axes_.empty()) { + // Add a unitary outer dimension for unfold + norm_inputs.resize(input_shapes.num_samples(), input_shapes.sample_dim() + 1); + norm_outputs.resize(output_shapes.num_samples(), output_shapes.sample_dim() + 1); + for (int i = 0; i < input_shapes.num_samples(); ++i) { + norm_inputs.set_tensor_shape(i, shape_cat(1, input_shapes[i])); + norm_outputs.set_tensor_shape(i, shape_cat(1, output_shapes[i])); + } + } else { + if (non_reduction_axes_.size() != 1) { + // Add a collapse group to collapse the outer (non-reduction) dimensions + in_collapse_groups.push_back(std::make_pair(0, non_reduction_axes_.size())); + out_collapse_groups.push_back(std::make_pair(0, non_reduction_axes_.size())); + reduced_start = non_reduction_axes_.size(); + } + norm_inputs = input_shapes; + norm_outputs = output_shapes; + } + + if (reduction_axes_.size() > 1) { + // Add a collapse group to collapse the inner (reduction) dimensions, possibly omitting the channel + // dimension.
+ in_collapse_groups.push_back(std::make_pair(reduced_start, reduction_axes_.size())); + } + + norm_inputs = collapse_dims(norm_inputs, in_collapse_groups); + norm_outputs = collapse_dims(norm_outputs, out_collapse_groups); + + auto splited_input_shapes = unfold_outer_dim(norm_inputs); + auto splited_output_shapes = unfold_outer_dim(norm_outputs); + + std::vector split_mapping; + split_mapping.reserve(splited_input_shapes.num_samples()); + auto norm_non_reduced = norm_inputs.first(1); + for (int i = 0; i < norm_non_reduced.num_samples(); ++i) { + for (int j = 0; j < norm_non_reduced[i][0]; ++j) { + split_mapping.push_back(i); + } + } + + split_mapping_ = std::move(split_mapping); + splited_input_shapes_ = std::move(splited_input_shapes); + splited_output_shapes_ = std::move(splited_output_shapes); + + output_desc.shape = std::move(output_shapes); + output_desc.type = DALI_FLOAT; +} + +bool HistogramCPU::SetupImpl(std::vector &output_desc, + const workspace_t &ws) { + output_desc.resize(1); + + auto &input = ws.template Input(0); + const size_t ndims = input.shape().sample_dim(); + + VerifyRangeArguments(ws, input.num_samples()); + assert(hist_dim_ != -1 && "Should be deduced from ranges or bins"); + + PrepareReductionAxes(input.GetLayout(), ndims, hist_dim_); + + // If an empty reduction axes were specified, histogram calculation becomes identity operation + if (is_identity_) { + output_desc[0].type = input.type(); + output_desc[0].shape = input.shape(); + } else { + DALI_ENFORCE(hist_dim_ >= 1, "Number of histogram dimensions should be at least one"); + DALI_ENFORCE(hist_dim_ <= 32, + "Number of histogram dimensions should not be no greater than 32"); + PrepareReductionShapes(input.shape(), output_desc[0]); + } + + return true; +} + +void HistogramCPU::VerifyBinsArgument(const workspace_t &ws, int num_samples, int hist_dim) { + assert(uniform_ && "Specified only for uniform histogram"); + const auto &input = ws.template Input(0); + param_num_bins_.Acquire(this->spec_, ws, input.shape().num_samples(), TensorShape<1>(hist_dim)); + auto bins_view = param_num_bins_.get(); + + DALI_ENFORCE(bins_view.num_elements() % num_samples == 0, + make_string("Histogram bins should be an array of bins per sample", + bins_view.num_elements())); + + batch_bins_.reserve(num_samples); + for (int i=0; i bins; + for (int j=0; j &ws, int num_samples, int hist_dim) { + assert(!uniform_ && "Can infer only for non-uniform histograms"); + + SmallVector bins; + for (int i=0; i(2 + i*2); + bins.push_back(lo.shape().num_samples()); + } + batch_bins_.reserve(num_samples); + + for (int i=0; i &ws, int num_samples) { + assert(ws.NumInput() == 3); + const auto &input = ws.template Input(0); + + const auto &ranges_lo = ws.template Input(1); + const auto &ranges_hi = ws.template Input(2); + + auto lo_view = view(ranges_lo); + auto hi_view = view(ranges_hi); + + int hist_dim = ranges_lo.num_samples() / num_samples; + + VerifyBinsArgument(ws, num_samples, hist_dim); + + DALI_ENFORCE(ranges_lo.num_samples() == ranges_hi.num_samples()); + + DALI_ENFORCE(hist_dim <= CV_CN_MAX, + make_string("Histogram dimensionality should not be greater than ", CV_CN_MAX)); + DALI_ENFORCE(hist_dim >= 1, "Ranges for at least one histogram should be specified"); + + batch_ranges_.reserve(input.num_samples()); + + for (int i=0; i dim_lo_hi(size_t(2*hist_dim)); + for (int j=0; j &ws, int num_samples) { + assert(!uniform_); + + DALI_ENFORCE(ws.NumInput() % 2 == 1, "Should have both ranges"); // FIXME message + + int hist_dim = ws.NumInput() / 2 - 1; 
+ + DALI_ENFORCE(hist_dim <= CV_CN_MAX, + make_string("Histogram dimensionality should not be greater than ", CV_CN_MAX)); + DALI_ENFORCE(hist_dim >= 1, "Ranges for at least one histogram should be specified"); + + InferBinsArgument(ws, num_samples, hist_dim); + + int nsamples = ws.template Input(0).shape().num_samples(); + + for (int r = 1; r < ws.NumInput(); ++r) { + auto &dim_ranges = ws.template Input(r); + DALI_ENFORCE(dim_ranges.type() == DALI_FLOAT, + make_string("Histogram bin ranges should be 32 bit floating-point numbers")); + auto sh_ranges = dim_ranges.shape(); + DALI_ENFORCE(sh_ranges.sample_dim(), + make_string("Histogram bin ranges for ", r, + " dimension should be one-dimensional array")); + + DALI_ENFORCE( + sh_ranges.num_elements() == 2 * nsamples, + make_string("Bin ranges for uniform histogram should consist of lower and upper bound", + dim_ranges.num_samples())); + + for (int i = 0; i < nsamples; ++i) { + batch_ranges_.push_back(GetFlattenedRanges(i, ws)); + } + } + return hist_dim; +} + +int HistogramCPU::VerifyRangeArguments(const workspace_t &ws, int num_samples) { + if (uniform_) { + return VerifyUniformRangeArguments(ws, num_samples); + } else { + return VerifyNonUniformRangeArguments(ws, num_samples); + } +} + +void HistogramCPU::RunImpl(workspace_t &ws) { + assert(hist_dim_ == 1 && "Multidimentional not implemented"); + + const auto &input = ws.Input(0); + auto &output = ws.Output(0); + + int nsamples = input.shape().num_samples(); + auto &thread_pool = ws.GetThreadPool(); + std::vector all_channels; + + for (int j = 0; j < hist_dim_; ++j) { + all_channels.push_back(j); + } + + TYPE_SWITCH(input.type(), type2id, Type, (uint8_t, uint16_t, float), ( + { + auto in_view = view(input); + + if (is_identity_) { + auto out_view_id = view(output); + run_identity(thread_pool, in_view, out_view_id); + return; + } + + auto out_view = view(output); + TensorListView transposed_in_view; + + bool needs_transpose = NeedsTranspose(); + if (needs_transpose) { + transpose_mem_.template Reserve( + in_view.num_elements() * sizeof(Type)); + auto scratch = transpose_mem_.GetScratchpad(); + + transposed_in_view = + transpose_view(thread_pool, scratch, in_view, axes_order_); + } + + auto splited_in_views = reinterpret( + needs_transpose ? transposed_in_view : in_view, splited_input_shapes_); + auto splited_out_views = reinterpret(out_view, splited_output_shapes_); + + assert(splited_in_views.num_samples() == splited_out_views.num_samples()); + assert(split_mapping_.size() == size_t(splited_in_views.num_samples())); + + for (int i = 0; i < splited_in_views.num_samples(); ++i) { + thread_pool.AddWork([&, i](int thread_id) { + SmallVector in_sizes; + + auto in_shape_span = splited_in_views.tensor_shape_span(i); + in_sizes.push_back(in_shape_span[0]); // FIXME, volume? 
+ + int in_type = CVMatType::get(hist_dim_); + + std::vector images = { + cv::Mat(1, in_sizes.data(), in_type, splited_in_views[i].data)}; + + cv::InputArray input_mat(images); + + std::vector bins; + auto out_shape_span = splited_out_views.tensor_shape_span(i); + + for (int j = 0; j < hist_dim_; ++j) { + bins.push_back(out_shape_span[j]); + } + + std::size_t sample_range = split_mapping_[i]; + const std::vector &ranges = batch_ranges_[sample_range]; + + cv::Mat mask; + cv::Mat output_mat; + + cv::calcHist(input_mat, all_channels, mask, output_mat, bins, ranges, uniform_); + + assert(output_mat.isContinuous() && output_mat.type() == CV_32FC1); + assert(output_mat.total() == size_t(splited_out_views[i].num_elements())); + + // OpenCV always allocates output array and we need to copy it. + float *hist_data = output_mat.ptr(); + TensorView hist_view(hist_data, splited_out_views[i].shape); + kernels::copy(splited_out_views[i], hist_view); + }); + } + thread_pool.RunAll(true); + + }), DALI_FAIL(make_string("Unsupported input type: ", input.type()))) // NOLINT +} + +DALI_SCHEMA(HistogramBase) + .AddOptionalArg>("axes", + R"code(Axis or axes along which reduction is performed. + +Not providing any axis results in reduction of all elements.)code", + nullptr) + .AddOptionalArg( + "axis_names", R"code(Name(s) of the axis or axes along which the reduction is performed. + +The input layout is used to translate the axis names to axis indices, for example ``axis_names="HW"`` with input +layout `"FHWC"` is equivalent to specifying ``axes=[1,2]``. This argument cannot be used together with ``axes``.)code", + nullptr) + .AddOptionalArg("channel_axis", "Specifies channel axis for multidimensional histogram", + nullptr) + .AddOptionalArg("channel_axis_name", + "Specifies channel axis name for mulitidimensional histogram", + nullptr); + +DALI_SCHEMA(HistogramOpName) + .DocStr("Calculates histogram.") + .NumInput(3, 1 + 2*CV_CN_MAX) + .NumOutput(1) + .AddParent("HistogramBase"); + +DALI_SCHEMA(UniformHistogramOpName) + .DocStr("Calculates uniform histogram.") + .NumInput(3) + .NumOutput(1) + .AddOptionalArg("num_bins", "An array of histogram bins for each dimension", std::vector(), + true) + .AddParent("HistogramBase"); + +DALI_REGISTER_OPERATOR(HistogramOpName, HistogramCPU, CPU); +DALI_REGISTER_OPERATOR(UniformHistogramOpName, HistogramCPU, CPU); diff --git a/dali/operators/generic/histogram/histogram.h b/dali/operators/generic/histogram/histogram.h new file mode 100644 index 00000000000..d8ec21b9378 --- /dev/null +++ b/dali/operators/generic/histogram/histogram.h @@ -0,0 +1,102 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef DALI_OPERATORS_GENERIC_REDUCE_HISTOGRAM_H_ +#define DALI_OPERATORS_GENERIC_REDUCE_HISTOGRAM_H_ + +#include "dali/kernels/kernel_manager.h" +#include "dali/operators/generic/reduce/axes_helper.h" +#include "dali/pipeline/operator/arg_helper.h" +#include "dali/pipeline/operator/operator.h" + +namespace dali { + +namespace hist_detail { + +struct HistReductionAxesHelper : detail::AxesHelper { + public: + explicit HistReductionAxesHelper(const OpSpec &); + + void PrepareReductionAxes(const TensorLayout &layout, int sample_dim, int hist_dim); + bool IsIdentityTransform() const { + return is_identity_; + } + bool IsSimpleReduction1() const; + bool NeedsTranspose() const; + + private: + void PrepareChannelAxisArg(const TensorLayout &layout, + const SmallVector &reduction_axes_mask, int hist_dim); + + public: + // TODO: make private + span reduction_axes_; + span non_reduction_axes_; + int channel_axis_ = -1; + SmallVector axes_order_; + std::string channel_axis_name_; + bool has_channel_axis_arg_; + bool has_channel_axis_name_arg_; + bool is_identity_ = false; +}; + +} // namespace hist_detail + +class HistogramCPU : public Operator, hist_detail::HistReductionAxesHelper { + public: + explicit HistogramCPU(const OpSpec &spec); + + bool CanInferOutputs() const override { + return true; + } + + ~HistogramCPU() override = default; + + private: + int VerifyRangeArguments(const workspace_t &ws, int num_samples); + int VerifyUniformRangeArguments(const workspace_t &ws, int num_samples); + int VerifyNonUniformRangeArguments(const workspace_t &ws, int num_samples); + + void VerifyBinsArgument(const workspace_t &ws, int num_samples, int hist_dim); + void InferBinsArgument(const workspace_t &ws, int num_samples, int hist_dim); + + void PrepareReductionShapes(const TensorListShape<> &input_shapes, OutputDesc &output_desc); + void SubdivideTensorsShapes(const TensorListShape<> &input_shapes, + const TensorListShape<> &output_shapes, OutputDesc &output_desc); + + TensorListShape<> GetBinShapes(int num_samples) const; + + public: + bool SetupImpl(std::vector &output_desc, const workspace_t &ws) override; + void RunImpl(workspace_t &ws) override; + + private: + USE_OPERATOR_MEMBERS(); + TensorListShape<> splited_input_shapes_; + TensorListShape<> splited_output_shapes_; + std::vector split_mapping_; + kernels::KernelManager kmgr_; + std::vector> batch_ranges_; + std::vector> batch_bins_; + ArgValue param_num_bins_; + kernels::ScratchpadAllocator transpose_mem_; + int hist_dim_ = -1; + bool needs_transpose_ = false; + bool uniform_ = true; +}; + +} // namespace dali + + +#endif // DALI_OPERATORS_GENERIC_REDUCE_HISTOGRAM_H_ diff --git a/dali/operators/generic/reduce/axes_helper.h b/dali/operators/generic/reduce/axes_helper.h new file mode 100644 index 00000000000..27dc40f591c --- /dev/null +++ b/dali/operators/generic/reduce/axes_helper.h @@ -0,0 +1,60 @@ +// Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ + +#ifndef DALI_OPERATORS_GENERIC_REDUCE_AXIS_HELPER_H__ +#define DALI_OPERATORS_GENERIC_REDUCE_AXIS_HELPER_H__ + +#include + +#include "dali/pipeline/operator/operator.h" + +namespace dali { +namespace detail { + +class AxesHelper { + public: + explicit inline AxesHelper(const OpSpec &spec) { + has_axes_arg_ = spec.TryGetRepeatedArgument(axes_, "axes"); + has_axis_names_arg_ = spec.TryGetArgument(axis_names_, "axis_names"); + has_empty_axes_arg_ = + (has_axes_arg_ && axes_.empty()) || (has_axis_names_arg_ && axis_names_.empty()); + + DALI_ENFORCE(!has_axes_arg_ || !has_axis_names_arg_, + "Arguments `axes` and `axis_names` are mutually exclusive"); + } + + void PrepareAxes(const TensorLayout &layout, int sample_dim) { + if (has_axis_names_arg_) { + axes_ = GetDimIndices(layout, axis_names_).to_vector(); + return; + } + + if (!has_axes_arg_) { + axes_.resize(sample_dim); + std::iota(axes_.begin(), axes_.end(), 0); + } + } + + bool has_axes_arg_; + bool has_axis_names_arg_; + bool has_empty_axes_arg_; + std::vector axes_; + TensorLayout axis_names_; +}; + +} // namespace detail +} // namespace dali + +#endif // DALI_OPERATORS_GENERIC_REDUCE_AXIS_HELPER_H__ \ No newline at end of file diff --git a/dali/operators/generic/reduce/reduce.h b/dali/operators/generic/reduce/reduce.h index 056a6ee02c3..d428a9ee775 100644 --- a/dali/operators/generic/reduce/reduce.h +++ b/dali/operators/generic/reduce/reduce.h @@ -18,50 +18,17 @@ #include #include -#include "dali/pipeline/operator/operator.h" #include "dali/kernels/kernel_manager.h" -#include "dali/kernels/reduce/reductions.h" #include "dali/kernels/reduce/reduce_cpu.h" #include "dali/kernels/reduce/reduce_gpu.h" #include "dali/kernels/reduce/reduce_setup_utils.h" +#include "dali/kernels/reduce/reductions.h" +#include "dali/operators/generic/reduce/axes_helper.h" +#include "dali/pipeline/operator/operator.h" #define REDUCE_TYPES (uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, uint64_t, int64_t, float) // NOLINT namespace dali { -namespace detail { - -class AxesHelper { - public: - explicit inline AxesHelper(const OpSpec &spec) { - has_axes_arg_ = spec.TryGetRepeatedArgument(axes_, "axes"); - has_axis_names_arg_ = spec.TryGetArgument(axis_names_, "axis_names"); - has_empty_axes_arg_ = - (has_axes_arg_ && axes_.empty()) || (has_axis_names_arg_ && axis_names_.empty()); - - DALI_ENFORCE(!has_axes_arg_ || !has_axis_names_arg_, - "Arguments `axes` and `axis_names` are mutually exclusive"); - } - - void PrepareAxes(const TensorLayout &layout, int sample_dim) { - if (has_axis_names_arg_) { - axes_ = GetDimIndices(layout, axis_names_).to_vector(); - return; - } - - if (!has_axes_arg_) { - axes_.resize(sample_dim); - std::iota(axes_.begin(), axes_.end(), 0); - } - } - - bool has_axes_arg_; - bool has_axis_names_arg_; - bool has_empty_axes_arg_; - vector axes_; - TensorLayout axis_names_; -}; - -} // namespace detail template < template class ReductionType, diff --git a/dali/pipeline/operator/op_schema.h b/dali/pipeline/operator/op_schema.h index 3462fb3996d..7a9d7c7e710 100644 --- a/dali/pipeline/operator/op_schema.h +++ b/dali/pipeline/operator/op_schema.h @@ -993,7 +993,7 @@ inline T OpSchema::GetDefaultValueForArgument(const std::string &s) const { } #define DALI_SCHEMA_REG(OpName) \ - int DALI_OPERATOR_SCHEMA_REQUIRED_FOR_##OpName() { \ + int CONCAT_2(DALI_OPERATOR_SCHEMA_REQUIRED_FOR_, OpName)() { \ return 42; \ } \ static ::dali::OpSchema* ANONYMIZE_VARIABLE(OpName) = \ diff --git a/dali/pipeline/operator/operator.h 
b/dali/pipeline/operator/operator.h index 2e82760b4d5..b527586117a 100644 --- a/dali/pipeline/operator/operator.h +++ b/dali/pipeline/operator/operator.h @@ -439,8 +439,8 @@ DALI_DECLARE_OPTYPE_REGISTRY(MixedOperator, OperatorBase); // Must be called from .cc or .cu file #define DALI_REGISTER_OPERATOR(OpName, OpType, device) \ - int DALI_OPERATOR_SCHEMA_REQUIRED_FOR_##OpName(); \ - static int ANONYMIZE_VARIABLE(OpName) = DALI_OPERATOR_SCHEMA_REQUIRED_FOR_##OpName(); \ + int CONCAT_2(DALI_OPERATOR_SCHEMA_REQUIRED_FOR_, OpName)(); \ + static int ANONYMIZE_VARIABLE(OpName) = CONCAT_2(DALI_OPERATOR_SCHEMA_REQUIRED_FOR_, OpName)(); \ DALI_DEFINE_OPTYPE_REGISTERER(OpName, OpType, device##Operator, ::dali::OperatorBase, #device) diff --git a/dali/test/python/test_operator_histogram.py b/dali/test/python/test_operator_histogram.py new file mode 100644 index 00000000000..c46523f8c51 --- /dev/null +++ b/dali/test/python/test_operator_histogram.py @@ -0,0 +1,106 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from itertools import cycle, permutations +from random import randint +from nvidia.dali import pipeline, pipeline_def +import nvidia.dali.fn as fn +import nvidia.dali.types as types +import nvidia.dali.ops as ops +from nvidia.dali.pipeline import Pipeline + +import numpy as np + +from test_utils import np_type_to_dali + +min_dim = 2 + +def _all_reduction_shapes(sample_dim, nbins): + axes = range(sample_dim) + sizes = range(min_dim, sample_dim+min_dim) + for axes_perm in permutations(axes): + for nred_axes in range(sample_dim+1): + input_shape = tuple((sizes[ax] for ax in range(sample_dim))) + reduction_axes = axes_perm[:nred_axes] + non_reduction_axes = [] + for ax in axes: + if not ax in reduction_axes: + non_reduction_axes.append(ax) + output_shape = [sizes[axis] for axis in non_reduction_axes] + if len(reduction_axes) != 0: + output_shape.append(nbins) + output_shape = tuple(output_shape) + yield (reduction_axes, input_shape, output_shape) + +def _jagged_batch(batch_size, input_sh, output_sh, np_type, axes): + batch = [] + output_shapes = [] + for i in range(batch_size): + map = {} + for sz in range(min_dim, len(input_sh)+min_dim): + map[sz] = randint(8, 8+len(input_sh)) + + mapped_in = [map[d] for d in [*input_sh,]] + if len([x for x in axes]) > 0: + mapped_out = [map[d] for d in [*output_sh,][:-1]] + mapped_out.append(output_sh[-1]) + else: + mapped_out = [map[d] for d in [*output_sh,]] + batch.append(np.ones(mapped_in, np_type)) + output_shapes.append(np.array(mapped_out)) + return (batch, output_shapes) + +def _testimpl_histogram_shape(batch_size, device, in_sh, out_sh, num_bins, axes, np_type): + range_01 = (np.array([0.0], dtype=np.float32), np.array([1.0], dtype=np.float32)) + ranges = {type(np.uint8) : range_01, type(np.uint16) : range_01, type(np.float32) : range_01} + @pipeline_def(batch_size=batch_size, num_threads=3, device_id=0) + def uniform_histogram1D_uniform_shape_pipe(np_dtype, num_bins=num_bins, 
device='cpu', axes=[]): + batches = [[np.ones(in_sh, dtype = np_dtype)]*batch_size, [np.zeros(in_sh, dtype = np_dtype)]*batch_size] + in_tensors = fn.external_source(source=batches, device=device, cycle=True) + out_tensors = fn.histogram.uniform_histogram(in_tensors, *ranges[type(np_type)], num_bins=num_bins, axes=axes) + return out_tensors + + @pipeline_def(batch_size=batch_size, num_threads=3, device_id=0) + def uniform_histogram1D_jagged_shape_pipe(np_dtype, num_bins=num_bins, device='cpu', axes=[]): + batch, out_sh_list = _jagged_batch(batch_size, in_sh, out_sh, np_type, axes) + batches = [batch]*2 + out_sizes_batches = [out_sh_list]*2 + in_tensors = fn.external_source(source=batches, device=device, cycle=True) + out_tensors = fn.histogram.uniform_histogram(in_tensors, *ranges[type(np_type)], num_bins=num_bins, axes=axes) + out_sizes = fn.external_source(source = out_sizes_batches, device=device, cycle=True) + return out_tensors, out_sizes + + pipe = uniform_histogram1D_uniform_shape_pipe(np_dtype=np.uint8, device=device, axes=axes, num_bins=num_bins) + pipe.build() + for iter in range(2): + out, = pipe.run() + for ret_sh in out.shape(): + assert(ret_sh == out_sh) + + pipe = uniform_histogram1D_jagged_shape_pipe(np_dtype=np.uint8, device=device, axes=axes, num_bins=num_bins) + pipe.build() + for iter in range(2): + out, sz = pipe.run() + for ret_sz, expected_sz in zip(out.shape(), sz.as_array()): + assert(ret_sz == tuple(expected_sz)) + +def test_reduce_shape_histogram(): + batch_size = 10 + for device in ['cpu']: + for sample_dim in range(1, 4): + for nbins in [1, 16, 1024]: + for type in [np.uint8, np.uint16, np.float32]: + for axes, in_sh, out_sh in _all_reduction_shapes(sample_dim, nbins): + yield _testimpl_histogram_shape, batch_size, device, in_sh, out_sh, [nbins], axes, type +