Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support static graph code-gen for temporal_shift #52686

Merged
merged 6 commits into from
Apr 13, 2023

Conversation

sanbuphy
Copy link
Contributor

@sanbuphy sanbuphy commented Apr 8, 2023

PR types

Others

PR changes

Others

Describe

#51842

@paddle-bot
Copy link

paddle-bot bot commented Apr 8, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added contributor External developers status: proposed labels Apr 8, 2023
Copy link
Contributor

@heavyrain-lzy heavyrain-lzy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请rerun Window-infer CI

@luotao1
Copy link
Contributor

luotao1 commented Apr 10, 2023

@heavyrain-lzy

请rerun Window-infer CI

CI 已经通过

@heavyrain-lzy
Copy link
Contributor

根据PR-CI-Static-Check的输出,这里需要把原始.cc文件中的comment拷贝到Python端。

  1. python/paddle/nn/functional/extension.py中的temporal_shift,删除@templatedoc(),替换${comment}
def temporal_shift(x, seg_num, shift_ratio=0.25, name=None, data_format="NCHW"):
    r"""

    **Temporal Shift Operator**

    Calculate the temporal shifting features for Input(X).

    Input(X) should be in shape of [N*T, C, H, W] or [N*T, H, W, C], while
    N is the batch size, T is the temporal segment number specified by
    :attr:`seg_num`, C is the channel number, H and W is the height and
    width of features.

    Temporal Shifting is calculated as follows when data format is NCHW:

    Step 1: Reshape Input(X) to [N, T, C, H, W].

    Step 2: Pad 0 to reshaping result in the 2nd(T) dimension with
    padding width as 1 on each side, padding result will be in shape
    of [N, T+2, C, H, W].

    Step 3: Assume :attr:`shift_ratio` is :math:`1/4`, slice padding
    result as follows:

    $$
    slice1 = x[:, :T, :C/4, :, :]
    $$
    $$
    slice2 = x[:, 2:T+2, C/4:C/2, :, :]
    $$
    $$
    slice3 = x[:, 1:T+1, C/2:, :, :]
    $$

    Step 4: Concatenate three slices along the 3rd(C) dimension and
    reshape result to [N*T, C, H, W].

    For details of temporal shifting, please refer to paper:
    `Temporal Shift Module <http://arxiv.org/abs/1811.08383>`_ .

    Args:
        x(Tensor): The input tensor of temporal shift operator.
                   This is a 4-D tensor with shape of [N*T, C, H, W] or
                   [N*T, H, W, C]. While N is the batch size,
                   T is the temporal segment number, C is the channel number,
                   H is the height of features and W is the width of features.
                   The data type is float16, float32 and float64
        seg_num(int): The temporal segment number, this should be a positive integer.
        shift_ratio(float): The shift ratio of the channels, the first :attr:`shift_ratio` part
                            of channels will be shifted by -1 along the temporal dimension,
                            and the second :attr:`shift_ratio` part of channels will be shifted
                            by 1 along the temporal dimension. :attr:`shift_ratio` should be in
                            range [0, 0.5]. Default 0.25.
        name(str, optional): For detailed information, please refer
                             to :ref:`api_guide_Name`. Usually name is no need to set and
                             None by default.
        data_format(str, optional): Data format that specifies the layout of input.
            It can be "NCHW" or "NHWC". Default: "NCHW".

    Returns:
        out(Tensor): The temporal shifting result is a tensor with the
        same shape and same data type as the input.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            input = paddle.randn([6, 4, 2, 2])
            out = F.temporal_shift(x=input, seg_num=2, shift_ratio=0.2)
    """

heavyrain-lzy
heavyrain-lzy previously approved these changes Apr 11, 2023
Copy link
Contributor

@heavyrain-lzy heavyrain-lzy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@heavyrain-lzy
Copy link
Contributor

@sunzhongkai 麻烦review一下PR

@heavyrain-lzy
Copy link
Contributor

是不是没有安装pre-commit,代码格式错误,ci-converage可能是随机挂

@sanbuphy
Copy link
Contributor Author

是不是没有安装pre-commit,代码格式错误,ci-converage可能是随机挂

commit过了,我再commit一次吧

@sanbuphy
Copy link
Contributor Author

是不是没有安装pre-commit,代码格式错误,ci-converage可能是随机挂

precommit的话 是提示没有可更新的,明天我再看看具体出在哪

Copy link
Contributor

@heavyrain-lzy heavyrain-lzy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@luotao1
Copy link
Contributor

luotao1 commented Apr 13, 2023

@sunzhongkai588 审核下文档~

Copy link
Contributor

@sunzhongkai588 sunzhongkai588 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@luotao1 luotao1 merged commit 9246b93 into PaddlePaddle:develop Apr 13, 2023
@sanbuphy sanbuphy deleted the temporalop branch April 13, 2023 07:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants