Skip to content

Commit

Permalink
added some simple changes regarding prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
EliHei2 committed Sep 26, 2024
1 parent 441434a commit aa200a3
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 66 deletions.
80 changes: 14 additions & 66 deletions scripts/predict_model_sample.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from segger.data.io import XeniumSample
from segger.training.train import LitSegger
from segger.training.segger_data_module import SeggerDataModule
from segger.prediction.predict_gpu import predict, load_model
from segger.prediction.predict import segment, load_model
from lightning.pytorch.loggers import CSVLogger
from pytorch_lightning import Trainer
from pathlib import Path
Expand All @@ -19,8 +19,8 @@

segger_data_dir = Path('./data_tidy/pyg_datasets/bc_embedding_0919')
models_dir = Path('./models/bc_embedding_0919')
benchmarks_path = Path('/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/xe_rep1_bc')

benchmarks_dir = Path('/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/xe_rep1_bc')
transcripts_file = 'data_raw/xenium/Xenium_FFPE_Human_Breast_Cancer_Rep1/transcripts.parquet'
# Initialize the Lightning data module
dm = SeggerDataModule(
data_dir=segger_data_dir,
Expand All @@ -40,67 +40,15 @@

receptive_field = {'k_bd': 4, 'dist_bd': 10,'k_tx': 5, 'dist_tx': 3}

# Perform segmentation (predictions)
segmentation_train = predict(
model,
dm.train_dataloader(),
score_cut=0.5,
receptive_field=receptive_field,
use_cc=True,
# device='cuda',
# num_workers=4
)

segmentation_val = predict(
segment(
model,
dm.val_dataloader(),
score_cut=0.5,
receptive_field=receptive_field,
use_cc=True,
# use_cc=False,
# device='cpu'
)

segmentation_test = predict(
model,
dm.test_dataloader(),
score_cut=0.5,
receptive_field=receptive_field,
use_cc=True,
# use_cc=False,
# device='cpu'
)



seg_combined = pd.concat([segmentation_train, segmentation_val, segmentation_test])
# Group by transcript_id and keep the row with the highest score for each transcript
seg_combined = pd.concat([segmentation_train, segmentation_val, segmentation_test]).reset_index()

# Group by transcript_id and keep the row with the highest score for each transcript
seg_final = seg_combined.loc[seg_combined.groupby('transcript_id')['score'].idxmax()]

# Drop rows where segger_cell_id is NaN
seg_final = seg_final.dropna(subset=['segger_cell_id'])

# Reset the index if needed
seg_final.reset_index(drop=True, inplace=True)

transcripts_df = dd.read_parquet('data_raw/xenium/Xenium_FFPE_Human_Breast_Cancer_Rep1/transcripts.parquet')

# # Assuming seg_final is already computed with pandas
# # Convert seg_final to a Dask DataFrame to enable efficient merging with Dask
seg_final_dd = dd.from_pandas(seg_final, npartitions=transcripts_df.npartitions)

# # Step 1: Merge segmentation with the transcripts on transcript_id
# # Use 'inner' join to keep only matching transcript_ids
transcripts_df_filtered = transcripts_df.merge(seg_final_dd, on='transcript_id', how='inner')

# Compute the result if needed
transcripts_df_filtered = transcripts_df_filtered.compute()


from segger.data.utils import create_anndata
segger_adata = create_anndata(transcripts_df_filtered, cell_id_col='segger_cell_id')
segger_adata.write(benchmarks_path / 'adata_segger_embedding_full.h5ad')

dm,
save_dir=benchmarks_dir,
seg_tag='test_segger_segment',
transcript_file=transcripts_file,
file_format='anndata',
receptive_field = receptive_field,
min_transcripts=10,
max_transcripts=1000
cell_id_col='segger_cell_id'
)
21 changes: 21 additions & 0 deletions src/segger/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def try_import(module_name):
print(f"Warning: cupy and/or cuvs are not installed. Please install them to use this functionality.")

import torch.utils.dlpack as dlpack
from datetime import timedelta



Expand Down Expand Up @@ -542,3 +543,23 @@ def coo_to_dense_adj(
nbr_idx[i, :len(nbrs)] = nbrs

return nbr_idx





def format_time(elapsed: float) -> str:
    """
    Render an elapsed duration in seconds as an h:m:s string.

    Fractional seconds are truncated (not rounded) before formatting,
    and the string form of ``datetime.timedelta`` is used for the layout.

    Parameters:
    ----------
    elapsed : float
        Elapsed time in seconds.
    Returns:
    -------
    str
        Formatted time in h:m:s.
    """
    # Truncate to whole seconds; timedelta's str() yields "H:MM:SS".
    whole_seconds = int(elapsed)
    duration = timedelta(seconds=whole_seconds)
    return str(duration)

0 comments on commit aa200a3

Please sign in to comment.