diff --git a/src/safeds/data/tabular/containers/_column.py b/src/safeds/data/tabular/containers/_column.py index 1eb3b9d69..a7bc1bc9e 100644 --- a/src/safeds/data/tabular/containers/_column.py +++ b/src/safeds/data/tabular/containers/_column.py @@ -1,8 +1,7 @@ from __future__ import annotations -import typing from numbers import Number -from typing import Any, Callable, Iterator +from typing import Any, Callable, Iterable, Iterator, Optional import numpy as np import pandas as pd @@ -17,10 +16,32 @@ class Column: - def __init__(self, data: typing.Iterable, name: str) -> None: + """ + A column of data. + + Parameters + ---------- + data : Iterable + The data. + name : str + The name of the column. + type_ : Optional[ColumnType] + The type of the column. If not specified, the type will be inferred from the data. + """ + + def __init__( + self, + data: Iterable, + name: str, + type_: Optional[ColumnType] = None, + ) -> None: self._data: pd.Series = data if isinstance(data, pd.Series) else pd.Series(data) self._name: str = name - self._type: ColumnType = ColumnType.from_numpy_dtype(self._data.dtype) + self._type: ColumnType = ( + type_ + if type_ is not None + else ColumnType.from_numpy_dtype(self._data.dtype) + ) @property def name(self) -> str: @@ -135,7 +156,7 @@ def rename(self, new_name: str) -> Column: column : Column A new column with the new name. """ - return Column(self._data, new_name) + return Column(self._data, new_name, self._type) def all(self, predicate: Callable[[Any], bool]) -> bool: """ @@ -250,7 +271,7 @@ def correlation_with(self, other_column: Column) -> float: ) return self._data.corr(other_column._data) - def get_unique_values(self) -> list[typing.Any]: + def get_unique_values(self) -> list[Any]: """ Return a list of all unique values in the column. diff --git a/src/safeds/data/tabular/containers/_table.py b/src/safeds/data/tabular/containers/_table.py index 0f18ded8c..1682338ab 100644 --- a/src/safeds/data/tabular/containers/_table.py +++ b/src/safeds/data/tabular/containers/_table.py @@ -313,8 +313,8 @@ def get_column(self, column_name: str) -> Column: :, [self.schema._get_column_index_by_name(column_name)] ].squeeze(), column_name, + self.schema.get_type_of_column(column_name), ) - output_column._type = self.schema.get_type_of_column(column_name) return output_column raise UnknownColumnNameError([column_name])