Skip to content

Commit

Permalink
Fix domain isolation for the case when multiple domain type is involved
Browse files Browse the repository at this point in the history
Previously when we were stacking domains we inserted the new domain
instructions between the upper-most domain and its operand. This caused
issues if that domain had more than one user with different attributes for
the domain inserted at the second pass, because we could have ended up
with edges between different domains.

After this change we insert the new domains between the lower most
domain and its user ensuring that the domain separates every instruction
with different attributes.

PiperOrigin-RevId: 209776741
  • Loading branch information
tensorflower-gardener committed Aug 22, 2018
1 parent cffdccd commit 5671899
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 26 deletions.
21 changes: 10 additions & 11 deletions tensorflow/compiler/xla/service/hlo_domain_isolator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,27 +31,27 @@ class HloDomainIsolator::RunContext {
StatusOr<bool> Run();

private:
// Inserts a kDomain instruction between parent and operand, in case
// the attribute (ie, sharding) values change between instruction and operand.
// Inserts a kDomain instruction between operand and instruction in case
// the attribute (ie, sharding) values change between root and instruction.
// Returns the newly inserted kDomain instruction, or nullptr if no kDomain
// instruction was necessary.
StatusOr<HloInstruction*> CreateDomain(HloInstruction* instruction,
HloInstruction* parent,
HloInstruction* root,
HloInstruction* operand);

HloModule* module_;
HloDomainIsolator* isolator_;
};

