[Dy2St] pir dy2st unittest verification - Part 11 (#59314)
---------

Co-authored-by: SigureMo <[email protected]>
gouzil and SigureMo authored Nov 29, 2023
1 parent cae8de7 commit 1d859c5
Showing 6 changed files with 112 additions and 102 deletions.
2 changes: 2 additions & 0 deletions test/dygraph_to_static/test_build_strategy.py
@@ -18,6 +18,7 @@
from dygraph_to_static_utils import (
Dy2StTestBase,
test_ast_only,
test_legacy_and_pt_and_pir,
)
from test_resnet import ResNetHelper

@@ -87,6 +88,7 @@ def test_in_static_mode_mkldnn(self):


class TestError(Dy2StTestBase):
@test_legacy_and_pt_and_pir
def test_type_error(self):
def foo(x):
out = x + 1
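
For context, tests marked with test_legacy_and_pt_and_pir are meant to run under the legacy IR, the program translator (PT), and PIR. The sketch below is a hypothetical illustration of how such a mode-sweeping decorator can be structured; the mode names and the _switch_ir_mode helper are assumptions for illustration, not the actual utilities in dygraph_to_static_utils.py.

# Hypothetical sketch of a mode-sweeping test decorator (illustration only).
# The real test_legacy_and_pt_and_pir lives in dygraph_to_static_utils.py;
# the mode names and _switch_ir_mode below are assumptions.
import functools

def run_in_all_ir_modes(test_func):
    @functools.wraps(test_func)
    def wrapper(self, *args, **kwargs):
        for mode in ("legacy", "pt", "pir"):   # assumed mode names
            with self.subTest(ir_mode=mode):
                _switch_ir_mode(mode)          # assumed helper toggling global IR state
                test_func(self, *args, **kwargs)
    return wrapper

def _switch_ir_mode(mode):
    # Placeholder: the real utilities flip flags/global state to select the
    # legacy IR, the program translator, or PIR before running the test body.
    pass
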
7 changes: 6 additions & 1 deletion test/dygraph_to_static/test_load_transformer.py
@@ -16,7 +16,10 @@
import unittest

import numpy as np
from dygraph_to_static_utils import Dy2StTestBase
from dygraph_to_static_utils import (
Dy2StTestBase,
test_legacy_and_pt_and_pir,
)

import paddle

@@ -45,6 +48,7 @@ class TestFallback(Dy2StTestBase):
def setUp(self):
self.x = paddle.to_tensor(1.0).astype('int')

@test_legacy_and_pt_and_pir
def test_name_load(self):
net_dy = Net()
net_st = Net()
@@ -54,6 +58,7 @@ def test_name_load(self):


class TestLoad2(Dy2StTestBase):
@test_legacy_and_pt_and_pir
def test_name_load_nograd(self):
@paddle.no_grad()
def func(x):
75 changes: 40 additions & 35 deletions test/dygraph_to_static/test_save_load.py
@@ -72,40 +72,38 @@ def test_save_load_same_result(self):
x_data = np.random.randn(30, 10, 32).astype('float32')
batch_num = 3

with base.dygraph.guard(place):
paddle.jit.enable_to_static(True)
x = base.dygraph.to_variable(x_data)
net = Linear(32, 64)
adam = Adam(learning_rate=0.1, parameters=net.parameters())

for i in range(batch_num):
static_out, static_loss = net(x)
# Update parameters
static_loss.backward()
adam.minimize(static_loss)
net.clear_gradients()
# Save parameters

paddle.save(net.state_dict(), self.model_path + '.pdparams')
# minimize() will update parameter, call net() to get output and avg_loss.
# Switch into eval mode.
net.eval()
paddle.jit.enable_to_static(True)
x = base.dygraph.to_variable(x_data)
net = Linear(32, 64)
adam = Adam(learning_rate=0.1, parameters=net.parameters())

for i in range(batch_num):
static_out, static_loss = net(x)
# Update parameters
static_loss.backward()
adam.minimize(static_loss)
net.clear_gradients()
# Save parameters

paddle.save(net.state_dict(), self.model_path + '.pdparams')
# minimize() will update parameter, call net() to get output and avg_loss.
# Switch into eval mode.
net.eval()
static_out, static_loss = net(x)

# load parameters into dygraph
with base.dygraph.guard(place):
dygraph_net = Linear(32, 64)
dygraph_net = Linear(32, 64)

# Load parameters
model_dict = paddle.load(self.model_path + '.pdparams')
dygraph_net.set_dict(model_dict)
# Switch into eval mode.
dygraph_net.eval()
# Load parameters
model_dict = paddle.load(self.model_path + '.pdparams')
dygraph_net.set_dict(model_dict)
# Switch into eval mode.
dygraph_net.eval()

x = base.dygraph.to_variable(x_data)
# predict output
paddle.jit.enable_to_static(False)
dygraph_out, dygraph_loss = dygraph_net(x)
x = base.dygraph.to_variable(x_data)
# predict output
paddle.jit.enable_to_static(False)
dygraph_out, dygraph_loss = dygraph_net(x)

np.testing.assert_allclose(
dygraph_out.numpy(), static_out.numpy(), rtol=1e-05
@@ -114,6 +112,17 @@ def test_save_load_same_result(self):
dygraph_loss.numpy(), static_loss.numpy(), rtol=1e-05
)

def _compute_op_num(self, composite_program):
if paddle.framework.use_pir_api():
comp_op_type_list = [
op.name() for op in composite_program.program.global_block().ops
]
else:
comp_op_type_list = [
op.type for op in composite_program.block(0).ops
]
return comp_op_type_list

@test_ast_only
def test_save_load_prim(self):
with base.dygraph.guard(place):
@@ -127,9 +136,7 @@ def test_save_load_prim(self):
composite_program = static_net.forward.get_concrete_program(self.x)[
1
].train_program
comp_op_type_list = [
op.type for op in composite_program.block(0).ops
]
comp_op_type_list = self._compute_op_num(composite_program)
self.assertNotIn("batch_norm", comp_op_type_list)
self.assertNotIn("relu", comp_op_type_list)
self.assertNotIn("pow", comp_op_type_list)
@@ -169,9 +176,7 @@ def test_save_load_prim_with_hook(self):
composite_program = static_net.forward.get_concrete_program(self.x)[
1
].train_program
comp_op_type_list = [
op.type for op in composite_program.block(0).ops
]
comp_op_type_list = self._compute_op_num(composite_program)
self.assertNotIn("batch_norm", comp_op_type_list)
self.assertNotIn("relu", comp_op_type_list)
self.assertNotIn("pow", comp_op_type_list)
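
The new _compute_op_num helper branches on paddle.framework.use_pir_api() because PIR and legacy programs expose their ops differently. Below is a condensed sketch of the same pattern; the function name collect_op_names is illustrative, and composite_program is assumed to be the train program returned by get_concrete_program, as in the diff.

import paddle

def collect_op_names(composite_program):
    # Under PIR, ops live on program.global_block().ops and report their name via op.name().
    if paddle.framework.use_pir_api():
        return [op.name() for op in composite_program.program.global_block().ops]
    # Under the legacy IR, ops live on block(0).ops and expose op.type.
    return [op.type for op in composite_program.block(0).ops]
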
122 changes: 60 additions & 62 deletions test/dygraph_to_static/test_sentiment.py
@@ -21,7 +21,6 @@
import paddle
from paddle import base
from paddle.base.dygraph import to_variable
from paddle.jit.api import to_static
from paddle.nn import Embedding, Linear

SEED = 2020
@@ -88,7 +87,6 @@ def __init__(self, dict_dim, batch_size, seq_len):
self._fc1_act = paddle.nn.Softmax()
self._fc_prediction = Linear(self.fc_hid_dim, self.class_dim)

@to_static
def forward(self, inputs, label=None):
emb = self.embedding(inputs)
o_np_mask = (paddle.reshape(inputs, [-1, 1]) != self.dict_dim).astype(
@@ -132,7 +130,6 @@ def __init__(self, dict_dim, batch_size, seq_len):
self._fc2 = Linear(self.hid_dim, self.fc_hid_dim)
self._fc_prediction = Linear(self.fc_hid_dim, self.class_dim)

@to_static
def forward(self, inputs, label=None):
emb = self.embedding(inputs)
o_np_mask = (paddle.reshape(inputs, [-1, 1]) != self.dict_dim).astype(
@@ -171,7 +168,7 @@ def __init__(self, dict_dim, batch_size, seq_len):
self.embedding = Embedding(
self.dict_dim + 1,
self.emb_dim,
weight_attr=base.ParamAttr(learning_rate=30),
weight_attr=paddle.ParamAttr(learning_rate=30),
sparse=False,
)
h_0 = np.zeros((self.batch_size, self.hid_dim), dtype="float32")
@@ -181,7 +178,6 @@ def __init__(self, dict_dim, batch_size, seq_len):
self._fc_prediction = Linear(self.fc_hid_dim, self.class_dim)
self._gru = DynamicGRU(size=self.hid_dim, h_0=h_0)

@to_static
def forward(self, inputs, label=None):
emb = self.embedding(inputs)
o_np_mask = (paddle.reshape(inputs, [-1, 1]) != self.dict_dim).astype(
@@ -219,7 +215,7 @@ def __init__(self, dict_dim, batch_size, seq_len):
self.embedding = Embedding(
self.dict_dim + 1,
self.emb_dim,
weight_attr=base.ParamAttr(learning_rate=30),
weight_attr=paddle.ParamAttr(learning_rate=30),
sparse=False,
)
h_0 = np.zeros((self.batch_size, self.hid_dim), dtype="float32")
@@ -234,7 +230,6 @@ def __init__(self, dict_dim, batch_size, seq_len):
size=self.hid_dim, h_0=h_0, is_reverse=True
)

@to_static
def forward(self, inputs, label=None):
emb = self.embedding(inputs)
o_np_mask = (paddle.reshape(inputs, [-1, 1]) != self.dict_dim).astype(
@@ -304,68 +299,71 @@ class Args:

def train(args, to_static):
paddle.jit.enable_to_static(to_static)
place = (
base.CUDAPlace(0) if base.is_compiled_with_cuda() else base.CPUPlace()
)
np.random.seed(SEED)
paddle.seed(SEED)
paddle.framework.random._manual_program_seed(SEED)

with base.dygraph.guard(place):
np.random.seed(SEED)
paddle.seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
train_reader = fake_data_reader(
args.class_num, args.vocab_size, args.batch_size, args.padding_size
)
train_loader = base.io.DataLoader.from_generator(capacity=24)
train_loader.set_sample_list_generator(train_reader)

train_reader = fake_data_reader(
args.class_num, args.vocab_size, args.batch_size, args.padding_size
if args.model_type == 'cnn_net':
model = paddle.jit.to_static(
CNN(args.vocab_size, args.batch_size, args.padding_size)
)
elif args.model_type == 'bow_net':
model = paddle.jit.to_static(
BOW(args.vocab_size, args.batch_size, args.padding_size)
)
train_loader = base.io.DataLoader.from_generator(capacity=24)
train_loader.set_sample_list_generator(train_reader)

if args.model_type == 'cnn_net':
model = CNN(args.vocab_size, args.batch_size, args.padding_size)
elif args.model_type == 'bow_net':
model = BOW(args.vocab_size, args.batch_size, args.padding_size)
elif args.model_type == 'gru_net':
model = GRU(args.vocab_size, args.batch_size, args.padding_size)
elif args.model_type == 'bigru_net':
model = BiGRU(args.vocab_size, args.batch_size, args.padding_size)
sgd_optimizer = paddle.optimizer.Adagrad(
learning_rate=args.lr, parameters=model.parameters()
elif args.model_type == 'gru_net':
model = paddle.jit.to_static(
GRU(args.vocab_size, args.batch_size, args.padding_size)
)
elif args.model_type == 'bigru_net':
model = paddle.jit.to_static(
BiGRU(args.vocab_size, args.batch_size, args.padding_size)
)
sgd_optimizer = paddle.optimizer.Adagrad(
learning_rate=args.lr, parameters=model.parameters()
)

loss_data = []
for eop in range(args.epoch):
time_begin = time.time()
for batch_id, data in enumerate(train_loader()):
word_ids, labels, seq_lens = data
doc = to_variable(word_ids.numpy().reshape(-1)).astype('int64')
label = labels.astype('int64')

model.train()
avg_cost, prediction, acc = model(doc, label)
loss_data.append(float(avg_cost))

avg_cost.backward()
sgd_optimizer.minimize(avg_cost)
model.clear_gradients()

if batch_id % args.log_step == 0:
time_end = time.time()
used_time = time_end - time_begin
# used_time may be 0.0, cause zero division error
if used_time < 1e-5:
used_time = 1e-5
print(
"step: %d, ave loss: %f, speed: %f steps/s"
% (
batch_id,
float(avg_cost),
args.log_step / used_time,
)
loss_data = []
for eop in range(args.epoch):
time_begin = time.time()
for batch_id, data in enumerate(train_loader()):
word_ids, labels, seq_lens = data
doc = paddle.to_tensor(word_ids.numpy().reshape(-1), dtype="int64")
label = labels.astype('int64')

model.train()
avg_cost, prediction, acc = model(doc, label)
loss_data.append(float(avg_cost))

avg_cost.backward()
sgd_optimizer.minimize(avg_cost)
model.clear_gradients()

if batch_id % args.log_step == 0:
time_end = time.time()
used_time = time_end - time_begin
# used_time may be 0.0, cause zero division error
if used_time < 1e-5:
used_time = 1e-5
print(
"step: %d, ave loss: %f, speed: %f steps/s"
% (
batch_id,
float(avg_cost),
args.log_step / used_time,
)
time_begin = time.time()
)
time_begin = time.time()

if batch_id == args.train_step:
break
batch_id += 1
if batch_id == args.train_step:
break
batch_id += 1
return loss_data


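
This file switches from decorating each forward with @to_static to wrapping the instantiated layer with paddle.jit.to_static(...) at construction time. A minimal sketch of the wrapping style, with ToyNet as a placeholder module not taken from the diff:

import paddle

class ToyNet(paddle.nn.Layer):  # placeholder module for illustration
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

# Style used in the diff: convert the whole layer instance instead of
# decorating forward(); the call below runs through the static-graph path.
model = paddle.jit.to_static(ToyNet())
out = model(paddle.randn([3, 4]))
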
5 changes: 2 additions & 3 deletions test/dygraph_to_static/test_simnet.py
@@ -17,9 +17,7 @@
import unittest

import numpy as np
from dygraph_to_static_utils import (
Dy2StTestBase,
)
from dygraph_to_static_utils import Dy2StTestBase, test_legacy_and_pt_and_pir
from simnet_dygraph_model import BOW, HingeLoss

import paddle
@@ -180,6 +178,7 @@ def train(conf_dict, to_static):


class TestSimnet(Dy2StTestBase):
@test_legacy_and_pt_and_pir
def test_dygraph_static_same_loss(self):
if base.is_compiled_with_cuda():
base.set_flags({"FLAGS_cudnn_deterministic": True})
3 changes: 2 additions & 1 deletion test/dygraph_to_static/test_simnet_v2.py
@@ -17,7 +17,7 @@
import unittest

import numpy as np
from dygraph_to_static_utils import Dy2StTestBase
from dygraph_to_static_utils import Dy2StTestBase, test_legacy_and_pt_and_pir
from simnet_dygraph_model_v2 import BOW, HingeLoss

import paddle
@@ -177,6 +177,7 @@ def train(conf_dict, to_static):


class TestSimnet(Dy2StTestBase):
@test_legacy_and_pt_and_pir
def test_dygraph_static_same_loss(self):
if paddle.is_compiled_with_cuda():
paddle.base.set_flags({"FLAGS_cudnn_deterministic": True})
