diff --git a/paddle/fluid/compiler/piano/note/CMakeLists.txt b/paddle/fluid/compiler/piano/note/CMakeLists.txt index 2d176fe23895a..5490e68a31cdd 100644 --- a/paddle/fluid/compiler/piano/note/CMakeLists.txt +++ b/paddle/fluid/compiler/piano/note/CMakeLists.txt @@ -7,6 +7,7 @@ target_compile_options(note_proto PUBLIC "-Wno-extra") cc_library(note_template_util SRCS element_type_util.cc populate_attribute_value.cc DEPS note_proto) cc_test(note_element_type_util_test SRCS element_type_util_test.cc DEPS note_template_util) cc_test(note_populate_attribute_value_test SRCS populate_attribute_value_test.cc DEPS note_template_util) +cc_test(note_type_traits_test SRCS type_traits_test.cc DEPS note_proto) cc_library(note_ir SRCS instruction.cc function.cc module.cc DEPS note_opcode note_proto piano_data_description note_template_util) cc_test(note_ir_test SRCS note_ir_test.cc DEPS note_ir) diff --git a/paddle/fluid/compiler/piano/note/type_traits.h b/paddle/fluid/compiler/piano/note/type_traits.h index 46e6f032caeb4..6b84e6e06a7fd 100644 --- a/paddle/fluid/compiler/piano/note/type_traits.h +++ b/paddle/fluid/compiler/piano/note/type_traits.h @@ -224,6 +224,29 @@ class UnboxingIterator : public ForwardIterator { SmartPtrIter iter_; }; +template +struct IsOneOf : public std::false_type {}; + +template +struct IsOneOf + : public std::conditional::value, std::true_type, + IsOneOf>::type {}; + +template +struct IsVariantMember; + +template +struct IsVariantMember> + : public IsOneOf {}; + +template +struct IsOneOfAttrType : public IsVariantMember {}; + +template +struct IsVector : public std::false_type {}; +template +struct IsVector> : public std::true_type {}; + } // namespace note } // namespace piano } // namespace paddle diff --git a/paddle/fluid/compiler/piano/note/type_traits_test.cc b/paddle/fluid/compiler/piano/note/type_traits_test.cc new file mode 100644 index 0000000000000..3d2381050181c --- /dev/null +++ b/paddle/fluid/compiler/piano/note/type_traits_test.cc @@ -0,0 +1,35 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/compiler/piano/note/type_traits.h" +#include +#include "gtest/gtest.h" + +namespace paddle { +namespace piano { +namespace note { + +TEST(IsOneOfAttrTypeTest, Basic) { + ASSERT_FALSE(IsOneOfAttrType::value); + ASSERT_TRUE(IsOneOfAttrType>::value); +} + +TEST(IsVectorTest, Basic) { + ASSERT_FALSE(IsVector::value); + ASSERT_TRUE(IsVector>::value); +} + +} // namespace note +} // namespace piano +} // namespace paddle diff --git a/paddle/fluid/compiler/piano/shape.h b/paddle/fluid/compiler/piano/shape.h index 791ea4b47a84d..684cbe5281d33 100644 --- a/paddle/fluid/compiler/piano/shape.h +++ b/paddle/fluid/compiler/piano/shape.h @@ -15,7 +15,10 @@ limitations under the License. */ #pragma once #include +#include +#include #include +#include #include #include "paddle/fluid/compiler/piano/layout.h" #include "paddle/fluid/compiler/piano/note/note.pb.h" @@ -59,6 +62,15 @@ class Shape { return dimensions_.size(); } + int64_t Numel() const { + if (IsTuple()) { + return tuple_shapes_.size(); + } else { + return std::accumulate(std::begin(dimensions()), std::end(dimensions()), + 1, std::multiplies()); + } + } + // Returns whether the shape is of the specified type bool IsArray() const { return !IsTuple(); } bool IsTuple() const { return element_type() == note::ELEMENT_TYPE_TUPLE; } diff --git a/paddle/fluid/compiler/piano/shape_test.cc b/paddle/fluid/compiler/piano/shape_test.cc index 81f20c2a25589..6c43b9252f4a6 100644 --- a/paddle/fluid/compiler/piano/shape_test.cc +++ b/paddle/fluid/compiler/piano/shape_test.cc @@ -73,6 +73,16 @@ TEST_F(ShapeTest, ShapeToString) { ASSERT_EQ("f32[2, 3]{}", array_string); } +TEST_F(ShapeTest, Basic) { + // scalar + ASSERT_EQ(0, scalar_.Rank()); + ASSERT_EQ(1, scalar_.Numel()); + + // 2-D array + ASSERT_EQ(2, array1_.Rank()); + ASSERT_EQ(18, array1_.Numel()); +} + TEST_F(ShapeTest, EqualToOther) { Shape s8_d23_no_layout(note::S32, {2, 3}); Shape f32_d23_no_layout(note::F32, {2, 3}); diff --git a/paddle/fluid/compiler/piano/symbolization/meta_op.cc b/paddle/fluid/compiler/piano/symbolization/meta_op.cc index 0ec08ee20da2d..399107d57eb65 100644 --- a/paddle/fluid/compiler/piano/symbolization/meta_op.cc +++ b/paddle/fluid/compiler/piano/symbolization/meta_op.cc @@ -17,7 +17,6 @@ limitations under the License. */ #include #include #include "paddle/fluid/compiler/piano/note/note.pb.h" -#include "paddle/fluid/compiler/piano/symbolization/shape_inference.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { diff --git a/paddle/fluid/compiler/piano/symbolization/meta_op.h b/paddle/fluid/compiler/piano/symbolization/meta_op.h index 3d0c68c4c131e..d2e202b871356 100644 --- a/paddle/fluid/compiler/piano/symbolization/meta_op.h +++ b/paddle/fluid/compiler/piano/symbolization/meta_op.h @@ -16,13 +16,16 @@ limitations under the License. */ #include #include +#include #include #include #include "paddle/fluid/compiler/piano/note/attribute_key_defs.h" #include "paddle/fluid/compiler/piano/note/element_type_util.h" #include "paddle/fluid/compiler/piano/note/populate_attribute_value.h" +#include "paddle/fluid/compiler/piano/note/type_traits.h" #include "paddle/fluid/compiler/piano/shape.h" #include "paddle/fluid/compiler/piano/symbolization/note_builder.h" +#include "paddle/fluid/compiler/piano/symbolization/shape_inference.h" namespace paddle { namespace piano { @@ -36,15 +39,20 @@ class Operand; Operand Parameter(NoteBuilder* builder, int64_t parameter_index, const Shape& shape, const std::string& name); -// a constant instruction passing literal 'value' with 0-D array(scalar) -// `builder`: NoteBuilder of current module -// `value`: The scalar value. users should explicitly specifiy the -// data type when the value may be obscure and deduced to -// another compatible type +// a constant instruction literal 'value' with N-D array +// (scalar or multi-dimension array) +// `builder`: NoteBuilder of current module +// `value`: The literal value. users should explicitly specifiy the +// data type when the value may be obscure and deduced to +// another compatible type +// `shape`: Shape of the literal template -Operand ConstantD0(NoteBuilder* builder, NativeT value) { - // construct shape - Shape result_shape(note::NativeToElementTypeProto(), {}); +Operand Constant(NoteBuilder* builder, const NativeT& value, + const Shape& shape) { + static_assert(note::IsOneOfAttrType::value, + "This NativeT is not supported in Constant"); + + auto result_shape = InferConstantShape(value, shape); note::InstructionProto instr; *instr.mutable_shape() = result_shape.ToProto(); // fill attribute of kConstant instruction diff --git a/paddle/fluid/compiler/piano/symbolization/meta_op_test.cc b/paddle/fluid/compiler/piano/symbolization/meta_op_test.cc index c8afbbb511e3f..c3dd18b1fe31c 100644 --- a/paddle/fluid/compiler/piano/symbolization/meta_op_test.cc +++ b/paddle/fluid/compiler/piano/symbolization/meta_op_test.cc @@ -32,24 +32,34 @@ TEST(MetaOpTest, TestParameter) { EXPECT_EQ(Shape(note::F32, {1, 2}), param_op.Shape()); } -TEST(MetaOpTest, TestConstantD0) { - NoteBuilder builder("test_constant_d0"); - auto constant_d0_op = ConstantD0(&builder, 110); +TEST(MetaOpTest, TestConstant) { + NoteBuilder builder("test_constant"); + // add a constant instruction with scalar value + auto constant_d0_op = Constant(&builder, 110, Shape(note::S32, {})); ASSERT_EQ(&builder, constant_d0_op.Builder()); EXPECT_TRUE(constant_d0_op.Valid()); EXPECT_EQ(Shape(note::S32, {}), constant_d0_op.Shape()); + // append a constant instruction with 2-D array value + auto constant_d2_op = Constant(&builder, std::vector({110, 119}), + Shape(note::S32, {1, 2})); + ASSERT_EQ(&builder, constant_d2_op.Builder()); + EXPECT_TRUE(constant_d2_op.Valid()); + EXPECT_EQ(Shape(note::S32, {1, 2}), constant_d2_op.Shape()); + // check the final build module auto&& module_proto = builder.Build(); ASSERT_EQ(1, module_proto.functions_size()); const auto& entry_proto = module_proto.functions(0); - ASSERT_EQ(1, entry_proto.instructions_size()); + ASSERT_EQ(2, entry_proto.instructions_size()); EXPECT_EQ(note::GetOpName(note::OpCode::kConstant), entry_proto.instructions(0).opcode()); - const auto& constant_instr = entry_proto.instructions(0); - ASSERT_EQ(1, constant_instr.attrs().size()); - const auto& attr_value = constant_instr.attrs().at(note::kConstantValue); - ASSERT_TRUE(attr_value.has_i()); - EXPECT_EQ(110, attr_value.i()); + const auto& constant_d2_instr = entry_proto.instructions(1); + ASSERT_EQ(1, constant_d2_instr.attrs().size()); + const auto& attr_value = constant_d2_instr.attrs().at(note::kConstantValue); + ASSERT_TRUE(attr_value.has_ints()); + EXPECT_EQ(2, attr_value.ints().value_size()); + EXPECT_EQ(110, attr_value.ints().value(0)); + EXPECT_EQ(119, attr_value.ints().value(1)); } TEST(MetaOpTest, TestBroadcast) { diff --git a/paddle/fluid/compiler/piano/symbolization/shape_inference.h b/paddle/fluid/compiler/piano/symbolization/shape_inference.h index 61583bdfb28f3..7b89f8c92a9ee 100644 --- a/paddle/fluid/compiler/piano/symbolization/shape_inference.h +++ b/paddle/fluid/compiler/piano/symbolization/shape_inference.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include "paddle/fluid/compiler/piano/note/opcode.h" +#include "paddle/fluid/compiler/piano/note/type_traits.h" #include "paddle/fluid/compiler/piano/shape.h" namespace paddle { @@ -39,6 +40,32 @@ Shape InferBroadcastShape(const Shape& input_shape, const std::vector& out_dimensions, const std::vector& dimensions_alignment); +// inference for constant operation +template +typename std::enable_if::value>::type ValidateShape( + const NativeT& value, const Shape& shape) { + PADDLE_ENFORCE_EQ(shape.IsArray(), true, + platform::errors::InvalidArgument( + "Shape of vector input should be array tuple")); + PADDLE_ENFORCE_EQ( + shape.Numel(), value.size(), + platform::errors::InvalidArgument("Number of element should be euqal to" + "the shape contains")); +} + +template +typename std::enable_if::value>::type ValidateShape( + const NativeT& value, const Shape& shape) { + PADDLE_ENFORCE_EQ(shape.Rank(), 0, platform::errors::InvalidArgument( + "Rank of Scalar value should be 0")); +} + +template +Shape InferConstantShape(const NativeT& value, const Shape& shape) { + ValidateShape(value, shape); + return shape; +} + } // namespace symbolization } // namespace piano } // namespace paddle diff --git a/paddle/fluid/compiler/piano/symbolization/shape_inference_test.cc b/paddle/fluid/compiler/piano/symbolization/shape_inference_test.cc index e6d1cc5dee3b5..808324506d55e 100644 --- a/paddle/fluid/compiler/piano/symbolization/shape_inference_test.cc +++ b/paddle/fluid/compiler/piano/symbolization/shape_inference_test.cc @@ -63,6 +63,22 @@ TEST(ShapeInferenceTest, TestInferBroadcastShape) { ASSERT_EQ(Shape(note::U64, {2, 3, 6}), res2); } +TEST(ShapeInferenceTest, TestInferConstantShape) { + // check validation on scalar value + ASSERT_THROW(InferConstantShape(110, Shape(note::F32, {1})), + paddle::platform::EnforceNotMet); + + // check validation on multi-dimension array value + ASSERT_THROW(InferConstantShape(std::vector({110, 119}), + Shape(note::F32, {1})), + paddle::platform::EnforceNotMet); + + // normal call + ASSERT_EQ(Shape(note::F32, {1, 2}), + InferConstantShape(std::vector({110, 119}), + Shape(note::F32, {1, 2}))); +} + } // namespace symbolization } // namespace piano } // namespace paddle