Skip to content

Commit

Permalink
unifiy constant instruction to support N-D array (PaddlePaddle#44)
Browse files Browse the repository at this point in the history
* unifiy constant instruction to support N-D array

* update template typename
  • Loading branch information
CtfGo authored Sep 2, 2021
1 parent c29ab78 commit e7f279f
Show file tree
Hide file tree
Showing 10 changed files with 159 additions and 18 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/compiler/piano/note/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ target_compile_options(note_proto PUBLIC "-Wno-extra")
cc_library(note_template_util SRCS element_type_util.cc populate_attribute_value.cc DEPS note_proto)
cc_test(note_element_type_util_test SRCS element_type_util_test.cc DEPS note_template_util)
cc_test(note_populate_attribute_value_test SRCS populate_attribute_value_test.cc DEPS note_template_util)
cc_test(note_type_traits_test SRCS type_traits_test.cc DEPS note_proto)

cc_library(note_ir SRCS instruction.cc function.cc module.cc DEPS note_opcode note_proto piano_data_description note_template_util)
cc_test(note_ir_test SRCS note_ir_test.cc DEPS note_ir)
23 changes: 23 additions & 0 deletions paddle/fluid/compiler/piano/note/type_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,29 @@ class UnboxingIterator : public ForwardIterator<SmartPtrIter> {
SmartPtrIter iter_;
};

template <typename T, typename... AllType>
struct IsOneOf : public std::false_type {};

template <typename T, typename FrontType, typename... RestType>
struct IsOneOf<T, FrontType, RestType...>
: public std::conditional<std::is_same<T, FrontType>::value, std::true_type,
IsOneOf<T, RestType...>>::type {};

template <typename T, typename VariantType>
struct IsVariantMember;

template <typename T, typename... AllType>
struct IsVariantMember<T, boost::variant<AllType...>>
: public IsOneOf<T, AllType...> {};

template <typename T>
struct IsOneOfAttrType : public IsVariantMember<T, AttrType> {};

template <typename T>
struct IsVector : public std::false_type {};
template <typename T, typename A>
struct IsVector<std::vector<T, A>> : public std::true_type {};

} // namespace note
} // namespace piano
} // namespace paddle
35 changes: 35 additions & 0 deletions paddle/fluid/compiler/piano/note/type_traits_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/compiler/piano/note/type_traits.h"
#include <type_traits>
#include "gtest/gtest.h"

namespace paddle {
namespace piano {
namespace note {

TEST(IsOneOfAttrTypeTest, Basic) {
ASSERT_FALSE(IsOneOfAttrType<int8_t>::value);
ASSERT_TRUE(IsOneOfAttrType<std::vector<double>>::value);
}

TEST(IsVectorTest, Basic) {
ASSERT_FALSE(IsVector<int>::value);
ASSERT_TRUE(IsVector<std::vector<int>>::value);
}

} // namespace note
} // namespace piano
} // namespace paddle
12 changes: 12 additions & 0 deletions paddle/fluid/compiler/piano/shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ limitations under the License. */
#pragma once

#include <cstdint>
#include <functional>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/compiler/piano/layout.h"
#include "paddle/fluid/compiler/piano/note/note.pb.h"
Expand Down Expand Up @@ -59,6 +62,15 @@ class Shape {
return dimensions_.size();
}

int64_t Numel() const {
if (IsTuple()) {
return tuple_shapes_.size();
} else {
return std::accumulate(std::begin(dimensions()), std::end(dimensions()),
1, std::multiplies<int64_t>());
}
}

// Returns whether the shape is of the specified type
bool IsArray() const { return !IsTuple(); }
bool IsTuple() const { return element_type() == note::ELEMENT_TYPE_TUPLE; }
Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/compiler/piano/shape_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,16 @@ TEST_F(ShapeTest, ShapeToString) {
ASSERT_EQ("f32[2, 3]{}", array_string);
}

