Skip to content

Commit

Permalink
Make tensorview over a whole list first, not over each particular ten…
Browse files Browse the repository at this point in the history
…sorvector

Signed-off-by: Kamil Tokarski <[email protected]>
  • Loading branch information
stiepan committed May 29, 2022
1 parent 6cd932f commit 86efd70
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 9 deletions.
7 changes: 4 additions & 3 deletions dali/operators/image/color/brightness_contrast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,15 @@ void BrightnessContrastCpu::RunImplHelper(workspace_t<CPUBackend> &ws) {
using Kernel = kernels::MultiplyAddCpu<OutputType, InputType, 3>;
kernel_manager_.template Resize<Kernel>(1);

auto in_view = view<const InputType, ndim>(input);
auto out_view = view<OutputType, ndim>(output);
for (int sample_id = 0; sample_id < num_samples; sample_id++) {
float add, mul;
OpArgsToKernelArgs<OutputType, InputType>(add, mul, brightness_[sample_id],
brightness_shift_[sample_id], contrast_[sample_id],
contrast_center[sample_id]);
auto in_view = view<const InputType, ndim>(input[sample_id]);
auto out_view = view<OutputType, ndim>(output[sample_id]);
auto planes_range = sequence_utils::unfolded_views_range<ndim - 3>(out_view, in_view);
auto planes_range =
sequence_utils::unfolded_views_range<ndim - 3>(out_view[sample_id], in_view[sample_id]);
const auto &in_range = planes_range.template get<1>();
for (auto &&views : planes_range) {
tp.AddWork([&, views, add, mul](int thread_id) {
Expand Down
6 changes: 3 additions & 3 deletions dali/operators/image/color/color_twist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,10 @@ void ColorTwistCpu::RunImplHelper(workspace_t<CPUBackend> &ws) {
using Kernel = kernels::LinearTransformationCpu<OutputType, InputType, 3, 3, 3>;
int num_samples = input.num_samples();
kernel_manager_.template Resize<Kernel>(num_samples);
auto in_view = view<const InputType, ndim>(input);
auto out_view = view<OutputType, ndim>(output);
for (int i = 0; i < num_samples; i++) {
auto in_view = view<const InputType, ndim>(input[i]);
auto out_view = view<OutputType, ndim>(output[i]);
auto planes_range = sequence_utils::unfolded_views_range<ndim - 3>(out_view, in_view);
auto planes_range = sequence_utils::unfolded_views_range<ndim - 3>(out_view[i], in_view[i]);
const auto &in_range = planes_range.template get<1>();
for (auto &&views : planes_range) {
tp.AddWork(
Expand Down
6 changes: 3 additions & 3 deletions dali/pipeline/data/sequence_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_PIPELINE_OPERATOR_SEQUENCE_UTILS_H_
#define DALI_PIPELINE_OPERATOR_SEQUENCE_UTILS_H_
#ifndef DALI_PIPELINE_DATA_SEQUENCE_UTILS_H_
#define DALI_PIPELINE_DATA_SEQUENCE_UTILS_H_

#include <tuple>
#include <utility>
Expand Down Expand Up @@ -190,4 +190,4 @@ CombinedRange<UnfoldedViewRange<Storages, Ts, ndims, ndims_to_unfold>...> unfold

} // namespace dali

#endif // DALI_PIPELINE_OPERATOR_SEQUENCE_UTILS_H_
#endif // DALI_PIPELINE_DATA_SEQUENCE_UTILS_H_

0 comments on commit 86efd70

Please sign in to comment.