Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stop fitting pipeline after last fit block #132

Merged
merged 6 commits into from
Jan 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions mlblocks/mlpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def _get_tunable_hyperparameters(self):

def _build_blocks(self):
blocks = OrderedDict()
last_fit_block = None

block_names_count = Counter()
for primitive in self.primitives:
Expand All @@ -118,11 +119,14 @@ def _build_blocks(self):
block = MLBlock(primitive, **block_params)
blocks[block_name] = block

if bool(block._fit):
last_fit_block = block_name

except Exception:
LOGGER.exception('Exception caught building MLBlock %s', primitive)
raise

return blocks
return blocks, last_fit_block

@staticmethod
def _get_pipeline_dict(pipeline, primitives):
Expand Down Expand Up @@ -207,7 +211,7 @@ def __init__(self, pipeline=None, primitives=None, init_params=None,

self.primitives = primitives or pipeline['primitives']
self.init_params = init_params or pipeline.get('init_params', dict())
self.blocks = self._build_blocks()
self.blocks, self._last_fit_block = self._build_blocks()
self._last_block_name = self._get_block_name(-1)

self.input_names = input_names or pipeline.get('input_names', dict())
Expand Down Expand Up @@ -767,7 +771,11 @@ def fit(self, X=None, y=None, output_=None, start_=None, debug=False, **kwargs):
debug_info = defaultdict(dict)
debug_info['debug'] = debug.lower() if isinstance(debug, str) else 'tmio'

fit_pending = True
for block_name, block in self.blocks.items():
if block_name == self._last_fit_block:
fit_pending = False

if start_:
if block_name == start_:
start_ = False
Expand All @@ -777,7 +785,7 @@ def fit(self, X=None, y=None, output_=None, start_=None, debug=False, **kwargs):

self._fit_block(block, block_name, context, debug_info)

if (block_name != self._last_block_name) or (block_name in output_blocks):
if fit_pending or output_blocks:
self._produce_block(
block, block_name, context, output_variables, outputs, debug_info)

Expand All @@ -787,16 +795,23 @@ def fit(self, X=None, y=None, output_=None, start_=None, debug=False, **kwargs):

# If there was an output_ but there are no pending
# outputs we are done.
if output_variables is not None and not output_blocks:
if len(outputs) > 1:
result = tuple(outputs)
else:
result = outputs[0]
if output_variables:
if not output_blocks:
if len(outputs) > 1:
result = tuple(outputs)
else:
result = outputs[0]

if debug:
return result, debug_info

return result

elif not fit_pending:
if debug:
return result, debug_info
return debug_info

return result
return

if start_:
# We skipped all the blocks up to the end
Expand Down
48 changes: 48 additions & 0 deletions tests/test_mlpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,54 @@ def test_get_inputs_no_fit(self):

assert inputs == expected

@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
def test_fit_pending_all_primitives(self):
    """Check ``fit`` when the last fittable block is the final block.

    Every block must be fitted, but only the blocks that come before
    the last fit block need to produce output for the following ones.
    """
    block_1 = get_mlblock_mock()
    block_2 = get_mlblock_mock()
    blocks = OrderedDict((
        ('a.primitive.Name#1', block_1),
        ('a.primitive.Name#2', block_2),
    ))

    # NOTE(review): ``autospec`` is not a ``MagicMock`` constructor option, so
    # this only sets an attribute named ``autospec`` -- confirm whether
    # ``spec=MLPipeline`` was intended.
    self_ = MagicMock(autospec=MLPipeline)
    self_.blocks = blocks
    self_._last_fit_block = 'a.primitive.Name#2'

    MLPipeline.fit(self_)

    # ``fit`` passes the block name as the second positional argument to
    # both helpers, so compare on that rather than on the full call, which
    # also carries the context and debug dictionaries.
    fitted_names = [
        recorded[0][1] for recorded in self_._fit_block.call_args_list
    ]
    assert fitted_names == ['a.primitive.Name#1', 'a.primitive.Name#2']

    produced_names = [
        recorded[0][1] for recorded in self_._produce_block.call_args_list
    ]
    assert produced_names == ['a.primitive.Name#1']

@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
def test_fit_pending_one_primitive(self):
    """Check ``fit`` stops once the last fittable block has been fitted.

    With the first block marked as the last fit block, only that block is
    fitted and nothing is produced, because no later block needs fitting.
    """
    block_1 = get_mlblock_mock()
    block_2 = get_mlblock_mock()
    blocks = OrderedDict((
        ('a.primitive.Name#1', block_1),
        ('a.primitive.Name#2', block_2),
    ))

    # NOTE(review): ``autospec`` is not a ``MagicMock`` constructor option, so
    # this only sets an attribute named ``autospec`` -- confirm whether
    # ``spec=MLPipeline`` was intended.
    self_ = MagicMock(autospec=MLPipeline)
    self_.blocks = blocks
    self_._last_fit_block = 'a.primitive.Name#1'

    MLPipeline.fit(self_)

    # ``fit`` passes the block name as the second positional argument, so
    # compare on that rather than on the full call signature.
    fitted_names = [
        recorded[0][1] for recorded in self_._fit_block.call_args_list
    ]
    assert fitted_names == ['a.primitive.Name#1']

    assert not self_._produce_block.called

@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
def test_fit_no_debug(self):
mlpipeline = MLPipeline(['a_primitive'])
Expand Down