
Commit

Merge branch 'improvements' of github.com:dominikstanojevic/no-tombo-mod into improvements
dominikstanojevic committed Feb 20, 2022
2 parents dc911c5 + e90ec0c commit e78ed98
Showing 1 changed file with 48 additions and 12 deletions.
60 changes: 48 additions & 12 deletions rockfish/model/iterative_inference.py
@@ -24,11 +24,13 @@ def parse_gpus(string):
 
 class RFDataset(IterableDataset):
     def __init__(self,
-                 path=str,
+                 path: str,
+                 batch_size: int,
                  window: int = 15,
                  block_size: int = 5) -> None:
         super(IterableDataset, self).__init__()
 
+        self.batch_size = batch_size
         self.seq_len = (2 * window) + 1
         self.block_size = block_size
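For context, the removed line path=str made the built-in str type the default value of path rather than a type hint; the replacement annotates it properly and adds a batch_size parameter that the rewritten __iter__ in the next hunk uses for binning. A minimal construction sketch, assuming a hypothetical .rf data file (window and block_size keep their defaults):

    # Hypothetical usage sketch; 'reads.rf' is a made-up path, not taken from the repository.
    dataset = RFDataset('reads.rf', batch_size=128)  # window=15, block_size=5 by default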

@@ -44,20 +46,54 @@ def __init__(self,
         self.end = len(self.offsets)
 
     def __iter__(self):
+        bins = [list() for _ in range(9)]
+        stored = 0
+
         self.fd.seek(self.offsets[self.start])
         for _ in range(self.start, self.end):
             example = self.read_iter_example()
+            bin = len(example.q_indices) // 10 - 3
+            bins[bin].append(example)
+            stored += 1
+
+            if len(bins[bin]) >= self.batch_size:  # bin is full, emit examples
+                for example in bins[bin]:
+                    signal = torch.tensor(example.signal,
+                                          dtype=torch.half).unfold(
+                                              -1, self.block_size,
+                                              self.block_size)
+                    bases = torch.tensor([ENCODING[b] for b in example.bases])
+                    q_indices = torch.tensor(example.q_indices)
+                    lengths = torch.tensor(example.lengths)
+
+                    r_pos_enc = self.mapping_encodings(lengths)
+
+                    yield example.read_id, self.ctgs[
+                        example.
+                        ctg], example.pos, signal, bases, r_pos_enc, q_indices
+
+                stored -= len(bins[bin])
+                bins[bin].clear()
+            elif stored >= 4 * self.batch_size:
+                for bin in bins:
+                    for example in bin:
+                        signal = torch.tensor(example.signal,
+                                              dtype=torch.half).unfold(
+                                                  -1, self.block_size,
+                                                  self.block_size)
+                        bases = torch.tensor(
+                            [ENCODING[b] for b in example.bases])
+                        q_indices = torch.tensor(example.q_indices)
+                        lengths = torch.tensor(example.lengths)
 
-            signal = torch.tensor(example.signal, dtype=torch.half).unfold(
-                -1, self.block_size, self.block_size)
-            bases = torch.tensor([ENCODING[b] for b in example.bases])
-            q_indices = torch.tensor(example.q_indices)
-            lengths = torch.tensor(example.lengths)
+                        r_pos_enc = self.mapping_encodings(lengths)
 
-            r_pos_enc = self.mapping_encodings(lengths)
+                        yield example.read_id, self.ctgs[
+                            example.
+                            ctg], example.pos, signal, bases, r_pos_enc, q_indices
 
-            yield example.read_id, self.ctgs[
-                example.ctg], example.pos, signal, bases, r_pos_enc, q_indices
+                    bin.clear()
+                stored = 0
 
     def read_iter_example(self):
         read_id, ctg, pos, n_points, q_indices_len = struct.unpack(
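The rewritten __iter__ buckets examples by the length of q_indices: bin = len(example.q_indices) // 10 - 3 gives nine bins, which only indexes correctly if lengths fall roughly between 30 and 119. A bin is emitted as soon as it holds batch_size examples, and if 4 * batch_size examples accumulate without any single bin filling, every bin is flushed. Each emitted run therefore contains examples of similar length, so collate_fn pads less; because the DataLoader below receives the same batch_size, each collated batch comes from one bin. The real __iter__ yields one tuple per example and lets the DataLoader regroup them; the sketch below yields whole lists to stay compact. Helper names are hypothetical, the bin index is clamped rather than assuming the length range, and the final drain of leftovers is an addition not visible in this hunk:

    from typing import Iterable, Iterator, List, Sequence


    def bucket_by_length(examples: Iterable[Sequence],
                         batch_size: int,
                         n_bins: int = 9,
                         flush_factor: int = 4) -> Iterator[List[Sequence]]:
        """Yield groups of similarly sized examples, mirroring the binning in __iter__."""
        bins: List[List[Sequence]] = [[] for _ in range(n_bins)]
        stored = 0

        for example in examples:
            b = min(max(len(example) // 10 - 3, 0), n_bins - 1)  # clamp the bin index
            bins[b].append(example)
            stored += 1

            if len(bins[b]) >= batch_size:  # bin is full, emit its examples
                yield bins[b][:]
                stored -= len(bins[b])
                bins[b].clear()
            elif stored >= flush_factor * batch_size:  # buffer limit hit, flush all bins
                for bin_ in bins:
                    if bin_:
                        yield bin_[:]
                        bin_.clear()
                stored = 0

        for bin_ in bins:  # drain leftovers (not part of the shown hunk)
            if bin_:
                yield bin_[:]

A quick check: list(bucket_by_length(['ACGT' * k for k in range(8, 30)], batch_size=4)) yields lists whose members all fall in the same 10-character length band.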
@@ -115,9 +151,9 @@ def inference(args):
     print(device)
     model.to(device)
 
-    data = RFDataset(args.data_path)
+    data = RFDataset(args.data_path, args.batch_size)
     loader = DataLoader(data,
-                        args.batch,
+                        args.batch_size,
                         False,
                         num_workers=args.workers,
                         collate_fn=collate_fn,
@@ -147,7 +183,7 @@ def get_arguments():
     parser.add_argument('-o', '--output', type=str, default='preditions.tsv')
     parser.add_argument('-d', '--gpus', default=None)
     parser.add_argument('-t', '--workers', type=int, default=0)
-    parser.add_argument('-b', '--batch', type=int, default=1)
+    parser.add_argument('-b', '--batch_size', type=int, default=1)
 
     return parser.parse_args()
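With the renamed flag, an invocation would look roughly like the following; the script's data and model arguments are not visible in this diff, so a placeholder stands in for them rather than guessing:

    # Hypothetical invocation; <data/model arguments> covers options not shown in this hunk.
    python rockfish/model/iterative_inference.py <data/model arguments> -b 256 -t 4 -o predictions.tsv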
