From 21d2d6faf11c1b1d34bc3f85b1cc286967f23afc Mon Sep 17 00:00:00 2001
From: Guillaume Klein
Date: Thu, 22 Nov 2018 22:48:57 +0100
Subject: [PATCH] Support for multi source in Transformer decoder (#272)

---
 CHANGELOG.md                               |  1 +
 config/models/multi_source_transformer.py  | 28 +++++++
 opennmt/decoders/decoder.py                |  9 ++-
 opennmt/decoders/self_attention_decoder.py | 85 +++++++++++++---------
 opennmt/encoders/encoder.py                | 10 ++-
 opennmt/inputters/inputter.py              | 22 ++++--
 opennmt/inputters/text_inputter.py         |  6 +-
 opennmt/models/transformer.py              | 35 ++++++---
 opennmt/tests/decoder_test.py              | 80 +++++++++++++-------
 opennmt/tests/inputter_test.py             |  2 +
 10 files changed, 196 insertions(+), 82 deletions(-)
 create mode 100644 config/models/multi_source_transformer.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 229cbd2e9..129e3c5cd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,7 @@ OpenNMT-tf follows [semantic versioning 2.0.0](https://semver.org/). The API cov
 
 ### New features
 
+* Multi-source Transformer architecture with serial attention layers (see the [example model configuration](https://github.com/OpenNMT/OpenNMT-tf/blob/master/config/models/multi_source_transformer.py))
 * Inference now accepts the parameter `bucket_width`: if set, the data will be sorted by length to increase the translation efficiency. The predictions will still be outputted in order as they are available. (Enabled by default when using automatic configuration.)
 
 ### Fixes and improvements

diff --git a/config/models/multi_source_transformer.py b/config/models/multi_source_transformer.py
new file mode 100644
index 000000000..279dae889
--- /dev/null
+++ b/config/models/multi_source_transformer.py
@@ -0,0 +1,28 @@
+"""Defines a dual-source Transformer architecture with serial attention layers
+and parameter sharing between the encoders.
+
+See for example https://arxiv.org/pdf/1809.00188.pdf.
+"""
+
+import opennmt as onmt
+
+def model():
+  return onmt.models.Transformer(
+      source_inputter=onmt.inputters.ParallelInputter([
+          onmt.inputters.WordEmbedder(
+              vocabulary_file_key="source_vocabulary_1",
+              embedding_size=512),
+          onmt.inputters.WordEmbedder(
+              vocabulary_file_key="source_vocabulary_2",
+              embedding_size=512)]),
+      target_inputter=onmt.inputters.WordEmbedder(
+          vocabulary_file_key="target_vocabulary",
+          embedding_size=512),
+      num_layers=6,
+      num_units=512,
+      num_heads=8,
+      ffn_inner_dim=2048,
+      dropout=0.1,
+      attention_dropout=0.1,
+      relu_dropout=0.1,
+      share_encoders=True)
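The model above expects a data configuration that supplies one vocabulary and one training file per source. A minimal sketch of the matching data block, written here as its Python dict equivalent of the usual YAML run configuration (all file names are placeholders):

    data = {
        # One training file per parallel source, aligned line by line.
        "train_features_file": ["train.src1.txt", "train.src2.txt"],
        "train_labels_file": "train.tgt.txt",
        # These keys must match the vocabulary_file_key values used by the inputters.
        "source_vocabulary_1": "src1.vocab",
        "source_vocabulary_2": "src2.vocab",
        "target_vocabulary": "tgt.vocab",
    }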
diff --git a/opennmt/decoders/decoder.py b/opennmt/decoders/decoder.py
index 28bc71f77..dc556eacf 100644
--- a/opennmt/decoders/decoder.py
+++ b/opennmt/decoders/decoder.py
@@ -159,6 +159,13 @@ def decode(self,
     _ = embedding
     if sampling_probability is not None:
       raise ValueError("Scheduled sampling is not supported by this decoder")
+    if (not self.support_multi_source
+        and memory is not None
+        and tf.contrib.framework.nest.is_sequence(memory)):
+      raise ValueError("Multiple source encodings are passed to this decoder "
+                       "but it does not support multi-source context. You should "
+                       "instead configure your encoder to merge the different "
+                       "encodings.")
     returned_values = self.decode_from_inputs(
         inputs,
@@ -329,7 +336,7 @@ def dynamic_decode_and_search(self,
       output_layer = build_output_layer(self.output_size, vocab_size, dtype=dtype)
     state = {"decoder": initial_state}
-    if self.support_alignment_history and not isinstance(memory, (tuple, list)):
+    if self.support_alignment_history and not tf.contrib.framework.nest.is_sequence(memory):
       state["attention"] = tf.zeros([batch_size, 0, tf.shape(memory)[1]], dtype=dtype)
     def _symbols_to_logits_fn(ids, step, state):
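In effect, a decoder advertising `support_multi_source` may now receive a tuple of encoder outputs, while every other decoder fails fast; single-source callers keep passing a plain tensor unchanged. A sketch of the multi-source call site (tensor names are placeholders; unrelated arguments omitted):

    # Accepted only when decoder.support_multi_source is True, e.g. the
    # SelfAttentionDecoder below; other decoders raise the ValueError above.
    outputs = decoder.decode(
        inputs,
        sequence_length,
        vocab_size=vocab_size,
        memory=(memory_1, memory_2),                    # one encoding per source
        memory_sequence_length=(lengths_1, lengths_2))  # matching tuple of lengths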
diff --git a/opennmt/decoders/self_attention_decoder.py b/opennmt/decoders/self_attention_decoder.py
index 97a22353d..eb9348737 100644
--- a/opennmt/decoders/self_attention_decoder.py
+++ b/opennmt/decoders/self_attention_decoder.py
@@ -66,15 +66,21 @@ def output_size(self):
   def support_alignment_history(self):
     return True
 
-  def _init_cache(self, batch_size, dtype=tf.float32):
+  @property
+  def support_multi_source(self):
+    return True
+
+  def _init_cache(self, batch_size, dtype=tf.float32, num_sources=1):
     cache = {}
     for l in range(self.num_layers):
       proj_cache_shape = [batch_size, self.num_heads, 0, self.num_units // self.num_heads]
-      layer_cache = {
-          "memory_keys": tf.zeros(proj_cache_shape, dtype=dtype),
-          "memory_values": tf.zeros(proj_cache_shape, dtype=dtype),
-      }
+      layer_cache = {}
+      layer_cache["memory"] = [
+          {
+              "memory_keys": tf.zeros(proj_cache_shape, dtype=dtype),
+              "memory_values": tf.zeros(proj_cache_shape, dtype=dtype)
+          } for _ in range(num_sources)]
       if self.self_attention_type == "scaled_dot":
         layer_cache["self_keys"] = tf.zeros(proj_cache_shape, dtype=dtype)
         layer_cache["self_values"] = tf.zeros(proj_cache_shape, dtype=dtype)
@@ -121,11 +127,15 @@ def _self_attention_stack(self,
       decoder_mask = transformer.cumulative_average_mask(
           sequence_length, maximum_length=tf.shape(inputs)[1], dtype=inputs.dtype)
 
-    if memory is not None and memory_sequence_length is not None:
-      memory_mask = transformer.build_sequence_mask(
-          memory_sequence_length,
-          num_heads=self.num_heads,
-          maximum_length=tf.shape(memory)[1])
+    if memory is not None and not tf.contrib.framework.nest.is_sequence(memory):
+      memory = (memory,)
+    if memory_sequence_length is not None:
+      if not tf.contrib.framework.nest.is_sequence(memory_sequence_length):
+        memory_sequence_length = (memory_sequence_length,)
+      memory_mask = [
+          transformer.build_sequence_mask(
+              length, num_heads=self.num_heads, maximum_length=tf.shape(m)[1])
+          for m, length in zip(memory, memory_sequence_length)]
 
     for l in range(self.num_layers):
       layer_name = "layer_{}".format(l)
@@ -142,7 +152,7 @@ def _self_attention_stack(self,
               mask=decoder_mask,
               cache=layer_cache,
               dropout=self.attention_dropout)
-          encoded = transformer.drop_and_add(
+          last_context = transformer.drop_and_add(
               inputs,
               encoded,
               mode,
@@ -160,36 +170,38 @@ def _self_attention_stack(self,
             z = tf.layers.dense(tf.concat([x, y], -1), self.num_units * 2)
             i, f = tf.split(z, 2, axis=-1)
             y = tf.sigmoid(i) * x + tf.sigmoid(f) * y
-          encoded = transformer.drop_and_add(
+          last_context = transformer.drop_and_add(
               inputs,
               y,
               mode,
               dropout=self.dropout)
 
       if memory is not None:
-        with tf.variable_scope("multi_head"):
-          context, last_attention = transformer.multi_head_attention(
-              self.num_heads,
-              transformer.norm(encoded),
-              memory,
-              mode,
-              mask=memory_mask,
-              cache=layer_cache,
-              dropout=self.attention_dropout,
-              return_attention=True)
-          context = transformer.drop_and_add(
-              encoded,
-              context,
-              mode,
-              dropout=self.dropout)
-      else:
-        context = encoded
+        for i, (mem, mask) in enumerate(zip(memory, memory_mask)):
+          memory_cache = layer_cache["memory"][i] if layer_cache is not None else None
+          with tf.variable_scope("multi_head" if i == 0 else "multi_head_%d" % i):
+            context, last_attention = transformer.multi_head_attention(
+                self.num_heads,
+                transformer.norm(last_context),
+                mem,
+                mode,
+                mask=mask,
+                cache=memory_cache,
+                dropout=self.attention_dropout,
+                return_attention=True)
+            last_context = transformer.drop_and_add(
+                last_context,
+                context,
+                mode,
+                dropout=self.dropout)
+            if i > 0:  # Do not return attention in case of multi-source.
+              last_attention = None
 
       with tf.variable_scope("ffn"):
         transformed = transformer.feed_forward(
-            transformer.norm(context),
+            transformer.norm(last_context),
             self.ffn_inner_dim,
             mode,
             dropout=self.relu_dropout)
         transformed = transformer.drop_and_add(
-            context,
+            last_context,
             transformed,
             mode,
             dropout=self.dropout)
@@ -227,7 +239,13 @@ def step_fn(self,
               memory=None,
               memory_sequence_length=None,
               dtype=tf.float32):
-    cache = self._init_cache(batch_size, dtype=dtype)
+    if memory is None:
+      num_sources = 0
+    elif tf.contrib.framework.nest.is_sequence(memory):
+      num_sources = len(memory)
+    else:
+      num_sources = 1
+    cache = self._init_cache(batch_size, dtype=dtype, num_sources=num_sources)
     def _fn(step, inputs, cache, mode):
       inputs = tf.expand_dims(inputs, 1)
       outputs, attention = self._self_attention_stack(
@@ -238,6 +256,7 @@ def _fn(step, inputs, cache, mode):
           memory_sequence_length=memory_sequence_length,
           step=step)
       outputs = tf.squeeze(outputs, axis=1)
-      attention = tf.squeeze(attention, axis=1)
+      if attention is not None:
+        attention = tf.squeeze(attention, axis=1)
       return outputs, cache, attention
     return _fn, cache
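For reference, the cache built by `_init_cache` now holds one key/value projection pair per source and per layer. A sketch of a single layer's entry for two sources (shapes symbolic; B batch size, H heads, D units; the zero-sized time dimension grows during decoding):

    layer_cache = {
        "memory": [  # one dict per source, in the order the memories are passed
            {"memory_keys": ...,     # [B, H, 0, D // H]
             "memory_values": ...},  # [B, H, 0, D // H]
            {"memory_keys": ..., "memory_values": ...},
        ],
        # only with scaled_dot self-attention:
        "self_keys": ...,    # [B, H, 0, D // H]
        "self_values": ...,  # [B, H, 0, D // H]
    }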
diff --git a/opennmt/encoders/encoder.py b/opennmt/encoders/encoder.py
index 5f2a8ed69..547357d8a 100644
--- a/opennmt/encoders/encoder.py
+++ b/opennmt/encoders/encoder.py
@@ -100,7 +100,8 @@ def __init__(self,
                outputs_reducer=ConcatReducer(axis=1),
                states_reducer=JoinReducer(),
                outputs_layer_fn=None,
-               combined_output_layer_fn=None):
+               combined_output_layer_fn=None,
+               share_parameters=False):
     """Initializes the parameters of the encoder.
 
     Args:
@@ -117,6 +118,8 @@ def __init__(self,
         output.
       combined_output_layer_fn: A callable to apply on the combined output
         (i.e. the output of :obj:`outputs_reducer`).
+      share_parameters: If ``True``, share parameters between the parallel
+        encoders.
 
     Raises:
       ValueError: if :obj:`outputs_layer_fn` is a list with a size not equal
@@ -132,6 +135,7 @@ def __init__(self,
     self.states_reducer = states_reducer if states_reducer is not None else JoinReducer()
     self.outputs_layer_fn = outputs_layer_fn
     self.combined_output_layer_fn = combined_output_layer_fn
+    self.share_parameters = share_parameters
 
   def encode(self, inputs, sequence_length=None, mode=tf.estimator.ModeKeys.TRAIN):
     all_outputs = []
@@ -142,7 +146,9 @@ def encode(self, inputs, sequence_length=None, mode=tf.estimator.ModeKeys.TRAIN):
       raise ValueError("ParallelEncoder expects as many inputs as parallel encoders")
     for i, encoder in enumerate(self.encoders):
-      with tf.variable_scope("encoder_{}".format(i)):
+      scope_name = "encoder_{}".format(i) if not self.share_parameters else "parallel_encoder"
+      reuse = self.share_parameters and i > 0
+      with tf.variable_scope(scope_name, reuse=reuse):
        if tf.contrib.framework.nest.is_sequence(inputs):
          encoder_inputs = inputs[i]
          length = sequence_length[i]
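A minimal sketch of the new flag (assuming the v1 API shown above): with `share_parameters=True` every sub-encoder runs in the same `"parallel_encoder"` variable scope, with `reuse=True` from the second one on, so all sources are encoded by a single set of weights.

    import opennmt as onmt

    # Both sub-encoders end up sharing one set of variables.
    encoder = onmt.encoders.ParallelEncoder(
        [onmt.encoders.SelfAttentionEncoder(6) for _ in range(2)],
        outputs_reducer=None,  # keep one encoding per source for the decoder
        states_reducer=None,
        share_parameters=True)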
""" - super(MixedInputter, self).__init__(inputters) - self.reducer = reducer + super(MixedInputter, self).__init__(inputters, reducer=reducer) self.dropout = dropout def get_length(self, data): diff --git a/opennmt/inputters/text_inputter.py b/opennmt/inputters/text_inputter.py index e1ef02ca4..582c605e9 100644 --- a/opennmt/inputters/text_inputter.py +++ b/opennmt/inputters/text_inputter.py @@ -505,7 +505,7 @@ def __init__(self, dropout=dropout, tokenizer=tokenizer, dtype=dtype) - self.num_outputs = num_outputs + self.output_size = num_outputs self.kernel_size = kernel_size self.stride = stride self.num_oov_buckets = 1 @@ -522,7 +522,7 @@ def transform(self, inputs, mode): outputs = tf.layers.conv1d( outputs, - self.num_outputs, + self.output_size, self.kernel_size, strides=self.stride) @@ -530,7 +530,7 @@ def transform(self, inputs, mode): outputs = tf.reduce_max(outputs, axis=1) # Split batch and sequence timesteps dimensions. - outputs = tf.reshape(outputs, [-1, tf.shape(inputs)[1], self.num_outputs]) + outputs = tf.reshape(outputs, [-1, tf.shape(inputs)[1], self.output_size]) return outputs diff --git a/opennmt/models/transformer.py b/opennmt/models/transformer.py index 8af0f97b4..3ee5ee037 100644 --- a/opennmt/models/transformer.py +++ b/opennmt/models/transformer.py @@ -3,6 +3,7 @@ import tensorflow as tf from opennmt.models.sequence_to_sequence import SequenceToSequence, EmbeddingsSharingLevel +from opennmt.encoders.encoder import ParallelEncoder from opennmt.encoders.self_attention_encoder import SelfAttentionEncoder from opennmt.decoders.self_attention_decoder import SelfAttentionDecoder from opennmt.layers.position import SinusoidalPositionEncoder @@ -27,13 +28,15 @@ def __init__(self, position_encoder=SinusoidalPositionEncoder(), decoder_self_attention_type="scaled_dot", share_embeddings=EmbeddingsSharingLevel.NONE, + share_encoders=False, alignment_file_key="train_alignments", name="transformer"): """Initializes a Transformer model. Args: source_inputter: A :class:`opennmt.inputters.inputter.Inputter` to process - the source data. + the source data. If this inputter returns parallel inputs, a multi + source Transformer architecture will be constructed. target_inputter: A :class:`opennmt.inputters.inputter.Inputter` to process the target data. Currently, only the :class:`opennmt.inputters.text_inputter.WordEmbedder` is supported. @@ -52,19 +55,31 @@ def __init__(self, share_embeddings: Level of embeddings sharing, see :class:`opennmt.models.sequence_to_sequence.EmbeddingsSharingLevel` for possible values. + share_encoders: In case of multi source architecture, whether to share the + separate encoders parameters or not. alignment_file_key: The data configuration key of the training alignment file to support guided alignment. name: The name of this model. 
""" - encoder = SelfAttentionEncoder( - num_layers, - num_units=num_units, - num_heads=num_heads, - ffn_inner_dim=ffn_inner_dim, - dropout=dropout, - attention_dropout=attention_dropout, - relu_dropout=relu_dropout, - position_encoder=position_encoder) + encoders = [ + SelfAttentionEncoder( + num_layers, + num_units=num_units, + num_heads=num_heads, + ffn_inner_dim=ffn_inner_dim, + dropout=dropout, + attention_dropout=attention_dropout, + relu_dropout=relu_dropout, + position_encoder=position_encoder) + for _ in range(source_inputter.num_outputs)] + if len(encoders) > 1: + encoder = ParallelEncoder( + encoders, + outputs_reducer=None, + states_reducer=None, + share_parameters=share_encoders) + else: + encoder = encoders[0] decoder = SelfAttentionDecoder( num_layers, num_units=num_units, diff --git a/opennmt/tests/decoder_test.py b/opennmt/tests/decoder_test.py index ac915ad2e..4ac7b68bb 100644 --- a/opennmt/tests/decoder_test.py +++ b/opennmt/tests/decoder_test.py @@ -10,6 +10,29 @@ from opennmt.layers import bridge +def _generate_source_context(batch_size, + depth, + initial_state_fn=None, + num_sources=1, + dtype=tf.float32): + memory_sequence_length = [ + np.random.randint(1, high=20, size=batch_size) for _ in range(num_sources)] + memory_time = [np.amax(length) for length in memory_sequence_length] + memory = [ + tf.placeholder_with_default( + np.random.randn(batch_size, time, depth).astype(dtype.as_numpy_dtype()), + shape=(None, None, depth)) + for time in memory_time] + if initial_state_fn is not None: + initial_state = initial_state_fn(tf.shape(memory[0])[0], dtype) + else: + initial_state = None + if num_sources == 1: + memory_sequence_length = memory_sequence_length[0] + memory = memory[0] + return initial_state, memory, memory_sequence_length + + class DecoderTest(tf.test.TestCase): def testSamplingProbability(self): @@ -45,7 +68,7 @@ def testSamplingProbability(self): self.assertAlmostEqual( 1.0 - (1.0 / (1.0 + math.exp(5.0 / 1.0))), sess.run(inv_sig_sample_prob)) - def _testDecoderTraining(self, decoder, initial_state_fn=None, dtype=tf.float32): + def _testDecoderTraining(self, decoder, initial_state_fn=None, num_sources=1, dtype=tf.float32): batch_size = 4 vocab_size = 10 time_dim = 5 @@ -55,15 +78,12 @@ def _testDecoderTraining(self, decoder, initial_state_fn=None, dtype=tf.float32) shape=(None, None, depth)) # NOTE: max(sequence_length) may be less than time_dim when num_gpus > 1 sequence_length = [1, 3, 4, 2] - memory_sequence_length = [3, 7, 5, 4] - memory_time = max(memory_sequence_length) - memory = tf.placeholder_with_default( - np.random.randn(batch_size, memory_time, depth).astype(dtype.as_numpy_dtype()), - shape=(None, None, depth)) - if initial_state_fn is not None: - initial_state = initial_state_fn(tf.shape(memory)[0], dtype) - else: - initial_state = None + initial_state, memory, memory_sequence_length = _generate_source_context( + batch_size, + depth, + initial_state_fn=initial_state_fn, + num_sources=num_sources, + dtype=dtype) outputs, _, _, attention = decoder.decode( inputs, sequence_length, @@ -74,7 +94,7 @@ def _testDecoderTraining(self, decoder, initial_state_fn=None, dtype=tf.float32) return_alignment_history=True) self.assertEqual(outputs.dtype, dtype) output_time_dim = tf.shape(outputs)[1] - if decoder.support_alignment_history: + if decoder.support_alignment_history and num_sources == 1: self.assertIsNotNone(attention) else: self.assertIsNone(attention) @@ -84,14 +104,15 @@ def _testDecoderTraining(self, decoder, initial_state_fn=None, 
diff --git a/opennmt/tests/decoder_test.py b/opennmt/tests/decoder_test.py
index ac915ad2e..4ac7b68bb 100644
--- a/opennmt/tests/decoder_test.py
+++ b/opennmt/tests/decoder_test.py
@@ -10,6 +10,29 @@
 from opennmt.layers import bridge
 
 
+def _generate_source_context(batch_size,
+                             depth,
+                             initial_state_fn=None,
+                             num_sources=1,
+                             dtype=tf.float32):
+  memory_sequence_length = [
+      np.random.randint(1, high=20, size=batch_size) for _ in range(num_sources)]
+  memory_time = [np.amax(length) for length in memory_sequence_length]
+  memory = [
+      tf.placeholder_with_default(
+          np.random.randn(batch_size, time, depth).astype(dtype.as_numpy_dtype()),
+          shape=(None, None, depth))
+      for time in memory_time]
+  if initial_state_fn is not None:
+    initial_state = initial_state_fn(tf.shape(memory[0])[0], dtype)
+  else:
+    initial_state = None
+  if num_sources == 1:
+    memory_sequence_length = memory_sequence_length[0]
+    memory = memory[0]
+  return initial_state, memory, memory_sequence_length
+
+
 class DecoderTest(tf.test.TestCase):
 
   def testSamplingProbability(self):
@@ -45,7 +68,7 @@ def testSamplingProbability(self):
       self.assertAlmostEqual(
           1.0 - (1.0 / (1.0 + math.exp(5.0 / 1.0))), sess.run(inv_sig_sample_prob))
 
-  def _testDecoderTraining(self, decoder, initial_state_fn=None, dtype=tf.float32):
+  def _testDecoderTraining(self, decoder, initial_state_fn=None, num_sources=1, dtype=tf.float32):
     batch_size = 4
     vocab_size = 10
     time_dim = 5
@@ -55,15 +78,12 @@ def _testDecoderTraining(self, decoder, initial_state_fn=None, dtype=tf.float32)
         shape=(None, None, depth))
     # NOTE: max(sequence_length) may be less than time_dim when num_gpus > 1
     sequence_length = [1, 3, 4, 2]
-    memory_sequence_length = [3, 7, 5, 4]
-    memory_time = max(memory_sequence_length)
-    memory = tf.placeholder_with_default(
-        np.random.randn(batch_size, memory_time, depth).astype(dtype.as_numpy_dtype()),
-        shape=(None, None, depth))
-    if initial_state_fn is not None:
-      initial_state = initial_state_fn(tf.shape(memory)[0], dtype)
-    else:
-      initial_state = None
+    initial_state, memory, memory_sequence_length = _generate_source_context(
+        batch_size,
+        depth,
+        initial_state_fn=initial_state_fn,
+        num_sources=num_sources,
+        dtype=dtype)
     outputs, _, _, attention = decoder.decode(
         inputs,
         sequence_length,
@@ -74,7 +94,7 @@ def _testDecoderTraining(self, decoder, initial_state_fn=None, dtype=tf.float32)
         return_alignment_history=True)
     self.assertEqual(outputs.dtype, dtype)
     output_time_dim = tf.shape(outputs)[1]
-    if decoder.support_alignment_history:
+    if decoder.support_alignment_history and num_sources == 1:
       self.assertIsNotNone(attention)
     else:
       self.assertIsNone(attention)
@@ -84,14 +104,15 @@ def _testDecoderTraining(self, decoder, initial_state_fn=None, dtype=tf.float32)
       sess.run(tf.global_variables_initializer())
       output_time_dim_val = sess.run(output_time_dim)
       self.assertEqual(time_dim, output_time_dim_val)
-      if decoder.support_alignment_history:
-        attention_val = sess.run(attention)
+      if decoder.support_alignment_history and num_sources == 1:
+        attention_val, memory_time = sess.run([attention, tf.shape(memory)[1]])
         self.assertAllEqual([batch_size, time_dim, memory_time], attention_val.shape)
       return saver.save(sess, os.path.join(self.get_temp_dir(), "model.ckpt"))
 
   def _testDecoderInference(self,
                             decoder,
                             initial_state_fn=None,
+                            num_sources=1,
                             with_beam_search=False,
                             with_alignment_history=False,
                             dtype=tf.float32,
@@ -103,19 +124,15 @@ def _testDecoderInference(self,
     depth = 6
     end_token = 2
     start_tokens = tf.placeholder_with_default([1] * batch_size, shape=[None])
-    memory_sequence_length = [3, 7, 5, 4]
-    memory_time = max(memory_sequence_length)
-    memory = tf.placeholder_with_default(
-        np.random.randn(batch_size, memory_time, depth).astype(dtype.as_numpy_dtype()),
-        shape=(None, None, depth))
-    memory_sequence_length = tf.placeholder_with_default(memory_sequence_length, shape=[None])
     embedding = tf.placeholder_with_default(
         np.random.randn(vocab_size, depth).astype(dtype.as_numpy_dtype()),
         shape=(vocab_size, depth))
-    if initial_state_fn is not None:
-      initial_state = initial_state_fn(tf.shape(memory)[0], dtype)
-    else:
-      initial_state = None
+    initial_state, memory, memory_sequence_length = _generate_source_context(
+        batch_size,
+        depth,
+        initial_state_fn=initial_state_fn,
+        num_sources=num_sources,
+        dtype=dtype)
 
     if with_beam_search:
       decode_fn = decoder.dynamic_decode_and_search
@@ -145,7 +162,6 @@ def _testDecoderInference(self,
       log_probs = outputs[3]
       self.assertEqual(log_probs.dtype, tf.float32)
 
-    decode_time = tf.shape(ids)[-1]
     saver = tf.train.Saver(var_list=tf.global_variables())
     with self.test_session(graph=tf.get_default_graph()) as sess:
@@ -159,9 +175,10 @@ def _testDecoderInference(self,
       else:
         self.assertEqual(5, len(outputs))
         alignment_history = outputs[4]
-        if decoder.support_alignment_history:
+        if decoder.support_alignment_history and num_sources == 1:
           self.assertIsInstance(alignment_history, tf.Tensor)
-          alignment_history, decode_time = sess.run([alignment_history, decode_time])
+          alignment_history, decode_time, memory_time = sess.run(
+              [alignment_history, tf.shape(ids)[-1], tf.shape(memory)[1]])
           self.assertAllEqual(
               [batch_size, num_hyps, decode_time, memory_time], alignment_history.shape)
         else:
@@ -172,17 +189,19 @@ def _testDecoderInference(self,
       self.assertAllEqual([batch_size, num_hyps], lengths.shape)
       self.assertAllEqual([batch_size, num_hyps], log_probs.shape)
 
-  def _testDecoder(self, decoder, initial_state_fn=None, dtype=tf.float32):
+  def _testDecoder(self, decoder, initial_state_fn=None, num_sources=1, dtype=tf.float32):
     with tf.Graph().as_default() as g:
       checkpoint_path = self._testDecoderTraining(
           decoder,
           initial_state_fn=initial_state_fn,
+          num_sources=num_sources,
          dtype=dtype)
     with tf.Graph().as_default() as g:
       self._testDecoderInference(
           decoder,
           initial_state_fn=initial_state_fn,
+          num_sources=num_sources,
          with_beam_search=False,
          with_alignment_history=False,
          dtype=dtype,
@@ -191,6 +210,7 @@ def _testDecoder(self, decoder, initial_state_fn=None, dtype=tf.float32):
       self._testDecoderInference(
           decoder,
           initial_state_fn=initial_state_fn,
+          num_sources=num_sources,
          with_beam_search=False,
          with_alignment_history=True,
          dtype=dtype,
@@ -199,6 +219,7 @@ def _testDecoder(self, decoder, initial_state_fn=None, dtype=tf.float32):
       self._testDecoderInference(
           decoder,
           initial_state_fn=initial_state_fn,
+          num_sources=num_sources,
          with_beam_search=True,
          with_alignment_history=False,
          dtype=dtype,
@@ -207,6 +228,7 @@ def _testDecoder(self, decoder, initial_state_fn=None, dtype=tf.float32):
       self._testDecoderInference(
           decoder,
           initial_state_fn=initial_state_fn,
+          num_sources=num_sources,
          with_beam_search=True,
          with_alignment_history=True,
          dtype=dtype,
@@ -243,6 +265,10 @@ def testSelfAttentionDecoderFP16(self):
     decoder = decoders.SelfAttentionDecoder(2, num_units=6, num_heads=2, ffn_inner_dim=12)
     self._testDecoder(decoder, dtype=tf.float16)
 
+  def testSelfAttentionDecoderMultiSource(self):
+    decoder = decoders.SelfAttentionDecoder(2, num_units=6, num_heads=2, ffn_inner_dim=12)
+    self._testDecoder(decoder, num_sources=2)
+
   def testPenalizeToken(self):
     log_probs = tf.zeros([4, 6])
     token_id = 1
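For orientation, what the new test helper returns in its two modes (a sketch; the values are random per run):

    # num_sources=2: memory is a list of two [batch, time_i, depth] tensors and
    # memory_sequence_length a list of two integer length vectors.
    _, memory, memory_sequence_length = _generate_source_context(4, 6, num_sources=2)

    # num_sources=1: both are unwrapped to plain tensors, so the existing
    # single-source decoder tests run unchanged.
    _, memory, memory_sequence_length = _generate_source_context(4, 6)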
diff --git a/opennmt/tests/inputter_test.py b/opennmt/tests/inputter_test.py
index 14fec97eb..a546f4775 100644
--- a/opennmt/tests/inputter_test.py
+++ b/opennmt/tests/inputter_test.py
@@ -307,6 +307,7 @@ def testParallelInputter(self):
     parallel_inputter = inputter.ParallelInputter([
         text_inputter.WordEmbedder("vocabulary_file_1", embedding_size=10),
         text_inputter.WordEmbedder("vocabulary_file_2", embedding_size=5)])
+    self.assertEqual(parallel_inputter.num_outputs, 2)
     features, transformed = self._makeDataset(
         parallel_inputter,
         data_files,
@@ -349,6 +350,7 @@ def testMixedInputter(self):
         text_inputter.WordEmbedder("vocabulary_file_1", embedding_size=10),
         text_inputter.CharConvEmbedder("vocabulary_file_2", 10, 5)],
         reducer=reducer.ConcatReducer())
+    self.assertEqual(mixed_inputter.num_outputs, 1)
     features, transformed = self._makeDataset(
         mixed_inputter,
         data_file,