StatusOr<HloInstruction*> HloDomainIsolator::RunContext::CreateDomain(
HloInstruction* instruction, HloInstruction* parent,
HloInstruction* instruction, HloInstruction* root,
HloInstruction* operand) {
HloInstruction* domain = nullptr;
std::unique_ptr<HloInstruction> domain_instruction =
isolator_->creator_(instruction, operand);
isolator_->creator_(instruction, root, operand);
if (domain_instruction != nullptr) {
domain = operand->parent()->AddInstruction(std::move(domain_instruction));
TF_RETURN_IF_ERROR(operand->ReplaceUseWith(parent, domain));
TF_RETURN_IF_ERROR(operand->ReplaceUseWith(instruction, domain));
}
return domain;
}
Expand All @@ -71,14 +71,13 @@ StatusOr<bool> HloDomainIsolator::RunContext::Run() {
// When applying multiple domains, we could end up stacking more than
// one in one edge, so here we want to build the effective
// (kDomain-less) instruction->operand edge.
HloInstruction* parent = instruction;
while (operand->opcode() == HloOpcode::kDomain) {
parent = operand;
operand = operand->mutable_operand(0);
HloInstruction* root = operand;
while (root->opcode() == HloOpcode::kDomain) {
root = root->mutable_operand(0);
}
// Check whether a kDomain is necessary between instruction and operand.
TF_ASSIGN_OR_RETURN(HloInstruction * domain,
CreateDomain(instruction, parent, operand));
CreateDomain(instruction, root, operand));
if (domain != nullptr) {
VLOG(4) << "New domain: " << domain->ToString();
++added_domains;
Expand Down
6 changes: 4 additions & 2 deletions tensorflow/compiler/xla/service/hlo_domain_isolator.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ class HloDomainIsolator : public HloPassInterface {
public:
// Creates a new kDomain instruction for the edge between the use instruction
// (the first HloInstruction argument), and the operand instruction (the
// second HloInstruction argument).
// third HloInstruction argument) if the interesting attribute of the
// instruction differs from the attribute of the root (the second
// HloInstruction argument).
// Returns nullptr in case no domain separation is necessary.
using DomainCreator = std::function<std::unique_ptr<HloInstruction>(
HloInstruction*, HloInstruction*)>;
HloInstruction*, HloInstruction*, HloInstruction*)>;

explicit HloDomainIsolator(DomainCreator creator);

Expand Down
64 changes: 62 additions & 2 deletions tensorflow/compiler/xla/service/hlo_domain_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,13 @@ class OpNameMetadata : public DomainMetadata {

// Creator function for OpNameMetadata domains.
std::unique_ptr<HloInstruction> OpNameDomainCreator(HloInstruction* instruction,
HloInstruction* root,
HloInstruction* operand) {
if (instruction->metadata().op_name() == operand->metadata().op_name()) {
if (instruction->metadata().op_name() == root->metadata().op_name()) {
return nullptr;
}
std::unique_ptr<DomainMetadata> operand_side_metadata =
absl::make_unique<OpNameMetadata>(operand->metadata().op_name());
absl::make_unique<OpNameMetadata>(root->metadata().op_name());
std::unique_ptr<DomainMetadata> user_side_metadata =
absl::make_unique<OpNameMetadata>(instruction->metadata().op_name());
return HloInstruction::CreateDomain(operand->shape(), operand,
Expand Down Expand Up @@ -524,5 +525,64 @@ ENTRY entry {
tpl->sharding());
}

// Regression test for stacked kDomain instructions whose upper domain has
// multiple users with differing attributes. Per this change's description,
// new domains must be inserted between the lower-most domain and its user so
// that every attribute change is separated by a domain boundary.
TEST_F(HloDomainTest, MultiDomainMultiUser) {
// HLO graph with pre-existing sharding domains stacked on edges; %d and %e
// carry op_name metadata "D" while their peers (%c, %f) carry none, so the
// op-name isolator must split those edges too.
const char* const hlo_string = R"(
HloModule Module
ENTRY %entry (p0: (f32[4], f32[4])) -> (f32[4], f32[4], f32[4]) {
%p0 = (f32[4], f32[4]) parameter(0)
%a = f32[4]{0} get-tuple-element(%p0), index=0
%domain = f32[4] domain(%a),
domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
%b = f32[4] get-tuple-element(%p0), index=1
%domain.1 = f32[4] domain(%b),
domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
%c = f32[4] add(%domain, %domain.1), sharding={maximal device=1}
%domain.2 = f32[4] domain(%c),
domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
%d = f32[4] subtract(%domain, %c),
sharding={maximal device=1}, metadata={op_name="D"}
%domain.3 = f32[4] domain(%d),
domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
%e = f32[4] multiply(%c, %d),
sharding={maximal device=1}, metadata={op_name="D"}
%f = f32[4] add(f32[4]{0} %e, f32[4]{0} %c), sharding={maximal device=1}
%domain.4 = f32[4]{0} domain(%f),
domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
ROOT %g = (f32[4], f32[4], f32[4]) tuple(%domain.2, %domain.3, %domain.4)
})";

TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();

// Run the op-name isolator on top of the already-present sharding domains;
// it must insert additional kDomain instructions (hence 'changed').
HloDomainIsolator opname_isolator(OpNameDomainCreator);
TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed,
opname_isolator.Run(module));
EXPECT_TRUE(opname_isolator_changed);

// Edges crossing an op_name change must now carry a domain; the e->d edge
// stays domain-free because both %e and %d have op_name "D".
EXPECT_TRUE(HasDomainEdge(module, "c", "a"));
EXPECT_TRUE(HasDomainEdge(module, "c", "b"));
EXPECT_TRUE(HasDomainEdge(module, "d", "a"));
EXPECT_TRUE(HasDomainEdge(module, "d", "c"));
EXPECT_FALSE(HasDomainEdge(module, "e", "d"));

// Remove the sharding domains first...
HloDomainRemover sharding_remover(ShardingMetadata::KindName(),
ShardingMetadata::NormalizeShardingDomain);
TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed,
sharding_remover.Run(module));
EXPECT_TRUE(sharding_remover_changed);

// ...then the op-name domains.
HloDomainRemover opname_remover(OpNameMetadata::KindName(),
OpNameDomainNormalizer);
TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed,
opname_remover.Run(module));
EXPECT_TRUE(opname_remover_changed);

// After both removers run, no kDomain instructions may remain on the edges.
EXPECT_FALSE(HasDomainEdge(module, "c", "a"));
EXPECT_FALSE(HasDomainEdge(module, "c", "b"));
EXPECT_FALSE(HasDomainEdge(module, "d", "a"));
EXPECT_FALSE(HasDomainEdge(module, "d", "c"));
}

} // namespace
} // namespace xla
20 changes: 11 additions & 9 deletions tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,27 +284,28 @@ Status ApplyDomainSharding(const DomainMetadata::Domain& domain,
// The kDomain instruction will be created only if the sharding differ between
// the instruction and the operand.
std::unique_ptr<HloInstruction> CreateDomain(HloInstruction* instruction,
HloInstruction* root,
HloInstruction* operand) {
const HloSharding* instruction_sharding =
instruction->has_sharding() ? &instruction->sharding() : nullptr;
const HloSharding* operand_sharding =
operand->has_sharding() ? &operand->sharding() : nullptr;
const HloSharding* root_sharding =
root->has_sharding() ? &root->sharding() : nullptr;
// No need for domain if they both have no sharding.
if (instruction_sharding == nullptr && operand_sharding == nullptr) {
if (instruction_sharding == nullptr && root_sharding == nullptr) {
return nullptr;
}
// No need for domain if they match.
if (instruction_sharding != nullptr && operand_sharding != nullptr &&
ShardingMatches(*instruction_sharding, *operand_sharding)) {
if (instruction_sharding != nullptr && root_sharding != nullptr &&
ShardingMatches(*instruction_sharding, *root_sharding)) {
return nullptr;
}
std::unique_ptr<HloSharding> real_instruction_sharding;
std::unique_ptr<HloSharding> real_operand_sharding;
if (instruction_sharding != nullptr) {
real_instruction_sharding = CloneShardingForDomain(*instruction_sharding);
}
if (operand_sharding != nullptr) {
real_operand_sharding = CloneShardingForDomain(*operand_sharding);
if (root_sharding != nullptr) {
real_operand_sharding = CloneShardingForDomain(*root_sharding);
}
VLOG(3) << "Creating domain:";
VLOG(3) << " Instruction: " << instruction->name();
Expand Down Expand Up @@ -417,8 +418,9 @@ Status ShardingMetadata::NormalizeShardingDomain(
}

std::unique_ptr<HloInstruction> CreateShardingDomain(
HloInstruction* instruction, HloInstruction* operand) {
return CreateDomain(instruction, operand);
HloInstruction* instruction, HloInstruction* root,
HloInstruction* operand) {
return CreateDomain(instruction, root, operand);
}

} // namespace xla
4 changes: 2 additions & 2 deletions tensorflow/compiler/xla/service/hlo_sharding_metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ class ShardingMetadata : public DomainMetadata {

// Given an HLO graph edge between instruction and one of its operands, creates
// a ShardingMetadata based kDomain instruction if the sharding between
// instruction and operand changes. Returns nullptr if there is no need for a
// instruction and root changes. Returns nullptr if there is no need for a
// domain separation.
std::unique_ptr<HloInstruction> CreateShardingDomain(
HloInstruction* instruction, HloInstruction* operand);
HloInstruction* instruction, HloInstruction* root, HloInstruction* operand);

} // namespace xla

Expand Down

0 comments on commit 5671899

Please sign in to comment.