Skip to content

Commit

Permalink
feat: optionally pass type to column (#79)
Browse files Browse the repository at this point in the history
Closes #78.

### Summary of Changes

For the sake of consistency it is now possible to pass the type of a
`Column` in the constructor. This also improves performance, for example
when we call `get_column` on a `Table`. In that case we already know the
type of the column anyway, so there's no reason to infer it again.

---------

Co-authored-by: lars-reimann <[email protected]>
  • Loading branch information
lars-reimann and lars-reimann authored Mar 25, 2023
1 parent bc63693 commit 64aa429
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
33 changes: 27 additions & 6 deletions src/safeds/data/tabular/containers/_column.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import annotations

import typing
from numbers import Number
from typing import Any, Callable, Iterator
from typing import Any, Callable, Iterable, Iterator, Optional

import numpy as np
import pandas as pd
Expand All @@ -17,10 +16,32 @@


class Column:
def __init__(self, data: typing.Iterable, name: str) -> None:
"""
A column of data.
Parameters
----------
data : Iterable
The data.
name : str
The name of the column.
type_ : Optional[ColumnType]
The type of the column. If not specified, the type will be inferred from the data.
"""

def __init__(
self,
data: Iterable,
name: str,
type_: Optional[ColumnType] = None,
) -> None:
self._data: pd.Series = data if isinstance(data, pd.Series) else pd.Series(data)
self._name: str = name
self._type: ColumnType = ColumnType.from_numpy_dtype(self._data.dtype)
self._type: ColumnType = (
type_
if type_ is not None
else ColumnType.from_numpy_dtype(self._data.dtype)
)

@property
def name(self) -> str:
Expand Down Expand Up @@ -135,7 +156,7 @@ def rename(self, new_name: str) -> Column:
column : Column
A new column with the new name.
"""
return Column(self._data, new_name)
return Column(self._data, new_name, self._type)

def all(self, predicate: Callable[[Any], bool]) -> bool:
"""
Expand Down Expand Up @@ -250,7 +271,7 @@ def correlation_with(self, other_column: Column) -> float:
)
return self._data.corr(other_column._data)

def get_unique_values(self) -> list[typing.Any]:
def get_unique_values(self) -> list[Any]:
"""
Return a list of all unique values in the column.
Expand Down
2 changes: 1 addition & 1 deletion src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,8 @@ def get_column(self, column_name: str) -> Column:
:, [self.schema._get_column_index_by_name(column_name)]
].squeeze(),
column_name,
self.schema.get_type_of_column(column_name),
)
output_column._type = self.schema.get_type_of_column(column_name)
return output_column

raise UnknownColumnNameError([column_name])
Expand Down

0 comments on commit 64aa429

Please sign in to comment.