From 65aa566c87d08c1d196b1c1e146ce16f2f0be43d Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 11 Jul 2022 13:35:45 -0700 Subject: [PATCH] Fixed splitting when providing pre-split inputs (#2248) --- ludwig/data/dataset/pandas.py | 2 + ludwig/data/dataset/ray.py | 2 + ludwig/data/preprocessing.py | 16 ++++ requirements.txt | 2 +- tests/integration_tests/test_preprocessing.py | 81 ++++++++++++++++++- 5 files changed, 101 insertions(+), 2 deletions(-) diff --git a/ludwig/data/dataset/pandas.py b/ludwig/data/dataset/pandas.py index c93b341f411..8f90292ada9 100644 --- a/ludwig/data/dataset/pandas.py +++ b/ludwig/data/dataset/pandas.py @@ -32,6 +32,8 @@ def __init__(self, dataset, features, data_hdf5_fp): self.features = features self.data_hdf5_fp = data_hdf5_fp self.size = len(dataset) + if self.size == 0: + raise ValueError("Dataset is empty following preprocessing") self.dataset = to_numpy_dataset(dataset) def get(self, proc_column, idx=None): diff --git a/ludwig/data/dataset/ray.py b/ludwig/data/dataset/ray.py index d8bb507e27a..88b4141e4ce 100644 --- a/ludwig/data/dataset/ray.py +++ b/ludwig/data/dataset/ray.py @@ -60,6 +60,8 @@ def __init__( ): self.df_engine = backend.df_engine self.ds = self.df_engine.to_ray_dataset(df) if not isinstance(df, str) else read_remote_parquet(df) + if self.size == 0: + raise ValueError("Dataset is empty following preprocessing") self.features = features self.training_set_metadata = training_set_metadata self.data_hdf5_fp = training_set_metadata.get(DATA_TRAIN_HDF5_FP) diff --git a/ludwig/data/preprocessing.py b/ludwig/data/preprocessing.py index c8fc435cd99..ed454c05970 100644 --- a/ludwig/data/preprocessing.py +++ b/ludwig/data/preprocessing.py @@ -1742,6 +1742,22 @@ def _preprocess_df_for_training( # needs preprocessing logger.info("Using training dataframe") dataset = concatenate_df(training_set, validation_set, test_set, backend) + + # Data is pre-split, so we override whatever split policy the user specified + if preprocessing_params["split"]: + warnings.warn( + 'Preprocessing "split" section provided, but pre-split dataset given as input. ' + "Ignoring split configuration." + ) + + preprocessing_params = { + **preprocessing_params, + "split": { + "type": "fixed", + "column": SPLIT, + }, + } + logger.info("Building dataset (it may take a while)") dataset, training_set_metadata = build_dataset( diff --git a/requirements.txt b/requirements.txt index 6143cc6bae4..03e114eae0e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,7 +19,7 @@ requests tables fsspec[http] dataclasses-json -jsonschema>=4.5.0 +jsonschema>=4.5.0,<4.7 marshmallow marshmallow-jsonschema marshmallow-dataclass==8.5.5 diff --git a/tests/integration_tests/test_preprocessing.py b/tests/integration_tests/test_preprocessing.py index 32befcfa57a..4e384d173b0 100644 --- a/tests/integration_tests/test_preprocessing.py +++ b/tests/integration_tests/test_preprocessing.py @@ -127,7 +127,7 @@ def test_dask_known_divisions(feature_fn, csv_filename, tmpdir): data_csv = generate_data( input_features, output_features, os.path.join(tmpdir, csv_filename), num_examples=num_examples ) - data_df = dd.from_pandas(pd.read_csv(data_csv), npartitions=1) + data_df = dd.from_pandas(pd.read_csv(data_csv), npartitions=10) assert data_df.known_divisions config = { @@ -231,3 +231,82 @@ def random_string(): # check that train_ds had invalid values replaced with the missing value assert len(concatenated_df) == len(df) assert np.all(concatenated_df[num_feat[PROC_COLUMN]] == 0.0) + + +@pytest.mark.parametrize("format", ["file", "df"]) +def test_presplit_override(format, tmpdir): + """Tests that provising a pre-split file or dataframe overrides the user's split config.""" + num_feat = number_feature(normalization=None) + input_features = [num_feat, sequence_feature(reduce_output="sum")] + output_features = [category_feature(vocab_size=5, reduce_input="sum")] + + data_csv = generate_data(input_features, output_features, os.path.join(tmpdir, "dataset.csv"), num_examples=25) + data_df = pd.read_csv(data_csv) + + # Set the feature value equal to an ordinal index so we can ensure the splits are identical before and after + # preprocessing. + data_df[num_feat[COLUMN]] = data_df.index + + train_df = data_df[:15] + val_df = data_df[15:20] + test_df = data_df[20:] + + train_data = train_df + val_data = val_df + test_data = test_df + + if format == "file": + train_data = os.path.join(tmpdir, "train.csv") + val_data = os.path.join(tmpdir, "val.csv") + test_data = os.path.join(tmpdir, "test.csv") + + train_df.to_csv(train_data) + val_df.to_csv(val_data) + test_df.to_csv(test_data) + + data_df.to_csv(data_csv, index=False) + config = { + "input_features": input_features, + "output_features": output_features, + "trainer": { + "epochs": 2, + }, + "preprocessing": {"split": {"type": "random"}}, + } + + model = LudwigModel(config, backend=LocalTestBackend()) + train_set, val_set, test_set, _ = model.preprocess( + training_set=train_data, validation_set=val_data, test_set=test_data + ) + + assert len(train_set) == len(train_df) + assert len(val_set) == len(val_df) + assert len(test_set) == len(test_df) + + assert np.all(train_set.to_df()[num_feat[PROC_COLUMN]].values == train_df[num_feat[COLUMN]].values) + assert np.all(val_set.to_df()[num_feat[PROC_COLUMN]].values == val_df[num_feat[COLUMN]].values) + assert np.all(test_set.to_df()[num_feat[PROC_COLUMN]].values == test_df[num_feat[COLUMN]].values) + + +@pytest.mark.parametrize("backend", ["local", "ray"]) +@pytest.mark.distributed +def test_empty_split_error(backend, tmpdir): + """Tests that an error is raised if one or more of the splits is empty after preprocessing.""" + data_csv_path = os.path.join(tmpdir, "data.csv") + + out_feat = binary_feature() + input_features = [number_feature()] + output_features = [out_feat] + config = {"input_features": input_features, "output_features": output_features} + + training_data_csv_path = generate_data(input_features, output_features, data_csv_path) + df = pd.read_csv(training_data_csv_path) + + # Convert all the output features rows to null. Because the default missing value strategy is to drop empty output + # rows, this will result in the dataset being empty after preprocessing. + df[out_feat[COLUMN]] = None + + with init_backend(backend): + ludwig_model = LudwigModel(config, backend=backend) + with pytest.raises(ValueError, match="Dataset is empty following preprocessing"): + ludwig_model.preprocess(dataset=df)