Skip to content

Commit

Permalink
Merge pull request #1780 from luotao1/in
Browse files Browse the repository at this point in the history
add field "prob" in paddle.infer
  • Loading branch information
luotao1 authored Apr 13, 2017
2 parents aa230bf + 4274883 commit b25c512
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
6 changes: 5 additions & 1 deletion paddle/py_paddle/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,17 @@ def __arguments_to_numpy__(i, arg):
assert isinstance(arg, swig_paddle.Arguments)
value = arg.getSlotValue(i)
ids = arg.getSlotIds(i)
prob = arg.getSlotIn(i)
if value is not None:
assert isinstance(value, swig_paddle.Matrix)
value = value.copyToNumpyMat()
if ids is not None:
assert isinstance(ids, swig_paddle.IVector)
ids = ids.copyToNumpyArray()
return {"value": value, "id": ids}
if prob is not None:
assert isinstance(prob, swig_paddle.Matrix)
prob = prob.copyToNumpyMat()
return {"value": value, "id": ids, "prob": prob}


def __monkeypatch_gradient_machine__():
Expand Down
8 changes: 5 additions & 3 deletions python/paddle/v2/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,11 @@ def infer(output_layer, parameters, input, feeding=None, field='value'):
:type input: collections.Iterable
:param feeding: Reader dictionary. Default could generate from input
value.
:param field: The prediction field. It should in [`value`, `ids`]. `value`
means return the prediction probabilities, `ids` means return
the prediction labels. Default is `value`
:param field: The prediction field. It should in [`value`, `id`, `prob`].
`value` and `prob` mean return the prediction probabilities,
`id` means return the prediction labels. Default is `value`.
Note that `prob` only used when output_layer is beam_search
or max_id.
:type field: str
:return: a numpy array
:rtype: numpy.ndarray
Expand Down

0 comments on commit b25c512

Please sign in to comment.