Skip to content

Commit

Permalink
Merge pull request #64 from pinellolab/codebase-nb-update
Browse files Browse the repository at this point in the history
Codebase Update via Noah's Refactored DDPM Notebook
  • Loading branch information
mateibejan1 authored Dec 23, 2022
2 parents 6a388ed + 3b53072 commit 7af7244
Show file tree
Hide file tree
Showing 14 changed files with 1,134 additions and 562 deletions.
93 changes: 56 additions & 37 deletions src/data/sequence_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,19 @@
import torch
import torchvision.transforms as T
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader


class SequenceDatasetBase(Dataset):
def __init__(self, data_path, sequence_length=200, sequence_encoding="polar", sequence_transform=None, cell_type_transform=None):
def __init__(
self,
data_path,
sequence_length: int = 200,
sequence_encoding: str = "polar",
sequence_transform=None,
cell_type_transform=None,
) -> None:
super().__init__()
self.data = pd.read_csv(data_path, sep="\t")
self.sequence_length = sequence_length
Expand All @@ -18,26 +26,26 @@ def __init__(self, data_path, sequence_length=200, sequence_encoding="polar", se
self.alphabet = ["A", "C", "T", "G"]
self.check_data_validity()

def __len__(self) -> int:
    """Return the number of rows held in ``self.data``."""
    # One dataset item per row of the loaded table.
    row_count = len(self.data.index)
    return row_count

def __getitem__(self, index):
# Looks up row `index` of self.data and returns the pair
# (encoded sequence, cell-type component), each optionally passed through
# its transform.
# NOTE(review): this span comes from a rendered diff with indentation
# stripped; the two `if ... not in current_seq:` lines below are the old and
# new spelling of the same statement (only the quote style differs).
# Iterating through DNA sequences from dataset and one-hot encoding all nucleotides
# NOTE(review): the encoding actually applied is whatever
# self.sequence_encoding selects, not necessarily one-hot — confirm against
# encode_sequence.
current_seq = self.data["raw_sequence"][index]
if 'N' not in current_seq:
if "N" not in current_seq:
# NOTE(review): X_seq is only assigned under the "N" guard. How much of the
# code below is nested under that guard cannot be determined from this view;
# confirm what is returned (or raised) for a sequence that contains "N".
X_seq = self.encode_sequence(current_seq, encoding=self.sequence_encoding)

# Reading cell component at current index
X_cell_type = self.data["component"][index]

# Both transforms are optional; each is applied only when one was supplied.
if self.sequence_transform is not None:
X_seq = self.sequence_transform(X_seq)
if self.cell_type_transform is not None:
X_cell_type = self.cell_type_transform(X_cell_type)

return X_seq, X_cell_type

def check_data_validity(self):
def check_data_validity(self) -> None:
"""
Checks if the data is valid.
"""
Expand All @@ -64,7 +72,7 @@ def encode_sequence(self, seq, encoding):
return seq

# Function for one hot encoding each line of the sequence dataset
def one_hot_encode(self, seq):
def one_hot_encode(self, seq) -> np.ndarray:
"""
One-hot encoding a sequence
"""
Expand All @@ -76,15 +84,17 @@ def one_hot_encode(self, seq):


class SequenceDatasetTrain(SequenceDatasetBase):
    """Dataset for the training split.

    Adds no behaviour of its own: ``data_path`` and every extra keyword
    argument are handed to ``SequenceDatasetBase`` unchanged.
    """

    def __init__(self, data_path="", **kwargs) -> None:
        # Collect everything into one mapping and forward it to the base class.
        forwarded = dict(kwargs, data_path=data_path)
        super().__init__(**forwarded)


class SequenceDatasetValidation(SequenceDatasetBase):
    """Dataset for the validation split.

    Adds no behaviour of its own: ``data_path`` and every extra keyword
    argument are handed to ``SequenceDatasetBase`` unchanged.
    """

    def __init__(self, data_path="", **kwargs) -> None:
        # Collect everything into one mapping and forward it to the base class.
        forwarded = dict(kwargs, data_path=data_path)
        super().__init__(**forwarded)


class SequenceDatasetTest(SequenceDatasetBase):
    """Dataset for the test split.

    Adds no behaviour of its own: ``data_path`` and every extra keyword
    argument are handed to ``SequenceDatasetBase`` unchanged.
    """

    def __init__(self, data_path="", **kwargs) -> None:
        # Collect everything into one mapping and forward it to the base class.
        forwarded = dict(kwargs, data_path=data_path)
        super().__init__(**forwarded)


Expand All @@ -94,16 +104,20 @@ def __init__(
train_path=None,
val_path=None,
test_path=None,
sequence_length=200,
sequence_encoding="polar",
sequence_length: int = 200,
sequence_encoding: str = "polar",
sequence_transform=None,
cell_type_transform=None,
batch_size=None,
num_workers=1
):
num_workers: int = 1,
) -> None:
super().__init__()
self.datasets = dict()
self.train_dataloader, self.val_dataloader, self.test_dataloader = None, None, None
self.train_dataloader, self.val_dataloader, self.test_dataloader = (
None,
None,
None,
)

