Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Rename valid_ to val_
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Mar 29, 2021
1 parent ba34bf4 commit e72d96d
Show file tree
Hide file tree
Showing 24 changed files with 156 additions and 158 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'da
# 2. Load the data
datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
valid_folder="data/hymenoptera_data/val/",
val_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
)

Expand Down Expand Up @@ -205,11 +205,11 @@ download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", 'data/')

# 2. Load the data
datamodule = SummarizationData.from_files(
train_file="data/xsum/train.csv",
valid_file="data/xsum/valid.csv",
test_file="data/xsum/test.csv",
input="input",
target="target"
train_file="data/xsum/train.csv",
val_file="data/xsum/valid.csv",
test_file="data/xsum/test.csv",
input="input",
target="target"
)

# 3. Build the model
Expand Down
46 changes: 23 additions & 23 deletions flash/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class DataModule(pl.LightningDataModule):
Args:
train_dataset: Dataset for training. Defaults to None.
valid_dataset: Dataset for validating model performance during training. Defaults to None.
val_dataset: Dataset for validating model performance during training. Defaults to None.
test_dataset: Dataset to test model performance. Defaults to None.
predict_dataset: Dataset to predict model performance. Defaults to None.
num_workers: The number of workers to use for parallelized loading. Defaults to None.
Expand All @@ -49,7 +49,7 @@ class DataModule(pl.LightningDataModule):
def __init__(
self,
train_dataset: Optional[Dataset] = None,
valid_dataset: Optional[Dataset] = None,
val_dataset: Optional[Dataset] = None,
test_dataset: Optional[Dataset] = None,
predict_dataset: Optional[Dataset] = None,
batch_size: int = 1,
Expand All @@ -58,14 +58,14 @@ def __init__(

super().__init__()
self._train_ds = train_dataset
self._valid_ds = valid_dataset
self._val_ds = val_dataset
self._test_ds = test_dataset
self._predict_ds = predict_dataset

if self._train_ds:
self.train_dataloader = self._train_dataloader

if self._valid_ds:
if self._val_ds:
self.val_dataloader = self._val_dataloader

if self._test_ds:
Expand Down Expand Up @@ -104,8 +104,8 @@ def set_running_stages(self):
if self._train_ds:
self.set_dataset_attribute(self._train_ds, 'running_stage', RunningStage.TRAINING)

if self._valid_ds:
self.set_dataset_attribute(self._valid_ds, 'running_stage', RunningStage.VALIDATING)
if self._val_ds:
self.set_dataset_attribute(self._val_ds, 'running_stage', RunningStage.VALIDATING)

if self._test_ds:
self.set_dataset_attribute(self._test_ds, 'running_stage', RunningStage.TESTING)
Expand All @@ -130,13 +130,13 @@ def _train_dataloader(self) -> DataLoader:
)

def _val_dataloader(self) -> DataLoader:
valid_ds: Dataset = self._valid_ds() if isinstance(self._valid_ds, Callable) else self._valid_ds
val_ds: Dataset = self._val_ds() if isinstance(self._val_ds, Callable) else self._val_ds
return DataLoader(
valid_ds,
val_ds,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=True,
collate_fn=self._resolve_collate_fn(valid_ds, RunningStage.VALIDATING)
collate_fn=self._resolve_collate_fn(val_ds, RunningStage.VALIDATING)
)

def _test_dataloader(self) -> DataLoader:
Expand Down Expand Up @@ -214,10 +214,10 @@ def autogenerate_dataset(
return AutoDataset(data, whole_data_load_fn, per_sample_load_fn, data_pipeline, running_stage=running_stage)

@staticmethod
def train_valid_test_split(
def train_val_test_split(
dataset: torch.utils.data.Dataset,
train_split: Optional[Union[float, int]] = None,
valid_split: Optional[Union[float, int]] = None,
val_split: Optional[Union[float, int]] = None,
test_split: Optional[Union[float, int]] = None,
seed: Optional[int] = 1234,
) -> Tuple[Dataset, Optional[Dataset], Optional[Dataset]]:
Expand All @@ -227,11 +227,11 @@ def train_valid_test_split(
dataset: Dataset to be split.
train_split: If Float, ratio of data to be contained within the train dataset. If Int,
number of samples to be contained within train dataset.
valid_split: If Float, ratio of data to be contained within the validation dataset. If Int,
val_split: If Float, ratio of data to be contained within the validation dataset. If Int,
number of samples to be contained within test dataset.
test_split: If Float, ratio of data to be contained within the test dataset. If Int,
number of samples to be contained within test dataset.
seed: Used for the train/val splits when valid_split is not None.
seed: Used for the train/val splits when val_split is not None.
"""
n = len(dataset)
Expand All @@ -243,12 +243,12 @@ def train_valid_test_split(
else:
_test_length = test_split

if valid_split is None:
if val_split is None:
_val_length = 0
elif isinstance(valid_split, float):
_val_length = int(n * valid_split)
elif isinstance(val_split, float):
_val_length = int(n * val_split)
else:
_val_length = valid_split
_val_length = val_split

if train_split is None:
_train_length = n - _val_length - _test_length
Expand All @@ -265,7 +265,7 @@ def train_valid_test_split(
train_ds, val_ds, test_ds = torch.utils.data.random_split(
dataset, [_train_length, _val_length, _test_length], generator
)
if valid_split is None:
if val_split is None:
val_ds = None
if test_split is None:
test_ds = None
Expand Down Expand Up @@ -293,7 +293,7 @@ def _generate_dataset_if_possible(
def from_load_data_inputs(
cls,
train_load_data_input: Optional[Any] = None,
valid_load_data_input: Optional[Any] = None,
val_load_data_input: Optional[Any] = None,
test_load_data_input: Optional[Any] = None,
predict_load_data_input: Optional[Any] = None,
preprocess: Optional[Preprocess] = None,
Expand All @@ -306,7 +306,7 @@ def from_load_data_inputs(
Args:
cls: ``DataModule`` subclass
train_load_data_input: Data to be received by the ``train_load_data`` function from this ``Preprocess``
valid_load_data_input: Data to be received by the ``val_load_data`` function from this ``Preprocess``
val_load_data_input: Data to be received by the ``val_load_data`` function from this ``Preprocess``
test_load_data_input: Data to be received by the ``test_load_data`` function from this ``Preprocess``
predict_load_data_input: Data to be received by the ``predict_load_data`` function from this ``Preprocess``
kwargs: Any extra arguments to instantiate the provided ``DataModule``
Expand All @@ -322,8 +322,8 @@ def from_load_data_inputs(
train_dataset = cls._generate_dataset_if_possible(
train_load_data_input, running_stage=RunningStage.TRAINING, data_pipeline=data_pipeline
)
valid_dataset = cls._generate_dataset_if_possible(
valid_load_data_input, running_stage=RunningStage.VALIDATING, data_pipeline=data_pipeline
val_dataset = cls._generate_dataset_if_possible(
val_load_data_input, running_stage=RunningStage.VALIDATING, data_pipeline=data_pipeline
)
test_dataset = cls._generate_dataset_if_possible(
test_load_data_input, running_stage=RunningStage.TESTING, data_pipeline=data_pipeline
Expand All @@ -333,7 +333,7 @@ def from_load_data_inputs(
)
datamodule = cls(
train_dataset=train_dataset,
valid_dataset=valid_dataset,
val_dataset=val_dataset,
test_dataset=test_dataset,
predict_dataset=predict_dataset,
**kwargs
Expand Down
4 changes: 2 additions & 2 deletions flash/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,13 @@ class Preprocess(Properties, torch.nn.Module):
def __init__(
self,
train_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
valid_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
val_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
test_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
predict_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
):
super().__init__()
self.train_transform = convert_to_modules(train_transform)
self.valid_transform = convert_to_modules(valid_transform)
self.val_transform = convert_to_modules(val_transform)
self.test_transform = convert_to_modules(test_transform)
self.predict_transform = convert_to_modules(predict_transform)

Expand Down
34 changes: 17 additions & 17 deletions flash/tabular/classification/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def state(self) -> TabularState:
@staticmethod
def generate_state(
train_df: DataFrame,
valid_df: Optional[DataFrame],
val_df: Optional[DataFrame],
test_df: Optional[DataFrame],
predict_df: Optional[DataFrame],
target_col: str,
Expand All @@ -100,8 +100,8 @@ def generate_state(

dfs = [train_df]

if valid_df is not None:
dfs += [valid_df]
if val_df is not None:
dfs += [val_df]

if test_df is not None:
dfs += [test_df]
Expand Down Expand Up @@ -197,7 +197,7 @@ def from_csv(
train_csv: Optional[str] = None,
categorical_cols: Optional[List] = None,
numerical_cols: Optional[List] = None,
valid_csv: Optional[str] = None,
val_csv: Optional[str] = None,
test_csv: Optional[str] = None,
predict_csv: Optional[str] = None,
batch_size: int = 8,
Expand All @@ -215,7 +215,7 @@ def from_csv(
target_col: The column containing the class id.
categorical_cols: The list of categorical columns.
numerical_cols: The list of numerical columns.
valid_csv: Validation data csv file.
val_csv: Validation data csv file.
test_csv: Test data csv file.
batch_size: The batchsize to use for parallel loading. Defaults to 64.
num_workers: The number of workers to use for parallelized loading.
Expand All @@ -234,7 +234,7 @@ def from_csv(
text_data = TabularData.from_files("train.csv", label_field="class", text_field="sentence")
"""
train_df = pd.read_csv(train_csv, **pandas_kwargs)
valid_df = pd.read_csv(valid_csv, **pandas_kwargs) if valid_csv else None
val_df = pd.read_csv(val_csv, **pandas_kwargs) if val_csv else None
test_df = pd.read_csv(test_csv, **pandas_kwargs) if test_csv else None
predict_df = pd.read_csv(predict_csv, **pandas_kwargs) if predict_csv else None

Expand All @@ -243,7 +243,7 @@ def from_csv(
target_col,
categorical_cols,
numerical_cols,
valid_df,
val_df,
test_df,
predict_df,
batch_size,
Expand All @@ -268,21 +268,21 @@ def emb_sizes(self) -> list:
@staticmethod
def _split_dataframe(
train_df: DataFrame,
valid_df: Optional[DataFrame] = None,
val_df: Optional[DataFrame] = None,
test_df: Optional[DataFrame] = None,
val_size: float = None,
test_size: float = None,
):
if valid_df is None and isinstance(val_size, float) and isinstance(test_size, float):
if val_df is None and isinstance(val_size, float) and isinstance(test_size, float):
assert 0 < val_size < 1
assert 0 < test_size < 1
train_df, valid_df = train_test_split(train_df, test_size=(val_size + test_size))
train_df, val_df = train_test_split(train_df, test_size=(val_size + test_size))

if test_df is None and isinstance(test_size, float):
assert 0 < test_size < 1
valid_df, test_df = train_test_split(valid_df, test_size=test_size)
val_df, test_df = train_test_split(val_df, test_size=test_size)

return train_df, valid_df, test_df
return train_df, val_df, test_df

@staticmethod
def _sanetize_cols(cat_cols: Optional[List], num_cols: Optional[List]):
Expand All @@ -298,7 +298,7 @@ def from_df(
target_col: str,
categorical_cols: Optional[List] = None,
numerical_cols: Optional[List] = None,
valid_df: Optional[DataFrame] = None,
val_df: Optional[DataFrame] = None,
test_df: Optional[DataFrame] = None,
predict_df: Optional[DataFrame] = None,
batch_size: int = 8,
Expand All @@ -316,7 +316,7 @@ def from_df(
target_col: The column containing the class id.
categorical_cols: The list of categorical columns.
numerical_cols: The list of numerical columns.
valid_df: Validation data DataFrame.
val_df: Validation data DataFrame.
test_df: Test data DataFrame.
batch_size: The batchsize to use for parallel loading. Defaults to 64.
num_workers: The number of workers to use for parallelized loading.
Expand All @@ -334,13 +334,13 @@ def from_df(
"""
categorical_cols, numerical_cols = cls._sanetize_cols(categorical_cols, numerical_cols)

train_df, valid_df, test_df = cls._split_dataframe(train_df, valid_df, test_df, val_size, test_size)
train_df, val_df, test_df = cls._split_dataframe(train_df, val_df, test_df, val_size, test_size)

preprocess_cls = preprocess_cls or cls.preprocess_cls

preprocess_state = preprocess_cls.generate_state(
train_df,
valid_df,
val_df,
test_df,
predict_df,
target_col,
Expand All @@ -353,7 +353,7 @@ def from_df(

return cls.from_load_data_inputs(
train_load_data_input=train_df,
valid_load_data_input=valid_df,
val_load_data_input=val_df,
test_load_data_input=test_df,
predict_load_data_input=predict_df,
batch_size=batch_size,
Expand Down
8 changes: 4 additions & 4 deletions flash/text/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def from_files(
target: Optional[str] = 'labels',
filetype: str = "csv",
backbone: str = "prajjwal1/bert-tiny",
valid_file: Optional[str] = None,
val_file: Optional[str] = None,
test_file: Optional[str] = None,
predict_file: Optional[str] = None,
max_length: int = 128,
Expand All @@ -255,7 +255,7 @@ def from_files(
target: The field storing the class id of the associated text.
filetype: .csv or .json
backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer.
valid_file: Path to validation data.
val_file: Path to validation data.
test_file: Path to test data.
batch_size: the batchsize to use for parallel loading. Defaults to 64.
num_workers: The number of workers to use for parallelized loading.
Expand Down Expand Up @@ -287,7 +287,7 @@ def from_files(

return cls.from_load_data_inputs(
train_load_data_input=train_file,
valid_load_data_input=valid_file,
val_load_data_input=val_file,
test_load_data_input=test_file,
predict_load_data_input=predict_file,
batch_size=batch_size,
Expand Down Expand Up @@ -327,7 +327,7 @@ def from_file(
target=None,
filetype=filetype,
backbone=backbone,
valid_file=None,
val_file=None,
test_file=None,
predict_file=predict_file,
max_length=max_length,
Expand Down
6 changes: 3 additions & 3 deletions flash/text/seq2seq/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def from_files(
target: Optional[str] = None,
filetype: str = "csv",
backbone: str = "sshleifer/tiny-mbart",
valid_file: Optional[str] = None,
val_file: Optional[str] = None,
test_file: Optional[str] = None,
predict_file: Optional[str] = None,
max_source_length: int = 128,
Expand All @@ -185,7 +185,7 @@ def from_files(
target: The field storing the target translation text.
filetype: ``csv`` or ``json`` File
backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer.
valid_file: Path to validation data.
val_file: Path to validation data.
test_file: Path to test data.
max_source_length: Maximum length of the source text. Any text longer will be truncated.
max_target_length: Maximum length of the target text. Any text longer will be truncated.
Expand Down Expand Up @@ -217,7 +217,7 @@ def from_files(

return cls.from_load_data_inputs(
train_load_data_input=train_file,
valid_load_data_input=valid_file,
val_load_data_input=val_file,
test_load_data_input=test_file,
predict_load_data_input=predict_file,
batch_size=batch_size,
Expand Down
Loading

0 comments on commit e72d96d

Please sign in to comment.