Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Enable task to returne object as if it's returned by its parent #26774

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/src/ray/runtime/task/local_mode_task_submitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ ObjectID LocalModeTaskSubmitter::Submit(InvocationSpec &invocation,
0,
local_mode_ray_tuntime_.GetCurrentTaskId(),
address,
-1,
1,
required_resources,
required_placement_resources,
Expand Down
5 changes: 5 additions & 0 deletions java/api/src/main/java/io/ray/api/call/ActorTaskCaller.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ public ActorTaskCaller<R> setConcurrencyGroup(String name) {
return self();
}

public ActorTaskCaller<R> setForwardObjectToParentTask(boolean ifForward) {
builder.setForwardObjectToParentTask(ifForward);
return self();
}

private ActorTaskCaller<R> self() {
return this;
}
Expand Down
14 changes: 12 additions & 2 deletions java/api/src/main/java/io/ray/api/options/CallOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,23 @@ public class CallOptions extends BaseTaskOptions {
public final int bundleIndex;
public final String concurrencyGroupName;
private final String serializedRuntimeEnvInfo;
public final boolean forwardObjectToParentTask;

private CallOptions(
String name,
Map<String, Double> resources,
PlacementGroup group,
int bundleIndex,
String concurrencyGroupName,
RuntimeEnv runtimeEnv) {
RuntimeEnv runtimeEnv,
boolean forwardObjectToParentTask) {
super(resources);
this.name = name;
this.group = group;
this.bundleIndex = bundleIndex;
this.concurrencyGroupName = concurrencyGroupName;
this.serializedRuntimeEnvInfo = runtimeEnv == null ? "" : runtimeEnv.toJsonBytes();
this.forwardObjectToParentTask = forwardObjectToParentTask;
}

/** This inner class for building CallOptions. */
Expand All @@ -38,6 +41,7 @@ public static class Builder {
private int bundleIndex;
private String concurrencyGroupName = "";
private RuntimeEnv runtimeEnv = null;
private boolean forwardObjectToParentTask = false;

/**
* Set a name for this task.
Expand Down Expand Up @@ -98,8 +102,14 @@ public Builder setRuntimeEnv(RuntimeEnv runtimeEnv) {
return this;
}

public Builder setForwardObjectToParentTask(boolean ifForward) {
this.forwardObjectToParentTask = ifForward;
return this;
}

public CallOptions build() {
return new CallOptions(name, resources, group, bundleIndex, concurrencyGroupName, runtimeEnv);
return new CallOptions(name, resources, group, bundleIndex, concurrencyGroupName,
runtimeEnv, forwardObjectToParentTask);
}
}
}
18 changes: 10 additions & 8 deletions java/build-jar-multiplatform.sh
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,12 @@ build_jars_multiplatform() {
return
fi
fi
if download_jars "ray-runtime-$version.jar"; then
prepare_native
build_jars multiplatform false
else
echo "download_jars failed, skip building multiplatform jars"
fi
# if download_jars "ray-runtime-$version.jar"; then
prepare_native
build_jars linux
# else
# echo "download_jars failed, skip building multiplatform jars"
# fi
}

# Download darwin/windows ray-related jar from s3
Expand Down Expand Up @@ -124,7 +124,8 @@ download_jars() {

# prepare native binaries and libraries.
prepare_native() {
for os in 'darwin' 'linux'; do
# for os in 'darwin' 'linux'; do
for os in 'linux'; do
cd "$JAR_BASE_DIR/$os"
jar xf "ray-runtime-$version.jar" "native/$os"
local native_dir="$WORKSPACE_DIR/java/runtime/native_dependencies/native/$os"
Expand All @@ -137,7 +138,8 @@ prepare_native() {
# Return 0 if native bianries and libraries exist and 1 if not.
native_files_exist() {
local os
for os in 'darwin' 'linux'; do
# for os in 'darwin' 'linux'; do
for os in 'linux'; do
native_dirs=()
native_dirs+=("$WORKSPACE_DIR/java/runtime/native_dependencies/native/$os")
for native_dir in "${native_dirs[@]}"; do
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,8 @@ private ObjectRef callNormalFunction(

ObjectRefImpl<?> impl = new ObjectRefImpl<>();
/// Mapping the object id to the object ref.
List<ObjectId> preparedReturnIds = getCurrentReturnIds(numReturns, ActorId.NIL);
List<ObjectId> preparedReturnIds = getCurrentReturnIds(numReturns, ActorId.NIL,
options.forwardObjectToParentTask);
if (rayConfig.runMode == RunMode.CLUSTER && numReturns > 0) {
ObjectRefImpl.registerObjectRefImpl(preparedReturnIds.get(0), impl);
}
Expand Down Expand Up @@ -354,7 +355,9 @@ private ObjectRef callActorFunction(

ObjectRefImpl<?> impl = new ObjectRefImpl<>();
/// Mapping the object id to the object ref.
List<ObjectId> preparedReturnIds = getCurrentReturnIds(numReturns, rayActor.getId());
System.err.println(options.forwardObjectToParentTask);
List<ObjectId> preparedReturnIds = getCurrentReturnIds(numReturns, rayActor.getId(),
options.forwardObjectToParentTask);
if (rayConfig.runMode == RunMode.CLUSTER && numReturns > 0) {
ObjectRefImpl.registerObjectRefImpl(preparedReturnIds.get(0), impl);
}
Expand Down Expand Up @@ -394,7 +397,7 @@ private BaseActorHandle createActorImpl(
return actor;
}

abstract List<ObjectId> getCurrentReturnIds(int numReturns, ActorId actorId);
abstract List<ObjectId> getCurrentReturnIds(int numReturns, ActorId actorId, boolean ifForward);

public WorkerContext getWorkerContext() {
return workerContext;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public Map<String, List<ResourceValue>> getAvailableResourceIds() {
}

@Override
List<ObjectId> getCurrentReturnIds(int numReturns, ActorId actorId) {
List<ObjectId> getCurrentReturnIds(int numReturns, ActorId actorId, boolean ifForward) {
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ public void killActor(BaseActorHandle actor, boolean noRestart) {
}

@Override
List<ObjectId> getCurrentReturnIds(int numReturns, ActorId actorId) {
List<byte[]> ret = nativeGetCurrentReturnIds(numReturns, actorId.getBytes());
List<ObjectId> getCurrentReturnIds(int numReturns, ActorId actorId, boolean ifForward) {
List<byte[]> ret = nativeGetCurrentReturnIds(numReturns, actorId.getBytes(), ifForward);
return ret.stream().map(ObjectId::new).collect(Collectors.toList());
}

Expand Down Expand Up @@ -291,7 +291,7 @@ private static native void nativeInitialize(

private static native String nativeGetNamespace();

private static native List<byte[]> nativeGetCurrentReturnIds(int numReturns, byte[] actorId);
private static native List<byte[]> nativeGetCurrentReturnIds(int numReturns, byte[] actorId, boolean ifForward);

private static native byte[] nativeGetCurrentNodeId();
}
6 changes: 6 additions & 0 deletions src/mock/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,12 @@ class MockCoreWorker : public CoreWorker {
rpc::AssignObjectOwnerReply *reply,
rpc::SendReplyCallback send_reply_callback),
(override));
MOCK_METHOD(void,
HandleUpdateForwardedObject,
(const rpc::UpdateForwardedObjectRequest &request,
rpc::UpdateForwardedObjectReply *reply,
rpc::SendReplyCallback send_reply_callback),
(override));
};

} // namespace core
Expand Down
5 changes: 5 additions & 0 deletions src/mock/ray/rpc/worker/core_worker_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ class MockCoreWorkerClientInterface : public ray::pubsub::MockSubscriberClientIn
(const AssignObjectOwnerRequest &request,
const ClientCallback<AssignObjectOwnerReply> &callback),
(override));
MOCK_METHOD(void,
UpdateForwardedObject,
(const UpdateForwardedObjectRequest &request,
const ClientCallback<UpdateForwardedObjectReply> &callback),
(override));
MOCK_METHOD(int64_t, ClientProcessedUpToSeqno, (), (override));
};

Expand Down
11 changes: 10 additions & 1 deletion src/ray/common/task/task_spec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,16 @@ size_t TaskSpecification::NumArgs() const { return message_->args_size(); }
size_t TaskSpecification::NumReturns() const { return message_->num_returns(); }

ObjectID TaskSpecification::ReturnId(size_t return_index) const {
return ObjectID::FromIndex(TaskId(), return_index + 1);
auto parent_num_returns = this->message_->parent_num_returns();
if (parent_num_returns < 0) {
return ObjectID::FromIndex(TaskId(), return_index + 1);
} else {
return ObjectID::FromIndex(ParentTaskId(), parent_num_returns + return_index + 1);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This means we can only have one forward_to_parent task call inside one task right?

}
}

bool TaskSpecification::ForwardToParent() const {
return this->message_->parent_num_returns() >= 0;
}

bool TaskSpecification::ArgByRef(size_t arg_index) const {
Expand Down
2 changes: 2 additions & 0 deletions src/ray/common/task/task_spec.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ class TaskSpecification : public MessageWrapper<rpc::TaskSpec> {

ObjectID ReturnId(size_t return_index) const;

bool ForwardToParent() const;

const uint8_t *ArgData(size_t arg_index) const;

size_t ArgDataSize(size_t arg_index) const;
Expand Down
2 changes: 2 additions & 0 deletions src/ray/common/task/task_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class TaskSpecBuilder {
uint64_t parent_counter,
const TaskID &caller_id,
const rpc::Address &caller_address,
int parent_num_returns,
uint64_t num_returns,
const std::unordered_map<std::string, double> &required_resources,
const std::unordered_map<std::string, double> &required_placement_resources,
Expand All @@ -124,6 +125,7 @@ class TaskSpecBuilder {
message_->set_parent_counter(parent_counter);
message_->set_caller_id(caller_id.Binary());
message_->mutable_caller_address()->CopyFrom(caller_address);
message_->set_parent_num_returns(parent_num_returns);
message_->set_num_returns(num_returns);
message_->mutable_required_resources()->insert(required_resources.begin(),
required_resources.end());
Expand Down
8 changes: 6 additions & 2 deletions src/ray/core_worker/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,14 @@ struct TaskOptions {
int num_returns,
std::unordered_map<std::string, double> &resources,
const std::string &concurrency_group_name = "",
const std::string &serialized_runtime_env_info = "{}")
const std::string &serialized_runtime_env_info = "{}",
bool use_parent_task_id = false)
: name(name),
num_returns(num_returns),
resources(resources),
concurrency_group_name(concurrency_group_name),
serialized_runtime_env_info(serialized_runtime_env_info) {}
serialized_runtime_env_info(serialized_runtime_env_info),
use_parent_task_id(use_parent_task_id) {}

/// The name of this task.
std::string name;
Expand All @@ -79,6 +81,8 @@ struct TaskOptions {
/// fields which not contained in Runtime Env, such as eager_install.
/// Propagated to child actors and tasks.
std::string serialized_runtime_env_info;

bool use_parent_task_id = false;
};

/// Options for actor creation tasks.
Expand Down
Loading