Check -1 shape for input_spec and program when prim or cinn enabled (#50473)

* Check -1 shape for input_spec and program when prim or cinn enabled

* Polish neg shape check

* Polish code

* Fix UT

* Fix UT in static
0x45f authored Feb 21, 2023
1 parent 8ad635d commit 1e7dc9c
Showing 4 changed files with 72 additions and 38 deletions.
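
For context, a minimal sketch of the scenario this commit guards against (hypothetical usage, not part of the diff; it assumes a Paddle build with the prim switches used below). A -1 dimension in an InputSpec marks that axis as dynamic, which prim/cinn cannot yet handle, so to_static now warns and falls back to input_spec=None:

import paddle
from paddle.fluid import core

# A -1 dimension in an InputSpec marks that axis (here, the batch) as dynamic.
spec = paddle.static.InputSpec(shape=[-1, 10], dtype='float32')

core._set_prim_all_enabled(True)  # turn on prim lowering
# After this commit, to_static warns that -1 shapes are unsupported under
# prim/cinn, drops the spec, and derives concrete shapes from the first call.
net = paddle.jit.to_static(paddle.nn.Linear(10, 5), input_spec=[spec])
out = net(paddle.randn([4, 10]))  # traced with the concrete shape [4, 10]
core._set_prim_all_enabled(False)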
28 changes: 28 additions & 0 deletions python/paddle/fluid/tests/unittests/test_input_spec.py
@@ -20,6 +20,7 @@

 import paddle
 import paddle.fluid as fluid
+from paddle.fluid import core
 from paddle.fluid.framework import convert_np_dtype_to_dtype_
 from paddle.jit.dy2static.utils import _compatible_non_tensor_spec
 from paddle.static import InputSpec
@@ -331,5 +332,32 @@ def test_case(self):
         )


+class NegSpecNet(paddle.nn.Layer):
+    def __init__(self):
+        super().__init__()
+        self.linear = paddle.nn.Linear(10, 5)
+
+    def forward(self, x):
+        return self.linear(x)
+
+
+class TestNegSpecWithPrim(unittest.TestCase):
+    def setUp(self):
+        paddle.disable_static()
+        core._set_prim_all_enabled(True)
+
+    def tearDown(self):
+        core._set_prim_all_enabled(False)
+
+    def test_run(self):
+        net = NegSpecNet()
+        net = paddle.jit.to_static(
+            net, input_spec=[paddle.static.InputSpec(shape=[-1, 10])]
+        )
+        x = paddle.randn([2, 10])
+        out = net(x)
+        np.testing.assert_equal(out.shape, [2, 5])
+
+
 if __name__ == '__main__':
     unittest.main()
26 changes: 0 additions & 26 deletions python/paddle/jit/dy2static/partial_program.py
@@ -1069,32 +1069,6 @@ def _valid_vars(self, vars):
        return vars if vars else None


-def _create_fake_var():
-    """
-    Create a fake_var (force on CPU) to handle empty input or output
-    """
-    if not framework.global_var._in_eager_mode_:
-        return [
-            core.VarBase(
-                core.VarDesc.VarType.FP32,
-                [],
-                "Fake_var",
-                core.VarDesc.VarType.RAW,
-                False,
-            )
-        ]
-    else:
-        return [
-            core.eager.Tensor(
-                core.VarDesc.VarType.FP32,
-                [],
-                "Fake_var",
-                core.VarDesc.VarType.RAW,
-                False,
-            )
-        ]
-
-
 def partial_program_from(concrete_program):
     inputs = concrete_program.inputs
     if inputs and isinstance(inputs[0], layers.Layer):
32 changes: 21 additions & 11 deletions python/paddle/jit/dy2static/program_translator.py
@@ -48,6 +48,7 @@
     func_to_source_code,
     input_specs_compatible,
     make_hashable,
+    prim_or_cinn_is_enabled,
     type_name,
     unwrap,
 )
@@ -320,6 +321,17 @@ def __init__(self, function, input_spec=None, **kwargs):
         self._dygraph_function = function
         self._class_instance = None

+        if input_spec is not None and prim_or_cinn_is_enabled(
+            kwargs.get("build_strategy", None)
+        ):
+            for spec in input_spec:
+                if spec is not None and -1 in spec.shape:
+                    input_spec = None
+                    warnings.warn(
+                        'Now prim and cinn do not support -1 shape, but input_spec has -1 shape so we set it to None.'
+                    )
+                    break
+
         self._input_spec = input_spec
         self._function_spec = FunctionSpec(function, input_spec)
         self._program_cache = ProgramCache()
@@ -1046,17 +1058,6 @@ def from_func_spec(
        )


-def _extract_indeed_params_buffers(class_instance):
-    """
-    To filter not initialzed buffers.
-    """
-    params = list(get_parameters(class_instance).values())
-    buffers = list(get_buffers(class_instance).values())
-    buffers = [buffer for buffer in buffers if len(buffer.shape) != 0]
-
-    return params + buffers
-
-
 class ParametersRecorder:
     def __init__(self):
         self.params_dict = {}
@@ -1177,6 +1178,15 @@ def _build_once(self, cache_key):
             else:
                 raise

+        if prim_or_cinn_is_enabled(cache_key.kwargs['build_strategy']):
+            for var in concrete_program.main_program.list_vars():
+                if -1 in var.shape:
+                    warnings.warn(
+                        "Now prim and cinn do not support -1 shape, but the shape of var {} is {}".format(
+                            var.name, var.shape
+                        )
+                    )
+
         concrete_program._to_prim()
         return concrete_program, partial_program_from(concrete_program)
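As an aside, a small standalone sketch (hypothetical, not part of the diff) of what the new check in _build_once observes: a static program built from a dynamic-shape data spec reports -1 dimensions through list_vars(), exactly the condition the added loop warns on:

import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    # The -1 batch dimension propagates into the program's variables.
    x = paddle.static.data(name='x', shape=[-1, 10], dtype='float32')
    y = paddle.static.nn.fc(x, size=5)

# The added code walks the variables the same way and warns for each
# shape that still contains -1.
for var in main.list_vars():
    if -1 in var.shape:
        print(var.name, var.shape)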
24 changes: 23 additions & 1 deletion python/paddle/jit/dy2static/utils.py
@@ -1493,7 +1493,7 @@ def _param_grad_names(program_desc, params):
     Parse PARAM@GARD name from original train and infer program.
     """
     names = []
-    # NOTE: `names` and `self._params` must be in the same order so that
+    # NOTE: `names` and `params` must be in the same order so that
     # the param grad name can be set correctly in the run_program.
     for param in params:
         candidate = [
@@ -1523,3 +1523,25 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size):
         var_name = op.output('Out')[0]
         names.append(var_name)
     return names
+
+
+def prim_or_cinn_is_enabled(build_strategy):
+    if build_strategy is not None and build_strategy.build_cinn_pass:
+        return True
+
+    if core._is_bwd_prim_enabled() or core._is_fwd_prim_enabled():
+        return True
+
+    env_flags = [
+        'FLAGS_prim_forward',
+        'FLAGS_prim_backward',
+        'FLAGS_prim_all',
+        'FLAGS_use_cinn',
+    ]
+    for flag in env_flags:
+        value = os.getenv(flag)
+        if value is None:
+            continue
+        elif value.lower() in ['true', '1']:
+            return True
+    return False
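
A quick sketch of the three paths through the new helper (hypothetical calls, assuming a Paddle build where the CINN pass and prim switches are available):

import os

import paddle
from paddle.fluid import core
from paddle.jit.dy2static.utils import prim_or_cinn_is_enabled

build_strategy = paddle.static.BuildStrategy()
print(prim_or_cinn_is_enabled(build_strategy))  # False: nothing enabled

build_strategy.build_cinn_pass = True           # path 1: CINN build pass
print(prim_or_cinn_is_enabled(build_strategy))  # True

core._set_prim_all_enabled(True)                # path 2: prim switches
print(prim_or_cinn_is_enabled(None))            # True
core._set_prim_all_enabled(False)

os.environ['FLAGS_prim_all'] = '1'              # path 3: environment flags
print(prim_or_cinn_is_enabled(None))            # True
del os.environ['FLAGS_prim_all']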
