diff --git a/mlblocks/mlpipeline.py b/mlblocks/mlpipeline.py index a4111bcb..d7935757 100644 --- a/mlblocks/mlpipeline.py +++ b/mlblocks/mlpipeline.py @@ -96,6 +96,7 @@ def _get_tunable_hyperparameters(self): def _build_blocks(self): blocks = OrderedDict() + last_fit_block = None block_names_count = Counter() for primitive in self.primitives: @@ -118,11 +119,14 @@ def _build_blocks(self): block = MLBlock(primitive, **block_params) blocks[block_name] = block + if bool(block._fit): + last_fit_block = block_name + except Exception: LOGGER.exception('Exception caught building MLBlock %s', primitive) raise - return blocks + return blocks, last_fit_block @staticmethod def _get_pipeline_dict(pipeline, primitives): @@ -207,7 +211,7 @@ def __init__(self, pipeline=None, primitives=None, init_params=None, self.primitives = primitives or pipeline['primitives'] self.init_params = init_params or pipeline.get('init_params', dict()) - self.blocks = self._build_blocks() + self.blocks, self._last_fit_block = self._build_blocks() self._last_block_name = self._get_block_name(-1) self.input_names = input_names or pipeline.get('input_names', dict()) @@ -767,7 +771,11 @@ def fit(self, X=None, y=None, output_=None, start_=None, debug=False, **kwargs): debug_info = defaultdict(dict) debug_info['debug'] = debug.lower() if isinstance(debug, str) else 'tmio' + fit_pending = True for block_name, block in self.blocks.items(): + if block_name == self._last_fit_block: + fit_pending = False + if start_: if block_name == start_: start_ = False @@ -777,7 +785,7 @@ def fit(self, X=None, y=None, output_=None, start_=None, debug=False, **kwargs): self._fit_block(block, block_name, context, debug_info) - if (block_name != self._last_block_name) or (block_name in output_blocks): + if fit_pending or output_blocks: self._produce_block( block, block_name, context, output_variables, outputs, debug_info) @@ -787,16 +795,23 @@ def fit(self, X=None, y=None, output_=None, start_=None, debug=False, **kwargs): # If there was an output_ but there are no pending # outputs we are done. - if output_variables is not None and not output_blocks: - if len(outputs) > 1: - result = tuple(outputs) - else: - result = outputs[0] + if output_variables: + if not output_blocks: + if len(outputs) > 1: + result = tuple(outputs) + else: + result = outputs[0] + + if debug: + return result, debug_info + + return result + elif not fit_pending: if debug: - return result, debug_info + return debug_info - return result + return if start_: # We skipped all the blocks up to the end diff --git a/tests/test_mlpipeline.py b/tests/test_mlpipeline.py index 97c59cd0..0ee4cf2c 100644 --- a/tests/test_mlpipeline.py +++ b/tests/test_mlpipeline.py @@ -681,6 +681,54 @@ def test_get_inputs_no_fit(self): assert inputs == expected + @patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock) + def test_fit_pending_all_primitives(self): + block_1 = get_mlblock_mock() + block_2 = get_mlblock_mock() + blocks = OrderedDict(( + ('a.primitive.Name#1', block_1), + ('a.primitive.Name#2', block_2), + )) + + self_ = MagicMock(autospec=MLPipeline) + self_.blocks = blocks + self_._last_fit_block = 'a.primitive.Name#2' + + MLPipeline.fit(self_) + + expected = [ + call('a.primitive.Name#1'), + call('a.primitive.Name#2') + ] + self_._fit_block.call_args_list = expected + + expected = [ + call('a.primitive.Name#1'), + ] + self_._produce_block.call_args_list = expected + + @patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock) + def test_fit_pending_one_primitive(self): + block_1 = get_mlblock_mock() + block_2 = get_mlblock_mock() + blocks = OrderedDict(( + ('a.primitive.Name#1', block_1), + ('a.primitive.Name#2', block_2), + )) + + self_ = MagicMock(autospec=MLPipeline) + self_.blocks = blocks + self_._last_fit_block = 'a.primitive.Name#1' + + MLPipeline.fit(self_) + + expected = [ + call('a.primitive.Name#1'), + ] + self_._fit_block.call_args_list = expected + + assert not self_._produce_block.called + @patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock) def test_fit_no_debug(self): mlpipeline = MLPipeline(['a_primitive'])