Skip to content

Commit

Permalink
[Dy2St][PIR] Copy stop_gradient in while_loop fake_value and arg (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo authored Feb 5, 2024
1 parent b4e13a6 commit 9882b25
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 2 deletions.
5 changes: 5 additions & 0 deletions paddle/pir/core/block_argument.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ uint32_t BlockArgument::index() const {
return IMPL_->index_;
}

const AttributeMap &BlockArgument::attributes() const {
CHECK_NULL_IMPL(attributes_);
return IMPL_->attributes_;
}

Attribute BlockArgument::attribute(const std::string &key) const {
return impl_ ? IMPL_->attribute(key) : nullptr;
}
Expand Down
2 changes: 2 additions & 0 deletions paddle/pir/core/block_argument.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include "paddle/pir/core/operation_utils.h"
#include "paddle/pir/core/value.h"
namespace pir {
class Block;
Expand All @@ -33,6 +34,7 @@ class IR_API BlockArgument : public Value {
Block *owner() const;
uint32_t index() const;

const AttributeMap &attributes() const;
Attribute attribute(const std::string &key) const;
void set_attribute(const std::string &key, Attribute value);

Expand Down
6 changes: 5 additions & 1 deletion paddle/pir/core/region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ void Region::CloneInto(Region &other, IrMapping &ir_mapping) const {
auto new_block = new Block;
ir_mapping.Add(&block, new_block);
for (const auto &arg : block.args()) {
ir_mapping.Add(arg, new_block->AddArg(arg.type()));
auto new_arg = new_block->AddArg(arg.type());
ir_mapping.Add(arg, new_arg);
for (auto &attr : arg.dyn_cast<BlockArgument>().attributes()) {
new_arg.set_attribute(attr.first, attr.second);
}
}
other.push_back(new_block);
}
Expand Down
9 changes: 8 additions & 1 deletion python/paddle/static/nn/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ def create_fake_value_for_undefined_var():
)
cf_yield([next_cond, *unified_next_vars])

# Reset type of UndefinedVar from next_vars
# Reset type and stop_gradient of UndefinedVar from next_vars
for idx, value in undefined_var_mapping.items():
if idx in constant_next_var_indices:
continue
Expand All @@ -821,6 +821,13 @@ def create_fake_value_for_undefined_var():
cur_block.args()[idx].set_type(value_new_type)
while_op.as_operation().results()[idx].set_type(value_new_type)

value_new_stop_gradient = flatten(next_vars)[idx].stop_gradient
value.stop_gradient = value_new_stop_gradient
cur_block.args()[idx].stop_gradient = value_new_stop_gradient
while_op.as_operation().results()[
idx
].stop_gradient = value_new_stop_gradient

# Restore the outputs by variable and constants
optimized_results = while_op.optimize_update()
(optimized_variable_results,) = select_by_indices(
Expand Down

0 comments on commit 9882b25

Please sign in to comment.