[DOC] Document How to Implement Custom Operators #788
Thanks for opening the issue @deanljohnson. The best documentation we have for a custom operator is the black-scholes example (MatX/examples/black_scholes.cu, line 57 at commit 31b784e).
There are many things we would like to implement that would likely involve many custom operators, but as an example I can describe the first thing I implemented as an operator. We have the following simple kernel:
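(A hypothetical stand-in for the kernel, since the original snippet is not reproduced here; the name decay_kernel, the alpha parameter, and the decay math are placeholders:)

```cpp
// Hypothetical 1D kernel: weight each input element by an index-dependent factor.
__global__ void decay_kernel(float* out, const float* in, float alpha, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = in[i] * expf(-alpha * static_cast<float>(i));
  }
}
```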
And I would like to implement that as an operator so that it can be used in matx expressions. The call should be something like:
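For instance, with a hypothetical decay() wrapper mirroring the kernel above:

```cpp
// Desired call site: the operation composes like any other MatX expression.
(out = decay(in, alpha)).run(stream);
```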
I could implement that as an operator with something like:
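A rough sketch of such a 1D-only operator, following the pattern in the black-scholes example (DecayOp and decay are placeholder names; the exact base class, type-alias names, and macros should be checked against the MatX version in use):

```cpp
#include <matx.h>

// Wraps another operator and applies an index-dependent weight to a 1D input.
// Note: BaseOp's namespace (matx vs. matx::detail) and the element-type alias
// (value_type vs. scalar_type) vary between MatX versions -- verify locally.
template <typename Op>
class DecayOp : public matx::BaseOp<DecayOp<Op>> {
 private:
  Op op_;
  float alpha_;

 public:
  using matxop = bool;                        // marks this type as a MatX operator
  using value_type = typename Op::value_type; // element type of the wrapped operator

  DecayOp(const Op &op, float alpha) : op_(op), alpha_(alpha) {}

  // 1D only: takes a single index along the vector
  __host__ __device__ auto operator()(matx::index_t i) const {
    return op_(i) * expf(-alpha_ * static_cast<float>(i));
  }

  static constexpr __host__ __device__ int32_t Rank() { return Op::Rank(); }
  constexpr __host__ __device__ matx::index_t Size(int dim) const { return op_.Size(dim); }
};

// Helper so call sites can write decay(in, alpha)
template <typename Op>
auto decay(const Op &op, float alpha) {
  return DecayOp<Op>(op, alpha);
}
```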
And this works if the input is 1D. However, if the input tensor is multi-dimensional, I want the operator to work over only the right-most dimension. I may have been using the term "broadcast" incorrectly here, since that seems to refer to combining operators, whereas I just want to apply (in this case) a 1D operation to a multi-dimensional input. In this example I could just reshape the input into a single dimension, but there are obviously cases where that is not possible, and I would think there would be a way to handle this "automatically" from the point of view of the call site. I was able to get that working with the following:
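One way to do that, sketched with the same placeholder names as above, is to make operator() variadic, forward every index to the wrapped operator, and use only the right-most index for the per-element weight (cuda::std::tuple comes from libcu++, which MatX already depends on):

```cpp
// Drop-in replacement for the 1D operator() in the DecayOp sketch above.
// Requires #include <cuda/std/tuple>.
template <typename... Is>
__host__ __device__ auto operator()(Is... indices) const {
  // Extract the right-most index; all indices are forwarded to the wrapped op.
  const auto last =
      cuda::std::get<sizeof...(Is) - 1>(cuda::std::make_tuple(indices...));
  return op_(indices...) * expf(-alpha_ * static_cast<float>(last));
}
```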
Such that now the following seems to work:
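For example (sizes are arbitrary, and stream is an existing cudaStream_t):

```cpp
// Apply the operator along the right-most dimension of a 2D tensor:
// each of the 16 rows gets the same 1D index-dependent weighting.
auto in  = matx::make_tensor<float>({16, 1024});
auto out = matx::make_tensor<float>({16, 1024});
(out = decay(in, 0.01f)).run(stream);
```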
I don't know if my operator implementation is the "right" way to do this, but it does seem to work. If this is the right way, I think examples/documentation like this would be useful to have available within the MatX documentation.

Regarding the transform questions and the use of custom kernels in the operation, I don't have an example ready at this time. The situation I am in is that I am investigating the integration of matx into our project, and I am trying to make sure that once the whole team is let loose on using matx, they don't immediately run into a ton of questions. An example (even just a pointer to an existing operator) that uses temporary storage could definitely be helpful. I would imagine that the storage is not allocated/de-allocated every time the operator is constructed (or maybe the temporary storage is passed into the operator?).

Minor note: that example seems to have a typo.
Thanks for the description @deanljohnson. We will get this pushed out soon since it's been a long-standing item on our list.

Very quickly, on your broadcasting question: typically, broadcasting just means you take an operator of a lower rank and perform an operation on it with an operator that is expecting a higher rank. As long as the sizes are compatible, MatX will implicitly broadcast. This is done using the function get_value, which takes an operator and indices as input and returns the value, broadcast if needed. For example, if you have a 2D 5x6 tensor and you call get_value with indices 3,4,2,3, the 3,4 will effectively be ignored and the tensor will be indexed with 2,3.

We will work on this early next week since you are not the first person to ask, and it should help people unfamiliar with the library.
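A small sketch of what that behavior looks like from the expression side, assuming the element-wise operators involved support broadcasting as described (names, sizes, and stream here are arbitrary):

```cpp
// A rank-1 operand combined with a rank-2 operand of compatible trailing size:
// b(j) is reused for every row i of a(i, j).
auto a = matx::make_tensor<float>({5, 6});
auto b = matx::make_tensor<float>({6});
auto c = matx::make_tensor<float>({5, 6});
(c = a + b).run(stream);
```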
Hi @deanljohnson, I added some documentation and it's going through our CI: #793. I did notice that there are some places where we're missing broadcasting support in the operators, so I will update the code to fix that and add it to the docs guide.
@cliffburdick Excellent write-up! Thank you!
Which documentation should be updated?
How to implement custom operators should be documented, and the documentation should address things like:
- operator()
- using matxop = bool and using matxoplvalue = bool
Please provide any links to existing documentation where the updates are needed
Unknown
I have not been able to find any documentation on how to do this, so I apologize if it already exists.