Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Shixiaowei02 committed Feb 21, 2022
1 parent a3633d5 commit 71aa355
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 9 deletions.
23 changes: 19 additions & 4 deletions paddle/infrt/host_context/kernel_frame_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <gtest/gtest.h>

#include "paddle/infrt/host_context/kernel_frame.h"
#include "paddle/infrt/host_context/kernel_utils.h"

namespace infrt {
namespace host_context {
Expand Down Expand Up @@ -46,21 +47,34 @@ TEST(KernelRegistry, basic) {
}
*/

void TestFunc(const std::string& arg_0,
const std::string& arg_1,
const std::string& arg_2,
Attribute<std::string> attr_0,
Result<std::string> res_0,
Result<std::string> res_1) {
CHECK_EQ(arg_0, "arg_0");
CHECK_EQ(arg_1, "arg_1");
CHECK_EQ(arg_2, "arg_2");
CHECK_EQ(attr_0.get(), "attr_0");

// res_0.Set(Argument<std::string>(ValueRef(new Value())));
// res_1.Set(Argument<std::string>(ValueRef(new Value())));
}

TEST(KernelRegistry, basic) {
KernelFrameBuilder kernel_frame;

Value arg_0(std::string{"arg_0"});
Value arg_1(std::string{"arg_1"});
Value arg_2(std::string{"arg_2"});
Value attr_0(std::string{"attr_0"});
Value res_0(std::string{"res_0"});
Value res_1(std::string{"res_1"});

kernel_frame.AddArgument(&arg_0);
kernel_frame.AddArgument(&arg_1);
kernel_frame.AddArgument(&arg_2);
kernel_frame.AddAttribute(&attr_0);
kernel_frame.SetResults({&res_0, &res_1});
kernel_frame.SetNumResults(2);

CHECK_EQ(kernel_frame.GetNumArgs(), 3);
CHECK_EQ(kernel_frame.GetNumResults(), 2);
Expand All @@ -69,7 +83,8 @@ TEST(KernelRegistry, basic) {

CHECK_EQ(kernel_frame.GetArgAt<std::string>(2), "arg_2");
CHECK_EQ(kernel_frame.GetAttributeAt(0)->get<std::string>(), "attr_0");
CHECK_EQ(kernel_frame.GetResults()[1]->get<std::string>(), "res_1");

KernelImpl<decltype(&TestFunc), TestFunc>::Invoke(&kernel_frame);
}

} // namespace host_context
Expand Down
12 changes: 7 additions & 5 deletions paddle/infrt/host_context/kernel_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,11 @@ struct KernelImpl<Return (*)(Args...), impl_fn> {
static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) {
static_assert(out_idx != -1,
"Do not place Results after RemainingResults");
static_assert(const_idx == 0,
"Arguments and results should appear before attributes");
Result<Head> arg(&frame->GetResults()[out_idx]);
// static_assert(const_idx == 0,
// "Arguments and results should appear before attributes");

// Result<Head> arg(&frame->GetResults()[out_idx]);
Result<Head> arg(new ValueRef());
KernelCallHelper<
Tail...>::template Invoke<in_idx, out_idx + 1, const_idx>(frame,
pargs...,
Expand All @@ -224,8 +226,8 @@ struct KernelImpl<Return (*)(Args...), impl_fn> {
struct KernelCallHelper<Attribute<Head>, Tail...> {
template <int in_idx, int out_idx, int const_idx, typename... PreviousArgs>
static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) {
static_assert(const_idx != -1,
"Do not place Attributes after RemainingAttributes");
// static_assert(const_idx != -1,
// "Do not place Attributes after RemainingAttributes");
Attribute<Head> arg(frame->GetAttributeAt(const_idx));
KernelCallHelper<
Tail...>::template Invoke<in_idx, out_idx, const_idx + 1>(frame,
Expand Down

0 comments on commit 71aa355

Please sign in to comment.