Remove map() and filter() in favor of comprehensions
Summary: In Python 3 these return lazy iterators rather than lists, so many usages currently present in Caffe2 (especially calls made only for their side effects) would silently do nothing. This diff removes (almost) all usages of these two functions in Caffe2 and its subprojects in favor of comprehensions and explicit loops, which are also easier to read and understand. A minimal sketch of the pitfall follows the commit metadata below.

Reviewed By: akyrola

Differential Revision: D5142049

fbshipit-source-id: e800631d2df7d0823fed698cae46c486038007dc
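
To make the failure mode concrete, here is a minimal sketch (not part of the commit) of how a side-effect-only map() call behaves under Python 3:

    # A minimal sketch (not from this commit) of the pitfall described in
    # the summary: in Python 3, map() returns a lazy iterator, so a bare
    # map() call made only for its side effects never runs.
    acc = []

    map(acc.append, [1, 2, 3])   # builds an iterator; nothing is appended
    assert acc == []             # the side effects never happened

    for x in [1, 2, 3]:          # the explicit loop this diff switches to
        acc.append(x)
    assert acc == [1, 2, 3]

    # Where a list result is wanted, the diff uses a comprehension instead:
    squares = [x * x for x in [1, 2, 3]]
    assert squares == [1, 4, 9]
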
tomdz authored and facebook-github-bot committed May 30, 2017
1 parent 0deec5b commit 47e921b
Showing 16 changed files with 133 additions and 93 deletions.
28 changes: 16 additions & 12 deletions caffe2/python/core.py
@@ -644,7 +644,7 @@ def _MakeDenseSumOps(self, generators, out_base_name):

sum_ops = [CreateOperator(
"Sum",
map(BlobReference, sum_op_input),
[BlobReference(x) for x in sum_op_input],
BlobReference(out_base_name))]
return sum_ops, out_base_name

@@ -681,15 +681,16 @@ def _MakeSparseSumOps(self, generators, out_base_name):
sum_ops = [
CreateOperator(
"Concat",
map(BlobReference, indices_concat_input),
map(BlobReference,
[indices_concat_output, indices_concat_split]),
[BlobReference(x) for x in indices_concat_input],
[BlobReference(x) for x in
[indices_concat_output, indices_concat_split]],
axis=0
),
CreateOperator(
"Concat",
map(BlobReference, values_concat_input),
map(BlobReference, [values_concat_output, values_concat_split]),
[BlobReference(x) for x in values_concat_input],
[BlobReference(x) for x in
[values_concat_output, values_concat_split]],
axis=0
),
]
@@ -1164,8 +1165,9 @@ def current_prefix():

@staticmethod
def _get_next_net_name(basename):
name = basename = '/'.join(filter(
lambda x: x, (Net.current_prefix(), basename)))
name = basename = '/'.join(
x for x in [Net.current_prefix(), basename] if x
)
next_idx = 1
while name in Net._net_names_used:
name = basename + '_' + str(next_idx)
@@ -1651,11 +1653,11 @@ def AddScopedExternalOutputs(self, *outputs):

@property
def external_inputs(self):
return map(_get_blob_ref, self._net.external_input)
return [_get_blob_ref(x) for x in self._net.external_input]

@property
def external_outputs(self):
return map(_get_blob_ref, self._net.external_output)
return [_get_blob_ref(x) for x in self._net.external_output]

def set_input_record(self, input_record):
from caffe2.python import schema
@@ -2123,9 +2125,11 @@ def execution_step(default_name,
step.AddNet(steps_or_nets)
elif isinstance(steps_or_nets, list):
if all(isinstance(x, Net) for x in steps_or_nets):
map(step.AddNet, steps_or_nets)
for x in steps_or_nets:
step.AddNet(x)
else:
map(step.AddSubstep, map(to_execution_step, steps_or_nets))
for x in steps_or_nets:
step.AddSubstep(to_execution_step(x))
elif steps_or_nets:
raise ValueError(
'steps_or_nets must be a step, a net, or a list of nets or steps.')
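
The execution_step hunk above is the clearest instance of the bug class this commit targets: the old map(step.AddNet, steps_or_nets) built an iterator that was never consumed, so under Python 3 no nets would ever be added. A hypothetical reduction, with an illustrative Step class standing in for the real Caffe2 type:

    # Hypothetical reduction of the execution_step fix above (Step is an
    # illustrative stand-in, not the real Caffe2 class).
    class Step:
        def __init__(self):
            self.nets = []

        def AddNet(self, net):
            self.nets.append(net)

    step = Step()
    nets = ['net_a', 'net_b']

    map(step.AddNet, nets)    # Python 3: a lazy iterator that is never consumed
    assert step.nets == []    # so no nets were actually added

    for net in nets:          # the fix applied in the diff
        step.AddNet(net)
    assert step.nets == ['net_a', 'net_b']
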
2 changes: 1 addition & 1 deletion caffe2/python/layer_model_instantiator.py
@@ -14,7 +14,7 @@ def _filter_layers(layers, include_tags):
if include_tags is None:
return layers
include_tags = set(include_tags)
return filter(lambda l: not include_tags.isdisjoint(l.tags), layers)
return [l for l in layers if not include_tags.isdisjoint(l.tags)]


def shrink_output_schema(net, out_schema):
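
filter() has the same failure mode when callers expect a list, as the _filter_layers change above shows. A sketch of the caller-side breakage, with assumed layer names rather than code from the repo:

    # Sketch (assumed inputs, not code from the repo): a Python 3 filter
    # object is a one-shot iterator, so list-style callers break.
    layers = ['fc', 'sparse_fc', 'bn']
    tagged = filter(lambda l: 'fc' in l, layers)

    # len(tagged) would raise TypeError in Python 3: filter objects have no len()
    assert list(tagged) == ['fc', 'sparse_fc']
    assert list(tagged) == []   # already exhausted; a second pass sees nothing

    tagged = [l for l in layers if 'fc' in l]   # the comprehension returns a real list
    assert len(tagged) == 2
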
34 changes: 17 additions & 17 deletions caffe2/python/models/seq2seq/beam_search.py
@@ -236,7 +236,7 @@ def choose_state_per_hypo(state_config):
window=state_config.state_link.window,
)
)
state_configs = map(choose_state_per_hypo, state_configs)
state_configs = [choose_state_per_hypo(c) for c in state_configs]
initial_scores = self.model.param_init_net.ConstantFill(
[],
'initial_scores',
@@ -324,32 +324,32 @@ def choose_state_per_hypo(state_config):
link_internal, link_external, link_offset, link_window = (
zip(*forward_links)
)
all_outputs = map(
lambda s: str(s) + '_all',
[scores_t, tokens_t, hypo_t, attention_t],
)
all_outputs = [
str(s) + '_all'
for s in [scores_t, tokens_t, hypo_t, attention_t]
]
results = self.model.net.RecurrentNetwork(
all_inputs,
all_outputs + ['step_workspaces'],
param=map(all_inputs.index, self.step_model.params),
alias_src=map(
lambda s: str(s) + '_states',
[
param=[all_inputs.index(p) for p in self.step_model.params],
alias_src=[
str(s) + '_states'
for s in [
self.scores_t_prev,
self.tokens_t_prev,
self.hypo_t_prev,
self.attention_t_prev,
],
),
]
],
alias_dst=all_outputs,
alias_offset=[0] * 4,
recurrent_states=recurrent_states,
initial_recurrent_state_ids=map(
all_inputs.index,
[state_config.initial_value for state_config in state_configs],
),
link_internal=map(str, link_internal),
link_external=map(str, link_external),
initial_recurrent_state_ids=[
all_inputs.index(state_config.initial_value)
for state_config in state_configs
],
link_internal=[str(l) for l in link_internal],
link_external=[str(l) for l in link_external],
link_offset=link_offset,
link_window=link_window,
backward_link_internal=[],
2 changes: 1 addition & 1 deletion caffe2/python/models/seq2seq/train.py
@@ -520,7 +520,7 @@ def _init_model(self):
def create_net(net):
workspace.CreateNet(
net,
input_blobs=map(str, net.external_inputs),
input_blobs=[str(i) for i in net.external_inputs],
)

create_net(self.model.net)
10 changes: 5 additions & 5 deletions caffe2/python/models/seq2seq/translate.py
@@ -398,11 +398,11 @@ def __init__(

workspace.CreateNet(
self.model.net,
input_blobs=map(str, [
self.encoder_inputs,
self.encoder_lengths,
self.max_output_seq_len,
]),
input_blobs=[
str(self.encoder_inputs),
str(self.encoder_lengths),
str(self.max_output_seq_len),
],
)

logger.info('Params created: ')
5 changes: 3 additions & 2 deletions caffe2/python/net_builder.py
@@ -31,8 +31,9 @@ def __init__(self, name=None, _stop_blob_required=False,
_stop_blob=None, _fullname=None):
nb = NetBuilder.current(required=False)
assert not _fullname or not name, 'Cannot set both _fullname and name'
self.name = _fullname or '/'.join(filter(lambda x: x, (
nb.name if nb else None, name)))
self.name = _fullname or '/'.join(
n for n in (nb.name if nb else None, name) if n
)
self._frozen = False
self._current_net = None
self._children = []
5 changes: 4 additions & 1 deletion caffe2/python/net_builder_test.py
@@ -117,7 +117,10 @@ def _actual_loop(self):
with c.Else():
ops.Add([total_tiny, val], [total_tiny])
ops.Add([total, val], total)
return map(final_output, (total, total_large, total_small, total_tiny))
return [
final_output(x)
for x in [total, total_large, total_small, total_tiny]
]

def test_loops(self):
with Task() as task:
31 changes: 20 additions & 11 deletions caffe2/python/net_printer.py
@@ -12,6 +12,7 @@
from collections import defaultdict
from contextlib import contextmanager
from copy import copy
from itertools import chain


class Visitor(object):
@@ -72,13 +73,16 @@ def need_blob(self, blob):

@Analyzer.register(OperatorDef)
def analyze_op(analyzer, op):
map(analyzer.need_blob, op.input)
map(analyzer.define_blob, op.output)
for x in op.input:
analyzer.need_blob(x)
for x in op.output:
analyzer.define_blob(x)


@Analyzer.register(Net)
def analyze_net(analyzer, net):
map(analyzer, net.Proto().op)
for x in net.Proto().op:
analyzer(x)


@Analyzer.register(ExecutionStep)
@@ -100,7 +104,8 @@ def analyze_step(analyzer, step):
'Error: Blobs created by multiple parallel steps: %s' % (
', '.join(all_new_blobs & new_blobs)))
all_new_blobs |= new_blobs
map(analyzer.define_blob, all_new_blobs)
for x in all_new_blobs:
analyzer.define_blob(x)


@Analyzer.register(Task)
@@ -209,7 +214,7 @@ def commonprefix(m):


def factor_prefix(vals, do_it):
vals = map(str, vals)
vals = [str(v) for v in vals]
prefix = commonprefix(vals) if len(vals) > 1 and do_it else ''
joined = ', '.join(v[len(prefix):] for v in vals)
return '%s[%s]' % (prefix, joined) if prefix else joined
@@ -221,10 +226,14 @@ def call(op, inputs=None, outputs=None, factor_prefixes=False):
else:
inputs_v = [a for a in inputs if not isinstance(a, tuple)]
inputs_kv = [a for a in inputs if isinstance(a, tuple)]
inputs = ', '.join(filter(
bool,
[factor_prefix(inputs_v, factor_prefixes)] +
['%s=%s' % kv for kv in inputs_kv]))
inputs = ', '.join(
x
for x in chain(
[factor_prefix(inputs_v, factor_prefixes)],
('%s=%s' % kv for kv in inputs_kv),
)
if x
)
call = '%s(%s)' % (op, inputs)
return call if not outputs else '%s = %s' % (
factor_prefix(outputs, factor_prefixes), call)
@@ -287,12 +296,12 @@ def print_step(text, step):

def _print_task_output(x):
assert isinstance(x, TaskOutput)
return 'Output[' + ', '.join(map(str, x.names)) + ']'
return 'Output[' + ', '.join(str(x) for x in x.names) + ']'


@Printer.register(Task)
def print_task(text, task):
outs = ', '.join(map(_print_task_output, task.outputs()))
outs = ', '.join(_print_task_output(o) for o in task.outputs())
context = [('node', task.node), ('name', task.name), ('outputs', outs)]
with text.context(call('Task', context)):
text(task.get_step())
3 changes: 2 additions & 1 deletion caffe2/python/net_printer_test.py
@@ -75,8 +75,9 @@ def test_undefined_blob(self):
with job:
with Task():
ops.Add(['a', 'b'])
with self.assertRaises(AssertionError):
with self.assertRaises(AssertionError) as e:
net_printer.analyze(job)
self.assertEqual("Blob undefined: a", str(e.exception))

def test_multiple_definition(self):
job = example_job()
18 changes: 9 additions & 9 deletions caffe2/python/operator_test/recurrent_network_test.py
@@ -286,19 +286,19 @@ def test_stateful_convolution_forward_only(
input_state_all, output_state_all, _ = model.net.RecurrentNetwork(
all_inputs,
all_outputs + ['step_workspaces'],
param=map(all_inputs.index, step_model.params),
param=[all_inputs.index(p) for p in step_model.params],
alias_src=recurrent_states,
alias_dst=all_outputs,
alias_offset=[conv_window - 1, 1],
recurrent_states=recurrent_states,
initial_recurrent_state_ids=map(
all_inputs.index,
initial_recurrent_states,
),
link_internal=map(
str,
[input_state_t_prev, input_state_t, output_state_t],
),
initial_recurrent_state_ids=[
all_inputs.index(s) for s in initial_recurrent_states
],
link_internal=[
str(input_state_t_prev),
str(input_state_t),
str(output_state_t),
],
link_external=['input_state', 'input_state', 'output_state'],
link_offset=[0, conv_window - 1, 1],
link_window=[conv_window, 1, 1],
28 changes: 18 additions & 10 deletions caffe2/python/operator_test/string_ops_test.py
@@ -25,11 +25,13 @@ def test_string_prefix(self, strings):
# an invalid utf-8 string. The goal here is just to avoid python
# complaining about the unicode -> str conversion.
strings = np.array(
map(lambda a: a.encode('utf-8'), strings), dtype=np.object)
[a.encode('utf-8') for a in strings], dtype=np.object
)

def string_prefix_ref(strings):
return (
np.array(map(lambda a: a[:length], strings), dtype=object), )
np.array([a[:length] for a in strings], dtype=object),
)

op = core.CreateOperator(
'StringPrefix',
@@ -46,11 +48,13 @@ def string_prefix_ref(strings):
def test_string_suffix(self, strings):
length = 3
strings = np.array(
map(lambda a: a.encode('utf-8'), strings), dtype=np.object)
[a.encode('utf-8') for a in strings], dtype=np.object
)

def string_suffix_ref(strings):
return (
np.array(map(lambda a: a[-length:], strings), dtype=object), )
np.array([a[-length:] for a in strings], dtype=object),
)

op = core.CreateOperator(
'StringSuffix',
@@ -67,11 +71,13 @@ def string_suffix_ref(strings):
def test_string_starts_with(self, strings):
prefix = 'a'
strings = np.array(
map(lambda a: str(strings), strings), dtype=np.object)
[str(a) for a in strings], dtype=np.object
)

def string_starts_with_ref(strings):
return (np.array(
map(lambda a: a.startswith(prefix), strings), dtype=bool), )
return (
np.array([a.startswith(prefix) for a in strings], dtype=bool),
)

op = core.CreateOperator(
'StringStartsWith',
@@ -88,11 +94,13 @@ def string_starts_with_ref(strings):
def test_string_ends_with(self, strings):
suffix = 'a'
strings = np.array(
map(lambda a: str(strings), strings), dtype=np.object)
[str(a) for a in strings], dtype=np.object
)

def string_ends_with_ref(strings):
return (np.array(
map(lambda a: a.endswith(suffix), strings), dtype=bool), )
return (
np.array([a.endswith(suffix) for a in strings], dtype=bool),
)

op = core.CreateOperator(
'StringEndsWith',
6 changes: 5 additions & 1 deletion caffe2/python/operator_test/text_file_reader_test.py
@@ -26,7 +26,11 @@ def test_text_file_reader(self):
row_data = zip(*col_data)
txt_file = tempfile.NamedTemporaryFile(delete=False)
txt_file.write(
'\n'.join(['\t'.join(map(str, f)) for f in row_data]) + '\n')
'\n'.join(
'\t'.join(str(x) for x in f)
for f in row_data
) + '\n'
)
txt_file.close()

for num_passes in range(1, 3):