From a30b9fb4fc765ef149b46f2d61fb5efbde449fd0 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Sun, 13 Feb 2022 14:46:31 +0000 Subject: [PATCH] add random_flip op --- paddle/fluid/operators/data/CMakeLists.txt | 3 + paddle/fluid/operators/data/data_reader_op.h | 2 + paddle/fluid/operators/data/random_flip_op.cc | 90 ++++++++++++++++++ paddle/fluid/operators/data/random_flip_op.h | 86 +++++++++++++++++ .../operators/data/unity_build_rule.cmake | 3 +- paddle/fluid/operators/flip_op.cc | 95 ++++++++++--------- python/paddle/fluid/dataloader/pipeline.py | 20 ++-- python/paddle/vision/ops.py | 21 +++- 8 files changed, 262 insertions(+), 58 deletions(-) create mode 100644 paddle/fluid/operators/data/random_flip_op.cc create mode 100644 paddle/fluid/operators/data/random_flip_op.h diff --git a/paddle/fluid/operators/data/CMakeLists.txt b/paddle/fluid/operators/data/CMakeLists.txt index bf6470bd02df3..915bda52a4de6 100644 --- a/paddle/fluid/operators/data/CMakeLists.txt +++ b/paddle/fluid/operators/data/CMakeLists.txt @@ -23,8 +23,11 @@ op_library(batch_decode_op SRCS batch_decode_op.cc batch_decode_op.cu DEPS nvjpe op_library(random_crop_and_resize_op SRCS random_crop_and_resize_op.cc random_crop_and_resize_op.cu DEPS ${OP_HEADER_DEPS}) op_library(batch_resize_op SRCS batch_resize_op.cc batch_resize_op.cu DEPS ${OP_HEADER_DEPS}) + op_library(file_label_loader_op SRCS file_label_loader_op.cc DEPS ${OP_HEADER_DEPS}) +op_library(random_flip_op SRCS random_flip_op.cc DEPS ${OP_HEADER_DEPS}) + # register_operators() # TODO: add test here diff --git a/paddle/fluid/operators/data/data_reader_op.h b/paddle/fluid/operators/data/data_reader_op.h index 1532d3e2c6d8e..daa4118272232 100644 --- a/paddle/fluid/operators/data/data_reader_op.h +++ b/paddle/fluid/operators/data/data_reader_op.h @@ -49,6 +49,7 @@ class Sampler { drop_last_(drop_last), rank_(rank), world_size_(world_size) { + LOG(ERROR) << "Sampler num_samples " << num_samples; sample_ids_.reserve(num_samples); for (int64_t i = 0; i < num_samples; i++) { sample_ids_.emplace_back(i); @@ -125,6 +126,7 @@ class DataReader { sampler_.GetNextIndices(&indices); // shutdown reader if indices drained if (indices.size() == 0) { + LOG(ERROR) << "DataReader indices drained"; for(auto& queue: output_queues_) { while (queue->Size()) sleep(0.5); queue->Close(); diff --git a/paddle/fluid/operators/data/random_flip_op.cc b/paddle/fluid/operators/data/random_flip_op.cc new file mode 100644 index 0000000000000..3575c002bc0a9 --- /dev/null +++ b/paddle/fluid/operators/data/random_flip_op.cc @@ -0,0 +1,90 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/fluid/operators/data/random_flip_op.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace operators { +namespace data { + +using framework::OpKernelType; +using framework::Tensor; + +class RandomFlipOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput("X"), true, + platform::errors::NotFound("Input(X) of RandomFlipOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + platform::errors::NotFound( + "Output(Out) of RandomFlipOp should not be null.")); + + auto x_dims = ctx->GetInputDim("X"); + ctx->SetOutputDim("Out", framework::make_ddim({x_dims[0], 1})); + ctx->ShareLoD("X", "Out"); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(input_data_type, + platform::CPUPlace()); + } +}; + +class RandomFlipOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of flip op."); + AddOutput("Out", "(Tensor), The output tensor in shape of [N, 1], N is " + "the batch size of X, bool data indicates whether to " + "perform flip in this sample."); + AddAttr("probability", "The probability to flip each sample.") + .SetDefault(0.5); + AddAttr("seed", "The seed for uniform random generator") + .SetDefault(0); + AddComment(R"DOC( + Random Flip Operator. + )DOC"); + } +}; + +class RandomFlipOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map& GetInputOutputWithSameType() + const override { + static std::unordered_map m{{"X", /*->*/ "Out"}}; + return m; + } +}; + +} // namespace data +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators::data; +namespace plat = paddle::platform; +REGISTER_OPERATOR(random_flip, ops::RandomFlipOp, ops::RandomFlipOpMaker, ops::RandomFlipOpInferVarType); + +REGISTER_OP_CPU_KERNEL( + random_flip, ops::RandomFlipCPUKernel, + ops::RandomFlipCPUKernel, + ops::RandomFlipCPUKernel); diff --git a/paddle/fluid/operators/data/random_flip_op.h b/paddle/fluid/operators/data/random_flip_op.h new file mode 100644 index 0000000000000..e8f31e1fe69c2 --- /dev/null +++ b/paddle/fluid/operators/data/random_flip_op.h @@ -0,0 +1,86 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { +namespace data { + +using Tensor = framework::Tensor; + +constexpr size_t dim_bitset_size = 64; + +class RandomFlipGenerator { + public: + RandomFlipGenerator(int seed, float prob) + : distribution_(prob), + seed_(seed) { + if (seed != 0) rng_.seed(seed); + else rng_.seed(time(0)); + } + + ~RandomFlipGenerator() = default; + + bool Generate() { return distribution_(rng_); } + + private: + std::bernoulli_distribution distribution_; + int seed_; + std::mt19937 rng_; +}; + +std::map> seed_to_generator_; + +static RandomFlipGenerator* CreateRandomFlipGenerator(int seed, float prob) { + auto iter = seed_to_generator_.find(seed); + if (iter == seed_to_generator_.end()) { + seed_to_generator_[seed] = std::unique_ptr( + new RandomFlipGenerator(seed, prob)); + } + + return seed_to_generator_[seed].get(); +} + +template +class RandomFlipCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + LOG(ERROR) << "RandomFlipCPUKernel enter"; + const Tensor* x = ctx.Input("X"); + Tensor* out = ctx.Output("Out"); + + auto prob = ctx.Attr("probability"); + auto seed = ctx.Attr("seed"); + + auto* data = out->mutable_data(ctx.GetPlace()); + auto* generator = CreateRandomFlipGenerator(seed, prob); + for (int64_t i = 0; i < x->dims()[0]; i++) { + data[i] = generator->Generate() ? 1 : 0; + } + LOG(ERROR) << "RandomFlipCPUKernel finish"; + } +}; + +} // namespace data +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/data/unity_build_rule.cmake b/paddle/fluid/operators/data/unity_build_rule.cmake index c49164464bb03..33fa45153fa4f 100644 --- a/paddle/fluid/operators/data/unity_build_rule.cmake +++ b/paddle/fluid/operators/data/unity_build_rule.cmake @@ -11,7 +11,8 @@ register_unity_group(cc nvjpeg_decoder.cc dataloader_op.cc map_op.cc - batch_decode_random_crop_op.cc) + batch_decode_random_crop_op.cc + random_flip_op.cc) register_unity_group(cu dataloader_op.cu.cc diff --git a/paddle/fluid/operators/flip_op.cc b/paddle/fluid/operators/flip_op.cc index a08a0ca142053..2261dfd19b6a6 100644 --- a/paddle/fluid/operators/flip_op.cc +++ b/paddle/fluid/operators/flip_op.cc @@ -36,52 +36,57 @@ class FlipOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, platform::errors::NotFound( "Output(Out) of FlipOp should not be null.")); + auto x_dims = ctx->GetInputDim("X"); - auto flip_dims = ctx->Attrs().Get>("axis"); - size_t flip_dims_size = flip_dims.size(); - - if (flip_dims_size > 0) { - // check if dims axis within range - auto min_max_d = std::minmax_element(flip_dims.begin(), flip_dims.end()); - PADDLE_ENFORCE_LT( - *min_max_d.first, x_dims.size(), - platform::errors::InvalidArgument( - "min(axes) should be less than the input tensor X's " - "axes of FlipOp. But received min(axes) = %d, " - "X's axes = %d, X's shape = [%s]", - *min_max_d.first, x_dims.size(), x_dims)); - PADDLE_ENFORCE_GE(*min_max_d.first, x_dims.size() * -1, - platform::errors::InvalidArgument( - "min(axes) should be greater than or equal to the " - "input tensor X's " - "axes of FlipOp times -1. But received " - "min(axes) = %d, X's " - "axes = %d, X's shape = [%s]", - *min_max_d.first, x_dims.size() * -1, x_dims)); - PADDLE_ENFORCE_LT( - *min_max_d.second, x_dims.size(), - platform::errors::InvalidArgument( - "max(axes) should be less than the input tensor X's " - "axes of FlipOp. But received max(axes) = %d, " - "X's axes = %d, X's shape = [%s]", - *min_max_d.second, x_dims.size(), x_dims)); - PADDLE_ENFORCE_GE(*min_max_d.second, x_dims.size() * -1, - platform::errors::InvalidArgument( - "max(axes) should be greater than or equal to the " - "input tensor X's " - "axes of FlipOp times -1. But received " - "max(axes) = %d, X's " - "axes = %d, X's shape = [%s]", - *min_max_d.second, x_dims.size() * -1, x_dims)); - - // check duplicates in dims - flip_dims.erase(std::unique(flip_dims.begin(), flip_dims.end()), - flip_dims.end()); - PADDLE_ENFORCE_EQ(flip_dims.size(), flip_dims_size, - platform::errors::InvalidArgument( - "axes has duplicates, original flip axes size=%d, " - "but unique flip axes size=%d.)", - flip_dims_size, flip_dims.size())); + + if (ctx->IsRuntime()) { + auto flip_dims = ctx->Attrs().Get>("axis"); + size_t flip_dims_size = flip_dims.size(); + + if (flip_dims_size > 0) { + // check if dims axis within range + auto min_max_d = std::minmax_element(flip_dims.begin(), flip_dims.end()); + PADDLE_ENFORCE_LT( + *min_max_d.first, x_dims.size(), + platform::errors::InvalidArgument( + "min(axes) should be less than the input tensor X's " + "axes of FlipOp. But received min(axes) = %d, " + "X's axes = %d, X's shape = [%s]", + *min_max_d.first, x_dims.size(), x_dims)); + PADDLE_ENFORCE_GE(*min_max_d.first, x_dims.size() * -1, + platform::errors::InvalidArgument( + "min(axes) should be greater than or equal to the " + "input tensor X's " + "axes of FlipOp times -1. But received " + "min(axes) = %d, X's " + "axes = %d, X's shape = [%s]", + *min_max_d.first, x_dims.size() * -1, x_dims)); + PADDLE_ENFORCE_LT( + *min_max_d.second, x_dims.size(), + platform::errors::InvalidArgument( + "max(axes) should be less than the input tensor X's " + "axes of FlipOp. But received max(axes) = %d, " + "X's axes = %d, X's shape = [%s]", + *min_max_d.second, x_dims.size(), x_dims)); + PADDLE_ENFORCE_GE(*min_max_d.second, x_dims.size() * -1, + platform::errors::InvalidArgument( + "max(axes) should be greater than or equal to the " + "input tensor X's " + "axes of FlipOp times -1. But received " + "max(axes) = %d, X's " + "axes = %d, X's shape = [%s]", + *min_max_d.second, x_dims.size() * -1, x_dims)); + + // check duplicates in dims + flip_dims.erase(std::unique(flip_dims.begin(), flip_dims.end()), + flip_dims.end()); + PADDLE_ENFORCE_EQ(flip_dims.size(), flip_dims_size, + platform::errors::InvalidArgument( + "axes has duplicates, original flip axes size=%d, " + "but unique flip axes size=%d.)", + flip_dims_size, flip_dims.size())); + } + } VLOG(3) << "flip operator x.shape=" << x_dims; diff --git a/python/paddle/fluid/dataloader/pipeline.py b/python/paddle/fluid/dataloader/pipeline.py index f2da5b12102f7..2206f39b8abff 100755 --- a/python/paddle/fluid/dataloader/pipeline.py +++ b/python/paddle/fluid/dataloader/pipeline.py @@ -127,16 +127,16 @@ def __next__(self): "Pipeline not built, please call build() firstly" self._output_vars = self._prepare_output_vars() - try: - import sys - import time - tic = time.time() - _C_ops.dataloader(self._output_vars, *self._attrs) - toc = time.time() - print("_C_ops calling cost {}ms".format((toc - tic) * 1000.)) - sys.stdout.flush() - except: - raise StopIteration + # try: + import sys + import time + tic = time.time() + _C_ops.dataloader(self._output_vars, *self._attrs) + toc = time.time() + print("_C_ops calling cost {}ms".format((toc - tic) * 1000.)) + sys.stdout.flush() + # except: + # raise StopIteration return {k: v for k, v in zip(self._out_names, self._output_vars)} diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index 5faa991d6b576..f9084b629c7e5 100644 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -1012,6 +1012,20 @@ def image_decode_random_crop(x, return out +def flip_vector(x, prob=0.5, name=None): + helper = LayerHelper("flip_vector", **locals()) + out = helper.create_variable( + name=unique_name.generate("flip_vector"), + type=core.VarDesc.VarType.LOD_TENSOR, + dtype=core.VarDesc.VarType.BOOL) + helper.append_op( + type="random_flip", + inputs={"X": x}, + outputs={"Out": out}, + attrs={"probability": prob}) + return out + + def random_flip(x, prob=0.5, name=None): if prob < 0. or prob > 1.: raise ValueError("prob should in (0, 1) in random_flip") @@ -1023,8 +1037,11 @@ def random_flip(x, prob=0.5, name=None): x[i] = paddle.flip(x[i], -1) return x - p = paddle.uniform([layers.shape(x)[0], 1], min=0., max=1.) - ie = layers.IfElse(p < prob) + # p = paddle.uniform([layers.shape(x)[0], 1], min=0., max=1.) + # prob = paddle.ones([layers.shape(x)[0], 1]) * prob + # cond = layers.less_than(p, prob) + cond = flip_vector(x, prob) + ie = layers.IfElse(cond) with ie.true_block(): out = ie.input(x) out = paddle.flip(x, -1)