Commit 401f2c4

Refactor dynamic factor model runner based on user input (resolves #78)
jvivian committed Jun 9, 2024
1 parent 2e6f6fa commit 401f2c4
Showing 2 changed files with 156 additions and 76 deletions.
4 changes: 2 additions & 2 deletions covid19_drdfm/streamlit/Dashboard.py
@@ -1,10 +1,10 @@
-import yaml
 import time
 from pathlib import Path

 import pandas as pd
 import plotly.io as pio
 import streamlit as st
+import yaml

 from covid19_drdfm.constants import FACTORS
 from covid19_drdfm.covid19 import get_df, get_project_h5ad
@@ -32,7 +32,7 @@ def get_data():
 var_df["Variables"] = var_df.index
 ad.obs["Time"] = pd.to_datetime(ad.obs.index)

-center_title("Dynamic Factor Model Runner")
+center_title("Legacy Dynamic Factor Model Runner for Covid-19")

 with st.expander("Variable correlations"):
     st.write("Data is normalized between [0, 1] before calculating correlation")
228 changes: 154 additions & 74 deletions covid19_drdfm/streamlit/pages/0_Dynamic_Factor_Model.py
@@ -1,14 +1,12 @@
-import time
 from pathlib import Path
+from typing import Optional

+import anndata as ann
 import pandas as pd
 import plotly.io as pio
 import streamlit as st
-import yaml

-from covid19_drdfm.constants import FACTORS
 from covid19_drdfm.dfm import ModelRunner
-import anndata as ann

 st.set_page_config(layout="wide")
 pio.templates.default = "plotly_white"
@@ -18,82 +16,164 @@ def center_title(text):
     return st.markdown(f"<h1 style='text-align: center; color: grey;'>{text}</h1>", unsafe_allow_html=True)


-def load_data(file):
-    if "csv" in file.type:
-        return pd.read_csv(file, index_col=0)
-    elif "tsv" in file.type:
-        return pd.read_csv(file, index_col=0, sep="\t")
-    elif "xlsx" in file.type:
-        return pd.read_excel(file, index_col=0)
-    else:
-        return None
-
-
-def create_anndata(df, factor_mappings, batch_col=None):
-    if batch_col:
-        adata = ann.AnnData(df.drop(columns=batch_col))
-        adata.obs[batch_col] = df[batch_col]
-    else:
-        adata = ann.AnnData(df)
-    adata.var["factor"] = [factor_mappings[x] for x in adata.var.index]
-    return adata
-
-
-def file_uploader():
-    # File uploader
-    file = st.file_uploader("Upload a data file (CSV, TSV, XLSX)", type=["csv", "tsv", "xlsx"])
-    if file is None:
-        st.error("Please provide input file")
-        st.stop()
-    df = load_data(file)
-    with st.expander("Raw Input Data"):
-        st.dataframe(df)
-    if df is not None:
-        # Optional batch column
-        batch_col = st.selectbox("Select a batch column (optional):", ["None"] + list(df.columns))
-        if batch_col == "None":
-            batch_col = None
-
-        # Ask for non-batch variables and their factor mappings
-        non_batch_cols = [col for col in df.columns if col != batch_col]
-        factor_mappings = {}
-        for col in non_batch_cols:
-            factor = st.text_input(f"Enter factor for {col}:", key=col)
-            if factor:
-                # factor_cats = factor.split(",")
-                # factor_mappings[col] = pd.Categorical(df[col], categories=factor_cats, ordered=True)
-                factor_mappings[col] = factor
-        if len(factor_mappings) != len(non_batch_cols):
-            st.warning("Fill in a Factor label for all variables!")
-
-        # Create anndata
-        ad = create_anndata(df, factor_mappings, batch_col)
-
-        # Transformations
-        options = st.multiselect(
-            "Select columns to apply transformations:", non_batch_cols, format_func=lambda x: f"Transform {x}"
-        )
-        transforms = {}
-        for opt in options:
-            transform = st.radio(f"Select transform type for {opt}:", ("difference", "logdiff"), key=f"trans_{opt}")
-            transforms[opt] = transform
-            ad.var[transform] = None
-            ad.var.loc[opt, transform] = True
-
-        # Show anndata and transforms
-        st.write("Anndata object:", ad)
-        st.dataframe(ad.var)
-        return ad
-
-
-ad = file_uploader()
-
-global_multiplier = st.slider("Global Multiplier", min_value=0, max_value=4, value=0)
-outdir = st.text_input("Location of output!", value=None)
-if not outdir:
-    st.stop()
-
-dfm = ModelRunner(ad, Path(outdir), batch=batch)
-dfm.run(global_multiplier=global_multiplier)
-st.write(dfm.results)
+class DataHandler:
+    """
+    Handles data loading and preprocessing for a Streamlit application.
+    """
+
+    def __init__(self):
+        self.df: Optional[pd.DataFrame] = None
+        self.ad: Optional[ann.AnnData] = None
+        self.batch_col: Optional[str] = None
+        self.non_batch_cols: Optional[list[str]] = None
+
+    def get_data(self) -> "DataHandler":
+        self.file_uploader().get_factor_mappings().apply_transforms().create_anndata()
+        return self
+
+    def file_uploader(self) -> "DataHandler":
+        """
+        Uploads a file and reads it into a DataFrame. Supported file types are CSV, TSV, and XLSX.
+
+        Returns:
+            A pandas DataFrame loaded from the uploaded file.
+
+        Raises:
+            RuntimeError: If no file is uploaded.
+        """
+        file = st.file_uploader("Upload a data file (CSV, TSV, XLSX)", type=["csv", "tsv", "xlsx"])
+        if file is None:
+            st.error("Please provide input file")
+            st.stop()
+        self.df = self.load_data(file)
+        with st.expander("Raw Input Data"):
+            st.dataframe(self.df)
+        if self.df is None:
+            st.error("DataFrame is empty! Check input data")
+            st.stop()
+        batch_col = st.sidebar.selectbox("Select a batch column (optional):", ["None", *list(self.df.columns)])
+        if batch_col == "None":
+            self.batch_col = None
+        self.non_batch_cols = [col for col in self.df.columns if col != batch_col]
+        return self
+
+    @staticmethod
+    def load_data(file) -> pd.DataFrame:
+        """
+        Loads a DataFrame from an uploaded file based on its MIME type.
+
+        Args:
+            file: UploadedFile object from Streamlit.
+
+        Returns:
+            A DataFrame containing the data from the file.
+
+        Raises:
+            ValueError: If the file type is unsupported.
+        """
+        file_type = file.type.split("/")[-1]
+        read_function = {
+            "csv": lambda f: pd.read_csv(f, index_col=0),
+            "tsv": lambda f: pd.read_csv(f, index_col=0, sep="\t"),
+            "xlsx": lambda f: pd.read_excel(f, index_col=0),
+        }.get(file_type, lambda _: None)
+
+        if read_function is None:
+            raise ValueError(f"Unsupported file type: {file_type}")
+
+        return read_function(file)
+
+    def apply_transforms(self) -> "DataHandler":
+        options = st.multiselect(
+            "Select columns to apply transformations:", self.non_batch_cols, format_func=lambda x: f"Transform {x}"
+        )
+        transforms = {}
+        for i, opt in enumerate(options):
+            if i % 2 == 0:
+                cols = st.columns(2)
+            transform = cols[i % 2].radio(
+                f"Select transform type for {opt}:", ("difference", "logdiff"), key=f"trans_{opt}"
+            )
+            transforms[opt] = transform
+            self.ad.var[transform] = None
+            self.ad.var.loc[opt, transform] = True
+        return self
+
+    def get_factor_mappings(self) -> "DataHandler":
+        factor_input = st.text_input("Enter all factor options separated by space:")
+        factor_options = factor_input.split()
+        if not factor_options:
+            st.warning("Enter at least one factor to assign to variables")
+            st.stop()
+        factor_mappings = {}
+        for i, col in enumerate(self.non_batch_cols):
+            if i % 2 == 0:
+                cols = st.columns(2)
+            col_factor = cols[i % 2].radio(
+                f"Select factor for {col}:",
+                options=factor_options,
+                key=col,
+                format_func=lambda x: f"{x}",
+                horizontal=True,
+            )
+            if col_factor:
+                factor_mappings[col] = col_factor
+
+        if len(factor_mappings) != len(self.non_batch_cols):
+            st.warning("Select a factor for each variable!")
+            st.stop()
+        self.factor_mappings = factor_mappings
+        return self
+
+    def create_anndata(self) -> ann.AnnData:
+        """
+        Creates an AnnData object from the loaded DataFrame with optional batch column handling.
+
+        Args:
+            factor_mappings: A dictionary mapping column names to their respective factors.
+            batch_col: Optional; the name of the column to use as the batch category.
+
+        Returns:
+            An AnnData object with additional metadata.
+        """
+        if self.batch_col and self.batch_col in self.df.columns:
+            ad = ann.AnnData(self.df.drop(columns=self.batch_col))
+            ad.obs[self.batch_col] = self.df[self.batch_col]
+        else:
+            ad = ann.AnnData(self.df)
+
+        ad.var["factor"] = [self.factor_mappings[x] for x in ad.var.index]
+        self.ad = ad
+        return ad
+
+
+def additional_params():
+    global_multiplier = st.sidebar.slider("Global Multiplier", min_value=0, max_value=4, value=0)
+    out_dir = st.sidebar.text_input("Output Directory", value=None)
+    if not out_dir:
+        st.warning("Specify output directory (in sidebar) to continue")
+        st.stop()
+    return global_multiplier, out_dir
+
+
+def run_model(ad, out_dir, batch, global_multiplier) -> ModelRunner:
+    dfm = ModelRunner(ad, Path(out_dir), batch=batch)
+    dfm.run(global_multiplier=global_multiplier)
+    st.subheader("Results")
+    for result in dfm.results:
+        if batch is not None:
+            st.subheader(result.name)
+        st.write(result.result.summary())
+        st.divider()
+        st.write(result.model.summary())
+    return dfm
+
+
+center_title("Dynamic Factor Model Runner")
+data = DataHandler().get_data()
+ad = data.ad
+global_multiplier, out_dir = additional_params()
+batch = None if ad.obs.empty else ad.obs.columns[0]
+dfm = run_model(ad, out_dir, batch, global_multiplier)
 st.balloons()
 st.stop()
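For reference, a minimal sketch of driving the refactored pipeline outside the Streamlit UI (e.g. in a quick test script). The input path, factor label, and output directory are placeholders, and the AnnData construction mirrors what DataHandler.create_anndata assembles; the ModelRunner calls (AnnData input, output Path, batch=..., run(global_multiplier=...), iterating dfm.results) follow how run_model() uses them in the diff above.

from pathlib import Path

import anndata as ann
import pandas as pd

from covid19_drdfm.dfm import ModelRunner

# Placeholder input: a wide table with observations as rows and variables as columns
df = pd.read_csv("input.csv", index_col=0)

# Build the AnnData the page would construct, with one factor label per variable
ad = ann.AnnData(df)
ad.var["factor"] = ["Global"] * ad.n_vars  # placeholder factor label

# Run the dynamic factor model the same way run_model() does on the page
dfm = ModelRunner(ad, Path("./results"), batch=None)
dfm.run(global_multiplier=0)
for result in dfm.results:
    print(result.result.summary())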
