Skip to content

Commit

Permalink
Changes the order of kernel frame, test=develop
Browse files Browse the repository at this point in the history
  • Loading branch information
Shixiaowei02 committed Feb 19, 2022
1 parent e410bfa commit a3633d5
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 27 deletions.
1 change: 1 addition & 0 deletions paddle/infrt/host_context/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ gather_srcs(infrt_src SRCS

cc_test_tiny(test_infrt_host_context_value SRCS value_test.cc DEPS infrt ${MLIR_IR_LIBS})
cc_test_tiny(test_infrt_kernel_utils SRCS kernel_utils_test.cc DEPS infrt ${MLIR_IR_LIBS})
cc_test_tiny(test_infrt_kernel_frame SRCS kernel_frame_test.cc DEPS infrt ${MLIR_IR_LIBS})
cc_test_tiny(test_infrt_kernel_registry SRCS kernel_registry_test.cc DEPS infrt ${MLIR_IR_LIBS})
cc_test_tiny(test_infrt_op_executable SRCS op_executable_test.cc DEPS infrt ${MLIR_IR_LIBS})
cc_test_tiny(test_infrt_core_runtime SRCS core_runtime_test.cc DEPS infrt ${MLIR_IR_LIBS})
Expand Down
55 changes: 30 additions & 25 deletions paddle/infrt/host_context/kernel_frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,10 @@ namespace host_context {
class KernelFrame {
public:
int GetNumArgs() const { return num_arguments_; }
int GetNumResults() const { return num_results_ == -1 ? 0 : num_results_; }
int GetNumAttributes() const {
return value_or_attrs_.size() - num_arguments_ -
(num_results_ == -1 ? 0 : num_results_);
int GetNumResults() const {
return value_or_attrs_.size() - num_arguments_ - GetNumAttributes();
}
int GetNumAttributes() const { return num_attrs_ == -1 ? 0 : num_attrs_; }

//! Get something at a specific position \p index. The element might be an
//! argument, an attribute or a result.
Expand Down Expand Up @@ -72,16 +71,19 @@ class KernelFrame {
Value* GetAttributeAt(int idx) {
// CHECK_NE(num_results_, -1)
//<< "Must call SetNumResults before GetAttributeAt";
CHECK_LT(idx,
static_cast<int>(value_or_attrs_.size() - num_arguments_ -
num_results_));
return value_or_attrs_[num_arguments_ + num_results_ + idx];
CHECK_LT(idx, GetNumAttributes());
return value_or_attrs_[num_arguments_ + idx];
}

void AddAttribute(Value* v) {
CHECK_NE(num_results_, -1)
<< "Must call SetNumResults before calling AddAttribute";
CHECK_EQ(num_results_, -1)
<< "Must call SetNumResults after calling AddAttribute";
value_or_attrs_.emplace_back(v);
if (num_attrs_ = -1) {
num_attrs_ = 1;
} else {
num_attrs_++;
}
}

template <typename T, typename... Args>
Expand All @@ -96,16 +98,17 @@ class KernelFrame {

template <typename T>
void SetResultAt(int index, T&& value) {
CHECK_LT(index, num_results_) << "Invalid result index";
CHECK(value_or_attrs_[num_arguments_ + index]);
value_or_attrs_[num_arguments_ + index]->set(std::move(value));
CHECK_LT(index, GetNumResults()) << "Invalid result index";
CHECK(value_or_attrs_[num_arguments_ + GetNumAttributes() + index]);
value_or_attrs_[num_arguments_ + GetNumAttributes() + index]->set(
std::move(value));
}

llvm::ArrayRef<Value*> GetResults() const {
return GetValues(num_arguments_, num_results_);
return GetValues(num_arguments_ + GetNumAttributes(), num_results_);
}
llvm::MutableArrayRef<Value*> GetResults() {
return GetMutableValues(num_arguments_, num_results_);
return GetMutableValues(num_arguments_ + GetNumAttributes(), num_results_);
}

llvm::ArrayRef<Value*> GetValues(size_t from, size_t length) const {
Expand All @@ -129,6 +132,7 @@ class KernelFrame {

protected:
int num_arguments_{};
int num_attrs_{-1};
int num_results_{-1};

llvm::SmallVector<Value*, 8> value_or_attrs_;
Expand All @@ -140,44 +144,45 @@ class KernelFrameBuilder : public KernelFrame {
public:
void AddArgument(Value* value) {
CHECK(value);
CHECK_EQ(num_results_, -1)
<< "Should call AddArgument before calling SetNumResults";
CHECK_EQ(num_attrs_, -1)
<< "Should call AddArgument before calling SetAttributes";
value_or_attrs_.push_back(value);
++num_arguments_;
}

void SetResults(llvm::ArrayRef<Value*> values) {
CHECK_EQ(num_arguments_, static_cast<int>(value_or_attrs_.size()));
CHECK_EQ(num_results_, -1);
CHECK_EQ(num_arguments_ + GetNumAttributes(),
static_cast<int>(value_or_attrs_.size()));
for (Value* x : values) {
value_or_attrs_.push_back(x);
}
num_results_ = values.size();
}

void SetNumResults(size_t n) {
CHECK_EQ(num_arguments_, static_cast<int>(value_or_attrs_.size()));
CHECK_EQ(num_results_, -1);
num_results_ = n;
CHECK_EQ(num_arguments_ + GetNumAttributes(),
static_cast<int>(value_or_attrs_.size()));
for (size_t i = 0; i < n; i++) {
value_or_attrs_.emplace_back(new Value);
}
}

void SetResultAt(int result_id, Value* value) {
CHECK_EQ(static_cast<int>(value_or_attrs_.size()),
num_arguments_ + num_results_)
num_arguments_ + GetNumAttributes() + num_results_)
<< "Call SetNumResults first";
CHECK_LT(result_id + num_arguments_,
CHECK_LT(result_id + num_arguments_ + GetNumAttributes(),
static_cast<int>(value_or_attrs_.size()));
CHECK(value);
value_or_attrs_[num_arguments_ + result_id]->set(value);
value_or_attrs_[num_arguments_ + GetNumAttributes() + result_id]->set(
value);
}

void Reset() {
value_or_attrs_.clear();
num_arguments_ = 0;
num_results_ = -1;
num_attrs_ = -1;
}
};

Expand Down
76 changes: 76 additions & 0 deletions paddle/infrt/host_context/kernel_frame_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>

#include "paddle/infrt/host_context/kernel_frame.h"

namespace infrt {
namespace host_context {
/*
TEST(KernelRegistry, basic) {
KernelFrameBuilder kernel_frame;
Value arg_0(std::string{"arg_0"});
Value arg_1(std::string{"arg_1"});
Value arg_2(std::string{"arg_2"});
Value res_0(std::string{"res_0"});
Value res_1(std::string{"res_1"});
Value attr_0(std::string{"attr_0"});
kernel_frame.AddArgument(&arg_0);
kernel_frame.AddArgument(&arg_1);
kernel_frame.AddArgument(&arg_2);
kernel_frame.SetResults({&res_0, &res_1});
kernel_frame.AddAttribute(&attr_0);
CHECK_EQ(kernel_frame.GetNumArgs(), 3);
CHECK_EQ(kernel_frame.GetNumResults(), 2);
CHECK_EQ(kernel_frame.GetNumAttributes(), 1);
CHECK_EQ(kernel_frame.GetNumElements(), 6UL);
CHECK_EQ(kernel_frame.GetArgAt<std::string>(2), "arg_2");
CHECK_EQ(kernel_frame.GetAttributeAt(0)->get<std::string>(), "attr_0");
CHECK_EQ(kernel_frame.GetResults()[1]->get<std::string>(), "res_1");
}
*/

TEST(KernelRegistry, basic) {
KernelFrameBuilder kernel_frame;

Value arg_0(std::string{"arg_0"});
Value arg_1(std::string{"arg_1"});
Value arg_2(std::string{"arg_2"});
Value attr_0(std::string{"attr_0"});
Value res_0(std::string{"res_0"});
Value res_1(std::string{"res_1"});

kernel_frame.AddArgument(&arg_0);
kernel_frame.AddArgument(&arg_1);
kernel_frame.AddArgument(&arg_2);
kernel_frame.AddAttribute(&attr_0);
kernel_frame.SetResults({&res_0, &res_1});

CHECK_EQ(kernel_frame.GetNumArgs(), 3);
CHECK_EQ(kernel_frame.GetNumResults(), 2);
CHECK_EQ(kernel_frame.GetNumAttributes(), 1);
CHECK_EQ(kernel_frame.GetNumElements(), 6UL);

CHECK_EQ(kernel_frame.GetArgAt<std::string>(2), "arg_2");
CHECK_EQ(kernel_frame.GetAttributeAt(0)->get<std::string>(), "attr_0");
CHECK_EQ(kernel_frame.GetResults()[1]->get<std::string>(), "res_1");
}

} // namespace host_context
} // namespace infrt
4 changes: 2 additions & 2 deletions paddle/infrt/kernel/pten/infershaped/pten_kernel_launcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ namespace kernel {

static void FakePtenInferShape(const ::pten::MetaTensor& a,
const ::pten::MetaTensor& b,
host_context::Attribute<bool> arg_0,
bool arg_0,
::pten::MetaTensor* c) {}

static void FakePtenKernel(const ::pten::CPUContext& /*Context*/,
const ::pten::DenseTensor& a,
const ::pten::DenseTensor& b,
host_context::Attribute<bool> arg_0,
bool arg_0,
::pten::DenseTensor* c) {}

template <typename KernelFunc,
Expand Down

0 comments on commit a3633d5

Please sign in to comment.