From 8f4d40a41b7035227dee5b082c74deaa42dc7617 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 23 Nov 2021 16:17:41 +0000 Subject: [PATCH 1/3] Fix Flash Zero datamodule kwargs --- CHANGELOG.md | 2 ++ flash/core/utilities/flash_cli.py | 12 ++++-------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 45ebfa166e..65511a2851 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where translation metrics were not computed correctly ([#992](https://github.com/PyTorchLightning/lightning-flash/pull/992)) +- Fixed a bug where additional `DataModule` keyword arguments could not be configured with Flash Zero for some tasks + ### Removed - Removed `OutputMapping` ([#939](https://github.com/PyTorchLightning/lightning-flash/pull/939)) diff --git a/flash/core/utilities/flash_cli.py b/flash/core/utilities/flash_cli.py index bd0992afba..673c07196c 100644 --- a/flash/core/utilities/flash_cli.py +++ b/flash/core/utilities/flash_cli.py @@ -211,9 +211,9 @@ def add_arguments_to_parser(self, parser) -> None: def add_subcommand_from_function(self, subcommands, function, function_name=None): subcommand = ArgumentParser() - datamodule_function = class_from_function(drop_kwargs(function), return_type=self.local_datamodule_class) - subcommand.add_class_arguments(datamodule_function, fail_untyped=False) if self.legacy: + datamodule_function = class_from_function(drop_kwargs(function), return_type=self.local_datamodule_class) + subcommand.add_class_arguments(datamodule_function, fail_untyped=False) input_transform_function = class_from_function(drop_kwargs(self.local_datamodule_class.input_transform_cls)) subcommand.add_class_arguments( input_transform_function, @@ -221,12 +221,8 @@ def add_subcommand_from_function(self, subcommands, function, function_name=None skip=get_overlapping_args(datamodule_function, input_transform_function), ) else: - base_datamodule_function = class_from_function(drop_kwargs(self.local_datamodule_class)) - subcommand.add_class_arguments( - base_datamodule_function, - fail_untyped=False, - skip=get_overlapping_args(datamodule_function, base_datamodule_function), - ) + datamodule_function = class_from_function(function, return_type=self.local_datamodule_class) + subcommand.add_class_arguments(datamodule_function, fail_untyped=False) subcommand_name = function_name or function.__name__ subcommands.add_subcommand(subcommand_name, subcommand) self._subcommand_builders[subcommand_name] = function From 0727fc29e656efd0eeabb662dffeb57145566fb7 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 23 Nov 2021 16:18:54 +0000 Subject: [PATCH 2/3] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 65511a2851..e47d72828e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,7 +48,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where translation metrics were not computed correctly ([#992](https://github.com/PyTorchLightning/lightning-flash/pull/992)) -- Fixed a bug where additional `DataModule` keyword arguments could not be configured with Flash Zero for some tasks +- Fixed a bug where additional `DataModule` keyword arguments could not be configured with Flash Zero for some tasks ([#994](https://github.com/PyTorchLightning/lightning-flash/pull/994)) ### Removed From ce617a71ee9e8345649cb3aa5b6c5e218563e1f5 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 23 Nov 2021 16:59:13 +0000 Subject: [PATCH 3/3] Fixes --- flash/core/utilities/flash_cli.py | 25 ++++++++++++++++++++++++- flash/tabular/forecasting/cli.py | 8 ++++++-- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/flash/core/utilities/flash_cli.py b/flash/core/utilities/flash_cli.py index 673c07196c..77814dad6b 100644 --- a/flash/core/utilities/flash_cli.py +++ b/flash/core/utilities/flash_cli.py @@ -76,6 +76,14 @@ def wrapper(*args, **kwargs): return wrapper +def get_kwarg_name(func) -> Optional[str]: + sig = signature(func) + var_kwargs = [p for p in sig.parameters.values() if p.kind == p.VAR_KEYWORD] + if len(var_kwargs) == 1: + return var_kwargs[0].name + return None + + def make_args_optional(cls, args: Set[str]): @wraps(cls) def wrapper(*args, **kwargs): @@ -220,8 +228,23 @@ def add_subcommand_from_function(self, subcommands, function, function_name=None fail_untyped=False, skip=get_overlapping_args(datamodule_function, input_transform_function), ) - else: + elif get_kwarg_name(function) == "data_module_kwargs": datamodule_function = class_from_function(function, return_type=self.local_datamodule_class) + subcommand.add_class_arguments( + datamodule_function, + fail_untyped=False, + skip={ + "self", + "train_dataset", + "val_dataset", + "test_dataset", + "predict_dataset", + "input", + "input_transform", + }, + ) + else: + datamodule_function = class_from_function(drop_kwargs(function), return_type=self.local_datamodule_class) subcommand.add_class_arguments(datamodule_function, fail_untyped=False) subcommand_name = function_name or function.__name__ subcommands.add_subcommand(subcommand_name, subcommand) diff --git a/flash/tabular/forecasting/cli.py b/flash/tabular/forecasting/cli.py index 44548587b8..f2260a68bf 100644 --- a/flash/tabular/forecasting/cli.py +++ b/flash/tabular/forecasting/cli.py @@ -32,7 +32,9 @@ def from_synthetic_ar_data( n_series: int = 100, max_encoder_length: int = 60, max_prediction_length: int = 20, - **data_module_kwargs, + batch_size: int = 4, + num_workers: int = 0, + **time_series_dataset_kwargs, ) -> TabularForecastingData: """Creates and loads a synthetic auto-regressive (AR) data set.""" data = generate_ar_data(seasonality=seasonality, timesteps=timesteps, n_series=n_series, seed=42) @@ -51,7 +53,9 @@ def from_synthetic_ar_data( max_prediction_length=max_prediction_length, train_data_frame=data[lambda x: x.time_idx <= training_cutoff], val_data_frame=data, - **data_module_kwargs, + batch_size=batch_size, + num_workers=num_workers, + **time_series_dataset_kwargs, )