Skip to content

Commit

Permalink
Merge pull request #51 from Shixiaowei02/Shixiaowei02-dev/infrt_kerne…
Browse files Browse the repository at this point in the history
…ls_9

fix fake kernel execute bug
  • Loading branch information
Shixiaowei02 authored Feb 21, 2022
2 parents 71aa355 + ab9152b commit fe4a7ef
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 32 deletions.
2 changes: 2 additions & 0 deletions paddle/infrt/host_context/kernel_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ std::string KernelFrame::DumpArgTypes() const {
ss << "pten::CPUContext,";
} else if (value->is_type<host_context::None>()) {
ss << "none,";
} else if (value->is_type<backends::CpuPtenContext>()) {
ss << "CpuPtenContext,";
} else {
ss << "unk,";
}
Expand Down
15 changes: 10 additions & 5 deletions paddle/infrt/host_context/kernel_frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ class KernelFrame {
return value_or_attrs_[index]->template get_or_default<T>();
}

Value* GetElementAt(int index) {
CHECK_LT(static_cast<size_t>(index), GetNumElements());
return value_or_attrs_[index];
}

// Get number of elements, either input, attributes or results.
size_t GetNumElements() const { return value_or_attrs_.size(); }

Expand Down Expand Up @@ -79,11 +84,11 @@ class KernelFrame {
CHECK_EQ(num_results_, -1)
<< "Must call SetNumResults after calling AddAttribute";
value_or_attrs_.emplace_back(v);
if (num_attrs_ = -1) {
num_attrs_ = 1;
} else {
num_attrs_++;
}
if (num_attrs_ == -1) num_attrs_ = 0;
num_attrs_++;

CHECK_EQ(value_or_attrs_.size(),
static_cast<size_t>(num_arguments_ + num_attrs_));
}

template <typename T, typename... Args>
Expand Down
2 changes: 1 addition & 1 deletion paddle/infrt/host_context/kernel_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ struct KernelImpl<Return (*)(Args...), impl_fn> {
static_assert(const_idx == 0,
"Arguments and results should appear before attributes.");

auto* value = frame->GetArgAt(in_idx);
auto* value = frame->GetElementAt(in_idx);
auto&& arg = value->get<ArgT>();

KernelCallHelper<
Expand Down
38 changes: 19 additions & 19 deletions paddle/infrt/host_context/mlir_to_runtime_translate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -274,25 +274,6 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
<< GetValue(operand) << " vs " << arg_value;
}

// process results
llvm::SmallVector<Value*, 4> res_values;
for (int i = 0, e = op->getNumResults(); i < e; i++) {
auto res = op->getResult(i);
res_values.push_back(AddValue(res));

VLOG(3) << "* op mlir res: " << DumpToString(res) << " " << GetValue(res);
}
impl_->cur_op->SetResults(res_values);

#ifdef INFRT_DEBUG
{
VLOG(3) << "check result";
for (int i = 0; i < impl_->cur_op->frame().GetNumResults(); i++) {
VLOG(3) << "+ res value: " << impl_->cur_op->frame().GetResults()[i];
}
}
#endif

// process attributes
auto attrs = op->getAttrs();

Expand Down Expand Up @@ -325,6 +306,25 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
}
}

// process results
llvm::SmallVector<Value*, 4> res_values;
for (int i = 0, e = op->getNumResults(); i < e; i++) {
auto res = op->getResult(i);
res_values.push_back(AddValue(res));

VLOG(3) << "* op mlir res: " << DumpToString(res) << " " << GetValue(res);
}
impl_->cur_op->SetResults(res_values);

#ifdef INFRT_DEBUG
{
VLOG(3) << "check result";
for (int i = 0; i < impl_->cur_op->frame().GetNumResults(); i++) {
VLOG(3) << "+ res value: " << impl_->cur_op->frame().GetResults()[i];
}
}
#endif

// process regions, we treat regions as attribute.
auto num_regions = op->getNumRegions();
if (num_regions > 0) {
Expand Down
4 changes: 2 additions & 2 deletions paddle/infrt/host_context/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ class Value : public common::Object {

template <typename T>
T& get() {
LOG(INFO) << data.index();
CHECK(data.template is<T>());
CHECK(data.template is<T>()) << "typeid: " << data.index()
<< " != " << ValueVariantType::IndexOf<T>;
return data.get<T>();
}

Expand Down
9 changes: 7 additions & 2 deletions paddle/infrt/kernel/pten/infershaped/pten_kernel_launcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@

#include <llvm/ADT/SmallVector.h>

#include <iostream>

#include "paddle/infrt/backends/host/pten_context.h"
#include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/kernel/pten/infershaped/infershaped_kernel_launcher.h"
#include "paddle/infrt/kernel/pten/infershaped/infershaped_utils.h"
Expand All @@ -27,11 +30,13 @@ static void FakePtenInferShape(const ::pten::MetaTensor& a,
bool arg_0,
::pten::MetaTensor* c) {}

static void FakePtenKernel(const ::pten::CPUContext& /*Context*/,
static void FakePtenKernel(const backends::CpuPtenContext& /*Context*/,
const ::pten::DenseTensor& a,
const ::pten::DenseTensor& b,
bool arg_0,
::pten::DenseTensor* c) {}
::pten::DenseTensor* c) {
std::cout << "@FakePtenKernel@" << std::endl;
}

template <typename KernelFunc,
KernelFunc kernel,
Expand Down
2 changes: 1 addition & 1 deletion paddle/infrt/support/variant.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,10 @@ class Variant {

IndexT index() { return index_; }

private:
template <typename T>
static constexpr size_t IndexOf = TupleIndexOf<T, Types>::value;

private:
static constexpr size_t kStorageSize = std::max({sizeof(Ts)...});
static constexpr size_t kAlignment = std::max({alignof(Ts)...});

Expand Down
5 changes: 3 additions & 2 deletions paddle/infrt/tests/dialect/pten/dense_tensor.mlir
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
// RUN: infrtexec %s | FileCheck %s

// CHECK-LABEL: @basic_tensor
func @basic_tensor() {
// CHECK-LABEL: @fake_pten_kernel_execute
func @fake_pten_kernel_execute() {
%allocator = "pten_dt.create_allocator.cpu" (): () -> !pten.CPU_allocator
%ctx = "pten_dt.create_context.cpu" (): () -> !pten.CPU_context
%t = "pten_dt.create_dense_tensor.cpu.f32.nchw" (%allocator) {dims=[1:i64], lod=[1:i64]}: (!pten.CPU_allocator) -> (!infrt.tensor<X86, NCHW, F32>)

// CHECK: @FakePtenKernel@
%d = "pten_dt.fake_pten_kernel" (%ctx, %t, %t) {transpose_x=false} : (!pten.CPU_context, !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) -> (!infrt.tensor<X86, NCHW, F32>)
infrt.return
}

0 comments on commit fe4a7ef

Please sign in to comment.