Skip to content

Commit

Permalink
Stop fitting pipeline after last fit block (#132)
Browse files Browse the repository at this point in the history
* initial early stop

* change fit to stop after fitting the last block with a fit attribute

* test early-stop calls

* remove comment

* change to fit pending
  • Loading branch information
sarahmish authored Jan 8, 2021
1 parent 1af7b1b commit 8446048
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 10 deletions.
35 changes: 25 additions & 10 deletions mlblocks/mlpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def _get_tunable_hyperparameters(self):

def _build_blocks(self):
blocks = OrderedDict()
last_fit_block = None

block_names_count = Counter()
for primitive in self.primitives:
Expand All @@ -118,11 +119,14 @@ def _build_blocks(self):
block = MLBlock(primitive, **block_params)
blocks[block_name] = block

if bool(block._fit):
last_fit_block = block_name

except Exception:
LOGGER.exception('Exception caught building MLBlock %s', primitive)
raise

return blocks
return blocks, last_fit_block

@staticmethod
def _get_pipeline_dict(pipeline, primitives):
Expand Down Expand Up @@ -207,7 +211,7 @@ def __init__(self, pipeline=None, primitives=None, init_params=None,

self.primitives = primitives or pipeline['primitives']
self.init_params = init_params or pipeline.get('init_params', dict())
self.blocks = self._build_blocks()
self.blocks, self._last_fit_block = self._build_blocks()
self._last_block_name = self._get_block_name(-1)

self.input_names = input_names or pipeline.get('input_names', dict())
Expand Down Expand Up @@ -767,7 +771,11 @@ def fit(self, X=None, y=None, output_=None, start_=None, debug=False, **kwargs):
debug_info = defaultdict(dict)
debug_info['debug'] = debug.lower() if isinstance(debug, str) else 'tmio'

fit_pending = True
for block_name, block in self.blocks.items():
if block_name == self._last_fit_block:
fit_pending = False

if start_:
if block_name == start_:
start_ = False
Expand All @@ -777,7 +785,7 @@ def fit(self, X=None, y=None, output_=None, start_=None, debug=False, **kwargs):

self._fit_block(block, block_name, context, debug_info)

if (block_name != self._last_block_name) or (block_name in output_blocks):
if fit_pending or output_blocks:
self._produce_block(
block, block_name, context, output_variables, outputs, debug_info)

Expand All @@ -787,16 +795,23 @@ def fit(self, X=None, y=None, output_=None, start_=None, debug=False, **kwargs):

# If there was an output_ but there are no pending
# outputs we are done.
if output_variables is not None and not output_blocks:
if len(outputs) > 1:
result = tuple(outputs)
else:
result = outputs[0]
if output_variables:
if not output_blocks:
if len(outputs) > 1:
result = tuple(outputs)
else:
result = outputs[0]

if debug:
return result, debug_info

return result

elif not fit_pending:
if debug:
return result, debug_info
return debug_info

return result
return

if start_:
# We skipped all the blocks up to the end
Expand Down
48 changes: 48 additions & 0 deletions tests/test_mlpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,54 @@ def test_get_inputs_no_fit(self):

assert inputs == expected

@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
def test_fit_pending_all_primitives(self):
    """Fit every block when the last fit block is the final one.

    With ``_last_fit_block`` set to the second (last) block, ``fit`` is
    expected to fit both blocks but only produce the first one, because
    nothing is left to fit once the last fit block has been reached.
    """
    block_1 = get_mlblock_mock()
    block_2 = get_mlblock_mock()
    blocks = OrderedDict((
        ('a.primitive.Name#1', block_1),
        ('a.primitive.Name#2', block_2),
    ))

    self_ = MagicMock(autospec=MLPipeline)
    self_.blocks = blocks
    self_._last_fit_block = 'a.primitive.Name#2'

    MLPipeline.fit(self_)

    # ``fit`` calls ``_fit_block(block, block_name, ...)`` and
    # ``_produce_block(block, block_name, ...)`` positionally, so the block
    # name is the second positional argument of each recorded call.
    # Each recorded call is an ``(args, kwargs)`` pair.
    fitted = [recorded[0][1] for recorded in self_._fit_block.call_args_list]
    assert fitted == ['a.primitive.Name#1', 'a.primitive.Name#2']

    produced = [recorded[0][1] for recorded in self_._produce_block.call_args_list]
    assert produced == ['a.primitive.Name#1']

@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
def test_fit_pending_one_primitive(self):
    """Stop fitting right after the last fit block.

    With ``_last_fit_block`` set to the first block, ``fit`` is expected to
    fit only that block, never reach the second one, and produce nothing.
    """
    block_1 = get_mlblock_mock()
    block_2 = get_mlblock_mock()
    blocks = OrderedDict((
        ('a.primitive.Name#1', block_1),
        ('a.primitive.Name#2', block_2),
    ))

    self_ = MagicMock(autospec=MLPipeline)
    self_.blocks = blocks
    self_._last_fit_block = 'a.primitive.Name#1'

    MLPipeline.fit(self_)

    # ``fit`` calls ``_fit_block(block, block_name, ...)`` positionally, so
    # the block name is the second positional argument of each recorded
    # call. Each recorded call is an ``(args, kwargs)`` pair.
    fitted = [recorded[0][1] for recorded in self_._fit_block.call_args_list]
    assert fitted == ['a.primitive.Name#1']

    assert not self_._produce_block.called

@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
def test_fit_no_debug(self):
mlpipeline = MLPipeline(['a_primitive'])
Expand Down

0 comments on commit 8446048

Please sign in to comment.