Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[WIP/RFC] Asynchronous native operators #672

Closed
wants to merge 1 commit into from

Conversation

vchuravy
Copy link
Contributor

This is by no means done yet, but I would like some early feedback :) I also don't expect this code to work in its current state. For me it is a rough draft of how I think things might work

Currently NativeOps creates a copy of the input/output data and then synchronizes the TBlobs afterwards. That is problematic for asynchronous operation, since in Julia ctx.async_on_complete needs to be called from the Julia side of the code and there can't be any operations afterwards.

Questions:

  • Is it allowed to create NDArrays like this [https://github.com/vchuravy/mxnet/blob/65409a1f6753f04239443c5011f4c40907a31d13/src/operator/native_op-inl.h#L57] ? Passing handel to NDArrays to Python and Julia makes things easier.
  • Currently ctx.async_on_complete is an object with a member function operator() getting a function pointer to a member function is quite tricky, any other ideas how we could handle this elegantly?

CC: @piiswrong
Ref: dmlc/MXNet.jl#16

@tqchen
Copy link
Member

tqchen commented Nov 21, 2015

Some direction of solution for the callback. Create another callback function in the NativeOperator, and pass the C style callback to the Julia side.

// C style callback
extern "C" 
static void CallNativeCallback(void *self) {
     static_cast<NativeCallback*>(self)->OnComplete();
}

class NativeCallback {
   private:
      engine::CallbackOnComplete callback_; 
      void OnComplete() {
           //other operations, such as copy data back
          callback_();
      }
};

@piiswrong
Copy link
Contributor

@vchuravy I think you can create ndarrays in this way. In fact that's what I was going to try. I'll try to refactor it. If it works I'll tell you.

@piiswrong
Copy link
Contributor

We probably should create another op for backward compatibility?

@pluskid
Copy link
Contributor

pluskid commented Nov 21, 2015

I guess backward compatibility of numpy op can be implemented on the python side using ndarray op.

@tqchen
Copy link
Member

tqchen commented Nov 21, 2015

Some caveat on NDArray ops which we need want to think a bit careful about when support them.

Blocking in Operator

def forward(self, src_nd):
    src_nd = src_nd + 1
    data = src_nd.asnumpy()

Here data calls src_nd.asnumpy(), which is an blocking operation. The plus operation will be dispatched to another worker thread when constraints are satisfied. If there is only one worker thread on say GPU, then this can cause a deadlock, as current thread(which is worker) wait for the plus operation to be completed, however, the plus operation is supposed to be executed by the worker thread.

Ideally, all operations in the operator definition should be executed by the same thread in serial manner, which makes it unit ops to be executed, this was not the case here.

Which NDArray to Expose

There are two ways of NDArrays that can be exposed, either expose the original NDArray that stores the data and result. Or, rewrap the pointer in the TBlob as a new NDArray. Currently the second way requires no changes in interface. Both have some things need to think about.

The semantics of engine means when we call the forward/backward function, the original NDArray is marked as reading and writing, and the state won't change until the on_complete callback is invoked.

To wrap another NDArray, we only need to make sure that when we get the result, we need to call on_complete, this can be achieved by an dummy engine call

Engine::Get()->push(reads={result_nd}, [on_compelte] () {
    on_complete();
})

This operation will be scheduled when result_nd is available, which in turn notifies the completion of this op.

Notification of Finish of Async Copy

def forward(self, src_nd, res_nd):
    src_nd = src_nd + 1
    src_nd.copyto(res_nd)

When the copyto returns, the copy operation did not necessarily finish. We need some way to notify the finish, like what was mentioned in last paragraph.

Summary

Putting async calls in an async call have some interesting problems we might want to think about, because current NDArray interface is async, these problems need to be thought about when implementing them. I think we can first go with the NativeOp as first phase, and supporting nd as second phase.

Also cc @hotpxl , can you also think about this to see if there is any possible issue on engine side?

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants