Skip to content

Commit

Permalink
Follow comments
Browse files Browse the repository at this point in the history
  • Loading branch information
reyoung committed Jan 22, 2018
1 parent 9f731a6 commit 87b424e
Showing 1 changed file with 9 additions and 18 deletions.
27 changes: 9 additions & 18 deletions python/paddle/v2/fluid/layers/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


def monkey_patch_variable():
def new_name():
def unique_tmp_name():
return unique_name("tmp")

def safe_get_dtype(var):
Expand All @@ -29,21 +29,9 @@ def safe_get_dtype(var):
raise ValueError("Cannot get data type from %s", var.name)
return dtype

def create_scalar(block, value, dtype):
value = float(value)
tmp_name = new_name()
var = block.create_var(name=tmp_name, shape=[1], dtype=dtype)
block.append_op(
type="fill",
outputs={"Out": [var]},
attrs={"value": [value],
"shape": [1],
"dtype": dtype})
return var

def create_tensor(block, value, dtype, shape):
value = float(value)
tmp_name = new_name()
tmp_name = unique_tmp_name()
var = block.create_var(name=tmp_name, shape=shape, dtype=dtype)
block.append_op(
type="fill_constant",
Expand All @@ -53,10 +41,13 @@ def create_tensor(block, value, dtype, shape):
'value': value})
return var

def create_scalar(block, value, dtype):
return create_tensor(block, value, dtype, shape=[1])

def create_tensor_with_batchsize(ref_var, value, dtype):
assert isinstance(ref_var, Variable)
value = float(value)
tmp_name = new_name()
tmp_name = unique_tmp_name()
var = ref_var.block.create_var(name=tmp_name, dtype=dtype)
ref_var.block.append_op(
type='fill_constant_batch_size_like',
Expand All @@ -68,7 +59,7 @@ def create_tensor_with_batchsize(ref_var, value, dtype):

def astype(self, dtype):
"""
Cast a variable to data type.
Cast a variable to a specified data type.
NOTE: The variable must be a Tensor
Args:
self(Variable): The source variable
Expand All @@ -77,7 +68,7 @@ def astype(self, dtype):
Returns:
Variable with new dtype
"""
tmp_name = new_name()
tmp_name = unique_tmp_name()
out = self.block.create_var(name=tmp_name, dtype=dtype)
self.block.append_op(
type="cast",
Expand Down Expand Up @@ -120,7 +111,7 @@ def __impl__(self, other_var):
self = other_var
other_var = tmp

tmp_name = new_name()
tmp_name = unique_tmp_name()
out = self.block.create_var(name=tmp_name, dtype=lhs_dtype)
self.block.append_op(
type=op_type,
Expand Down

0 comments on commit 87b424e

Please sign in to comment.