Skip to content

Commit

Permalink
Merge branch 'main' into 169-set-c-parameter-for-regularization-of-a-…
Browse files Browse the repository at this point in the history
…supportvectormachine
  • Loading branch information
lars-reimann authored May 5, 2023
2 parents 7017360 + 5adadad commit 9f9ba6f
Show file tree
Hide file tree
Showing 19 changed files with 282 additions and 52 deletions.
31 changes: 29 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pandas = "^2.0.0"
pillow = "^9.5.0"
scikit-learn = "^1.2.0"
seaborn = "^0.12.2"
openpyxl = "^3.1.2"

[tool.poetry.group.dev.dependencies]
pytest = "^7.2.1"
Expand Down
49 changes: 49 additions & 0 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import matplotlib.pyplot as plt
import numpy as np
import openpyxl
import pandas as pd
import seaborn as sns
from pandas import DataFrame
Expand Down Expand Up @@ -84,6 +85,33 @@ def from_csv_file(path: str | Path) -> Table:
except FileNotFoundError as exception:
raise FileNotFoundError(f'File "{path}" does not exist') from exception

@staticmethod
def from_excel_file(path: str | Path) -> Table:
    """
    Read data from an Excel file into a table.

    Columns whose header contains the text "Unnamed" are skipped.

    Parameters
    ----------
    path : str | Path
        The path to the Excel file.

    Returns
    -------
    table : Table
        The table created from the Excel file.

    Raises
    ------
    FileNotFoundError
        If the specified file does not exist.
    ValueError
        If the file could not be read.
    """
    try:
        return Table(
            pd.read_excel(
                path,
                engine="openpyxl",
                # Header values are not guaranteed to be strings (e.g. numeric header cells), so convert
                # before the substring test instead of failing with a TypeError.
                usecols=lambda colname: "Unnamed" not in str(colname),
            ),
        )
    except FileNotFoundError as exception:
        raise FileNotFoundError(f'File "{path}" does not exist') from exception

