Skip to content

Commit

Permalink
[Eager]Support dygraph2program for Eager Mode without program_desc_tr…
Browse files Browse the repository at this point in the history
  • Loading branch information
Aurelius84 authored May 7, 2022
1 parent bde51fd commit 2533aff
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 10 deletions.
93 changes: 83 additions & 10 deletions paddleslim/core/dygraph.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
import paddle
import collections
import logging
import numpy as np
from paddle.fluid.framework import _dygraph_tracer, dygraph_only, _dygraph_guard
from paddle.fluid.dygraph.base import program_desc_tracing_guard
from paddle.fluid.framework import _dygraph_tracer, dygraph_only, _dygraph_guard, program_guard
from paddle.fluid.dygraph.base import program_desc_tracing_guard, _switch_declarative_mode_guard_
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.framework import Block, ParamBase, Program, Variable
from ..common import get_logger
Expand All @@ -13,6 +14,22 @@
_logger = get_logger(__name__, level=logging.INFO)


class NameGenerator:
def __init__(self):
self.ids = collections.defaultdict(int)

def name(self, prefix):
assert isinstance(prefix, str)

name = "{}_{}".format(prefix, self.ids[prefix])
self.ids[prefix] += 1

return name


NG = NameGenerator()


def _is_shape(values):
if not isinstance(values, (list, tuple)):
return False
Expand All @@ -31,7 +48,7 @@ def _is_shapes(values):
return True


def _create_tensors(shapes, dtypes=None):
def _create_tensors(shapes, dtypes=None, is_static=False):
if dtypes is not None:
assert len(shapes) == len(
dtypes
Expand All @@ -41,8 +58,13 @@ def _create_tensors(shapes, dtypes=None):
dtypes = len(shapes) * ['float32']
tensors = []
for shape, dtype in zip(shapes, dtypes):
data = np.ones(tuple(shape)).astype(dtype)
tensors.append(paddle.to_tensor(data))
if is_static:
tensors.append(
paddle.static.data(
shape=shape, dtype=dtype, name=NG.name("feed")))
else:
data = np.ones(tuple(shape)).astype(dtype)
tensors.append(paddle.to_tensor(data))
return tensors


Expand Down Expand Up @@ -72,21 +94,35 @@ def extract_vars(inputs):
return vars


def to_variables(inputs):
def _to_var(x):
"""
Convert Variable or np.array into Placeholder.
"""
shape = x.shape
dtype = x.dtype
name = getattr(x, "name", None) or NG.name("feed")
return paddle.static.data(shape=shape, dtype=dtype, name=name)


def to_variables(inputs, is_static=False):
"""
Find and rename variables. Find np.ndarray and convert it to variable.
"""
if isinstance(inputs, Variable) or isinstance(inputs, np.ndarray):
return paddle.fluid.dygraph.to_variable(inputs)
if isinstance(inputs,
(Variable, paddle.Tensor)) or isinstance(inputs, np.ndarray):
if is_static:
return _to_var(inputs)
else:
return paddle.fluid.dygraph.to_variable(inputs)
elif isinstance(inputs, dict):
ret = {}
for _key in inputs:
ret[_key] = to_variables(inputs[_key])
ret[_key] = to_variables(inputs[_key], is_static)
return inputs
elif isinstance(inputs, list):
ret = []
for _value in inputs:
ret.append(to_variables(_value))
ret.append(to_variables(_value, is_static))
return ret


Expand All @@ -99,9 +135,15 @@ def dygraph2program(layer,
extract_inputs_fn=None,
extract_outputs_fn=None,
dtypes=None):
print(type(layer))
assert isinstance(layer, Layer)
extract_inputs_fn = extract_inputs_fn if extract_inputs_fn is not None else extract_vars
extract_outputs_fn = extract_outputs_fn if extract_outputs_fn is not None else extract_vars

if os.environ.get("FLAGS_enable_eager_mode") == "1":
return _dy2prog(layer, inputs, feed_prefix, fetch_prefix, tmp_prefix,
extract_inputs_fn, extract_outputs_fn, dtypes)

tracer = _dygraph_tracer()._get_program_desc_tracer()

with program_desc_tracing_guard(True):
Expand Down Expand Up @@ -131,3 +173,34 @@ def dygraph2program(layer,
program.blocks = [Block(program, 0)]
program._sync_with_cpp()
return program


def _dy2prog(layer,
inputs,
feed_prefix='feed_',
fetch_prefix='fetch_',
tmp_prefix='t_',
extract_inputs_fn=None,
extract_outputs_fn=None,
dtypes=None):
"""
Tracing program in Eager Mode.
"""
paddle.enable_static()

program = Program()
# convert ParamBase into Parameter automatically by _switch_declarative_mode_guard_
with program_guard(program), _switch_declarative_mode_guard_(True):
if _is_shape(inputs):
shapes = [inputs]
inputs = _create_tensors(shapes, dtypes=dtypes, is_static=True)
elif _is_shapes(inputs):
inputs = _create_tensors(inputs, dtypes=dtypes, is_static=True)
else:
inputs = to_variables(inputs, is_static=True)
inputs = extract_inputs_fn(inputs)
outputs = layer(*inputs)

paddle.disable_static()

return program
66 changes: 66 additions & 0 deletions tests/test_dy2prog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os
import sys
sys.path.append("../")
import paddle
import unittest
from paddleslim.core import dygraph2program


class Model(paddle.nn.Layer):
def __init__(self):
super(Model, self).__init__()
self.conv = paddle.nn.Conv2D(
in_channels=1, out_channels=256, kernel_size=3, stride=1, padding=1)
self.pool2d_avg = paddle.nn.AdaptiveAvgPool2D([1, 1])
self.out = paddle.nn.Linear(256, 10)

def forward(self, inputs):
inputs = paddle.reshape(inputs, shape=[0, 1, 28, 28])
y = self.conv(inputs)
y = self.pool2d_avg(y)
y = paddle.reshape(y, shape=[-1, 256])
y = self.out(y)
return y


class TestEagerDygraph2Program(unittest.TestCase):
def setUp(self):
os.environ['FLAGS_enable_eager_mode'] = "1"
self.prepare_inputs()
self.prepare_layer()

def prepare_inputs(self):
self.inputs = [3, 28, 28]

def prepare_layer(self):
self.layer = Model()

def test_dy2prog(self):
program = dygraph2program(self.layer, self.inputs)
self.assert_program(program)

def assert_program(self, program):
ops = [
'reshape2', 'conv2d', 'elementwise_add', 'pool2d', 'reshape2',
'matmul_v2', 'elementwise_add'
]
self.assertListEqual([op.type for op in program.block(0).ops], ops)


class TestEagerDygraph2Program2(TestEagerDygraph2Program):
def prepare_inputs(self):
self.inputs = [[3, 28, 28]]


class TestEagerDygraph2Program3(TestEagerDygraph2Program):
def prepare_inputs(self):
self.inputs = paddle.randn([3, 28, 28])


class TestEagerDygraph2Program4(TestEagerDygraph2Program):
def prepare_inputs(self):
self.inputs = [paddle.randn([3, 28, 28])]


if __name__ == "__main__":
unittest.main()

0 comments on commit 2533aff

Please sign in to comment.