From d8d7fc0846cad4382f8343f4340efe7e47eaf0fa Mon Sep 17 00:00:00 2001 From: sarahmish Date: Tue, 29 Dec 2020 16:22:19 -0500 Subject: [PATCH 1/5] initial early stop --- mlblocks/mlpipeline.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/mlblocks/mlpipeline.py b/mlblocks/mlpipeline.py index 8367b327..5ba0cbc8 100644 --- a/mlblocks/mlpipeline.py +++ b/mlblocks/mlpipeline.py @@ -93,6 +93,7 @@ def _get_tunable_hyperparameters(self): def _build_blocks(self): blocks = OrderedDict() + last_fit_block = None block_names_count = Counter() for primitive in self.primitives: @@ -115,11 +116,14 @@ def _build_blocks(self): block = MLBlock(primitive, **block_params) blocks[block_name] = block + if bool(block._fit): + last_fit_block = primitive + except Exception: LOGGER.exception("Exception caught building MLBlock %s", primitive) raise - return blocks + return blocks, last_fit_block @staticmethod def _get_pipeline_dict(pipeline, primitives): @@ -204,7 +208,7 @@ def __init__(self, pipeline=None, primitives=None, init_params=None, self.primitives = primitives or pipeline['primitives'] self.init_params = init_params or pipeline.get('init_params', dict()) - self.blocks = self._build_blocks() + self.blocks, self._last_fit_block = self._build_blocks() self._last_block_name = self._get_block_name(-1) self.input_names = input_names or pipeline.get('input_names', dict()) @@ -750,7 +754,7 @@ def fit(self, X=None, y=None, output_=None, start_=None, debug=False, **kwargs): self._fit_block(block, block_name, context, debug_info) - if (block_name != self._last_block_name) or (block_name in output_blocks): + if (block_name != self._last_fit_block) or (block_name in output_blocks): self._produce_block( block, block_name, context, output_variables, outputs, debug_info) From ba514f0d87d02c67645eb28c388d465988865ba6 Mon Sep 17 00:00:00 2001 From: sarahmish Date: Tue, 29 Dec 2020 17:11:14 -0500 Subject: [PATCH 2/5] change to stop after fitting the last block with attribute --- mlblocks/mlpipeline.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mlblocks/mlpipeline.py b/mlblocks/mlpipeline.py index 73d31790..a362053b 100644 --- a/mlblocks/mlpipeline.py +++ b/mlblocks/mlpipeline.py @@ -120,7 +120,7 @@ def _build_blocks(self): blocks[block_name] = block if bool(block._fit): - last_fit_block = primitive + last_fit_block = block_name except Exception: LOGGER.exception('Exception caught building MLBlock %s', primitive) @@ -771,6 +771,7 @@ def fit(self, X=None, y=None, output_=None, start_=None, debug=False, **kwargs): debug_info = defaultdict(dict) debug_info['debug'] = debug.lower() if isinstance(debug, str) else 'tmio' + early_stop = False for block_name, block in self.blocks.items(): if start_: if block_name == start_: @@ -781,7 +782,10 @@ def fit(self, X=None, y=None, output_=None, start_=None, debug=False, **kwargs): self._fit_block(block, block_name, context, debug_info) - if (block_name != self._last_fit_block) or (block_name in output_blocks): + if block_name == self._last_fit_block: + early_stop = True + + if (not early_stop) or (block_name in output_blocks): self._produce_block( block, block_name, context, output_variables, outputs, debug_info) From 03b2f710935b4d5a09c8278ccc8e65b3c7476202 Mon Sep 17 00:00:00 2001 From: sarahmish Date: Mon, 4 Jan 2021 16:39:00 -0500 Subject: [PATCH 3/5] test early-stop calls --- tests/test_mlpipeline.py | 63 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/tests/test_mlpipeline.py b/tests/test_mlpipeline.py index 97c59cd0..53087594 100644 --- a/tests/test_mlpipeline.py +++ b/tests/test_mlpipeline.py @@ -681,6 +681,69 @@ def test_get_inputs_no_fit(self): assert inputs == expected + @patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock) + def test_fit_no_early_stop(self): + block_1 = get_mlblock_mock() + block_2 = get_mlblock_mock() + blocks = OrderedDict(( + ('a.primitive.Name#1', block_1), + ('a.primitive.Name#2', block_2), + )) + + self_ = MagicMock(autospec=MLPipeline) + self_.blocks = blocks + self_._last_fit_block = 'a.primitive.Name#2' + + MLPipeline.fit(self_) + + expected = [ + call('a.primitive.Name#1'), + call('a.primitive.Name#2') + ] + self_._fit_block.call_args_list = expected + + expected = [ + call('a.primitive.Name#1'), + ] + self_._produce_block.call_args_list = expected + + @patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock) + def test_fit_early_stop(self): + block_1 = get_mlblock_mock() + block_2 = get_mlblock_mock() + blocks = OrderedDict(( + ('a.primitive.Name#1', block_1), + ('a.primitive.Name#2', block_2), + )) + + self_ = MagicMock(autospec=MLPipeline) + self_.blocks = blocks + self_._last_fit_block = 'a.primitive.Name#1' + + MLPipeline.fit(self_) + + expected = [ + call('a.primitive.Name#1'), + ] + self_._fit_block.call_args_list = expected + + assert not self_._produce_block.called + + # @patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock) + # # @patch('mlpipeline._produce_block') + # def test_fit_early_stop(self): + # pipeline = MLPipeline(['a_primitive', 'another_primitive']) + # pipeline._last_fit_block = 'a_primitive#1' + + # pipeline.fit() + + # expected_calls = [ + # call(get_mlblock_mock(), 'a_primitive#1'), + # call(get_mlblock_mock(), 'another_primitive#1') + # ] + + # assert pipeline.call_args_list == expected_calls + @patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock) def test_fit_no_debug(self): mlpipeline = MLPipeline(['a_primitive']) From e65078965853c74aa3aafab58503c5b1096fac46 Mon Sep 17 00:00:00 2001 From: sarahmish Date: Mon, 4 Jan 2021 16:42:14 -0500 Subject: [PATCH 4/5] remove comment --- tests/test_mlpipeline.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/tests/test_mlpipeline.py b/tests/test_mlpipeline.py index 53087594..b8c021b4 100644 --- a/tests/test_mlpipeline.py +++ b/tests/test_mlpipeline.py @@ -729,21 +729,6 @@ def test_fit_early_stop(self): assert not self_._produce_block.called - # @patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock) - # # @patch('mlpipeline._produce_block') - # def test_fit_early_stop(self): - # pipeline = MLPipeline(['a_primitive', 'another_primitive']) - # pipeline._last_fit_block = 'a_primitive#1' - - # pipeline.fit() - - # expected_calls = [ - # call(get_mlblock_mock(), 'a_primitive#1'), - # call(get_mlblock_mock(), 'another_primitive#1') - # ] - - # assert pipeline.call_args_list == expected_calls - @patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock) def test_fit_no_debug(self): mlpipeline = MLPipeline(['a_primitive']) From 382b89a7f5df9a20ff3ff19fd727957c68a41fee Mon Sep 17 00:00:00 2001 From: sarahmish Date: Thu, 7 Jan 2021 18:18:46 -0500 Subject: [PATCH 5/5] change to fit pending --- mlblocks/mlpipeline.py | 31 +++++++++++++++++++------------ tests/test_mlpipeline.py | 4 ++-- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/mlblocks/mlpipeline.py b/mlblocks/mlpipeline.py index a362053b..d7935757 100644 --- a/mlblocks/mlpipeline.py +++ b/mlblocks/mlpipeline.py @@ -771,8 +771,11 @@ def fit(self, X=None, y=None, output_=None, start_=None, debug=False, **kwargs): debug_info = defaultdict(dict) debug_info['debug'] = debug.lower() if isinstance(debug, str) else 'tmio' - early_stop = False + fit_pending = True for block_name, block in self.blocks.items(): + if block_name == self._last_fit_block: + fit_pending = False + if start_: if block_name == start_: start_ = False @@ -782,10 +785,7 @@ def fit(self, X=None, y=None, output_=None, start_=None, debug=False, **kwargs): self._fit_block(block, block_name, context, debug_info) - if block_name == self._last_fit_block: - early_stop = True - - if (not early_stop) or (block_name in output_blocks): + if fit_pending or output_blocks: self._produce_block( block, block_name, context, output_variables, outputs, debug_info) @@ -795,16 +795,23 @@ def fit(self, X=None, y=None, output_=None, start_=None, debug=False, **kwargs): # If there was an output_ but there are no pending # outputs we are done. - if output_variables is not None and not output_blocks: - if len(outputs) > 1: - result = tuple(outputs) - else: - result = outputs[0] + if output_variables: + if not output_blocks: + if len(outputs) > 1: + result = tuple(outputs) + else: + result = outputs[0] + + if debug: + return result, debug_info + return result + + elif not fit_pending: if debug: - return result, debug_info + return debug_info - return result + return if start_: # We skipped all the blocks up to the end diff --git a/tests/test_mlpipeline.py b/tests/test_mlpipeline.py index b8c021b4..0ee4cf2c 100644 --- a/tests/test_mlpipeline.py +++ b/tests/test_mlpipeline.py @@ -682,7 +682,7 @@ def test_get_inputs_no_fit(self): assert inputs == expected @patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock) - def test_fit_no_early_stop(self): + def test_fit_pending_all_primitives(self): block_1 = get_mlblock_mock() block_2 = get_mlblock_mock() blocks = OrderedDict(( @@ -708,7 +708,7 @@ def test_fit_no_early_stop(self): self_._produce_block.call_args_list = expected @patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock) - def test_fit_early_stop(self): + def test_fit_pending_one_primitive(self): block_1 = get_mlblock_mock() block_2 = get_mlblock_mock() blocks = OrderedDict((