Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix Flash Zero datamodule kwargs (#994)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Nov 23, 2021
1 parent 504c4c2 commit 19cf911
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where translation metrics were not computed correctly ([#992](https://github.com/PyTorchLightning/lightning-flash/pull/992))

- Fixed a bug where additional `DataModule` keyword arguments could not be configured with Flash Zero for some tasks ([#994](https://github.com/PyTorchLightning/lightning-flash/pull/994))

### Removed

- Removed `OutputMapping` ([#939](https://github.com/PyTorchLightning/lightning-flash/pull/939))
Expand Down
31 changes: 25 additions & 6 deletions flash/core/utilities/flash_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ def wrapper(*args, **kwargs):
return wrapper


def get_kwarg_name(func) -> Optional[str]:
sig = signature(func)
var_kwargs = [p for p in sig.parameters.values() if p.kind == p.VAR_KEYWORD]
if len(var_kwargs) == 1:
return var_kwargs[0].name
return None


def make_args_optional(cls, args: Set[str]):
@wraps(cls)
def wrapper(*args, **kwargs):
Expand Down Expand Up @@ -211,22 +219,33 @@ def add_arguments_to_parser(self, parser) -> None:

def add_subcommand_from_function(self, subcommands, function, function_name=None):
subcommand = ArgumentParser()
datamodule_function = class_from_function(drop_kwargs(function), return_type=self.local_datamodule_class)
subcommand.add_class_arguments(datamodule_function, fail_untyped=False)
if self.legacy:
datamodule_function = class_from_function(drop_kwargs(function), return_type=self.local_datamodule_class)
subcommand.add_class_arguments(datamodule_function, fail_untyped=False)
input_transform_function = class_from_function(drop_kwargs(self.local_datamodule_class.input_transform_cls))
subcommand.add_class_arguments(
input_transform_function,
fail_untyped=False,
skip=get_overlapping_args(datamodule_function, input_transform_function),
)
else:
base_datamodule_function = class_from_function(drop_kwargs(self.local_datamodule_class))
elif get_kwarg_name(function) == "data_module_kwargs":
datamodule_function = class_from_function(function, return_type=self.local_datamodule_class)
subcommand.add_class_arguments(
base_datamodule_function,
datamodule_function,
fail_untyped=False,
skip=get_overlapping_args(datamodule_function, base_datamodule_function),
skip={
"self",
"train_dataset",
"val_dataset",
"test_dataset",
"predict_dataset",
"input",
"input_transform",
},
)
else:
datamodule_function = class_from_function(drop_kwargs(function), return_type=self.local_datamodule_class)
subcommand.add_class_arguments(datamodule_function, fail_untyped=False)
subcommand_name = function_name or function.__name__
subcommands.add_subcommand(subcommand_name, subcommand)
self._subcommand_builders[subcommand_name] = function
Expand Down
8 changes: 6 additions & 2 deletions flash/tabular/forecasting/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def from_synthetic_ar_data(
n_series: int = 100,
max_encoder_length: int = 60,
max_prediction_length: int = 20,
**data_module_kwargs,
batch_size: int = 4,
num_workers: int = 0,
**time_series_dataset_kwargs,
) -> TabularForecastingData:
"""Creates and loads a synthetic auto-regressive (AR) data set."""
data = generate_ar_data(seasonality=seasonality, timesteps=timesteps, n_series=n_series, seed=42)
Expand All @@ -51,7 +53,9 @@ def from_synthetic_ar_data(
max_prediction_length=max_prediction_length,
train_data_frame=data[lambda x: x.time_idx <= training_cutoff],
val_data_frame=data,
**data_module_kwargs,
batch_size=batch_size,
num_workers=num_workers,
**time_series_dataset_kwargs,
)


Expand Down

0 comments on commit 19cf911

Please sign in to comment.