Skip to content

Commit

Permalink
Add to_hash func and paddle2arg map for cinn (#49402)
Browse files Browse the repository at this point in the history
  • Loading branch information
WhatGhost authored Jan 5, 2023
1 parent 1228bad commit 1168a17
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 12 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ size_t CinnCacheKey::Hash::operator()(const CinnCacheKey& key) const {

for (const auto& name_shape : key.input_shapes_) {
has_str << name_shape.first;
has_str << name_shape.second.to_str();
has_str << std::hash<phi::DDim>()(name_shape.second);
}

has_str << key.graph_hash_val_;
Expand Down
17 changes: 7 additions & 10 deletions paddle/fluid/operators/cinn/cinn_launch_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ void CinnLaunchContext::InitializeArguments() {
framework::DDim(cinn_buffer->dims, cinn_buffer->dimensions).to_str(),
cinn_tensor->type());
name2argument_.emplace(arg, cinn_buffer.get());
auto pdvar2cinnbuf_ = cinn2paddle_varmap_.at(arg);
paddle2argument_.emplace(pdvar2cinnbuf_, cinn_buffer.get());
hold_buffers_.emplace_back(std::move(cinn_buffer));
}
VLOG(4) << "Total argument size:" << name2argument_.size();
Expand Down Expand Up @@ -491,17 +493,12 @@ framework::InterpreterCore* CinnLaunchContext::InitializeInterpreterCore(

cinn_buffer_t* CinnLaunchContext::GetCinnBufferOfVar(
const std::string& var_name) {
auto it = paddle2cinn_varmap_.find(var_name);
auto res = paddle2argument_.find(var_name);
PADDLE_ENFORCE_NE(
it,
paddle2cinn_varmap_.end(),
platform::errors::InvalidArgument(
"Variable(%s) not found in compilation result", var_name));
auto res = name2argument_.find(it->second);
PADDLE_ENFORCE_NE(res,
name2argument_.end(),
platform::errors::NotFound(
"Argument(%s) not be initialized", it->second));
res,
paddle2argument_.end(),
platform::errors::NotFound("Variable(%s) not found in compilation result",
var_name));
return static_cast<cinn_buffer_t*>(res->second);
}

Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/operators/cinn/cinn_launch_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ class CinnLaunchContext {
// this map saves all execution arguments with their cinn names as key,
// and it is passed to the Execute interface of a cinn runtime program.
std::map<std::string, cinn_pod_value_t> name2argument_;
// this map saves all execution arguments with paddle variables as key,
// this map conbine name2argument_ and paddle2cinn_varmap_
std::map<std::string, cinn_pod_value_t> paddle2argument_;
};

} // namespace operators::details
Expand Down
13 changes: 13 additions & 0 deletions paddle/phi/core/ddim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,16 @@ DDim DDim::transpose(const std::vector<int>& axis) const {
}

} // namespace phi

namespace std {

std::size_t hash<phi::DDim>::operator()(phi::DDim const& ddim) const {
int ndim = ddim.size();
std::size_t seed = ndim;
for (int i = 0; i < ndim; ++i) {
seed ^= ddim.Get()[i] + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}

} // namespace std
9 changes: 8 additions & 1 deletion paddle/phi/core/ddim.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class DDim {

std::string to_str() const;

DDim reshape(std::vector<int>& shape) const;
DDim reshape(std::vector<int>& shape) const; // NOLINT

DDim transpose(const std::vector<int>& axis) const;

Expand Down Expand Up @@ -262,3 +262,10 @@ using DDim = phi::DDim;

} // namespace framework
} // namespace paddle

namespace std {
template <>
struct hash<phi::DDim> {
std::size_t operator()(phi::DDim const& ddim) const;
};
} // namespace std
8 changes: 8 additions & 0 deletions paddle/phi/tests/core/test_ddim.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -124,5 +124,13 @@ TEST(DDim, Print) {
EXPECT_EQ("", ss2.str());
}

TEST(DDim, Hash) {
// hash a DDim
std::size_t h;
phi::DDim ddim = phi::make_ddim({2, 3, 4});
h = std::hash<phi::DDim>()(ddim);
EXPECT_EQ(h, 0xa16fb2b2967ul);
}

} // namespace tests
} // namespace phi

0 comments on commit 1168a17

Please sign in to comment.