Skip to content

Commit

Permalink
Configure HloPjRtTestBase with new option structs.
Browse files Browse the repository at this point in the history
A struct instead of optional parameters make it easier for us to express
different test setups. In most applications we expect the default values to
suffice.

PiperOrigin-RevId: 706083556
  • Loading branch information
nvgrw authored and Google-ML-Automation committed Dec 14, 2024
1 parent 9beb0f6 commit b5e97bc
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
13 changes: 6 additions & 7 deletions xla/tests/hlo_pjrt_test_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,11 @@ std::unique_ptr<HloRunnerInterface> GetHloRunnerForReference() {

} // namespace

HloPjRtTestBase::HloPjRtTestBase(
bool verifier_layout_sensitive, bool allow_mixed_precision_in_hlo_verifier,
HloPredicate instruction_can_change_layout_func)
: HloRunnerAgnosticTestBase(
GetHloRunnerForTest(), GetHloRunnerForReference(),
verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier,
instruction_can_change_layout_func) {}
HloPjRtTestBase::HloPjRtTestBase(HloPjRtTestBaseOptions options)
: HloRunnerAgnosticTestBase(GetHloRunnerForTest(),
GetHloRunnerForReference(),
options.verifier_layout_sensitive,
options.allow_mixed_precision_in_hlo_verifier,
options.instruction_can_change_layout_func) {}

} // namespace xla
11 changes: 7 additions & 4 deletions xla/tests/hlo_pjrt_test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,17 @@ limitations under the License.

namespace xla {

struct HloPjRtTestBaseOptions {
bool verifier_layout_sensitive = false;
bool allow_mixed_precision_in_hlo_verifier = true;
HloPredicate instruction_can_change_layout_func;
};

class HloPjRtTestBase : public HloRunnerAgnosticTestBase {
protected:
// This uses the SE interpreter backend for the reference backend and
// automatically finds a PjRt backend for the test backend.
explicit HloPjRtTestBase(
bool verifier_layout_sensitive = false,
bool allow_mixed_precision_in_hlo_verifier = true,
HloPredicate instruction_can_change_layout_func = {});
explicit HloPjRtTestBase(HloPjRtTestBaseOptions options = {});
};

} // namespace xla
Expand Down

0 comments on commit b5e97bc

Please sign in to comment.