Skip to content

Commit

Permalink
fine gray
Browse files Browse the repository at this point in the history
  • Loading branch information
juAlberge committed Jan 22, 2024
1 parent b66721b commit 7a698fe
Showing 1 changed file with 93 additions and 15 deletions.
108 changes: 93 additions & 15 deletions benchmark/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,17 @@
integrated_brier_score_incidence,
brier_score_incidence,
brier_score_incidence_oracle,
integrated_brier_score_incidence_oracle,
)
from hazardous.metrics._concordance import concordance_index_ipcw

from main import DATASET_GRID, SEER_PATH, SEED

sns.set_style("whitegrid")
sns.set_style(
style="white",
)
sns.set_context("paper")
sns.set_palette("colorblind")


def aggregate_result(path_session_dataset, estimator_names):
Expand Down Expand Up @@ -108,6 +113,9 @@ def __init__(self, path_session, estimator_names, dataset_name):
self.estimator_names = estimator_names
self.dataset_name = dataset_name
self.df = aggregate_result(self.path_session / dataset_name, estimator_names)
path_profile = Path(path_session) / f"{dataset_name}_plots/"
path_profile.mkdir(parents=True, exist_ok=True)
self.path_profile = path_profile

def plot_PSR(self, data_params):
df = self.load_cv_results(data_params)
Expand All @@ -124,6 +132,7 @@ def plot_PSR(self, data_params):
fig.suptitle(
f"Time-varying Brier score ({censoring_fraction:.1%} {kind} censoring)"
)
sns.despine(fig=fig)

for estimator_name in self.estimator_names:
estimator = _get_estimator(df, estimator_name)
Expand Down Expand Up @@ -168,6 +177,7 @@ def plot_PSR(self, data_params):

ax.set_title(f"event {event_id}")
axes[0].legend()
plt.savefig(self.path_profile / "PSR.pdf", format="pdf")

def plot_marginal_incidence(self, data_params):
df = self.load_cv_results(data_params)
Expand Down Expand Up @@ -225,6 +235,7 @@ def plot_marginal_incidence(self, data_params):

for ax in [axes[1], axes[2]]:
ax.legend().remove()
plt.savefig(self.path_profile / "marginal_incidence.pdf", format="pdf")

def plot_individual_incidence(self, data_params, sample_ids=2):
if isinstance(sample_ids, int):
Expand All @@ -247,9 +258,9 @@ def plot_individual_incidence(self, data_params, sample_ids=2):
for estimator_name in self.estimator_names:
estimator = _get_estimator(df, estimator_name)
y_pred = self.get_predictions(
X.loc[sample_ids], time_grid, estimator, estimator_name
X.iloc[sample_ids], time_grid, estimator, estimator_name
)
y_test = y.loc[sample_ids]
y_test = y.iloc[sample_ids]

for row_idx in range(len(sample_ids)):
y_sample = y_test.iloc[row_idx]
Expand All @@ -275,6 +286,7 @@ def plot_individual_incidence(self, data_params, sample_ids=2):
if col_idx == 0:
ax.legend()
plt.tight_layout()
plt.savefig(self.path_profile / "individual_incidence.pdf", format="pdf")

def print_table_metrics(self, data_params):
df = self.load_cv_results(data_params)
Expand Down Expand Up @@ -350,6 +362,58 @@ def print_table_metrics(self, data_params):
print("Ct-index")
display(results_ct_index)

def plot_performance_time(self, data_params, x_col="n_samples"):
    """Scatter-plot the mean IPSR against the mean fit time.

    One point is drawn per (estimator, ``x_col`` value) pair. For each
    pair, the dataset is reloaded with ``x_col`` set to that value, the
    estimator stored in the cross-validation results predicts on it, and
    the integrated Brier score is averaged over the
    ``data_params["n_events"]`` events. The figure is saved as
    ``performance_vs_time.pdf`` under ``self.path_profile``.

    Parameters
    ----------
    data_params : dict
        Dataset parameters; must contain ``"n_events"``. The dict is not
        modified: a per-iteration copy carries the ``x_col`` override.
    x_col : str, default="n_samples"
        Column of the cross-validation results to group by.
    """
    df = self.load_cv_results(data_params, x_col)
    fig, ax = plt.subplots(figsize=(8, 4))
    records = {
        "mean_fit_time": [],
        "std_fit_time": [],
        "estimator_name": [],
        "mean_ipsr": [],
    }
    for estimator_name in tqdm(self.estimator_names):
        for x_col_param, df_group in df.groupby(x_col):
            # Work on a copy so the caller's dict keeps its original value
            # for `x_col` after this method returns.
            params = {**data_params, x_col: x_col_param}

            X, y = self.load_dataset(params, return_X_y=True)
            time_grid = make_time_grid(y["duration"])

            estimator = _get_estimator(df_group, estimator_name)

            y_train = estimator.y_train  # hack for benchmarks
            y_pred = self.get_predictions(X, time_grid, estimator, estimator_name)

            event_specific_ipsr = [
                integrated_brier_score_incidence(
                    y_train=y_train,
                    y_test=y,
                    # TODO: remove when removing GBI.
                    y_pred=y_pred[event_id] if y_pred.ndim == 3 else y_pred,
                    times=time_grid,
                    event_of_interest=event_id,
                )
                for event_id in range(1, params["n_events"] + 1)
            ]
            records["mean_fit_time"].append(df_group["mean_fit_time"].values[0])
            records["std_fit_time"].append(df_group["std_fit_time"].values[0])
            records["estimator_name"].append(estimator_name + f" {x_col_param} TS")
            records["mean_ipsr"].append(np.mean(event_specific_ipsr))

    fit_time = pd.DataFrame(records)
    sns.scatterplot(
        fit_time,
        x="mean_fit_time",
        y="mean_ipsr",
        hue="estimator_name",
        ax=ax,
    )
    ax.set(
        xlabel="time(s) to fit",
        ylabel="IPSR",
    )
    # Save through the figure handle rather than pyplot's implicit
    # "current figure", which may have changed since `plt.subplots`.
    fig.savefig(self.path_profile / "performance_vs_time.pdf", format="pdf")


