Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: table.keep_only_columns now maps column names to correct data #194

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,9 @@ def from_columns(columns: list[Column]) -> Table:

Raises
------
MissingDataError
If an empty list is given.
ColumnLengthMismatchError
If any of the column sizes does not match with the others.
"""
if len(columns) == 0:
raise MissingDataError("This function requires at least one column.")

dataframe: DataFrame = pd.DataFrame()

for column in columns:
Expand Down Expand Up @@ -566,7 +561,7 @@ def keep_only_columns(self, column_names: list[str]) -> Table:
Raises
------
ColumnNameError
If any of the given columns do not exist.
If any of the given columns does not exist.
"""
invalid_columns = []
column_indices = []
Expand All @@ -578,7 +573,7 @@ def keep_only_columns(self, column_names: list[str]) -> Table:
if len(invalid_columns) != 0:
raise UnknownColumnNameError(invalid_columns)
transformed_data = self._data[column_indices]
transformed_data.columns = [name for name in self._schema.get_column_names() if name in column_names]
transformed_data.columns = column_names
return Table(transformed_data)

def remove_columns(self, column_names: list[str]) -> Table:
Expand All @@ -598,7 +593,7 @@ def remove_columns(self, column_names: list[str]) -> Table:
Raises
------
ColumnNameError
If any of the given columns do not exist.
If any of the given columns does not exist.
"""
invalid_columns = []
column_indices = []
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import pandas as pd
import pytest
from safeds.data.tabular.containers import Column, Table
from safeds.data.tabular.exceptions import MissingDataError

from tests.helpers import resolve_resource_path

Expand All @@ -15,8 +13,3 @@ def test_from_columns() -> None:
table_restored: Table = Table.from_columns(columns_table)

assert table_restored == table_expected


def test_from_columns_invalid() -> None:
with pytest.raises(MissingDataError):
Table.from_columns([])
lars-reimann marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,18 +1,45 @@
import pytest
from safeds.data.tabular.containers import Table
from safeds.data.tabular.containers import Column, Table
from safeds.data.tabular.exceptions import UnknownColumnNameError

from tests.helpers import resolve_resource_path

class TestKeepOnlyColumns:
@pytest.mark.parametrize(
("table", "column_names", "expected"),
[
(
Table.from_columns([Column("A", [1]), Column("B", [2])]),
[],
Table.from_columns([]),
),
(
Table.from_columns([Column("A", [1]), Column("B", [2])]),
["A"],
Table.from_columns([Column("A", [1])]),
),
(
Table.from_columns([Column("A", [1]), Column("B", [2])]),
["B"],
Table.from_columns([Column("B", [2])]),
),
(
Table.from_columns([Column("A", [1]), Column("B", [2])]),
["A", "B"],
Table.from_columns([Column("A", [1]), Column("B", [2])]),
),
# Related to https://github.com/Safe-DS/Stdlib/issues/115
(
Table.from_columns([Column("A", [1]), Column("B", [2]), Column("C", [3])]),
["C", "A"],
Table.from_columns([Column("C", [3]), Column("A", [1])]),
),
],
)
def test_should_keep_only_listed_columns(self, table: Table, column_names: list[str], expected: Table) -> None:
transformed_table = table.keep_only_columns(column_names)
assert transformed_table == expected

def test_keep_columns() -> None:
table = Table.from_csv_file(resolve_resource_path("test_table_from_csv_file.csv"))
transformed_table = table.keep_only_columns(["A"])
assert transformed_table.schema.has_column("A")
assert not transformed_table.schema.has_column("B")


def test_keep_columns_warning() -> None:
table = Table.from_csv_file(resolve_resource_path("test_table_from_csv_file.csv"))
with pytest.raises(UnknownColumnNameError):
table.keep_only_columns(["C"])
def test_should_raise_if_column_does_no_exist(self) -> None:
table = Table.from_columns([Column("A", [1]), Column("B", [2])])
with pytest.raises(UnknownColumnNameError):
table.keep_only_columns(["C"])