diff --git a/covid19_drdfm/streamlit/Dashboard.py b/covid19_drdfm/streamlit/Dashboard.py
index fa72324..f1af655 100644
--- a/covid19_drdfm/streamlit/Dashboard.py
+++ b/covid19_drdfm/streamlit/Dashboard.py
@@ -1,10 +1,10 @@
-import yaml
import time
from pathlib import Path
import pandas as pd
import plotly.io as pio
import streamlit as st
+import yaml
from covid19_drdfm.constants import FACTORS
from covid19_drdfm.covid19 import get_df, get_project_h5ad
@@ -32,7 +32,7 @@ def get_data():
var_df["Variables"] = var_df.index
ad.obs["Time"] = pd.to_datetime(ad.obs.index)
-center_title("Dynamic Factor Model Runner")
+center_title("Legacy Dynamic Factor Model Runner for Covid-19")
with st.expander("Variable correlations"):
st.write("Data is normalized between [0, 1] before calculating correlation")
diff --git a/covid19_drdfm/streamlit/pages/0_Dynamic_Factor_Model.py b/covid19_drdfm/streamlit/pages/0_Dynamic_Factor_Model.py
index c457335..e79c412 100644
--- a/covid19_drdfm/streamlit/pages/0_Dynamic_Factor_Model.py
+++ b/covid19_drdfm/streamlit/pages/0_Dynamic_Factor_Model.py
@@ -1,14 +1,12 @@
-import time
from pathlib import Path
+from typing import Optional
+import anndata as ann
import pandas as pd
import plotly.io as pio
import streamlit as st
-import yaml
-from covid19_drdfm.constants import FACTORS
from covid19_drdfm.dfm import ModelRunner
-import anndata as ann
st.set_page_config(layout="wide")
pio.templates.default = "plotly_white"
@@ -18,82 +16,164 @@ def center_title(text):
return st.markdown(f"
{text}
", unsafe_allow_html=True)
-def load_data(file):
- if "csv" in file.type:
- return pd.read_csv(file, index_col=0)
- elif "tsv" in file.type:
- return pd.read_csv(file, index_col=0, sep="\t")
- elif "xlsx" in file.type:
- return pd.read_excel(file, index_col=0)
- else:
- return None
-
-
-def create_anndata(df, factor_mappings, batch_col=None):
- if batch_col:
- adata = ann.AnnData(df.drop(columns=batch_col))
- adata.obs[batch_col] = df[batch_col]
- else:
- adata = ann.AnnData(df)
- adata.var["factor"] = [factor_mappings[x] for x in adata.var.index]
- return adata
-
-
-def file_uploader():
- # File uploader
- file = st.file_uploader("Upload a data file (CSV, TSV, XLSX)", type=["csv", "tsv", "xlsx"])
- if file is None:
- st.error("Please provide input file")
- st.stop()
- df = load_data(file)
- with st.expander("Raw Input Data"):
- st.dataframe(df)
- if df is not None:
- # Optional batch column
- batch_col = st.selectbox("Select a batch column (optional):", ["None"] + list(df.columns))
- if batch_col == "None":
- batch_col = None
+class DataHandler:
+ """
+ Handles data loading and preprocessing for a Streamlit application.
+ """
- # Ask for non-batch variables and their factor mappings
- non_batch_cols = [col for col in df.columns if col != batch_col]
- factor_mappings = {}
- for col in non_batch_cols:
- factor = st.text_input(f"Enter factor for {col}:", key=col)
- if factor:
- # factor_cats = factor.split(",")
- # factor_mappings[col] = pd.Categorical(df[col], categories=factor_cats, ordered=True)
- factor_mappings[col] = factor
- if len(factor_mappings) != len(non_batch_cols):
- st.warning("Fill in a Factor label for all variables!")
+ def __init__(self):
+ self.df: Optional[pd.DataFrame] = None
+ self.ad: Optional[ann.AnnData] = None
+ self.batch_col: Optional[str] = None
+ self.non_batch_cols: Optional[list[str]] = None
+
+ def get_data(self) -> "DataHandler":
+        self.file_uploader(); self.get_factor_mappings(); self.create_anndata(); self.apply_transforms()
+ return self
+
+ def file_uploader(self) -> "DataHandler":
+ """
+ Uploads a file and reads it into a DataFrame. Supported file types are CSV, TSV, and XLSX.
+
+ Returns:
+            self, with df, batch_col, and non_batch_cols populated.
+
+        Notes:
+            Halts the Streamlit script via st.stop() when no file is uploaded.
+ """
+ file = st.file_uploader("Upload a data file (CSV, TSV, XLSX)", type=["csv", "tsv", "xlsx"])
+ if file is None:
+ st.error("Please provide input file")
st.stop()
+ self.df = self.load_data(file)
+ with st.expander("Raw Input Data"):
+ st.dataframe(self.df)
+ if self.df is None:
+ st.error("DataFrame is empty! Check input data")
+ st.stop()
+ batch_col = st.sidebar.selectbox("Select a batch column (optional):", ["None", *list(self.df.columns)])
+        self.batch_col = None if batch_col == "None" else batch_col
+        # Exclude the chosen batch column from the modeled variables
+        self.non_batch_cols = [col for col in self.df.columns if col != self.batch_col]
+ return self
- # Create anndata
- ad = create_anndata(df, factor_mappings, batch_col)
+ @staticmethod
+ def load_data(file) -> pd.DataFrame:
+ """
+ Loads a DataFrame from an uploaded file based on its MIME type.
- # Transformations
- options = st.multiselect(
- "Select columns to apply transformations:", non_batch_cols, format_func=lambda x: f"Transform {x}"
- )
- transforms = {}
- for opt in options:
- transform = st.radio(f"Select transform type for {opt}:", ("difference", "logdiff"), key=f"trans_{opt}")
- transforms[opt] = transform
- ad.var[transform] = None
- ad.var.loc[opt, transform] = True
+ Args:
+ file: UploadedFile object from Streamlit.
- # Show anndata and transforms
- st.write("Anndata object:", ad)
- st.dataframe(ad.var)
- return ad
+ Returns:
+ A DataFrame containing the data from the file.
+ Raises:
+ ValueError: If the file type is unsupported.
+ """
+        file_type = Path(file.name).suffix.lower().lstrip(".")
+ read_function = {
+ "csv": lambda f: pd.read_csv(f, index_col=0),
+ "tsv": lambda f: pd.read_csv(f, index_col=0, sep="\t"),
+ "xlsx": lambda f: pd.read_excel(f, index_col=0),
+        }.get(file_type)
-ad = file_uploader()
+ if read_function is None:
+ raise ValueError(f"Unsupported file type: {file_type}")
-global_multiplier = st.slider("Global Multiplier", min_value=0, max_value=4, value=0)
-outdir = st.text_input("Location of output!", value=None)
-if not outdir:
- st.stop()
+ return read_function(file)
+
+ def apply_transforms(self) -> "DataHandler":
+ options = st.multiselect(
+ "Select columns to apply transformations:", self.non_batch_cols, format_func=lambda x: f"Transform {x}"
+ )
+ transforms = {}
+ for i, opt in enumerate(options):
+ if i % 2 == 0:
+ cols = st.columns(2)
+ transform = cols[i % 2].radio(
+ f"Select transform type for {opt}:", ("difference", "logdiff"), key=f"trans_{opt}"
+ )
+ transforms[opt] = transform
+ self.ad.var[transform] = None
+ self.ad.var.loc[opt, transform] = True
+ return self
+
+ def get_factor_mappings(self) -> "DataHandler":
+ factor_input = st.text_input("Enter all factor options separated by space:")
+ factor_options = factor_input.split()
+ if not factor_options:
+ st.warning("Enter at least one factor to assign to variables")
+ st.stop()
+ factor_mappings = {}
+ for i, col in enumerate(self.non_batch_cols):
+ if i % 2 == 0:
+ cols = st.columns(2)
+ col_factor = cols[i % 2].radio(
+ f"Select factor for {col}:",
+ options=factor_options,
+ key=col,
+ format_func=lambda x: f"{x}",
+ horizontal=True,
+ )
+ if col_factor:
+ factor_mappings[col] = col_factor
+
+ if len(factor_mappings) != len(self.non_batch_cols):
+ st.warning("Select a factor for each variable!")
+ st.stop()
+ self.factor_mappings = factor_mappings
+ return self
+
+ def create_anndata(self) -> ann.AnnData:
+ """
+ Creates an AnnData object from the loaded DataFrame with optional batch column handling.
+
+        Reads self.df, self.factor_mappings and self.batch_col, which are
+        populated by the preceding pipeline steps (file_uploader and
+        get_factor_mappings); the result is also stored on self.ad.
+
+ Returns:
+ An AnnData object with additional metadata.
+ """
+ if self.batch_col and self.batch_col in self.df.columns:
+ ad = ann.AnnData(self.df.drop(columns=self.batch_col))
+ ad.obs[self.batch_col] = self.df[self.batch_col]
+ else:
+ ad = ann.AnnData(self.df)
+
+ ad.var["factor"] = [self.factor_mappings[x] for x in ad.var.index]
+ self.ad = ad
+ return ad
+
+
+def additional_params():
+ global_multiplier = st.sidebar.slider("Global Multiplier", min_value=0, max_value=4, value=0)
+ out_dir = st.sidebar.text_input("Output Directory", value=None)
+ if not out_dir:
+ st.warning("Specify output directory (in sidebar) to continue")
+ st.stop()
+ return global_multiplier, out_dir
+
+
+def run_model(ad, out_dir, batch, global_multiplier) -> ModelRunner:
+ dfm = ModelRunner(ad, Path(out_dir), batch=batch)
+ dfm.run(global_multiplier=global_multiplier)
+ st.subheader("Results")
+ for result in dfm.results:
+ if batch is not None:
+ st.subheader(result.name)
+ st.write(result.result.summary())
+ st.divider()
+ st.write(result.model.summary())
+ return dfm
+
+
+center_title("Dynamic Factor Model Runner")
+data = DataHandler().get_data()
+ad = data.ad
+global_multiplier, out_dir = additional_params()
batch = None if ad.obs.empty else ad.obs.columns[0]
-dfm = ModelRunner(ad, Path(outdir), batch=batch)
-dfm.run(global_multiplier=global_multiplier)
-st.write(dfm.results)
+dfm = run_model(ad, out_dir, batch, global_multiplier)
+st.balloons()
+st.stop()