
[DOC] Document How to Implement Custom Operators #788

Closed
deanljohnson opened this issue Nov 1, 2024 · 5 comments

Comments


deanljohnson commented Nov 1, 2024

Which documentation should be updated?

How to implement custom operators should be documented and the documentation should address things like:

  1. How to support broadcasting the operation over extra dimensions
  2. How to handle temporary workspaces (e.g. should they be allocated in the constructor, or somewhere else? and is there a caching mechanism for these workspaces that should be used?)
  3. How to integrate custom kernels into the execution (as opposed to just implementing operator())
  4. What conventions like using matxop = bool and using matxoplvalue = bool mean and when they are required

Please provide any links to existing documentation where the updates are needed

Unknown

I have not been able to find any documentation on how to do this, so I apologize if it already exists.

@cliffburdick
Collaborator

Thanks for opening the issue @deanljohnson. The best documentation we currently have for a custom operator is the Black-Scholes example:

class BlackScholes : public BaseOp<BlackScholes<O, I1>> {

  1. You do not need to do anything to support this; the MatX framework handles it in the background. For example, if you multiply a 2D operator by a 4D operator and their dimensions are compatible, MatX will implicitly evaluate the 2D operator repeatedly across the extra dimensions of the 4D operator.

  2. It sounds like you want to write a transform rather than a regular custom operator. Transforms require quite a bit more work and are not currently documented. A transform is typically anything that is not element-wise, requires temporary storage, and/or uses custom kernels. If this is what you'd like to do, please let us know and we can write some documentation.

  3. See the previous one. Maybe you can describe what you're trying to add.

  4. Agreed, and this can be part of the docs.


deanljohnson commented Nov 1, 2024

There are many things we would like to implement that would likely involve many custom operators, but I can describe the first operator I implemented as an example.

We have the following simple kernel:

template <typename ComplexT, typename RealT>
__global__ void SoftQPSKDemod(const ComplexT* samples, RealT* bits, int num_samples)
{
    const int idx = threadIdx.x + blockDim.x * blockIdx.x;
    const int stride = blockDim.x * gridDim.x;

    for (int i = idx; i < num_samples; i += stride)
    {
        const auto bits_offset = i * 2;
        bits[bits_offset] = -samples[i].x;
        bits[bits_offset + 1] = -samples[i].y;
    }
}

And I would like to implement that as an operator so that it can be used in matx expressions. The call should be something like:

    auto samples_t = matx::make_tensor(samples, {num_samples});   // complex<float>
    auto bits_t = matx::make_tensor(bits, {num_samples * 2});     // float
    (bits_t = softQPSKDemodulate(samples_t)).run(stream); // real expressions would include more operations

I could implement that as an operator with something like:

template <typename TOp, typename RealT = typename TOp::value_type::value_type>
class SoftQpskDemodOp : public matx::BaseOp<SoftQpskDemodOp<TOp, RealT>>
{
   static_assert(matx::is_complex_v<typename TOp::value_type>);

public:
   using value_type = RealT;
   using matxop = bool;

   __MATX_INLINE__ SoftQpskDemodOp(const TOp& input) : m_input(input) {}

   __MATX_INLINE__ __host__ __device__ matx::index_t Size(int dim) const {
      return 2 * m_input.Size(m_input.Rank() - 1);
   }

   __MATX_INLINE__ __host__ __device__ RealT operator()(matx::index_t i) const
   {
      const auto input_idx = i / 2;
      const auto& sample = m_input(input_idx);
      if (i % 2 == 0)
         return RealT(-sample.real());
      else
         return RealT(-sample.imag());
   }

   __MATX_INLINE__ static constexpr int32_t Rank() { return 1; }

private:
   TOp m_input;
};

template <typename TOp>
auto softQPSKDemodulate(const TOp& op)
{
   return SoftQpskDemodOp(op);
}

And this works if the input is 1D. However, if the input tensor is multi-dimensional, I want the operator to work over only the right-most dimension. I may have been using the term "broadcast" incorrectly here, since that seems to refer to combining operators, whereas I just want to apply (in this case) a 1D operation to a multi-dimensional input.

In this example, I could just reshape the input into a single dimension, but there are obviously cases where that is not possible, and I would think there would be a way to handle this "automatically" from the point of view of the call site.

I was able to get that working with the following:

template <typename TOp, typename RealT = typename TOp::value_type::value_type>
class SoftQpskDemodOp : public matx::BaseOp<SoftQpskDemodOp<TOp, RealT>>
{
   static_assert(matx::is_complex_v<typename TOp::value_type>);

public:
   using value_type = RealT;
   using matxop = bool;

   __MATX_INLINE__ SoftQpskDemodOp(const TOp& input) : m_input(input)
   {
   }

   __MATX_INLINE__ __host__ __device__ matx::index_t Size(int dim) const
   {
      if (dim == Rank() - 1)
         return m_input.Size(dim) * 2;
      else
         return m_input.Size(dim);
   }

   template <typename... Is>
   __MATX_INLINE__ __host__ __device__ RealT operator()(Is... indices) const
   {
      // inds = output indices
      auto inds = cuda::std::make_tuple(indices...);
      const auto output_idx = cuda::std::get<Rank() - 1>(inds);

      // change indices to input indices
      const auto input_idx = output_idx / 2;
      cuda::std::get<Rank() - 1>(inds) = input_idx;

      const auto& sample = cuda::std::apply(m_input, inds);
      if (output_idx % 2 == 0)
         return RealT(-sample.real());
      else
         return RealT(-sample.imag());
   }

   __MATX_INLINE__ static constexpr int32_t Rank()
   {
      return TOp::Rank();
   }

private:
   TOp m_input;
};

Such that now the following seems to work:

    auto samples_t = matx::make_tensor(samples, {2, num_samples});  // complex<float>
    auto bits_t = matx::make_tensor(bits, {2, num_samples * 2});    // float
    (bits_t = softQPSKDemodulate(samples_t)).run(stream);

I don't know if my operator implementation is the "right" way to do this, but it does seem to work. If this is the right way, I think examples/documentation like this would be useful to have available within the matx project.

Regarding the transform questions and the use of custom kernels, I don't have an example ready at this time. I am currently investigating integrating MatX into our project and am trying to make sure that, once the whole team is let loose on MatX, they don't immediately run into a ton of questions. An example (even just a pointer to an existing operator) that uses temporary storage would definitely be helpful. I would imagine the storage is not allocated/de-allocated every time the operator is constructed (or maybe the temporary storage is passed into the operator?).

Minor note: that example seems to have a typo as it says Time non-operator version for both the operator and non-operator calls.

@cliffburdick
Collaborator

Thanks for the description @deanljohnson. We will get this pushed out soon since it's been a long-standing item on our list. Very quickly on your broadcasting question: typically broadcasting just means you take an operator of a lower rank and perform an operation on it with an operator that is expecting a higher rank. As long as the sizes are compatible, MatX will broadcast implicitly. This is done using the function get_value, which takes an operator and indices as input and returns the value, broadcasted if needed. For example, if you have a 2D 5x6 tensor and you call get_value with indices 3,4,2,3, the 3,4 will effectively be ignored and the tensor will be indexed with 2,3.

We will work on this early next week since you are not the first person to ask and it should help people unfamiliar with the library.

@cliffburdick
Collaborator

Hi @deanljohnson, I added some documentation and it's going through our CI: #793

I did notice that there are some places we're missing broadcasting support in the operators, so I will update the code to fix that and add it to the docs guide.

@deanljohnson
Author

@cliffburdick Excellent write-up! Thank you!
