Skip to content

Commit

Permalink
Fix range analysis bug.
Browse files Browse the repository at this point in the history
The way we multiply operand ranges with constant is wrong because step was not multiplied when the operand is constant.

PiperOrigin-RevId: 705709863
  • Loading branch information
fhoushmand authored and Google-ML-Automation committed Dec 13, 2024
1 parent 072c10c commit 8102719
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 27 deletions.
15 changes: 10 additions & 5 deletions xla/service/value_range.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,16 @@ Range RecursivelyIdentifyRange(
return Range{};
}
ConstantValue single_value = lhs.IsSingleValue() ? lhs.min() : rhs.min();
ConstantValue min = lhs.IsSingleValue() ? rhs.min().mul(single_value)
: lhs.min().mul(single_value);
ConstantValue max = lhs.IsSingleValue() ? rhs.max().mul(single_value)
: lhs.max().mul(single_value);
return Range{min, max, single_value, lhs.IsLinear() && rhs.IsLinear()};
Range operand_range = lhs.IsSingleValue() ? rhs : lhs;
// When multiplying with a constant, min, max, and step are all
// multiplied by the single value.
ConstantValue min = operand_range.min().mul(single_value);
ConstantValue max = operand_range.max().mul(single_value);
if (!operand_range.IsStepKnown()) {
return Range{min, max, operand_range.IsLinear()};
}
ConstantValue step = operand_range.step().mul(single_value);
return Range{min, max, step, operand_range.IsLinear()};
}
case HloOpcode::kSelect: {
VLOG(5) << "Handling Select: " << instr->ToString();
Expand Down
47 changes: 25 additions & 22 deletions xla/service/value_range_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ TEST_F(ValueRangeTest, AddedValue) {
EXPECT_FALSE(range.IsSingleValue());
EXPECT_TRUE(range.IsLinear());
EXPECT_EQ(range.min().GetSignedValue(), 124);
EXPECT_EQ(range.max().GetSignedValue(), 129);
EXPECT_EQ(range.max().GetSignedValue(), 124 + 5);
EXPECT_EQ(range.step().GetSignedValue(), 1);
}

Expand All @@ -78,18 +78,19 @@ TEST_F(ValueRangeTest, MultiplyValue) {
const HloInstruction* root = module->entry_computation()->root_instruction();
const HloInstruction* p0 = root->operand(0);
absl::flat_hash_map<const HloInstruction*, Range> fs;
fs.insert(
std::make_pair(p0, Range{ConstantValue::GetZero(32, /*is_signed=*/true),
ConstantValue::GetSigned(5, 32),
ConstantValue::GetOne(32, /*is_signed=*/false),
/*is_linear=*/true}));
// p0 has range min = 0, max = 32, step = 2.
fs.insert(std::make_pair(
p0, Range{/*min=*/ConstantValue::GetSigned(0, /*bitwidth=*/32),
/*max=*/ConstantValue::GetSigned(32, /*bitwidth=*/32),
/*step=*/ConstantValue::GetUnsigned(2, /*bitwidth=*/32),
/*is_linear=*/true}));
auto range = RecursivelyIdentifyRange(root, fs);
EXPECT_FALSE(range.IsEmpty());
EXPECT_FALSE(range.IsSingleValue());
EXPECT_TRUE(range.IsLinear());
EXPECT_EQ(range.min().GetSignedValue(), 0);
EXPECT_EQ(range.max().GetSignedValue(), 5120);
EXPECT_EQ(range.step().GetSignedValue(), 1024);
EXPECT_EQ(range.max().GetSignedValue(), 32 * 1024);
EXPECT_EQ(range.step().GetSignedValue(), 2 * 1024);
}

TEST_F(ValueRangeTest, ConstantValuePred) {
Expand Down Expand Up @@ -151,27 +152,28 @@ TEST_F(ValueRangeTest, ConstantValueWithConditional) {
const HloInstruction* p0 =
module->entry_computation()->parameter_instruction(0);
absl::flat_hash_map<const HloInstruction*, Range> fs;
fs.insert(
std::make_pair(p0, Range{ConstantValue::GetZero(32, /*is_signed=*/true),
ConstantValue::GetSigned(5, 32),
ConstantValue::GetOne(32, /*is_signed=*/false),
/*is_linear=*/true}));
// p0 has range min = 0, max = 32, step = 2.
fs.insert(std::make_pair(
p0, Range{/*min=*/ConstantValue::GetSigned(0, /*bitwidth=*/32),
/*max=*/ConstantValue::GetSigned(32, /*bitwidth=*/32),
/*step=*/ConstantValue::GetUnsigned(2, /*bitwidth=*/32),
/*is_linear=*/true}));

auto add_range = RecursivelyIdentifyRange(add, fs, alias_analysis.get());
EXPECT_FALSE(add_range.IsEmpty());
EXPECT_FALSE(add_range.IsSingleValue());
EXPECT_TRUE(add_range.IsLinear());
EXPECT_EQ(add_range.min().GetSignedValue(), 1024);
EXPECT_EQ(add_range.max().GetSignedValue(), 1029);
EXPECT_EQ(add_range.step().GetSignedValue(), 1);
EXPECT_EQ(add_range.max().GetSignedValue(), 1024 + 32);
EXPECT_EQ(add_range.step().GetSignedValue(), 2);

auto mult_range = RecursivelyIdentifyRange(mult, fs, alias_analysis.get());
EXPECT_FALSE(mult_range.IsEmpty());
EXPECT_FALSE(mult_range.IsSingleValue());
EXPECT_TRUE(mult_range.IsLinear());
EXPECT_EQ(mult_range.min().GetSignedValue(), 0);
EXPECT_EQ(mult_range.max().GetSignedValue(), 5120);
EXPECT_EQ(mult_range.step().GetSignedValue(), 1024);
EXPECT_EQ(mult_range.max().GetSignedValue(), 32 * 1024);
EXPECT_EQ(mult_range.step().GetSignedValue(), 2 * 1024);
}

TEST_F(ValueRangeTest, SelectValueWithCompareInConditional) {
Expand Down Expand Up @@ -216,11 +218,12 @@ TEST_F(ValueRangeTest, SelectValueWithCompareInConditional) {
const HloInstruction* p0 =
module->entry_computation()->parameter_instruction(0);
absl::flat_hash_map<const HloInstruction*, Range> fs;
fs.insert(
std::make_pair(p0, Range{ConstantValue::GetZero(32, /*is_signed=*/true),
ConstantValue::GetSigned(5, 32),
ConstantValue::GetOne(32, /*is_signed=*/false),
/*is_linear=*/true}));
// p0 has range min = 0, max = 32, step = 2.
fs.insert(std::make_pair(
p0, Range{/*min=*/ConstantValue::GetSigned(0, /*bitwidth=*/32),
/*max=*/ConstantValue::GetSigned(32, /*bitwidth=*/32),
/*step=*/ConstantValue::GetUnsigned(2, /*bitwidth=*/32),
/*is_linear=*/true}));

auto select1_range =
RecursivelyIdentifyRange(select1, fs, alias_analysis.get());
Expand Down

0 comments on commit 8102719

Please sign in to comment.