if train_path:
self.datasets["train"] = train_path
Expand Down Expand Up @@ -131,43 +145,48 @@ def setup(self):
sequence_length=self.sequence_length,
sequence_encoding=self.sequence_encoding,
sequence_transform=self.sequence_transform,
cell_type_transform=self.cell_type_transform
cell_type_transform=self.cell_type_transform,
)
if "validation" in self.datasets:
self.val_data = SequenceDatasetValidation(
data_path=self.datasets["validation"],
sequence_length=self.sequence_length,
sequence_encoding=self.sequence_encoding,
sequence_transform=self.sequence_transform,
cell_type_transform=self.cell_type_transform
cell_type_transform=self.cell_type_transform,
)
if "test" in self.datasets:
self.test_data = SequenceDatasetTest(
data_path=self.datasets["test"],
sequence_length=self.sequence_length,
sequence_encoding=self.sequence_encoding,
sequence_transform=self.sequence_transform,
cell_type_transform=self.cell_type_transform
cell_type_transform=self.cell_type_transform,
)

def _train_dataloader(self):
    """Build a shuffling ``DataLoader`` over ``self.train_data``.

    Batch size and worker count come from ``self.batch_size`` and
    ``self.num_workers``; pinned memory is always requested.
    """
    loader_options = {
        "batch_size": self.batch_size,
        "shuffle": True,  # draw training samples in a fresh random order each pass
        "num_workers": self.num_workers,
        "pin_memory": True,
    }
    return DataLoader(self.train_data, **loader_options)

def _val_dataloader(self):
    """Build the ``DataLoader`` over ``self.val_data``.

    Batch size and worker count come from ``self.batch_size`` and
    ``self.num_workers``; pinned memory is always requested.

    Samples are served in dataset order. Shuffling brings no benefit during
    evaluation and makes per-batch validation output differ from run to run,
    so it is disabled here.
    """
    return DataLoader(
        self.val_data,
        self.batch_size,
        shuffle=False,  # deterministic order for reproducible validation
        num_workers=self.num_workers,
        pin_memory=True,
    )

def _test_dataloader(self):
    """Build the ``DataLoader`` over ``self.test_data``.

    Batch size and worker count come from ``self.batch_size`` and
    ``self.num_workers``; pinned memory is always requested.

    Samples are served in dataset order. Shuffling brings no benefit during
    evaluation and makes per-batch test output differ from run to run, so it
    is disabled here.
    """
    return DataLoader(
        self.test_data,
        self.batch_size,
        shuffle=False,  # deterministic order for reproducible testing
        num_workers=self.num_workers,
        pin_memory=True,
    )
137 changes: 0 additions & 137 deletions src/models/diffusion/ddim.py

This file was deleted.

Loading

0 comments on commit 7af7244

Please sign in to comment.