Skip to content

Commit

Permalink
Add device_count accessor to HloRunnerInterface.
Browse files Browse the repository at this point in the history
Also fixes hlo_runner_interface includes.

PiperOrigin-RevId: 713504316
  • Loading branch information
nvgrw authored and Google-ML-Automation committed Jan 9, 2025
1 parent aa1378b commit 29fadea
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 6 deletions.
11 changes: 9 additions & 2 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4613,14 +4613,21 @@ cc_library(
deps = [
":computation_placer",
":executable",
"//xla:status_macros",
"//xla:types",
":hlo_module_config",
"//xla:literal",
"//xla:shape_util",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/parser:hlo_parser",
"//xla/tsl/platform:env",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
2 changes: 2 additions & 0 deletions xla/service/hlo_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ class HloRunner : public HloRunnerInterface {
return backend().compiler()->ShapeSizeBytesFunction();
}

int device_count() const override { return backend().device_count(); }

private:
absl::StatusOr<ExecutionOutput> ExecuteWithExecutionInputs(
Executable* executable, std::vector<ExecutionInput> arguments,
Expand Down
16 changes: 16 additions & 0 deletions xla/service/hlo_runner_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,23 @@ limitations under the License.

#include "xla/service/hlo_runner_interface.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/literal.h"
#include "xla/service/executable.h"
#include "xla/service/hlo_module_config.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "tsl/platform/statusor.h"

namespace xla {

Expand Down
12 changes: 8 additions & 4 deletions xla/service/hlo_runner_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,21 @@ limitations under the License.

#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/literal.h"
#include "xla/service/computation_placer.h"
#include "xla/service/executable.h"
#include "xla/status_macros.h"
#include "xla/types.h"
#include "xla/shape.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"

Expand Down Expand Up @@ -226,6 +226,10 @@ class HloRunnerInterface {
// This function is used e.g. to create a VerifiedHloModule. It returns an
// integer representing the size of the shape in bytes as opposed to a Shape.
virtual DeviceShapeSizeFn device_shape_size_fn() const = 0;

// Returns the number of devices which are known. Not all of these devices may
// be usable by XLA.
virtual int device_count() const = 0;
};

} // namespace xla
Expand Down
2 changes: 2 additions & 0 deletions xla/service/hlo_runner_pjrt.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ class HloRunnerPjRt : public HloRunnerInterface {
return device_shape_size_fn_;
}

int device_count() const override { return pjrt_client_->device_count(); }

private:
absl::StatusOr<CompileOptions> GenerateDefaultCompileOptions(
HloModule* module, bool run_hlo_passes);
Expand Down

0 comments on commit 29fadea

Please sign in to comment.