Skip to content

Commit

Permalink
add Prostate data loader
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipMay committed Dec 8, 2023
1 parent 5ff61f2 commit 7a9ff83
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 3 deletions.
44 changes: 44 additions & 0 deletions mltb2/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import os
from hashlib import sha256
from io import StringIO
from typing import Tuple

import joblib
Expand Down Expand Up @@ -113,3 +114,46 @@ def load_colon() -> Tuple[pd.Series, pd.DataFrame]:
result = joblib.load(full_path)

return result


def load_prostate() -> Tuple[pd.Series, pd.DataFrame]:
    """Load prostate data.

    The data is loaded and parsed from `prostate data
    <https://web.stanford.edu/~hastie/CASI_files/DATA/prostate.html>`_.

    On the first call the CSV file is downloaded, its SHA-256 checksum is verified,
    and the parsed result is cached as ``prostate.pkl.gz`` in the directory returned
    by ``get_and_create_mltb2_data_dir()``. Later calls load the cached file.

    Returns:
        Tuple containing labels and data. The data is the transposed CSV content with a
        default integer index. A label is ``0`` if the original CSV column name contains
        ``"control"`` and ``1`` if it contains ``"cancer"``.

    Raises:
        AssertionError: If the checksum of the downloaded file does not match or a CSV
            column name contains neither ``"control"`` nor ``"cancer"``.
        requests.HTTPError: If the server answers the download with an HTTP error status.
    """
    filename = "prostate.pkl.gz"
    mltb2_data_home = get_and_create_mltb2_data_dir()
    full_path = os.path.join(mltb2_data_home, filename)

    if os.path.exists(full_path):
        return joblib.load(full_path)

    # download data file
    url = "https://web.stanford.edu/~hastie/CASI_files/DATA/prostmat.csv"
    page = requests.get(url, timeout=10)
    # fail with a clear HTTP error instead of a confusing checksum mismatch on an error page
    page.raise_for_status()
    page_str = page.text

    # check checksum of data file
    # NOTE: raised explicitly (not via ``assert``) so the check also runs under ``python -O``
    expected_hash = "f1ccfd3c9a837c002ec5d6489ab139c231739c3611189be14d15ca5541b92036"
    page_hash = sha256(page_str.encode("utf-8")).hexdigest()
    if page_hash != expected_hash:
        raise AssertionError(page_hash)

    # transpose: the CSV column names become the row index of the data frame
    data_df = pd.read_csv(StringIO(page_str)).T

    labels = []
    for label in data_df.index:
        if "control" in label:
            labels.append(0)
        elif "cancer" in label:
            labels.append(1)
        else:
            # an unknown column would silently misalign labels and data rows
            raise AssertionError("This must not happen!")

    data_df = data_df.reset_index(drop=True)  # reset the index to default integer index
    label_series = pd.Series(labels)
    result = (label_series, data_df)
    joblib.dump(result, full_path, compress=("gzip", 3))
    return result
2 changes: 1 addition & 1 deletion mltb2/fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def get_model_path_and_download() -> str:
url=model_url,
sha256_checksum=sha256_checksum,
)
assert fetch_remote_file_path == model_full_path # noqa: S101
assert fetch_remote_file_path == model_full_path

return model_full_path

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ ignore = [
"S106", # Possible hardcoded password assigned to argument: "{}"
"COM812", # Trailing comma missing
"S101", # Use of `assert` detected
"PLR2004", # Magic value used in comparison,
"PLR2004", # Magic value used in comparison
"B011", # Do not `assert False`
]

[tool.ruff.per-file-ignores]
Expand Down
13 changes: 12 additions & 1 deletion tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pandas as pd

from mltb2.data import _load_colon_data, _load_colon_label, load_colon
from mltb2.data import _load_colon_data, _load_colon_label, load_colon, load_prostate


def test_load_colon_data():
Expand All @@ -30,3 +30,14 @@ def test_load_colon():
assert isinstance(result[1], pd.DataFrame)
assert result[0].shape == (62,)
assert result[1].shape == (62, 2000)


def test_load_prostate_data():
    """Check type and shape of the tuple returned by ``load_prostate``."""
    loaded = load_prostate()

    # overall container
    assert loaded is not None
    assert isinstance(loaded, tuple)
    assert len(loaded) == 2

    labels, data = loaded

    # element types
    assert isinstance(labels, pd.Series)
    assert isinstance(data, pd.DataFrame)

    # element shapes
    assert labels.shape == (102,)
    assert data.shape == (102, 6033)

0 comments on commit 7a9ff83

Please sign in to comment.