From ba2cbb70d9f2dc3d31581222c3f95c09a017329c Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 1 Feb 2021 19:13:56 +0000 Subject: [PATCH] Add typing --- flash/text/seq2seq/core/data.py | 14 +++++++------- flash/text/seq2seq/summarization/data.py | 14 +++++++------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py index a7560570129..bf510f4b0aa 100644 --- a/flash/text/seq2seq/core/data.py +++ b/flash/text/seq2seq/core/data.py @@ -144,13 +144,13 @@ def default_pipeline(): @classmethod def from_files( cls, - train_file, + train_file: str, input: str = 'input', target: Optional[str] = None, - filetype="csv", - backbone="sshleifer/tiny-mbart", - valid_file=None, - test_file=None, + filetype: str = "csv", + backbone: str = "sshleifer/tiny-mbart", + valid_file: Optional[str] = None, + test_file: Optional[str] = None, max_source_length: int = 128, max_target_length: int = 128, padding: Union[str, bool] = 'max_length', @@ -217,8 +217,8 @@ def from_file( predict_file: str, input: str = 'input', target: Optional[str] = None, - backbone="sshleifer/tiny-mbart", - filetype="csv", + backbone: str = "sshleifer/tiny-mbart", + filetype: str = "csv", max_source_length: int = 128, max_target_length: int = 128, padding: Union[str, bool] = 'max_length', diff --git a/flash/text/seq2seq/summarization/data.py b/flash/text/seq2seq/summarization/data.py index 99bd541d4f3..20e0eb2ba2b 100644 --- a/flash/text/seq2seq/summarization/data.py +++ b/flash/text/seq2seq/summarization/data.py @@ -29,13 +29,13 @@ def default_pipeline(): @classmethod def from_files( cls, - train_file, + train_file: str, input: str = 'input', target: Optional[str] = None, - filetype="csv", - backbone="t5-small", - valid_file=None, - test_file=None, + filetype: str = "csv", + backbone: str = "t5-small", + valid_file: str = None, + test_file: str = None, max_source_length: int = 512, max_target_length: int = 128, padding: Union[str, bool] = 'max_length', @@ -91,8 +91,8 @@ def from_file( predict_file: str, input: str = 'src_text', target: Optional[str] = None, - backbone="t5-small", - filetype="csv", + backbone: str = "t5-small", + filetype: str = "csv", max_source_length: int = 512, max_target_length: int = 128, padding: Union[str, bool] = 'longest',