Skip to content

Commit

Permalink
[Prim] Filter tensor type for int_array and scalar input in composite…
Browse files Browse the repository at this point in the history
… rule (#51208)
  • Loading branch information
cyber-pioneer authored Mar 14, 2023
1 parent 60d04fa commit 775fb43
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle.fluid import core


def fn(x, shape):
    """Expand *x* to *shape*; thin wrapper so it can be wrapped by to_static."""
    return paddle.expand(x, shape=shape)


class TestIntarrayInput(unittest.TestCase):
    """Test int_array input processing during the composite (prim) rules.

    Covers two cases for the ``shape`` argument of ``paddle.expand``:
    a plain Python list (supported) and a Tensor (rejected, because
    composite rules cannot handle tensor-valued int_array inputs).
    """

    def setUp(self):
        # Enable prim mode for each test and ALWAYS restore it afterwards.
        # The original code disabled the flag at the end of each test body,
        # which leaked global prim state into later tests whenever an
        # assertion or exception fired before the disable call.
        core._set_prim_all_enabled(True)
        self.addCleanup(core._set_prim_all_enabled, False)

    def test_non_tensor_input(self):
        """A plain list shape must pass through the composite rule."""
        np_data = np.random.random([3, 4]).astype("float32")
        tensor_data = paddle.to_tensor(np_data)
        net = paddle.jit.to_static(fn)

        _ = net(tensor_data, shape=[2, 3, 4]).numpy()

    def test_error_input(self):
        """In composite rules, a Tensor shape is not supported as int_array input."""
        np_data = np.random.random([3, 4]).astype("float32")
        tensor_data = paddle.to_tensor(np_data)
        shape = paddle.to_tensor([2, 3, 4])
        net = paddle.jit.to_static(fn)
        with self.assertRaises(ValueError):
            _ = net(tensor_data, shape).numpy()


# Allow running this test module directly (e.g. `python test_file.py`).
if __name__ == '__main__':
    unittest.main()
2 changes: 0 additions & 2 deletions python/paddle/incubate/autograd/composite_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import functools
import operator

import paddle.framework.dtype as dtypes
from paddle.fluid import core

from .primitives import * # noqa: F403
Expand Down Expand Up @@ -361,7 +360,6 @@ def fill_any_like(x, fill_value, dtype, place=None):
"""define composite rule of op full_like."""
"""op name: full_like op type name: fill_any_like."""
"""arg place is not used, add it here to keep same as python api."""
dtype = dtypes.dtype(dtype)
val = full(x.shape, fill_value, dtype)
return val

Expand Down
53 changes: 40 additions & 13 deletions python/paddle/incubate/autograd/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import typing

import paddle
import paddle.framework.dtype as dtypes
from paddle.fluid import framework as framework

from .phi_ops_map import op_info, op_map
Expand Down Expand Up @@ -159,15 +160,52 @@ def _solve_arg(item):
return arg_type.strip(), arg_name.strip()


def _get_attr_value(op, arg_type, arg_name):
    """Resolve the runtime value of attribute ``arg_name`` on ``op``.

    The phi-level argument name may map to a different legacy attribute
    name via ``op_map``; values of type ``DataType`` are wrapped with
    ``dtypes.dtype``. Returns ``None`` when the attribute is optional
    and absent from the op — callers must record that case explicitly.
    """
    attr_mapping = op_map[op.type].get("attrs", {})
    # Translate the phi argument name to the operator's attr name, if remapped.
    resolved_name = attr_mapping.get(arg_name, arg_name)

    # Note: in some cases, attrs may be optional, thus assign None.
    # Such a case must be recorded.
    if resolved_name not in op.attr_names:
        return None

    value = op.attr(resolved_name)
    return dtypes.dtype(value) if arg_type == "DataType" else value


def _get_args_values(op, phi_name):
"get attrs' values for api args' values"
args = op_info[phi_name]
args_list = args["args"].split(",")
inputs = []
attrs = []

for item in args_list:
arg_type, arg_name = _solve_arg(item)
op_content = op_map[op.type]
# IntArray and Scalar are special cases which may cause dynamic shape. In these case, tensor-relative types are removed in composite op.
if arg_type in ("IntArray", "Scalar"):
tensor_key = "int_array" if arg_type == "IntArray" else "scalar"
if op_content.get(tensor_key):
tensor_content = op_content[tensor_key].get(arg_name)
if not tensor_content:
raise ValueError(
f'No value found for {arg_name} of {arg_type} type for operator {op.type}.'
)
for item in ("tensor_name", "tensors_name"):
# name of intarray may differ from operator arg_name
arg_name_new = tensor_content.get(item)
if (
arg_name_new is not None
and arg_name_new in op.input_names
and get_var_block(op.block, op.input(arg_name_new))
):
raise ValueError(
f"Tensor type of {arg_type} is not supported in composite op. Please set other type value of input arg {arg_name_new} for operator {op.type}."
)

if arg_type in ("Tensor", "Tensor[]"):
# assume Tensor type must belong to inputs
if (
Expand All @@ -178,19 +216,8 @@ def _get_args_values(op, phi_name):
else:
inputs.append(arg_name)
else:
op_content = op_map[op.type]
if (
"attrs" in op_content.keys()
and arg_name in op_content["attrs"].keys()
):
arg_name = op_content["attrs"][arg_name]

# Note: in some cases, attrs may be optional , thus assign None. Such case must be recorded.

if arg_name not in op.attr_names:
attrs.append(None)
else:
attrs.append(op.attr(arg_name))
attr_value = _get_attr_value(op, arg_type, arg_name)
attrs.append(attr_value)

return inputs, attrs

Expand Down

0 comments on commit 775fb43

Please sign in to comment.