Skip to content

Commit

Permalink
Implement rolling Gini graph; bump version to 0.16.2
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewfallan committed May 23, 2024
1 parent 4d8abcb commit e710d1e
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 13 deletions.
1 change: 1 addition & 0 deletions src/seismicrna/core/mu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .frame import *
from .nan import *
from .scale import *
from .trends import *
from .unbias import *

########################################################################
Expand Down
2 changes: 1 addition & 1 deletion src/seismicrna/core/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

logger = getLogger(__name__)

__version__ = "0.16.1"
__version__ = "0.16.2"


def parse_version(version: str = __version__):
Expand Down
8 changes: 2 additions & 6 deletions src/seismicrna/graph/aucroll.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
StructOneTableRunner,
StructOneTableWriter)
from .roc import PROFILE_NAME, rename_columns
from .roll import RollingGraph
from .roll import RollingGraph, RollingRunner
from .trace import iter_rolling_auc_traces
from ..core.arg import opt_window, opt_winmin

Expand Down Expand Up @@ -70,11 +70,7 @@ def get_graph(self, rels_group: str, **kwargs):
return RollingAUCGraph(table=self.table, rel=rels_group, **kwargs)


class RollingAUCRunner(StructOneTableRunner, PosGraphRunner):

@classmethod
def var_params(cls):
return super().var_params() + [opt_window, opt_winmin]
class RollingAUCRunner(RollingRunner, StructOneTableRunner, PosGraphRunner):

@classmethod
def get_writer_type(cls):
Expand Down
10 changes: 5 additions & 5 deletions src/seismicrna/graph/corroll.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from plotly import graph_objects as go

from .base import PosGraphWriter, PosGraphRunner
from .roll import RollingGraph
from .twotable import TwoTableMergedGraph, TwoTableRunner, TwoTableWriter
from .roll import RollingGraph, RollingRunner
from .trace import iter_seq_line_traces
from ..core.arg import opt_metric, opt_window, opt_winmin
from .twotable import TwoTableMergedGraph, TwoTableRunner, TwoTableWriter
from ..core.arg import opt_metric
from ..core.mu import compare_windows, get_comp_name

logger = getLogger(__name__)
Expand Down Expand Up @@ -66,11 +66,11 @@ def get_graph_type(cls):
return RollingCorrelationGraph


class RollingCorrelationRunner(TwoTableRunner, PosGraphRunner):
class RollingCorrelationRunner(RollingRunner, TwoTableRunner, PosGraphRunner):

@classmethod
def var_params(cls):
return super().var_params() + [opt_metric, opt_window, opt_winmin]
return super().var_params() + [opt_metric]

@classmethod
def get_writer_type(cls):
Expand Down
104 changes: 104 additions & 0 deletions src/seismicrna/graph/giniroll.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import os
from functools import cached_property
from logging import getLogger

import pandas as pd
from click import command
from plotly import graph_objects as go

from .base import PosGraphWriter, PosGraphRunner
from .onetable import OneTableGraph, OneTableRunner, OneTableWriter
from .rel import OneRelGraph
from .roll import RollingGraph, RollingRunner
from .trace import iter_rolling_gini_traces
from ..core.header import format_clust_name
from ..core.mu import calc_gini
from ..core.seq import iter_windows

logger = getLogger(__name__)

COMMAND = __name__.split(os.path.extsep)[-1]


class RollingGiniGraph(OneTableGraph, OneRelGraph, RollingGraph):

@classmethod
def graph_kind(cls):
return COMMAND

@classmethod
def what(cls):
return "Rolling Gini coefficient"

@property
def y_title(self):
return "Gini"

@cached_property
def data(self):
data = self._fetch_data(self.table,
order=self.order,
clust=self.clust)
gini = pd.DataFrame(index=data.index, dtype=float)
for cluster, cluster_data in data.items():
cluster_gini = pd.Series(index=gini.index, dtype=float)
for center, (window,) in iter_windows(cluster_data,
size=self._size,
min_count=self._min_count):
cluster_gini.loc[center] = calc_gini(window)
if isinstance(cluster, tuple):
_, order, clust = cluster
label = format_clust_name(order, clust)
else:
label = cluster
gini[label] = cluster_gini
return gini

def get_traces(self):
for row, trace in enumerate(iter_rolling_gini_traces(self.data),
start=1):
yield (row, 1), trace

def _figure_layout(self, fig: go.Figure):
super()._figure_layout(fig)
fig.update_yaxes(gridcolor="#d0d0d0")


class RollingGiniWriter(OneTableWriter, PosGraphWriter):

def get_graph(self, rels_group: str, **kwargs):
return RollingGiniGraph(table=self.table, rel=rels_group, **kwargs)


class RollingGiniRunner(RollingRunner, OneTableRunner, PosGraphRunner):

@classmethod
def get_writer_type(cls):
return RollingGiniWriter


@command(COMMAND, params=RollingGiniRunner.params())
def cli(*args, **kwargs):
""" Rolling Gini coefficient. """
return RollingGiniRunner.run(*args, **kwargs)

########################################################################
# #
# © Copyright 2024, the Rouskin Lab. #
# #
# This file is part of SEISMIC-RNA. #
# #
# SEISMIC-RNA is free software; you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation; either version 3 of the License, or #
# (at your option) any later version. #
# #
# SEISMIC-RNA is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANT- #
# ABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General #
# Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with SEISMIC-RNA; if not, see <https://www.gnu.org/licenses>. #
# #
########################################################################
2 changes: 2 additions & 0 deletions src/seismicrna/graph/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from . import (aucroll,
corroll,
delprof,
giniroll,
histpos,
histread,
profile,
Expand All @@ -21,6 +22,7 @@ def cli():
for module in (aucroll,
corroll,
delprof,
giniroll,
histpos,
histread,
profile,
Expand Down
10 changes: 9 additions & 1 deletion src/seismicrna/graph/roll.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from abc import ABC
from functools import cached_property

from .base import GraphBase
from .base import GraphBase, GraphRunner
from ..core.arg import opt_window, opt_winmin
from ..core.seq import POS_NAME


Expand All @@ -28,6 +29,13 @@ def predicate(self):
"-".join(map(str, [self._size, self._min_count]))]
)


class RollingRunner(GraphRunner, ABC):

@classmethod
def var_params(cls):
return super().var_params() + [opt_window, opt_winmin]

########################################################################
# #
# © Copyright 2024, the Rouskin Lab. #
Expand Down
10 changes: 10 additions & 0 deletions src/seismicrna/graph/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,16 @@ def iter_rolling_auc_traces(aucs: pd.DataFrame, profile: str):
yield get_rolling_auc_trace(auc, profile, str(struct))


def get_rolling_gini_trace(gini: pd.Series, cluster: str):
return go.Scatter(x=gini.index.get_level_values(POS_NAME),
y=gini,
name=cluster)


def iter_rolling_gini_traces(ginis: pd.DataFrame):
for cluster, gini in ginis.items():
yield get_rolling_gini_trace(gini, str(cluster))

########################################################################
# #
# © Copyright 2024, the Rouskin Lab. #
Expand Down

0 comments on commit e710d1e

Please sign in to comment.