
Does it need to enhance matmul_op to support 4-D inputs #7319

Closed
lcy-seso opened this issue Jan 8, 2018 · 1 comment · Fixed by #7656


lcy-seso commented Jan 8, 2018

While checking the dot-product attention in ConvS2S and Transformer, I found that in multi-head (self-)attention, both inputs of the batched matrix multiplication can potentially be 4-D tensors.

It seems we can enhance the current matmul_op to support 4-D tensors as its inputs; however, I guess the right choice depends on how the computation is batched to maximize speed.

Alternatively, the multiple heads can simply be wrapped in a Python API using a for loop.
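
For illustration, here is a minimal NumPy sketch of the shapes involved; NumPy's np.matmul already broadcasts over the leading batch dimensions, which is the behavior an enhanced matmul_op would need. The dimension names and sizes below are assumptions for this example, not part of any existing API.

```python
import numpy as np

# Assumed example sizes: batch, heads, query length, key length, per-head size.
bs, n_heads, N, M, d_head = 2, 8, 5, 7, 64

q = np.random.rand(bs, n_heads, N, d_head)   # 4-D query tensor
k = np.random.rand(bs, n_heads, M, d_head)   # 4-D key tensor

# Both matmul inputs are 4-D; the two leading dimensions are batched over.
scores = np.matmul(q, k.transpose(0, 1, 3, 2))
print(scores.shape)   # (2, 8, 5, 7)
```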

@lcy-seso lcy-seso added the NMT label Jan 8, 2018
@lcy-seso lcy-seso changed the title from "Enhance matmul_op to support 4-D inputs" to "Does it need to enhance matmul_op to support 4-D inputs" Jan 8, 2018

lcy-seso commented Jan 8, 2018

A straightforward way is to use a for loop to define multiple heads (a minimal sketch follows the lists below).

  • query: $Q$ is a 3-D tensor with shape $[\text{bs} \times N \times D]$, where
    • $\text{bs}$ is the batch size,
    • $N$ is the maximum length of the query sentences (all sentences in a mini-batch are padded to the same length),
    • $D$ is the hidden size.
  • key: $K$ is a 3-D tensor with shape $[\text{bs} \times M \times D]$, where
    • $\text{bs}$ is the batch size,
    • $M$ is the maximum length of the key sentences (all sentences in a mini-batch are padded to the same length),
    • $D$ is the hidden size.
  • value: $V$ is a 3-D tensor with shape $[\text{bs} \times M \times D]$, where
    • $\text{bs}$ is the batch size,
    • $M$ is the maximum length of the key sentences (all sentences in a mini-batch are padded to the same length),
    • $D$ is the hidden size.

For self-attention, $M = N$. For attention between encoder and decoder, $M \ne N$.

For a single head $i$, compute:

  • $\widetilde{Q_i} = QW_i^Q$

    • $W_i^Q$ is a projection matrix, a 2-D tensor with shape $[D \times D']$.
    • $D'$ is the size of the context vector computed by one head, set to $D/\text{number of heads}$.
  • $\widetilde{K_i} = KW_i^K$

    • $W_i^K$ is a projection matrix, a 2-D tensor with shape $[D \times D']$.
  • $\widetilde{V_i} = VW_i^V$

    • $W_i^V$ is a projection matrix, a 2-D tensor with shape $[D \times D']$.
  • Then use $\widetilde{Q_i}$, $\widetilde{K_i}$ and $\widetilde{V_i}$ to compute dot-product attention as described in #7309 (Support padding_idx in the lookup_table_op).
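
A minimal NumPy sketch of this per-head loop, just to pin down the shapes: NumPy stands in for the framework ops, and the random projection matrices, the softmax helper and the Transformer-style scaling are illustrative assumptions rather than any existing API.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

bs, N, M, D, n_heads = 2, 5, 7, 64, 8
d_prime = D // n_heads                       # D' = D / number of heads

Q = np.random.rand(bs, N, D)
K = np.random.rand(bs, M, D)
V = np.random.rand(bs, M, D)

heads = []
for i in range(n_heads):
    W_q = np.random.rand(D, d_prime)         # W_i^Q
    W_k = np.random.rand(D, d_prime)         # W_i^K
    W_v = np.random.rand(D, d_prime)         # W_i^V

    q_i = Q @ W_q                            # (bs, N, D')
    k_i = K @ W_k                            # (bs, M, D')
    v_i = V @ W_v                            # (bs, M, D')

    # Dot-product attention for head i: only 3-D batched matmuls are needed.
    scores = q_i @ k_i.transpose(0, 2, 1) / np.sqrt(d_prime)   # (bs, N, M)
    heads.append(softmax(scores) @ v_i)      # (bs, N, D')

context = np.concatenate(heads, axis=-1)     # (bs, N, D)
```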


Alternatively, it seems that:

  1. $\widetilde{Q}$, $\widetilde{K}$ and $\widetilde{V}$ can be computed for all the heads at once by simply setting $D' = D$. The resulting $\widetilde{Q}$, $\widetilde{K}$ and $\widetilde{V}$ are all 3-D tensors.

  2. Then reshape the 3-D tensors $\widetilde{Q}$, $\widetilde{K}$ and $\widetilde{V}$ into 4-D tensors by making the number of heads a separate dimension.

  3. Finally, compute the multi-head attention for all heads in parallel. This passes two 4-D tensors to matmul_op as its inputs (a sketch of this variant follows the list).
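
Under the same assumptions as the sketch above, the fused variant could look like this in NumPy: one projection per input with $D' = D$, a reshape/transpose that turns the number of heads into a dimension, and a batched matmul over two 4-D tensors. This is only a shape-level sketch, not the Fluid API.

```python
import numpy as np

bs, N, M, D, n_heads = 2, 5, 7, 64, 8
d_head = D // n_heads

Q = np.random.rand(bs, N, D)
K = np.random.rand(bs, M, D)
V = np.random.rand(bs, M, D)

# One projection per input covering all heads at once (D' = D).
W_q, W_k, W_v = (np.random.rand(D, D) for _ in range(3))
q_all, k_all, v_all = Q @ W_q, K @ W_k, V @ W_v          # 3-D: (bs, len, D)

def split_heads(x, n_heads):
    # (bs, len, D) -> (bs, n_heads, len, D / n_heads)
    b, l, d = x.shape
    return x.reshape(b, l, n_heads, d // n_heads).transpose(0, 2, 1, 3)

q4, k4, v4 = (split_heads(x, n_heads) for x in (q_all, k_all, v_all))

# Both matmul inputs are now 4-D; all heads are computed in parallel.
scores = np.matmul(q4, k4.transpose(0, 1, 3, 2)) / np.sqrt(d_head)  # (bs, n_heads, N, M)
weights = np.exp(scores - scores.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)                # softmax over the key dimension
context = np.matmul(weights, v4)                         # (bs, n_heads, N, d_head)

# Merge the heads back into the hidden dimension: (bs, N, D).
context = context.transpose(0, 2, 1, 3).reshape(bs, N, D)
```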


P.S. I use Chrome with the "GitHub with MathJax" plugin to render the LaTeX formulas in this issue.