class WeibullDisplayer(BaseDisplayer):
def __init__(self, path_session, estimator_names):
Expand Down Expand Up @@ -389,13 +453,15 @@ def plot_memory_time(self, data_params, x_col="n_samples"):
title="Time to test",
ylabel=None,
)
plt.savefig(self.path_profile / "memory_time.pdf", format="pdf")

def plot_IPSR(self, data_params):
    """Draw one IPSR panel per varying parameter and save the figure.

    Two side-by-side axes sharing the y-axis are created, one for
    ``"n_samples"`` and one for ``"censoring_relative_scale"``; each is
    filled by ``self._plot_IPSR``. The figure is saved as ``IPSR.pdf``
    under ``self.path_profile``.
    """
    x_cols = ["n_samples", "censoring_relative_scale"]
    # Create the figure exactly once: a second `plt.subplots` call would
    # leave an unused, never-closed figure behind.
    fig, axes = plt.subplots(ncols=2, figsize=(8, 4), sharey=True)
    sns.despine(fig=fig)
    for x_col, ax in zip(x_cols, axes):
        self._plot_IPSR(data_params, x_col, ax)
    # Save through the figure handle so the right figure is written even
    # if pyplot's current figure changed while plotting.
    fig.savefig(self.path_profile / "IPSR.pdf", format="pdf")

def _plot_IPSR(self, data_params, x_col, ax):
df = self.load_cv_results(data_params, x_col)
Expand All @@ -405,7 +471,12 @@ def _plot_IPSR(self, data_params, x_col, ax):
for x_col_param, df_group in df.groupby(x_col):
data_params[x_col] = x_col_param

X, y = self.load_dataset(data_params, return_X_y=True)
bunch = self.load_dataset(data_params, return_X_y=False)
X, y = bunch.X, bunch.y
scale_censoring, shape_censoring = (
bunch.scale_censoring,
bunch.shape_censoring,
)
time_grid = make_time_grid(y["duration"])

estimator = _get_estimator(df_group, estimator_name)
Expand All @@ -415,13 +486,14 @@ def _plot_IPSR(self, data_params, x_col, ax):
event_specific_ipsr = []
for idx in range(data_params["n_events"]):
event_specific_ipsr.append(
integrated_brier_score_incidence(
integrated_brier_score_incidence_oracle(
y_train=y_train,
y_test=y,
# TODO: remove when removing GBI.
y_pred=y_pred[idx + 1] if y_pred.ndim == 3 else y_pred,
y_pred=y_pred[idx + 1],
times=time_grid,
event_of_interest=idx + 1,
shape_censoring=shape_censoring,
scale_censoring=scale_censoring,
)
)
x_col_params.append(x_col_param)
Expand All @@ -446,7 +518,9 @@ def load_cv_results(self, data_params, x_col=None):

def load_dataset(self, data_params, return_X_y=False, use_cache=True):
    """Generate the synthetic competing-Weibull dataset for ``data_params``.

    ``use_cache`` is accepted in the signature but discarded here. The
    generator is called with a fixed ``random_state`` so repeated calls
    with the same parameters use the same seed.

    Parameters
    ----------
    data_params : dict
        Keyword arguments forwarded to ``make_synthetic_competing_weibull``.
    return_X_y : bool, default=False
        Forwarded to the generator unchanged.
    use_cache : bool, default=True
        Ignored.
    """
    del use_cache
    # Single seeded call: an earlier unseeded `return` would have made this
    # one unreachable.
    return make_synthetic_competing_weibull(
        **data_params, return_X_y=return_X_y, random_state=1345
    )

def get_predictions(self, X, times, estimator, estimator_name):
"""TODO: implement cache if some estimators take long to predict."""
Expand Down Expand Up @@ -492,6 +566,7 @@ def plot_memory_time(self, data_params, x_col=None):
title="Time to test",
ylabel=None,
)
plt.savefig(self.path_profile / "memory_time.pdf", format="pdf")

def load_cv_results(self, data_params, x_col=None):
del data_params, x_col
Expand Down Expand Up @@ -535,8 +610,8 @@ def get_predictions(self, X, times, estimator, estimator_name):

# %%

path_session = "2024-01-17"
estimator_names = ["gbmi_10", "gbmi_20"]
path_session = "2024-01-20"
estimator_names = ["gbmi_competing_loss"]
displayer = SEERDisplayer(path_session, estimator_names)

data_params = {}
Expand All @@ -560,8 +635,8 @@ def get_predictions(self, X, times, estimator, estimator_name):

# %%

path_session = "2024-01-15"
estimator_names = ["gbmi_10", "gbmi_20"]
path_session = "2024-01-22"
estimator_names = ["fine_and_gray"]
displayer = WeibullDisplayer(path_session, estimator_names)

data_params = {
Expand All @@ -572,11 +647,14 @@ def get_predictions(self, X, times, estimator, estimator_name):
}
displayer.plot_memory_time(data_params)

# %%
displayer.plot_performance_time(data_params)
# %%

data_params["n_samples"] = 10_000
displayer.plot_IPSR(data_params)

# %%
data_params
# %%
data_params.update(
{
Expand Down

0 comments on commit 7a698fe

Please sign in to comment.