Skip to content

Commit

Permalink
remove commented code
Browse files Browse the repository at this point in the history
  • Loading branch information
aerubanov committed Nov 30, 2023
1 parent d4c88df commit b8fde47
Showing 1 changed file with 0 additions and 159 deletions.
159 changes: 0 additions & 159 deletions pymc/logprob/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,165 +294,6 @@ def transform(self, node):
return None


# def remove_check_parameter_from_scan(node):
# node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
# local_fgraph_topo = io_toposort(node_inputs, node_outputs)
# op = node.op
# givens = {}
# to_remove_set = set()
# for nd in local_fgraph_topo:
# if nd not in to_remove_set:
# if isinstance(nd.op, CheckParameterValue):
# givens[nd.outputs[0]] = nd.inputs[0]
# to_remove_set.add(nd)
# elif isinstance(nd.op, Scan):
# new_node = remove_check_parameter_from_scan(nd)
# if new_node is not None:
# givens.update(zip(nd.outputs, new_node.owner.outputs))
# to_remove_set.add(nd)
# if len(to_remove_set) == 0:
# return None
# op_outs = clone_replace(node_outputs, replace=givens)

# nwScan = Scan(
# node_inputs,
# op_outs,
# op.info,
# mode=op.mode,
# profile=op.profile,
# truncate_gradient=op.truncate_gradient,
# name=op.name,
# allow_gc=op.allow_gc,
# )
# nw_node = nwScan(*(node.inputs))
# return nw_node


# def replace_check_parameter_by_ninf_in_scan(node):
# node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
# local_fgraph_topo = io_toposort(node_inputs, node_outputs)
# op = node.op
# givens = {}
# to_remove_set = set()
# for nd in local_fgraph_topo:
# if nd not in to_remove_set:
# if isinstance(nd.op, CheckParameterValue):
# if nd.op.can_be_replaced_by_ninf:
# logp_expr, *logp_conds = nd.inputs
# if len(logp_conds) > 1:
# logp_cond = pt.all(logp_conds)
# else:
# (logp_cond,) = logp_conds
# new_node = pt.switch(logp_cond, logp_expr, -np.inf)

# if new_node.dtype != nd.outputs[0].dtype:
# new_node = pt.cast(new_node, nd.outputs[0].dtype)
# givens.update(zip(nd.outputs, new_node.owner.outputs))
# to_remove_set.add(nd)
# elif isinstance(nd.op, Scan):
# new_node = replace_check_parameter_by_ninf_in_scan(nd)
# if new_node is not None:
# givens.update(zip(nd.outputs, new_node.owner.outputs))
# to_remove_set.add(nd)
# if len(to_remove_set) == 0:
# return None
# op_outs = clone_replace(node_outputs, replace=givens)

# nwScan = Scan(
# node_inputs,
# op_outs,
# op.info,
# mode=op.mode,
# profile=op.profile,
# truncate_gradient=op.truncate_gradient,
# name=op.name,
# allow_gc=op.allow_gc,
# )
# nw_node = nwScan(*(node.inputs))
# return nw_node


# def remove_check_parameter_op_from_graph(node):
# node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
# op = node.op

# local_fgraph_topo = io_toposort(node_inputs, node_outputs)
# op = node.op
# givens = {}
# to_remove_set = set()
# for nd in local_fgraph_topo:
# if nd not in to_remove_set:
# if isinstance(nd.op, CheckParameterValue):
# givens[nd.outputs[0]] = nd.inputs[0]
# to_remove_set.add(nd)
# elif isinstance(nd.op, OpFromGraph):
# new_node = remove_check_parameter_op_from_graph(nd)
# if new_node is not None:
# givens.update(zip(nd.outputs, new_node.owner.outputs))
# to_remove_set.add(nd)
# if len(to_remove_set) == 0:
# return None
# op_outs = clone_replace(node_outputs, replace=givens)

# nwOpFromGraph = OpFromGraph(
# node_inputs,
# op_outs,
# op.is_inline,
# op.lop_overrides,
# op.grad_overrides,
# op.rop_overrides,
# connection_pattern=op._connection_pattern,
# name=op.name,
# **op.kwargs,
# )
# nw_node = nwOpFromGraph(*(node.inputs))
# return nw_node


# def replace_check_parameters_by_ninf_in_op_from_graph(node):
# node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
# op = node.op
# givens = {}
# to_remove_set = set()
# for nd in io_toposort(node_inputs, node_outputs):
# if nd not in to_remove_set:
# if isinstance(nd.op, CheckParameterValue):
# if nd.op.can_be_replaced_by_ninf:
# logp_expr, *logp_conds = nd.inputs
# if len(logp_conds) > 1:
# logp_cond = pt.all(logp_conds)
# else:
# (logp_cond,) = logp_conds
# new_node = pt.switch(logp_cond, logp_expr, -np.inf)

# if new_node.dtype != nd.outputs[0].dtype:
# new_node = pt.cast(new_node, nd.outputs[0].dtype)
# givens.update(zip(nd.outputs, new_node.owner.outputs))
# to_remove_set.add(nd)
# elif isinstance(nd.op, OpFromGraph):
# new_node = replace_check_parameters_by_ninf_in_op_from_graph(nd)
# if new_node is not None:
# givens.update(zip(nd.outputs, new_node.owner.outputs))
# to_remove_set.add(nd)
# if len(to_remove_set) == 0:
# return None
# op_outs = clone_replace(node_outputs, replace=givens)

# nwOpFromGraph = OpFromGraph(
# node_inputs,
# op_outs,
# op.is_inline,
# op.lop_overrides,
# op.grad_overrides,
# op.rop_overrides,
# connection_pattern=op._connection_pattern,
# name=op.name,
# **op.kwargs,
# )
# nw_node = nwOpFromGraph(*(node.inputs))
# return nw_node


@node_rewriter(tracks=[CheckParameterValue, Scan, OpFromGraph])
def local_remove_check_parameter(fgraph, node):
"""Rewrite that removes CheckParameterValue
Expand Down

0 comments on commit b8fde47

Please sign in to comment.