Skip to content

Commit

Permalink
Compare data loaders with original implementation. (#111)
Browse files Browse the repository at this point in the history
* comment test_load_colon_data and test_load_colon_label functions

* Add ori_data_loader

* fix linting for ori_data_loader

* Add tests to compare loaded data with original data
  • Loading branch information
PhilipMay authored Dec 13, 2023
1 parent f72d161 commit cf8ab38
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 2 deletions.
127 changes: 127 additions & 0 deletions tests/ori_data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright (c) 2021 Sigrun May, Helmholtz-Zentrum für Infektionsforschung GmbH (HZI)
# Copyright (c) 2021 Sigrun May, Ostfalia Hochschule für angewandte Wissenschaften
# Copyright (c) 2020 Philip May
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT

# this is the original implementation from
# https://github.com/sigrun-may/cv-pruner/blob/ac35eba88a824e6bb6a6435cda67224a4db69e65/examples/data_loader.py

"""Data loader module."""

from typing import Tuple

import numpy as np
import pandas as pd
import requests
from bs4 import BeautifulSoup


def load_colon_data() -> Tuple[pd.Series, pd.DataFrame]:
    """Load colon data.

    The data is loaded and parsed from the internet.
    Also see <http://genomics-pubs.princeton.edu/oncology/affydata/index.html>

    Returns:
        Tuple containing labels and data. The labels are a series of 62
        values (0 where the parsed integer is positive, 1 otherwise). The data
        is a data frame with 62 rows and 2000 columns named ``gene_0`` ...
        ``gene_1999``.

    Raises:
        AssertionError: If the downloaded pages do not have the expected
            number of rows, columns or labels.
    """
    html_data = "http://genomics-pubs.princeton.edu/oncology/affydata/I2000.html"

    page = requests.get(html_data, timeout=10)

    soup = BeautifulSoup(page.content, "html.parser")
    colon_text_data = soup.get_text()

    # only lines longer than 20 characters are parsed as rows of numbers
    colon_rows = [[float(s) for s in line.split()] for line in colon_text_data.splitlines() if len(line) > 20]
    assert len(colon_rows) == 2000
    assert len(colon_rows[0]) == 62

    # transpose so that each of the 62 parsed columns becomes one row
    data = np.array(colon_rows).T

    html_label = "http://genomics-pubs.princeton.edu/oncology/affydata/tissues.html"
    page = requests.get(html_label, timeout=10)
    soup = BeautifulSoup(page.content, "html.parser")
    colon_label_lines = soup.get_text().splitlines()

    label = []

    for line in colon_label_lines:
        try:
            i = int(line)
        except ValueError:
            # lines that are not a plain integer are not labels
            continue
        label.append(0 if i > 0 else 1)

    assert len(label) == 62

    data_df = pd.DataFrame(data)

    # generate feature names
    data_df.columns = ["gene_" + str(column_name) for column_name in data_df.columns]

    return pd.Series(label), data_df


# TODO append random features and shuffle


def load_prostate_data() -> Tuple[pd.Series, pd.DataFrame]:
    """Load prostate data.

    The data is loaded and parsed from <https://web.stanford.edu/~hastie/CASI_files/DATA/prostate.html>

    Returns:
        Tuple containing labels and data. The labels are 0 for columns whose
        name contains "control" and 1 for columns whose name contains
        "cancer". The data is the transposed CSV content with shape
        (102, 6033).

    Raises:
        AssertionError: If a column name contains neither "control" nor
            "cancer", or if the loaded data does not have the expected size.
    """
    df = pd.read_csv("https://web.stanford.edu/~hastie/CASI_files/DATA/prostmat.csv")
    data = df.T

    # labels
    labels = []
    for label in df.columns:  # pylint:disable=no-member
        if "control" in label:
            labels.append(0)
        elif "cancer" in label:
            labels.append(1)
        else:
            # raise explicitly: ``assert False`` would be removed under ``python -O``
            raise AssertionError("This must not happen!")

    assert len(labels) == 102
    assert data.shape == (102, 6033)

    return pd.Series(labels), data


def load_leukemia_data() -> Tuple[pd.Series, pd.DataFrame]:
    """Load leukemia data.

    The data is loaded and parsed from the internet.
    Also see <https://web.stanford.edu/~hastie/CASI_files/DATA/leukemia.html>

    Returns:
        Tuple containing labels and data. The labels are 0 for columns whose
        name contains "ALL" and 1 for columns whose name contains "AML".
        The data is the transposed CSV content with shape (72, 7128).

    Raises:
        AssertionError: If a column name contains neither "ALL" nor "AML",
            or if the loaded data does not have the expected size.
    """
    df = pd.read_csv("https://web.stanford.edu/~hastie/CASI_files/DATA/leukemia_big.csv")
    data = df.T

    # labels
    labels = []
    for label in df.columns:  # pylint:disable=no-member
        if "ALL" in label:
            labels.append(0)
        elif "AML" in label:
            labels.append(1)
        else:
            # raise explicitly: ``assert False`` would be removed under ``python -O``
            raise AssertionError("This must not happen!")

    assert len(labels) == 72
    assert data.shape == (72, 7128)

    return pd.Series(labels), data
34 changes: 32 additions & 2 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,22 @@
# which is available at https://opensource.org/licenses/MIT

import pandas as pd
from numpy.testing import assert_almost_equal

from mltb2.data import _load_colon_data, _load_colon_label, load_colon, load_leukemia_big, load_prostate

from .ori_data_loader import load_colon_data, load_leukemia_data, load_prostate_data


def test_load_colon_data():
    """Check that ``_load_colon_data`` returns a data frame of shape (62, 2000)."""
    # load once; a second identical call would only repeat the same work
    result = _load_colon_data()  # only load data not labels
    assert result is not None
    assert isinstance(result, pd.DataFrame)
    assert result.shape == (62, 2000)


def test_load_colon_label():
    """Check that ``_load_colon_label`` returns a series with 62 entries."""
    # load once; a second identical call would only repeat the same work
    result = _load_colon_label()  # only load labels not data
    assert result is not None
    assert isinstance(result, pd.Series)
    assert len(result) == 62
Expand All @@ -32,6 +35,15 @@ def test_load_colon():
assert result[1].shape == (62, 2000)


def test_load_colon_compare_original():
    """Compare the output of ``load_colon`` with the original ``load_colon_data``."""
    current = load_colon()
    reference = load_colon_data()

    # element 0 and element 1 of both results: shapes first, then values
    for idx in (0, 1):
        assert current[idx].shape == reference[idx].shape
    for idx in (0, 1):
        assert_almost_equal(current[idx].to_numpy(), reference[idx].to_numpy())


def test_load_prostate():
result = load_prostate()
assert result is not None
Expand All @@ -43,6 +55,15 @@ def test_load_prostate():
assert result[1].shape == (102, 6033)


def test_load_prostate_compare_original():
    """Compare the output of ``load_prostate`` with the original ``load_prostate_data``."""
    current = load_prostate()
    reference = load_prostate_data()

    # element 0 and element 1 of both results: shapes first, then values
    for idx in (0, 1):
        assert current[idx].shape == reference[idx].shape
    for idx in (0, 1):
        assert_almost_equal(current[idx].to_numpy(), reference[idx].to_numpy())


def test_load_leukemia_big():
result = load_leukemia_big()
assert result is not None
Expand All @@ -52,3 +73,12 @@ def test_load_leukemia_big():
assert isinstance(result[1], pd.DataFrame)
assert result[0].shape == (72,)
assert result[1].shape == (72, 7128)


def test_load_leukemia_big_compare_original():
    """Compare the output of ``load_leukemia_big`` with the original ``load_leukemia_data``."""
    current = load_leukemia_big()
    reference = load_leukemia_data()

    # element 0 and element 1 of both results: shapes first, then values
    for idx in (0, 1):
        assert current[idx].shape == reference[idx].shape
    for idx in (0, 1):
        assert_almost_equal(current[idx].to_numpy(), reference[idx].to_numpy())

0 comments on commit cf8ab38

Please sign in to comment.