Skip to content

Commit

Permalink
Add per frame parameters support to brightness_contrast and color_twi…
Browse files Browse the repository at this point in the history
…st families (#3937)

* Add per-frame parameters support to color-manipulating operators (hue, hsv, saturation, color_twist, brightness_contrast, brightness, contrast)
* Make contrast_center in brightness_contrast family per-sample (and per-frame)
* Add unfolded_views_range utility to simplify sequence-like scenarios when SequenceOperator is not applicable

Signed-off-by: Kamil Tokarski <[email protected]>
  • Loading branch information
stiepan authored Jun 14, 2022
1 parent b4a46e1 commit e687379
Show file tree
Hide file tree
Showing 12 changed files with 562 additions and 160 deletions.
56 changes: 24 additions & 32 deletions dali/operators/image/color/brightness_contrast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "dali/operators/image/color/brightness_contrast.h"
#include "dali/kernels/imgproc/pointwise/multiply_add.h"
#include "dali/pipeline/data/sequence_utils.h"

namespace dali {

Expand All @@ -31,12 +32,12 @@ This operator can also change the type of data.)code")
.NumOutput(1)
.AddOptionalArg("brightness",
"Brightness mutliplier.",
kDefaultBrightness, true)
kDefaultBrightness, true, true)
.AddOptionalArg("brightness_shift", R"code(The brightness shift.
For signed types, 1.0 represents the maximum positive value that can be represented by
the type.)code",
kDefaultBrightnessShift, true)
kDefaultBrightnessShift, true, true)
.AddOptionalTypeArg("dtype",
R"code(Output data type.
Expand All @@ -57,12 +58,12 @@ This operator can also change the type of data.)code")
.NumOutput(1)
.AddOptionalArg("contrast", R"code(The contrast multiplier, where 0.0 produces
the uniform grey.)code",
kDefaultContrast, true)
kDefaultContrast, true, true)
.AddOptionalArg("contrast_center", R"code(The intensity level that is unaffected by contrast.
This is the value that all pixels assume when the contrast is zero. When not set,
the half of the input type's positive range (or 0.5 for ``float``) is used.)code",
brightness_contrast::HalfRange<float>(), false)
brightness_contrast::HalfRange<float>(), true, true)
.AddOptionalTypeArg("dtype",
R"code(Output data type.
Expand Down Expand Up @@ -94,45 +95,34 @@ DALI_REGISTER_OPERATOR(Brightness, BrightnessContrastCpu, CPU);
DALI_REGISTER_OPERATOR(Contrast, BrightnessContrastCpu, CPU);


template <typename OutputType, typename InputType>
template <typename OutputType, typename InputType, int ndim>
void BrightnessContrastCpu::RunImplHelper(workspace_t<CPUBackend> &ws) {
const auto &input = ws.template Input<CPUBackend>(0);
auto &output = ws.template Output<CPUBackend>(0);
output.SetLayout(input.GetLayout());
auto out_shape = output.shape();
auto& tp = ws.GetThreadPool();
TensorListShape<> sh = input.shape();
auto num_dims = sh.sample_dim();
int num_samples = input.shape().num_samples();
int num_samples = input.num_samples();
const auto &contrast_center = GetContrastCenter<InputType>(ws, num_samples);

using Kernel = kernels::MultiplyAddCpu<OutputType, InputType, 3>;
kernel_manager_.Initialize<Kernel>();
kernel_manager_.template Resize<Kernel>(1);

auto in_view = view<const InputType, ndim>(input);
auto out_view = view<OutputType, ndim>(output);
for (int sample_id = 0; sample_id < num_samples; sample_id++) {
auto sample_shape = out_shape.tensor_shape_span(sample_id);
auto vol = volume(sample_shape.begin() + num_dims - 3, sample_shape.end());
float add, mul;
OpArgsToKernelArgs<OutputType, InputType>(add, mul,
brightness_[sample_id],
brightness_shift_[sample_id],
contrast_[sample_id]);
if (num_dims == 4) {
int num_frames = sample_shape[0];
for (int frame_id = 0; frame_id < num_frames; frame_id++) {
tp.AddWork([&, sample_id, frame_id, add, mul](int thread_id) {
OpArgsToKernelArgs<OutputType, InputType>(add, mul, brightness_[sample_id],
brightness_shift_[sample_id], contrast_[sample_id],
contrast_center[sample_id]);
auto planes_range =
sequence_utils::unfolded_views_range<ndim - 3>(out_view[sample_id], in_view[sample_id]);
const auto &in_range = planes_range.template get<1>();
for (auto &&views : planes_range) {
tp.AddWork([&, views, add, mul](int thread_id) {
kernels::KernelContext ctx;
auto tvin = subtensor(view<const InputType, 4>(input[sample_id]), frame_id);
auto tvout = subtensor(view<OutputType, 4>(output[sample_id]), frame_id);
auto &[tvout, tvin] = views;
kernel_manager_.Run<Kernel>(0, ctx, tvout, tvin, add, mul);
}, vol);
}
} else {
tp.AddWork([&, sample_id, add, mul](int thread_id) {
kernels::KernelContext ctx;
auto tvin = view<const InputType, 3>(input[sample_id]);
auto tvout = view<OutputType, 3>(output[sample_id]);
kernel_manager_.Run<Kernel>(0, ctx, tvout, tvin, add, mul);
});
}, in_range.SliceSize());
}
}
tp.RunAll();
Expand All @@ -142,9 +132,11 @@ void BrightnessContrastCpu::RunImpl(workspace_t<CPUBackend> &ws) {
const auto &input = ws.template Input<CPUBackend>(0);
TYPE_SWITCH(input.type(), type2id, InputType, BRIGHTNESS_CONTRAST_SUPPORTED_TYPES, (
TYPE_SWITCH(output_type_, type2id, OutputType, BRIGHTNESS_CONTRAST_SUPPORTED_TYPES, (
VALUE_SWITCH(input.sample_dim(), NDim, (3, 4), (
{
RunImplHelper<OutputType, InputType>(ws);
RunImplHelper<OutputType, InputType, NDim>(ws);
}
), DALI_FAIL(make_string("Unsupported sample dimensionality: ", input.sample_dim()))) // NOLINT
), DALI_FAIL(make_string("Unsupported output type: ", output_type_))) // NOLINT
), DALI_FAIL(make_string("Unsupported input type: ", input.type()))) // NOLINT
}
Expand Down
9 changes: 5 additions & 4 deletions dali/operators/image/color/brightness_contrast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@ void BrightnessContrastGpu::RunImplHelper(workspace_t<GPUBackend> &ws) {
output.SetLayout(input.GetLayout());
auto sh = input.shape();
int num_samples = input.num_samples();
const auto &contrast_center = GetContrastCenter<InputType>(ws, num_samples);
auto num_dims = sh.sample_dim();

addends_.resize(num_samples);
multipliers_.resize(num_samples);
for (int i = 0; i < num_samples; i++) {
OpArgsToKernelArgs<OutputType, InputType>(addends_[i], multipliers_[i],
brightness_[i], brightness_shift_[i],
contrast_[i]);
OpArgsToKernelArgs<OutputType, InputType>(addends_[i], multipliers_[i], brightness_[i],
brightness_shift_[i], contrast_[i],
contrast_center[i]);
}

TensorListView<StorageGPU, const InputType, 3> tvin;
Expand All @@ -53,7 +54,7 @@ void BrightnessContrastGpu::RunImplHelper(workspace_t<GPUBackend> &ws) {
using Kernel = kernels::MultiplyAddGpu<OutputType, InputType, 3>;
kernels::KernelContext ctx;
ctx.gpu.stream = ws.stream();
kernel_manager_.Initialize<Kernel>();
kernel_manager_.template Resize<Kernel>(1);

kernel_manager_.Setup<Kernel>(0, ctx, tvin, brightness_, contrast_);
kernel_manager_.Run<Kernel>(0, ctx, tvout, tvin, addends_, multipliers_);
Expand Down
53 changes: 30 additions & 23 deletions dali/operators/image/color/brightness_contrast.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@
#include <memory>
#include <string>
#include <vector>
#include "dali/core/format.h"
#include "dali/core/static_switch.h"
#include "dali/kernels/kernel_manager.h"
#include "dali/pipeline/data/views.h"
#include "dali/pipeline/operator/common.h"
#include "dali/pipeline/operator/operator.h"
#include "dali/core/format.h"
#include "dali/pipeline/operator/sequence_operator.h"

#define BRIGHTNESS_CONTRAST_SUPPORTED_TYPES (uint8_t, int16_t, int32_t, float)

Expand Down Expand Up @@ -53,38 +54,35 @@ const float kDefaultBrightnessShift = 0;
const float kDefaultContrast = 1.f;

template <typename Backend>
class BrightnessContrastOp : public Operator<Backend> {
class BrightnessContrastOp : public SequenceOperator<Backend> {
public:
~BrightnessContrastOp() override = default;

DISABLE_COPY_MOVE_ASSIGN(BrightnessContrastOp);

protected:
explicit BrightnessContrastOp(const OpSpec &spec)
: Operator<Backend>(spec),
: SequenceOperator<Backend>(spec),
output_type_(DALI_NO_TYPE),
input_type_(DALI_NO_TYPE) {
spec.TryGetArgument(output_type_arg_, "dtype");
if (spec.HasArgument("contrast_center"))
contrast_center_ = spec.GetArgument<float>("contrast_center");

if (std::is_same<Backend, GPUBackend>::value) {
kernel_manager_.Resize(1);
} else {
kernel_manager_.Resize(max_batch_size_);
}
}

// Unconditionally returns true.
// NOTE(review): presumably this tells the framework that SetupImpl supplies the
// output description (shape/type) — confirm against the Operator base class.
bool CanInferOutputs() const override {
return true;
}

// DHWC data already requires a 4-dim processing path in this operator, and FHWC
// sequences can go through that same path. Frames are therefore expanded into
// separate samples only when per-frame argument inputs are present; otherwise
// expansion would needlessly inflate the number of samples and parameters.
bool ShouldExpand(const workspace_t<Backend> &ws) override {
  if (!SequenceOperator<Backend>::ShouldExpand(ws)) {
    return false;
  }
  return this->HasPerFrameArgInputs(ws);
}

template <typename OutputType, typename InputType>
void OpArgsToKernelArgs(float &addend, float &multiplier,
float brightness, float brightness_shift, float contrast) {
float contrast_center = std::isnan(contrast_center_)
? brightness_contrast::HalfRange<InputType>()
: contrast_center_;
void OpArgsToKernelArgs(float &addend, float &multiplier, float brightness,
float brightness_shift, float contrast,
float contrast_center) {
float brightness_range = brightness_contrast::FullRange<OutputType>();
// The formula is:
// out = brightness_shift * brightness_range +
Expand Down Expand Up @@ -123,10 +121,20 @@ class BrightnessContrastOp : public Operator<Backend> {
output_type_ = output_type_arg_ != DALI_NO_TYPE ? output_type_arg_ : input_type_;
}

bool SetupImpl(std::vector<OutputDesc> &output_desc,
const workspace_t<Backend> &ws) override {
/**
 * @brief Provides the per-sample contrast centers for the current batch.
 *
 * If the "contrast_center" argument is defined in the operator spec, the values
 * are read through GetPerSampleArgument. Otherwise every sample receives
 * brightness_contrast::HalfRange<InputType>().
 *
 * @tparam InputType input element type; selects the fallback value.
 * @param ws workspace forwarded to GetPerSampleArgument.
 * @param num_samples number of entries the returned vector should hold.
 * @return reference to the member vector `contrast_center_`; the contents are
 *         overwritten by the next call.
 */
template <typename InputType>
const vector<float> &GetContrastCenter(const workspace_t<Backend> &ws, int num_samples) {
  if (this->spec_.ArgumentDefined("contrast_center")) {
    this->GetPerSampleArgument(contrast_center_, "contrast_center", ws, num_samples);
  } else {
    // An argument cannot stop being defined in a built pipeline, so this branch
    // is taken either always or never. The fallback depends on InputType, though,
    // so overwrite every entry instead of only appending the missing ones:
    // resize() would keep values cached for a different input type from
    // an earlier iteration.
    contrast_center_.assign(num_samples, brightness_contrast::HalfRange<InputType>());
  }
  return contrast_center_;
}

bool SetupImpl(std::vector<OutputDesc> &output_desc, const workspace_t<Backend> &ws) override {
const auto &input = ws.template Input<Backend>(0);
const auto &output = ws.template Output<Backend>(0);
AcquireArguments(ws);

auto sh = input.shape();
Expand All @@ -138,11 +146,10 @@ class BrightnessContrastOp : public Operator<Backend> {
}

USE_OPERATOR_MEMBERS();
std::vector<float> brightness_, brightness_shift_, contrast_;
std::vector<float> brightness_, brightness_shift_, contrast_, contrast_center_;
DALIDataType output_type_arg_ = DALI_NO_TYPE;
DALIDataType output_type_ = DALI_NO_TYPE;
DALIDataType input_type_ = DALI_NO_TYPE;
float contrast_center_ = std::nanf("");
kernels::KernelManager kernel_manager_;
};

Expand All @@ -156,7 +163,7 @@ class BrightnessContrastCpu : public BrightnessContrastOp<CPUBackend> {
* "overloaded virtual function `dali::Operator<dali::CPUBackend>::RunImpl` is only partially
* overridden in class `dali::brightness_contrast::BrightnessContrast<dali::CPUBackend>`"
*/
using Operator<CPUBackend>::RunImpl;
using SequenceOperator<CPUBackend>::RunImpl;

~BrightnessContrastCpu() override = default;

Expand All @@ -165,7 +172,7 @@ class BrightnessContrastCpu : public BrightnessContrastOp<CPUBackend> {
protected:
void RunImpl(workspace_t<CPUBackend> &ws) override;

template <typename OutputType, typename InputType>
template <typename OutputType, typename InputType, int ndim>
void RunImplHelper(workspace_t<CPUBackend> &ws);
};

Expand Down
Loading

0 comments on commit e687379

Please sign in to comment.