TEST_F(ShapeTest, Basic) {
// scalar
ASSERT_EQ(0, scalar_.Rank());
ASSERT_EQ(1, scalar_.Numel());

// 2-D array
ASSERT_EQ(2, array1_.Rank());
ASSERT_EQ(18, array1_.Numel());
}

TEST_F(ShapeTest, EqualToOther) {
Shape s8_d23_no_layout(note::S32, {2, 3});
Shape f32_d23_no_layout(note::F32, {2, 3});
Expand Down
1 change: 0 additions & 1 deletion paddle/fluid/compiler/piano/symbolization/meta_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License. */
#include <numeric>
#include <utility>
#include "paddle/fluid/compiler/piano/note/note.pb.h"
#include "paddle/fluid/compiler/piano/symbolization/shape_inference.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
Expand Down
24 changes: 16 additions & 8 deletions paddle/fluid/compiler/piano/symbolization/meta_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@ limitations under the License. */

#include <cstdint>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "paddle/fluid/compiler/piano/note/attribute_key_defs.h"
#include "paddle/fluid/compiler/piano/note/element_type_util.h"
#include "paddle/fluid/compiler/piano/note/populate_attribute_value.h"
#include "paddle/fluid/compiler/piano/note/type_traits.h"
#include "paddle/fluid/compiler/piano/shape.h"
#include "paddle/fluid/compiler/piano/symbolization/note_builder.h"
#include "paddle/fluid/compiler/piano/symbolization/shape_inference.h"

namespace paddle {
namespace piano {
Expand All @@ -36,15 +39,20 @@ class Operand;
Operand Parameter(NoteBuilder* builder, int64_t parameter_index,
const Shape& shape, const std::string& name);

// a constant instruction passing literal 'value' with 0-D array(scalar)
// `builder`: NoteBuilder of current module
// `value`: The scalar value. users should explicitly specifiy the
// data type when the value may be obscure and deduced to
// another compatible type
// a constant instruction literal 'value' with N-D array
// (scalar or multi-dimension array)
// `builder`: NoteBuilder of current module
// `value`: The literal value. users should explicitly specifiy the
// data type when the value may be obscure and deduced to
// another compatible type
// `shape`: Shape of the literal
template <typename NativeT>
Operand ConstantD0(NoteBuilder* builder, NativeT value) {
// construct shape
Shape result_shape(note::NativeToElementTypeProto<NativeT>(), {});
Operand Constant(NoteBuilder* builder, const NativeT& value,
const Shape& shape) {
static_assert(note::IsOneOfAttrType<NativeT>::value,
"This NativeT is not supported in Constant");

auto result_shape = InferConstantShape<NativeT>(value, shape);
note::InstructionProto instr;
*instr.mutable_shape() = result_shape.ToProto();
// fill attribute of kConstant instruction
Expand Down
28 changes: 19 additions & 9 deletions paddle/fluid/compiler/piano/symbolization/meta_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,34 @@ TEST(MetaOpTest, TestParameter) {
EXPECT_EQ(Shape(note::F32, {1, 2}), param_op.Shape());
}

TEST(MetaOpTest, TestConstantD0) {
NoteBuilder builder("test_constant_d0");
auto constant_d0_op = ConstantD0<int32_t>(&builder, 110);
TEST(MetaOpTest, TestConstant) {
NoteBuilder builder("test_constant");
// add a constant instruction with scalar value
auto constant_d0_op = Constant<int32_t>(&builder, 110, Shape(note::S32, {}));
ASSERT_EQ(&builder, constant_d0_op.Builder());
EXPECT_TRUE(constant_d0_op.Valid());
EXPECT_EQ(Shape(note::S32, {}), constant_d0_op.Shape());
// append a constant instruction with 2-D array value
auto constant_d2_op = Constant(&builder, std::vector<int32_t>({110, 119}),
Shape(note::S32, {1, 2}));
ASSERT_EQ(&builder, constant_d2_op.Builder());
EXPECT_TRUE(constant_d2_op.Valid());
EXPECT_EQ(Shape(note::S32, {1, 2}), constant_d2_op.Shape());

// check the final build module
auto&& module_proto = builder.Build();
ASSERT_EQ(1, module_proto.functions_size());
const auto& entry_proto = module_proto.functions(0);
ASSERT_EQ(1, entry_proto.instructions_size());
ASSERT_EQ(2, entry_proto.instructions_size());
EXPECT_EQ(note::GetOpName(note::OpCode::kConstant),
entry_proto.instructions(0).opcode());
const auto& constant_instr = entry_proto.instructions(0);
ASSERT_EQ(1, constant_instr.attrs().size());
const auto& attr_value = constant_instr.attrs().at(note::kConstantValue);
ASSERT_TRUE(attr_value.has_i());
EXPECT_EQ(110, attr_value.i());
const auto& constant_d2_instr = entry_proto.instructions(1);
ASSERT_EQ(1, constant_d2_instr.attrs().size());
const auto& attr_value = constant_d2_instr.attrs().at(note::kConstantValue);
ASSERT_TRUE(attr_value.has_ints());
EXPECT_EQ(2, attr_value.ints().value_size());
EXPECT_EQ(110, attr_value.ints().value(0));
EXPECT_EQ(119, attr_value.ints().value(1));
}

TEST(MetaOpTest, TestBroadcast) {
Expand Down
27 changes: 27 additions & 0 deletions paddle/fluid/compiler/piano/symbolization/shape_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License. */

#include <vector>
#include "paddle/fluid/compiler/piano/note/opcode.h"
#include "paddle/fluid/compiler/piano/note/type_traits.h"
#include "paddle/fluid/compiler/piano/shape.h"

namespace paddle {
Expand All @@ -39,6 +40,32 @@ Shape InferBroadcastShape(const Shape& input_shape,
const std::vector<int64_t>& out_dimensions,
const std::vector<int64_t>& dimensions_alignment);

// inference for constant operation
template <typename NativeT>
typename std::enable_if<note::IsVector<NativeT>::value>::type ValidateShape(
const NativeT& value, const Shape& shape) {
PADDLE_ENFORCE_EQ(shape.IsArray(), true,
platform::errors::InvalidArgument(
"Shape of vector input should be array tuple"));
PADDLE_ENFORCE_EQ(
shape.Numel(), value.size(),
platform::errors::InvalidArgument("Number of element should be euqal to"
"the shape contains"));
}

template <typename NativeT>
typename std::enable_if<!note::IsVector<NativeT>::value>::type ValidateShape(
const NativeT& value, const Shape& shape) {
PADDLE_ENFORCE_EQ(shape.Rank(), 0, platform::errors::InvalidArgument(
"Rank of Scalar value should be 0"));
}

template <typename NativeT>
Shape InferConstantShape(const NativeT& value, const Shape& shape) {
ValidateShape(value, shape);
return shape;
}

} // namespace symbolization
} // namespace piano
} // namespace paddle
16 changes: 16 additions & 0 deletions paddle/fluid/compiler/piano/symbolization/shape_inference_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,22 @@ TEST(ShapeInferenceTest, TestInferBroadcastShape) {
ASSERT_EQ(Shape(note::U64, {2, 3, 6}), res2);
}

TEST(ShapeInferenceTest, TestInferConstantShape) {
// check validation on scalar value
ASSERT_THROW(InferConstantShape<int32_t>(110, Shape(note::F32, {1})),
paddle::platform::EnforceNotMet);

// check validation on multi-dimension array value
ASSERT_THROW(InferConstantShape(std::vector<int32_t>({110, 119}),
Shape(note::F32, {1})),
paddle::platform::EnforceNotMet);

// normal call
ASSERT_EQ(Shape(note::F32, {1, 2}),
InferConstantShape(std::vector<int32_t>({110, 119}),
Shape(note::F32, {1, 2})));
}

} // namespace symbolization
} // namespace piano
} // namespace paddle

0 comments on commit e7f279f

Please sign in to comment.