diff --git a/rockfish/model/datasets.py b/rockfish/model/datasets.py
index 6c43977..9015aa5 100644
--- a/rockfish/model/datasets.py
+++ b/rockfish/model/datasets.py
@@ -1,12 +1,12 @@
-from dataclasses import dataclass
-from io import BufferedReader
+import numpy as np
 
 import torch
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import Dataset, DataLoader
 
-import math
 import pytorch_lightning as pl
 
+from dataclasses import dataclass
+from io import BufferedReader
 import struct
 import sys
@@ -74,7 +74,7 @@ class Example:
     read_id: str
     ctg: int
     pos: int
-    signal: List[float]
+    signal: np.ndarray
     q_indices: List[int]
     lengths: List[int]
     bases: str
diff --git a/rockfish/model/iterative_inference.py b/rockfish/model/iterative_inference.py
index ae731c9..6feedb4 100644
--- a/rockfish/model/iterative_inference.py
+++ b/rockfish/model/iterative_inference.py
@@ -1,16 +1,21 @@
+from numpy import dtype
 import torch
 from torch.nn import DataParallel
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import IterableDataset, DataLoader
 from tqdm import tqdm
 
+import numpy as np
 import math
 import struct
 import argparse
+import collections
 
 from datasets import read_offsets, parse_ctgs, Example, MappingEncodings
 from model import Rockfish
 
+from typing import *
+
 
 ENCODING = {b: i for i, b in enumerate('ACGT')}
 
@@ -46,7 +51,7 @@ def __init__(self,
         self.end = len(self.offsets)
 
     def __iter__(self):
-        bins = [list() for _ in range(9)]
+        bins = [list() for _ in range(10)]
         stored = 0
 
         self.fd.seek(self.offsets[self.start])
@@ -75,8 +80,12 @@ def __iter__(self):
                    stored -= len(bins[bin])
                    bins[bin].clear()
            elif stored >= 4 * self.batch_size:
-                for bin in bins:
-                    for example in bin:
+                batch_processed = 0
+
+                for bin in reversed(bins):
+                    while len(bin) > 0:
+                        example = bin.pop()
+
                         signal = torch.tensor(example.signal,
                                               dtype=torch.half).unfold(
                                                   -1, self.block_size,
@@ -92,9 +101,34 @@ def __iter__(self):
                            example.
                            ctg], example.pos, signal, bases, r_pos_enc, q_indices
 
-                    bin.clear()
-                stored = 0
-
+                        batch_processed += 1
+                        stored -= 1
+
+                        if batch_processed == self.batch_size:
+                            break
+
+                    if batch_processed == self.batch_size:
+                        break
+
+        for bin in bins:
+            for example in bin:
+                signal = torch.tensor(example.signal,
+                                      dtype=torch.half).unfold(
+                                          -1, self.block_size,
+                                          self.block_size)
+                bases = torch.tensor(
+                    [ENCODING[b] for b in example.bases])
+                q_indices = torch.tensor(example.q_indices)
+                lengths = torch.tensor(example.lengths)
+
+                r_pos_enc = self.mapping_encodings(lengths)
+
+                yield example.read_id, self.ctgs[
+                    example.
+                    ctg], example.pos, signal, bases, r_pos_enc, q_indices
+
+            bin.clear()
+
     def read_iter_example(self):
         read_id, ctg, pos, n_points, q_indices_len = struct.unpack(
             '=36sHIHH', self.fd.read(46))
@@ -105,7 +139,7 @@ def read_iter_example(self):
             self.fd.read(n_bytes))
 
         event_len_start = n_points + q_indices_len
 
-        return Example(read_id.decode(), ctg, pos, data[:n_points],
+        return Example(read_id.decode(), ctg, pos, np.array(data[:n_points], dtype=np.half),
                        data[n_points:event_len_start],
                        data[event_len_start:-1], data[-1].decode())
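
Note (not part of the patch): the reworked __iter__ buffers examples into length bins and, once 4 * batch_size examples are stored, drains exactly one batch from the longest bins first (reversed(bins)) while leaving the rest buffered; anything still buffered at end of stream is flushed bin by bin. The following is a minimal, self-contained sketch of that bucketing strategy under those assumptions; Example, drain_batches, and the bin-index rule here are hypothetical stand-ins for illustration, not the repository's API.

from dataclasses import dataclass
from typing import Iterator, List


@dataclass
class Example:
    read_id: str
    length: int  # stand-in for signal length


def drain_batches(examples: List[Example],
                  batch_size: int,
                  n_bins: int = 10) -> Iterator[List[Example]]:
    # One bucket per length range, mirroring `bins = [list() for _ in range(10)]`.
    bins: List[List[Example]] = [[] for _ in range(n_bins)]
    stored = 0
    max_len = max(e.length for e in examples)

    for example in examples:
        # Hypothetical binning rule; the real code derives the bin index
        # from the example's signal length differently.
        bin_idx = min(example.length * n_bins // (max_len + 1), n_bins - 1)
        bins[bin_idx].append(example)
        stored += 1

        if stored >= 4 * batch_size:
            # Drain exactly one batch, longest bins first, by popping.
            batch: List[Example] = []
            for bin_ in reversed(bins):
                while bin_ and len(batch) < batch_size:
                    batch.append(bin_.pop())
                    stored -= 1
                if len(batch) == batch_size:
                    break
            yield batch

    # End-of-stream flush: yield whatever is still buffered.
    leftover = [e for bin_ in bins for e in bin_]
    for i in range(0, len(leftover), batch_size):
        yield leftover[i:i + batch_size]


if __name__ == '__main__':
    data = [Example(f'read_{i}', length=(i * 37) % 400 + 50) for i in range(64)]
    for batch in drain_batches(data, batch_size=8):
        print([e.length for e in batch])

Grouping similarly sized examples into one batch keeps the padding added by pad_sequence small, which is presumably why the patch drains one bin-sorted batch at a time instead of flushing every bin whenever the buffer fills.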