@staticmethod
def from_json_file(path: str | Path) -> Table:
"""
Expand Down Expand Up @@ -1242,6 +1270,27 @@ def to_csv_file(self, path: str | Path) -> None:
data_to_csv.columns = self._schema.column_names
data_to_csv.to_csv(path, index=False)

def to_excel_file(self, path: str | Path) -> None:
    """
    Write the data from the table into an Excel file.

    If the file and/or the directories do not exist, they will be created. If the file already exists, it will be
    overwritten.

    Parameters
    ----------
    path : str | Path
        The path to the output file.
    """
    # The parent directories must exist before anything is written to `path`; otherwise saving the workbook
    # below fails for a not-yet-existing directory.
    Path(path).parent.mkdir(parents=True, exist_ok=True)

    # Create Excel metadata in the file
    tmp_table_file = openpyxl.Workbook()
    tmp_table_file.save(path)

    # Work on a copy so that relabelling the columns does not touch the table's own data.
    data_to_excel = self._data.copy()
    data_to_excel.columns = self._schema.column_names
    data_to_excel.to_excel(path)

def to_json_file(self, path: str | Path) -> None:
"""
Write the data from the table into a JSON file.
Expand Down
44 changes: 41 additions & 3 deletions src/safeds/ml/classical/classification/_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,45 @@ def accuracy(self, validation_or_test_set: TaggedTable) -> float:
"""
if not isinstance(validation_or_test_set, TaggedTable) and isinstance(validation_or_test_set, Table):
raise UntaggedTableError
expected = validation_or_test_set.target
predicted = self.predict(validation_or_test_set.features).target

return sk_accuracy_score(expected._data, predicted._data)
expected_values = validation_or_test_set.target
predicted_values = self.predict(validation_or_test_set.features).target

return sk_accuracy_score(expected_values._data, predicted_values._data)

def precision(self, validation_or_test_set: TaggedTable, positive_class: int | str = 1) -> float:
    """
    Compute the classifier's precision on the given data.

    Parameters
    ----------
    validation_or_test_set : TaggedTable
        The validation or test set.
    positive_class : int | str
        The class to be considered positive. All other classes are considered negative.

    Returns
    -------
    precision : float
        The calculated precision score, i.e. the ratio of correctly predicted positives to all predicted positives.
        Return 1 if no positive predictions are made.

    Raises
    ------
    UntaggedTableError
        If the given table is a Table but not a TaggedTable.
    """
    if not isinstance(validation_or_test_set, TaggedTable) and isinstance(validation_or_test_set, Table):
        raise UntaggedTableError

    expected_values = validation_or_test_set.target
    predicted_values = self.predict(validation_or_test_set.features).target

    n_true_positives = 0
    n_predicted_positives = 0

    # strict=True makes a length mismatch between expectation and prediction an error instead of silently
    # truncating to the shorter sequence.
    for expected_value, predicted_value in zip(expected_values, predicted_values, strict=True):
        if predicted_value == positive_class:
            n_predicted_positives += 1
            if expected_value == positive_class:
                n_true_positives += 1

    # Without any positive prediction the ratio is undefined; by convention of this method it is 1.
    if n_predicted_positives == 0:
        return 1.0
    return n_true_positives / n_predicted_positives
Binary file added tests/resources/dummy_excel_file.xlsx
Binary file not shown.
Binary file added tests/resources/image/snapshot_boxplot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/resources/image/snapshot_heatmap.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/resources/image/snapshot_histogram_str.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/resources/image/snapshot_lineplot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/resources/image/snapshot_scatterplot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
29 changes: 10 additions & 19 deletions tests/safeds/data/tabular/containers/_column/test_plot_boxplot.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,20 @@
import _pytest
import matplotlib.pyplot as plt
import pytest
from safeds.data.image.containers import Image
from safeds.data.tabular.containers import Table
from safeds.data.tabular.exceptions import NonNumericColumnError

from tests.helpers import resolve_resource_path

def test_plot_boxplot_complex() -> None:
    """Building and box-plotting a column that contains a complex number must raise NotImplementedError."""
    # Both statements stay inside the context: the expected error may come from either of them.
    with pytest.raises(NotImplementedError):  # noqa: PT012
        complex_table = Table.from_dict({"A": [1, 2, complex(1, -2)]})
        complex_table.get_column("A").plot_boxplot()

def test_should_match_snapshot() -> None:
    """The rendered boxplot must be byte-identical to the stored snapshot image."""
    table = Table.from_dict({"A": [1, 2, 3]})
    # Render exactly once; a second, discarded render adds nothing to the comparison.
    current = table.get_column("A").plot_boxplot()
    snapshot = Image.from_png_file(resolve_resource_path("./image/snapshot_boxplot.png"))
    assert snapshot._image.tobytes() == current._image.tobytes()


def test_plot_boxplot_non_numeric() -> None:
def test_should_raise_if_column_contains_non_numerical_values() -> None:
    """Box-plotting a column that mixes numbers with a string must raise NonNumericColumnError."""
    mixed_table = Table.from_dict({"A": [1, 2, "A"]})
    with pytest.raises(NonNumericColumnError):
        mixed_table.get_column("A").plot_boxplot()


def test_plot_boxplot_float(monkeypatch: _pytest.monkeypatch.MonkeyPatch) -> None:
    """Box-plotting a column of ints and floats must complete without raising."""
    # Replace plt.show with a no-op for the duration of the test.
    monkeypatch.setattr(plt, "show", lambda: None)
    table = Table.from_dict({"A": [1, 2, 3.5]})
    table.get_column("A").plot_boxplot()


def test_plot_boxplot_int(monkeypatch: _pytest.monkeypatch.MonkeyPatch) -> None:
    """Box-plotting a column of ints must complete without raising."""
    # Replace plt.show with a no-op for the duration of the test.
    monkeypatch.setattr(plt, "show", lambda: None)
    table = Table.from_dict({"A": [1, 2, 3]})
    table.get_column("A").plot_boxplot()
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
import _pytest
import matplotlib.pyplot as plt
from safeds.data.image.containers import Image
from safeds.data.tabular.containers import Table

from tests.helpers import resolve_resource_path

def test_plot_histogram(monkeypatch: _pytest.monkeypatch) -> None:
monkeypatch.setattr(plt, "show", lambda: None)

def test_should_match_snapshot_numeric() -> None:
    """The histogram of a numeric column must be byte-identical to the stored snapshot image."""
    table = Table.from_dict({"A": [1, 2, 3]})
    # Render exactly once; a second, discarded render adds nothing to the comparison.
    current = table.get_column("A").plot_histogram()
    snapshot = Image.from_png_file(resolve_resource_path("./image/snapshot_histogram_numeric.png"))
    assert snapshot._image.tobytes() == current._image.tobytes()


def test_should_match_snapshot_str() -> None:
    """The histogram of a string column must be byte-identical to the stored snapshot image."""
    text_column = Table.from_dict({"A": ["A", "B", "Apple"]}).get_column("A")
    rendered = text_column.plot_histogram()
    reference = Image.from_png_file(resolve_resource_path("./image/snapshot_histogram_str.png"))
    assert reference._image.tobytes() == rendered._image.tobytes()
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from pathlib import Path

import pytest
from safeds.data.tabular.containers import Table

from tests.helpers import resolve_resource_path


def _expected_dummy_excel_table() -> Table:
    """Build the table that the test expects to read from the dummy Excel resource."""
    return Table.from_dict(
        {
            "A": [1],
            "B": [2],
        },
    )


@pytest.mark.parametrize(
    ("path", "expected"),
    [
        (resolve_resource_path("./dummy_excel_file.xlsx"), _expected_dummy_excel_table()),
        (Path(resolve_resource_path("./dummy_excel_file.xlsx")), _expected_dummy_excel_table()),
    ],
    ids=["string path", "object path"],
)
def test_should_create_table_from_excel_file(path: str | Path, expected: Table) -> None:
    """Reading the dummy Excel file must yield the expected table for both kinds of path argument."""
    loaded = Table.from_excel_file(path)
    assert loaded == expected


def test_should_raise_if_file_not_found() -> None:
    """Reading an Excel file that does not exist must raise FileNotFoundError."""
    missing_file_name = "test_table_from_excel_file_invalid.xls"
    with pytest.raises(FileNotFoundError):
        Table.from_excel_file(resolve_resource_path(missing_file_name))
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import _pytest
import matplotlib.pyplot as plt
from safeds.data.image.containers import Image
from safeds.data.tabular.containers import Table

from tests.helpers import resolve_resource_path

def test_plot_correlation_heatmap_non_numeric(monkeypatch: _pytest.monkeypatch.MonkeyPatch) -> None:
    """Plotting a correlation heatmap for a table with a non-numeric column must complete without raising."""
    # Replace plt.show with a no-op for the duration of the test.
    monkeypatch.setattr(plt, "show", lambda: None)
    table = Table.from_dict({"A": [1, 2, "A"], "B": [1, 2, 3]})
    table.plot_correlation_heatmap()


def test_plot_correlation_heatmap(monkeypatch: _pytest.monkeypatch.MonkeyPatch) -> None:
    """Plotting a correlation heatmap for a numeric table must complete without raising."""
    # Replace plt.show with a no-op for the duration of the test.
    monkeypatch.setattr(plt, "show", lambda: None)
    table = Table.from_dict({"A": [1, 2, 3.5], "B": [2, 4, 7]})
    table.plot_correlation_heatmap()
def test_should_match_snapshot() -> None:
    """The rendered correlation heatmap must be byte-identical to the stored snapshot image."""
    numeric_data = {"A": [1, 2, 3.5], "B": [0.2, 4, 77]}
    rendered = Table.from_dict(numeric_data).plot_correlation_heatmap()
    reference = Image.from_png_file(resolve_resource_path("./image/snapshot_heatmap.png"))
    assert reference._image.tobytes() == rendered._image.tobytes()
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import _pytest
import matplotlib.pyplot as plt
import pytest
from safeds.data.image.containers import Image
from safeds.data.tabular.containers import Table
from safeds.data.tabular.exceptions import UnknownColumnNameError

from tests.helpers import resolve_resource_path

def test_plot_lineplot(monkeypatch: _pytest.monkeypatch) -> None:
monkeypatch.setattr(plt, "show", lambda: None)

def test_should_match_snapshot() -> None:
    """The rendered line plot must be byte-identical to the stored snapshot image."""
    table = Table.from_dict({"A": [1, 2, 3], "B": [2, 4, 7]})
    # Render exactly once; a second, discarded render adds nothing to the comparison.
    current = table.plot_lineplot("A", "B")
    snapshot = Image.from_png_file(resolve_resource_path("./image/snapshot_lineplot.png"))
    assert snapshot._image.tobytes() == current._image.tobytes()


def test_plot_lineplot_wrong_column_name() -> None:
def test_should_raise_if_column_does_not_exist() -> None:
    """Requesting a line plot for a column name the table does not have must raise UnknownColumnNameError."""
    two_column_table = Table.from_dict({"A": [1, 2, 3], "B": [2, 4, 7]})
    with pytest.raises(UnknownColumnNameError):
        two_column_table.plot_lineplot("C", "A")
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import _pytest
import matplotlib.pyplot as plt
import pytest
from safeds.data.image.containers import Image
from safeds.data.tabular.containers import Table
from safeds.data.tabular.exceptions import UnknownColumnNameError

from tests.helpers import resolve_resource_path

def test_plot_scatterplot(monkeypatch: _pytest.monkeypatch) -> None:
monkeypatch.setattr(plt, "show", lambda: None)

def test_should_match_snapshot() -> None:
    """The rendered scatter plot must be byte-identical to the stored snapshot image."""
    table = Table.from_dict({"A": [1, 2, 3], "B": [2, 4, 7]})
    # Render exactly once; a second, discarded render adds nothing to the comparison.
    current = table.plot_scatterplot("A", "B")
    snapshot = Image.from_png_file(resolve_resource_path("./image/snapshot_scatterplot.png"))
    assert snapshot._image.tobytes() == current._image.tobytes()


def test_plot_scatterplot_wrong_column_name() -> None:
def test_should_raise_if_column_does_not_exist() -> None:
    """Requesting a scatter plot for a column name the table does not have must raise UnknownColumnNameError."""
    two_column_table = Table.from_dict({"A": [1, 2, 3], "B": [2, 4, 7]})
    with pytest.raises(UnknownColumnNameError):
        two_column_table.plot_scatterplot("C", "A")
26 changes: 26 additions & 0 deletions tests/safeds/data/tabular/containers/_table/test_to_excel_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from pathlib import Path
from tempfile import NamedTemporaryFile

from safeds.data.tabular.containers import Table


def test_should_create_csv_file_from_table_by_str() -> None:
    """
    Round-trip a table through an Excel file that is addressed by a string path.

    NOTE(review): the name says "csv" although the Excel writer and reader are exercised; the name is kept so
    existing test selections keep working.
    """
    table = Table.from_dict({"col1": ["col1_1"], "col2": ["col2_1"]})
    with NamedTemporaryFile(suffix=".xlsx") as tmp_table_file:
        # Only the unique name is needed; closing removes the placeholder so the writer starts from scratch.
        tmp_table_file.close()
        excel_path = tmp_table_file.name
    try:
        table.to_excel_file(excel_path)
        table_r = Table.from_excel_file(excel_path)
    finally:
        # The temporary-file object stops tracking the path once it is closed, so the file written by
        # to_excel_file has to be removed explicitly.
        Path(excel_path).unlink(missing_ok=True)
    assert table == table_r


def test_should_create_csv_file_from_table_by_path() -> None:
    """
    Round-trip a table through an Excel file that is addressed by a Path object.

    NOTE(review): the name says "csv" although the Excel writer and reader are exercised; the name is kept so
    existing test selections keep working.
    """
    table = Table.from_dict({"col1": ["col1_1"], "col2": ["col2_1"]})
    with NamedTemporaryFile(suffix=".xlsx") as tmp_table_file:
        # Only the unique name is needed; closing removes the placeholder so the writer starts from scratch.
        tmp_table_file.close()
        excel_path = Path(tmp_table_file.name)
    try:
        table.to_excel_file(excel_path)
        table_r = Table.from_excel_file(excel_path)
    finally:
        # The temporary-file object stops tracking the path once it is closed, so the file written by
        # to_excel_file has to be removed explicitly.
        excel_path.unlink(missing_ok=True)
    assert table == table_r
Loading

0 comments on commit 9f9ba6f

Please sign in to comment.