Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: raise if remove_colums is called with unknown column by default #852

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,20 +642,23 @@ def remove_columns(
self,
names: str | list[str],
/,
*,
ignore_unknown_names: bool = False,
) -> Table:
"""
Return a new table without the specified columns.

**Notes:**

- The original table is not modified.
- This method does not raise if a column does not exist. You can use it to ensure that the resulting table does
not contain certain columns.

Parameters
lars-reimann marked this conversation as resolved.
Show resolved Hide resolved
----------
names:
The names of the columns to remove.
ignore_unknown_names:
If set to True, columns that are not present in the table will be ignored.
If set to False, an error will be raised if any of the specified columns do not exist.

Returns
-------
Expand All @@ -677,7 +680,7 @@ def remove_columns(
| 6 |
+-----+

>>> table.remove_columns(["c"])
>>> table.remove_columns(["c"], ignore_unknown_names=True)
+-----+-----+
| a | b |
| --- | --- |
Expand All @@ -691,6 +694,9 @@ def remove_columns(
if isinstance(names, str):
names = [names]

if not ignore_unknown_names:
_check_columns_exist(self, names)

return Table._from_polars_lazy_frame(
self._lazy_frame.drop(names),
)
Expand Down Expand Up @@ -931,7 +937,7 @@ def replace_column(
_check_columns_dont_exist(self, [column.name for column in new_columns], old_name=old_name)

if len(new_columns) == 0:
return self.remove_columns(old_name)
return self.remove_columns(old_name, ignore_unknown_names=True)

if len(new_columns) == 1:
new_column = new_columns[0]
Expand Down
2 changes: 1 addition & 1 deletion src/safeds/ml/classical/_supervised_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def _predict_with_sklearn_model(
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="X does not have valid feature names")
predicted_target_vector = model.predict(features._data_frame)
output = dataset.remove_columns(target_name).add_columns(
output = dataset.remove_columns(target_name, ignore_unknown_names=True).add_columns(
Column(target_name, predicted_target_vector),
)

Expand Down
2 changes: 1 addition & 1 deletion src/safeds/ml/classical/regression/_arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def predict(self, time_series: TimeSeriesDataset) -> Table:
# make a table without
forecast_horizon = len(time_series.target._series.to_numpy())
result_table = time_series.to_table()
result_table = result_table.remove_columns([time_series.target.name])
result_table = result_table.remove_columns([time_series.target.name], ignore_unknown_names=True)
# Validation
if not self.is_fitted or self._arima is None:
raise ModelNotFittedError
Expand Down
Original file line number Diff line number Diff line change
@@ -1,26 +1,67 @@
import pytest
from safeds.data.tabular.containers import Table
from safeds.exceptions import ColumnNotFoundError


# Test cases where no exception is expected
@pytest.mark.parametrize(
("table", "expected", "columns"),
("table", "expected", "columns", "ignore_unknown_names"),
[
(Table({"col1": [1, 2, 1], "col2": ["a", "b", "c"]}), Table({"col1": [1, 2, 1]}), ["col2"]),
(Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), Table(), ["col1", "col2"]),
(Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), []),
(Table(), Table(), []),
(Table(), Table(), ["col1"]),
(Table({"col1": [1, 2, 1], "col2": ["a", "b", "c"]}), Table({"col1": [1, 2, 1]}), ["col2"], True),
(Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), Table(), ["col1", "col2"], True),
(Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), [], True),
(Table(), Table(), [], True),
(Table(), Table(), ["col1"], True),
(Table({"col1": [1, 2, 1], "col2": ["a", "b", "c"]}), Table({"col1": [1, 2, 1]}), ["col2"], False),
(Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}), Table(), ["col1", "col2"], False),
(
Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}),
Table({"col1": [1, 2, 1], "col2": [1, 2, 4]}),
[],
False,
),
(Table(), Table(), [], False),
],
ids=[
"one column, ignore unknown names",
"multiple columns, ignore unknown names",
"no columns, ignore unknown names",
"empty, ignore unknown names",
"missing columns, ignore unknown names",
"one column",
"multiple columns",
"no columns",
"empty",
"missing columns",
],
)
def test_should_remove_table_columns(table: Table, expected: Table, columns: list[str]) -> None:
table = table.remove_columns(columns)
def test_should_remove_table_columns_no_exception(
table: Table,
expected: Table,
columns: list[str],
ignore_unknown_names: bool,
) -> None:
table = table.remove_columns(columns, ignore_unknown_names=ignore_unknown_names)
assert table.schema == expected.schema
assert table == expected
assert table.row_count == expected.row_count


# Test cases where an exception is expected
@pytest.mark.parametrize(
("table", "columns", "ignore_unknown_names"),
[
(Table(), ["col1"], False),
(Table(), ["col12"], False),
],
ids=[
"missing columns",
"missing columns",
],
)
def test_should_raise_error_for_unknown_columns(
table: Table,
columns: list[str],
ignore_unknown_names: bool,
) -> None:
with pytest.raises(ColumnNotFoundError):
table.remove_columns(columns, ignore_unknown_names=ignore_unknown_names)
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def test_should_raise_if_not_fitted(self, classifier: Classifier, valid_data: Ta
def test_should_raise_if_dataset_misses_features(self, classifier: Classifier, valid_data: TabularDataset) -> None:
fitted_classifier = classifier.fit(valid_data)
with pytest.raises(DatasetMissesFeaturesError, match="[feat1, feat2]"):
fitted_classifier.predict(valid_data.features.remove_columns(["feat1", "feat2"]))
fitted_classifier.predict(valid_data.features.remove_columns(["feat1", "feat2"], ignore_unknown_names=True))

@pytest.mark.parametrize(
("invalid_data", "expected_error", "expected_error_msg"),
Expand Down
2 changes: 1 addition & 1 deletion tests/safeds/ml/classical/regression/test_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def test_should_raise_if_not_fitted(self, regressor: Regressor, valid_data: Tabu
def test_should_raise_if_dataset_misses_features(self, regressor: Regressor, valid_data: TabularDataset) -> None:
fitted_regressor = regressor.fit(valid_data)
with pytest.raises(DatasetMissesFeaturesError, match="[feat1, feat2]"):
fitted_regressor.predict(valid_data.features.remove_columns(["feat1", "feat2"]))
fitted_regressor.predict(valid_data.features.remove_columns(["feat1", "feat2"], ignore_unknown_names=True))

@pytest.mark.parametrize(
("invalid_data", "expected_error", "expected_error_msg"),
Expand Down
2 changes: 1 addition & 1 deletion tests/safeds/ml/nn/test_forward_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_forward_model(device: Device) -> None:
table_1 = Table.from_csv_file(
path=resolve_resource_path(_inflation_path),
)
table_1 = table_1.remove_columns(["date"])
table_1 = table_1.remove_columns(["date"], ignore_unknown_names=True)
table_2 = table_1.slice_rows(start=0, length=table_1.row_count - 14)
table_2 = table_2.add_columns([(table_1.slice_rows(start=14)).get_column("value").rename("target")])
train_table, test_table = table_2.split_rows(0.8)
Expand Down