Commit

Fixed splitting when providing pre-split inputs (#2248)
tgaddair authored Jul 11, 2022
1 parent 7c929d3 commit a4eb05c
Showing 5 changed files with 101 additions and 2 deletions.
2 changes: 2 additions & 0 deletions ludwig/data/dataset/pandas.py
@@ -35,6 +35,8 @@ def __init__(self, dataset, features, data_hdf5_fp):
self.features = features
self.data_hdf5_fp = data_hdf5_fp
self.size = len(dataset)
if self.size == 0:
raise ValueError("Dataset is empty following preprocessing")
self.dataset = to_numpy_dataset(dataset)

def to_df(self, features: Optional[Iterable[BaseFeature]] = None) -> DataFrame:
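The guard added above fails fast when preprocessing leaves no rows, rather than surfacing an opaque error later during training. A minimal standalone sketch of the same check (the `_DatasetStub` class below is illustrative, not Ludwig's actual `PandasDataset`):

```python
import pandas as pd


class _DatasetStub:
    """Illustrative stand-in for the dataset wrapper patched above."""

    def __init__(self, dataset: pd.DataFrame):
        self.size = len(dataset)
        # Fail fast: an empty dataframe here means preprocessing dropped every row.
        if self.size == 0:
            raise ValueError("Dataset is empty following preprocessing")
        self.dataset = dataset


_DatasetStub(pd.DataFrame({"num": [1.0, 2.0]}))  # constructs normally

try:
    _DatasetStub(pd.DataFrame({"num": []}))  # raises immediately
except ValueError as e:
    print(e)  # Dataset is empty following preprocessing
```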
2 changes: 2 additions & 0 deletions ludwig/data/dataset/ray.py
@@ -60,6 +60,8 @@ def __init__(
):
self.df_engine = backend.df_engine
self.ds = self.df_engine.to_ray_dataset(df) if not isinstance(df, str) else read_remote_parquet(df)
if self.size == 0:
raise ValueError("Dataset is empty following preprocessing")
self.features = features
self.training_set_metadata = training_set_metadata
self.data_hdf5_fp = training_set_metadata.get(DATA_TRAIN_HDF5_FP)
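In the Ray-backed dataset the same guard reads a `size` property rather than calling `len()`, since the underlying data is a distributed Ray Dataset. A hedged sketch of that shape, assuming `size` simply counts rows (the stub class is illustrative, not Ludwig's `RayDataset`):

```python
import pandas as pd
import ray


class _RayDatasetStub:
    """Illustrative stand-in; mirrors the check added above."""

    def __init__(self, df: pd.DataFrame):
        self.ds = ray.data.from_pandas(df)
        # The guard consults the row count of the distributed dataset.
        if self.size == 0:
            raise ValueError("Dataset is empty following preprocessing")

    @property
    def size(self) -> int:
        return self.ds.count()


_RayDatasetStub(pd.DataFrame({"num": [1.0, 2.0]}))  # constructs normally
```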
16 changes: 16 additions & 0 deletions ludwig/data/preprocessing.py
@@ -1716,6 +1716,22 @@ def _preprocess_df_for_training(
# needs preprocessing
logger.info("Using training dataframe")
dataset = concatenate_df(training_set, validation_set, test_set, backend)

# Data is pre-split, so we override whatever split policy the user specified
if preprocessing_params["split"]:
warnings.warn(
'Preprocessing "split" section provided, but pre-split dataset given as input. '
"Ignoring split configuration."
)

preprocessing_params = {
**preprocessing_params,
"split": {
"type": "fixed",
"column": SPLIT,
},
}

logger.info("Building dataset (it may take a while)")

data, training_set_metadata = build_dataset(
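The effect of the override above is that any user-specified split strategy is replaced by a fixed split keyed on the split column produced when the pre-split frames are concatenated. A small sketch of that merge, assuming the `SPLIT` constant resolves to the column name `"split"` (the helper name below is hypothetical):

```python
import warnings

SPLIT = "split"  # assumed value of ludwig.constants.SPLIT


def _override_split_for_presplit(preprocessing_params: dict) -> dict:
    """Replace any user-provided split config with a fixed split on the SPLIT column."""
    if preprocessing_params.get("split"):
        warnings.warn(
            'Preprocessing "split" section provided, but pre-split dataset given as input. '
            "Ignoring split configuration."
        )
    return {**preprocessing_params, "split": {"type": "fixed", "column": SPLIT}}


params = _override_split_for_presplit({"split": {"type": "random"}})
print(params["split"])  # {'type': 'fixed', 'column': 'split'}
```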
2 changes: 1 addition & 1 deletion requirements.txt
@@ -19,7 +19,7 @@ requests
tables
fsspec[http]
dataclasses-json
-jsonschema>=4.5.0
+jsonschema>=4.5.0,<4.7
marshmallow
marshmallow-jsonschema
marshmallow-dataclass==8.5.5
81 changes: 80 additions & 1 deletion tests/integration_tests/test_preprocessing.py
@@ -128,7 +128,7 @@ def test_dask_known_divisions(feature_fn, csv_filename, tmpdir):
data_csv = generate_data(
input_features, output_features, os.path.join(tmpdir, csv_filename), num_examples=num_examples
)
-data_df = dd.from_pandas(pd.read_csv(data_csv), npartitions=1)
+data_df = dd.from_pandas(pd.read_csv(data_csv), npartitions=10)
assert data_df.known_divisions

config = {
@@ -232,3 +232,82 @@ def random_string():
# check that train_ds had invalid values replaced with the missing value
assert len(concatenated_df) == len(df)
assert np.all(concatenated_df[num_feat[PROC_COLUMN]] == 0.0)


@pytest.mark.parametrize("format", ["file", "df"])
def test_presplit_override(format, tmpdir):
"""Tests that provising a pre-split file or dataframe overrides the user's split config."""
num_feat = number_feature(normalization=None)
input_features = [num_feat, sequence_feature(reduce_output="sum")]
output_features = [category_feature(vocab_size=5, reduce_input="sum")]

data_csv = generate_data(input_features, output_features, os.path.join(tmpdir, "dataset.csv"), num_examples=25)
data_df = pd.read_csv(data_csv)

# Set the feature value equal to an ordinal index so we can ensure the splits are identical before and after
# preprocessing.
data_df[num_feat[COLUMN]] = data_df.index

train_df = data_df[:15]
val_df = data_df[15:20]
test_df = data_df[20:]

train_data = train_df
val_data = val_df
test_data = test_df

if format == "file":
train_data = os.path.join(tmpdir, "train.csv")
val_data = os.path.join(tmpdir, "val.csv")
test_data = os.path.join(tmpdir, "test.csv")

train_df.to_csv(train_data)
val_df.to_csv(val_data)
test_df.to_csv(test_data)

data_df.to_csv(data_csv, index=False)
config = {
"input_features": input_features,
"output_features": output_features,
"trainer": {
"epochs": 2,
},
"preprocessing": {"split": {"type": "random"}},
}

model = LudwigModel(config, backend=LocalTestBackend())
train_set, val_set, test_set, _ = model.preprocess(
training_set=train_data, validation_set=val_data, test_set=test_data
)

assert len(train_set) == len(train_df)
assert len(val_set) == len(val_df)
assert len(test_set) == len(test_df)

assert np.all(train_set.to_df()[num_feat[PROC_COLUMN]].values == train_df[num_feat[COLUMN]].values)
assert np.all(val_set.to_df()[num_feat[PROC_COLUMN]].values == val_df[num_feat[COLUMN]].values)
assert np.all(test_set.to_df()[num_feat[PROC_COLUMN]].values == test_df[num_feat[COLUMN]].values)


@pytest.mark.parametrize("backend", ["local", "ray"])
@pytest.mark.distributed
def test_empty_split_error(backend, tmpdir):
"""Tests that an error is raised if one or more of the splits is empty after preprocessing."""
data_csv_path = os.path.join(tmpdir, "data.csv")

out_feat = binary_feature()
input_features = [number_feature()]
output_features = [out_feat]
config = {"input_features": input_features, "output_features": output_features}

training_data_csv_path = generate_data(input_features, output_features, data_csv_path)
df = pd.read_csv(training_data_csv_path)

# Convert all the output features rows to null. Because the default missing value strategy is to drop empty output
# rows, this will result in the dataset being empty after preprocessing.
df[out_feat[COLUMN]] = None

with init_backend(backend):
ludwig_model = LudwigModel(config, backend=backend)
with pytest.raises(ValueError, match="Dataset is empty following preprocessing"):
ludwig_model.preprocess(dataset=df)
