diff --git a/xla/service/value_range.cc b/xla/service/value_range.cc index 850db808d7392..0bdf42ae090b6 100644 --- a/xla/service/value_range.cc +++ b/xla/service/value_range.cc @@ -215,11 +215,16 @@ Range RecursivelyIdentifyRange( return Range{}; } ConstantValue single_value = lhs.IsSingleValue() ? lhs.min() : rhs.min(); - ConstantValue min = lhs.IsSingleValue() ? rhs.min().mul(single_value) - : lhs.min().mul(single_value); - ConstantValue max = lhs.IsSingleValue() ? rhs.max().mul(single_value) - : lhs.max().mul(single_value); - return Range{min, max, single_value, lhs.IsLinear() && rhs.IsLinear()}; + Range operand_range = lhs.IsSingleValue() ? rhs : lhs; + // When multiplying with a constant, min, max, and step are all + // multiplied by the single value. + ConstantValue min = operand_range.min().mul(single_value); + ConstantValue max = operand_range.max().mul(single_value); + if (!operand_range.IsStepKnown()) { + return Range{min, max, operand_range.IsLinear()}; + } + ConstantValue step = operand_range.step().mul(single_value); + return Range{min, max, step, operand_range.IsLinear()}; } case HloOpcode::kSelect: { VLOG(5) << "Handling Select: " << instr->ToString(); diff --git a/xla/service/value_range_test.cc b/xla/service/value_range_test.cc index 05a64ae3a6d9b..0b83a374e5da0 100644 --- a/xla/service/value_range_test.cc +++ b/xla/service/value_range_test.cc @@ -59,7 +59,7 @@ TEST_F(ValueRangeTest, AddedValue) { EXPECT_FALSE(range.IsSingleValue()); EXPECT_TRUE(range.IsLinear()); EXPECT_EQ(range.min().GetSignedValue(), 124); - EXPECT_EQ(range.max().GetSignedValue(), 129); + EXPECT_EQ(range.max().GetSignedValue(), 124 + 5); EXPECT_EQ(range.step().GetSignedValue(), 1); } @@ -78,18 +78,19 @@ TEST_F(ValueRangeTest, MultiplyValue) { const HloInstruction* root = module->entry_computation()->root_instruction(); const HloInstruction* p0 = root->operand(0); absl::flat_hash_map fs; - fs.insert( - std::make_pair(p0, Range{ConstantValue::GetZero(32, /*is_signed=*/true), - ConstantValue::GetSigned(5, 32), - ConstantValue::GetOne(32, /*is_signed=*/false), - /*is_linear=*/true})); + // p0 has range min = 0, max = 32, step = 2. + fs.insert(std::make_pair( + p0, Range{/*min=*/ConstantValue::GetSigned(0, /*bitwidth=*/32), + /*max=*/ConstantValue::GetSigned(32, /*bitwidth=*/32), + /*step=*/ConstantValue::GetUnsigned(2, /*bitwidth=*/32), + /*is_linear=*/true})); auto range = RecursivelyIdentifyRange(root, fs); EXPECT_FALSE(range.IsEmpty()); EXPECT_FALSE(range.IsSingleValue()); EXPECT_TRUE(range.IsLinear()); EXPECT_EQ(range.min().GetSignedValue(), 0); - EXPECT_EQ(range.max().GetSignedValue(), 5120); - EXPECT_EQ(range.step().GetSignedValue(), 1024); + EXPECT_EQ(range.max().GetSignedValue(), 32 * 1024); + EXPECT_EQ(range.step().GetSignedValue(), 2 * 1024); } TEST_F(ValueRangeTest, ConstantValuePred) { @@ -151,27 +152,28 @@ TEST_F(ValueRangeTest, ConstantValueWithConditional) { const HloInstruction* p0 = module->entry_computation()->parameter_instruction(0); absl::flat_hash_map fs; - fs.insert( - std::make_pair(p0, Range{ConstantValue::GetZero(32, /*is_signed=*/true), - ConstantValue::GetSigned(5, 32), - ConstantValue::GetOne(32, /*is_signed=*/false), - /*is_linear=*/true})); + // p0 has range min = 0, max = 32, step = 2. + fs.insert(std::make_pair( + p0, Range{/*min=*/ConstantValue::GetSigned(0, /*bitwidth=*/32), + /*max=*/ConstantValue::GetSigned(32, /*bitwidth=*/32), + /*step=*/ConstantValue::GetUnsigned(2, /*bitwidth=*/32), + /*is_linear=*/true})); auto add_range = RecursivelyIdentifyRange(add, fs, alias_analysis.get()); EXPECT_FALSE(add_range.IsEmpty()); EXPECT_FALSE(add_range.IsSingleValue()); EXPECT_TRUE(add_range.IsLinear()); EXPECT_EQ(add_range.min().GetSignedValue(), 1024); - EXPECT_EQ(add_range.max().GetSignedValue(), 1029); - EXPECT_EQ(add_range.step().GetSignedValue(), 1); + EXPECT_EQ(add_range.max().GetSignedValue(), 1024 + 32); + EXPECT_EQ(add_range.step().GetSignedValue(), 2); auto mult_range = RecursivelyIdentifyRange(mult, fs, alias_analysis.get()); EXPECT_FALSE(mult_range.IsEmpty()); EXPECT_FALSE(mult_range.IsSingleValue()); EXPECT_TRUE(mult_range.IsLinear()); EXPECT_EQ(mult_range.min().GetSignedValue(), 0); - EXPECT_EQ(mult_range.max().GetSignedValue(), 5120); - EXPECT_EQ(mult_range.step().GetSignedValue(), 1024); + EXPECT_EQ(mult_range.max().GetSignedValue(), 32 * 1024); + EXPECT_EQ(mult_range.step().GetSignedValue(), 2 * 1024); } TEST_F(ValueRangeTest, SelectValueWithCompareInConditional) { @@ -216,11 +218,12 @@ TEST_F(ValueRangeTest, SelectValueWithCompareInConditional) { const HloInstruction* p0 = module->entry_computation()->parameter_instruction(0); absl::flat_hash_map fs; - fs.insert( - std::make_pair(p0, Range{ConstantValue::GetZero(32, /*is_signed=*/true), - ConstantValue::GetSigned(5, 32), - ConstantValue::GetOne(32, /*is_signed=*/false), - /*is_linear=*/true})); + // p0 has range min = 0, max = 32, step = 2. + fs.insert(std::make_pair( + p0, Range{/*min=*/ConstantValue::GetSigned(0, /*bitwidth=*/32), + /*max=*/ConstantValue::GetSigned(32, /*bitwidth=*/32), + /*step=*/ConstantValue::GetUnsigned(2, /*bitwidth=*/32), + /*is_linear=*/true})); auto select1_range = RecursivelyIdentifyRange(select1, fs, alias_analysis.get());