diff --git a/docs/api/paddle/summary_cn.rst b/docs/api/paddle/summary_cn.rst index 0a72e66a812..b9001c2aaaa 100644 --- a/docs/api/paddle/summary_cn.rst +++ b/docs/api/paddle/summary_cn.rst @@ -78,3 +78,17 @@ summary # --------------------------------------------------------------------------- # {'total_params': 61610, 'trainable_params': 61610} + # multi input demo + class LeNetMultiInput(LeNet): + def forward(self, inputs, y): + x = self.features(inputs) + if self.num_classes > 0: + x = paddle.flatten(x, 1) + x = self.fc(x + y) + return x + + lenet_multi_input = LeNetMultiInput() + params_info = paddle.summary(lenet_multi_input, [(1, 1, 28, 28), (1, 400)], + ['float32', 'float32']) + print(params_info) +