Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make reg2bins, reg2intervals faster on whole-chromosome queries #1596

Merged
merged 2 commits into from
Apr 13, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 174 additions & 43 deletions hts.c
Original file line number Diff line number Diff line change
Expand Up @@ -2903,73 +2903,201 @@ uint64_t hts_idx_get_n_no_coor(const hts_idx_t* idx)
****************/

// Note: even with 32-bit hts_pos_t, end needs to be 64-bit here due to 1LL<<s.
static inline int reg2bins(int64_t beg, int64_t end, hts_itr_t *itr, int min_shift, int n_lvls)
static inline int reg2bins_narrow(int64_t beg, int64_t end, hts_itr_t *itr, int min_shift, int n_lvls, bidx_t *bidx)
{
int l, t, s = min_shift + (n_lvls<<1) + n_lvls;
if (beg >= end) return 0;
if (end >= 1LL<<s) end = 1LL<<s;
for (--end, l = 0, t = 0; l <= n_lvls; s -= 3, t += 1<<((l<<1)+l), ++l) {
hts_pos_t b, e;
int n, i;
b = t + (beg>>s); e = t + (end>>s); n = e - b + 1;
if (itr->bins.n + n > itr->bins.m) {
itr->bins.m = itr->bins.n + n;
kroundup32(itr->bins.m);
itr->bins.a = (int*)realloc(itr->bins.a, sizeof(int) * itr->bins.m);
int i;
b = t + (beg>>s); e = t + (end>>s);
for (i = b; i <= e; ++i) {
if (kh_get(bin, bidx, i) != kh_end(bidx)) {
assert(itr->bins.n < itr->bins.m);
itr->bins.a[itr->bins.n++] = i;
}
}
for (i = b; i <= e; ++i) itr->bins.a[itr->bins.n++] = i;
}
return itr->bins.n;
}

/*
 * Append to itr->bins every bin stored in the hash table bidx that
 * overlaps the half-open region [beg, end).
 *
 * This variant walks the hash table slots instead of enumerating the bin
 * numbers that cover the region.  For each stored key it works out the
 * key's level and the range of bin numbers the region spans at that
 * level, and keeps the key if it falls inside that range.
 *
 * A negative beg is treated as 0.  Space in itr->bins.a is not grown
 * here; the assert documents that it must already be available.
 *
 * Returns the resulting itr->bins.n.
 */
static inline int reg2bins_wide(int64_t beg, int64_t end, hts_itr_t *itr, int min_shift, int n_lvls, bidx_t *bidx)
{
    const hts_pos_t top_shift = 3 * n_lvls + min_shift;
    const int64_t last = end - 1;               // inclusive end of the region
    const int64_t start = beg < 0 ? 0 : beg;
    khint_t slot;

    for (slot = kh_begin(bidx); slot != kh_end(bidx); slot++) {
        hts_pos_t id, level_first, lo, hi;
        int depth;

        if (!kh_exist(bidx, slot))
            continue;

        id = (hts_pos_t) kh_key(bidx, slot);
        depth = hts_bin_level(id);
        if (depth > n_lvls)
            continue; // Key lies beyond the levels of this index (dodgy index?)

        // Range of bin numbers the region covers at this key's level
        level_first = hts_bin_first(depth);
        lo = level_first + (start >> (top_shift - 3 * depth));
        hi = level_first + (last >> (top_shift - 3 * depth));
        if (id < lo || id > hi)
            continue;

        assert(itr->bins.n < itr->bins.m);
        itr->bins.a[itr->bins.n++] = id;
    }
    return itr->bins.n;
}

/*
 * Fill itr->bins with the bins present in bidx that overlap [beg, end).
 *
 * The region is first clipped to [0, 1 << (min_shift + 3 * n_lvls)].
 * The function then counts how many bin numbers cover the region over all
 * levels and compares that with the number of hash table buckets, so that
 * it can pick the cheaper of two strategies:
 *   - reg2bins_narrow: look up each covering bin number in the hash;
 *   - reg2bins_wide:   scan the hash and test each stored bin.
 *
 * Before dispatching, itr->bins.a is grown so that it can hold
 * min(covering bins, bins stored in the hash) extra entries, which is the
 * most either strategy can add.
 *
 * Returns 0 if the clipped region is empty, -1 on allocation failure or
 * if the required size would overflow (errno is set to ENOMEM in the
 * overflow case), otherwise the value returned by the chosen strategy.
 */
static inline int reg2bins(int64_t beg, int64_t end, hts_itr_t *itr, int min_shift, int n_lvls, bidx_t *bidx)
{
    int l, s = min_shift + (n_lvls<<1) + n_lvls;
    size_t reg_bin_count = 0, hash_bin_count = kh_n_buckets(bidx), max_bins;
    hts_pos_t end1;

    // Clamp here so the count below and both strategies see the same
    // region.  reg2bins_wide already clamps; without this the narrow
    // path would shift a negative value and probe bins belonging to the
    // previous level.
    if (beg < 0) beg = 0;
    if (end >= 1LL<<s) end = 1LL<<s;
    if (beg >= end) return 0;
    end1 = end - 1;

    // Count bins to see if it's faster to iterate through the hash table
    // or the set of bins covering the region
    for (l = 0; l <= n_lvls; s -= 3, ++l) {
        reg_bin_count += (end1 >> s) - (beg >> s) + 1;
    }

    max_bins = reg_bin_count < kh_size(bidx) ? reg_bin_count : kh_size(bidx);
    if (itr->bins.m - itr->bins.n < max_bins) {
        // Worst-case memory usage.  May be wasteful on very sparse
        // data, but the bin list usually won't be too big anyway.
        size_t new_m = max_bins + itr->bins.n;
        if (new_m > INT_MAX || new_m > SIZE_MAX / sizeof(int)) {
            errno = ENOMEM;
            return -1;
        }
        // Keep the old pointer until realloc is known to have succeeded
        int *new_a = realloc(itr->bins.a, new_m * sizeof(*new_a));
        if (!new_a) return -1;
        itr->bins.a = new_a;
        itr->bins.m = new_m;
    }

    if (reg_bin_count < hash_bin_count) {
        return reg2bins_narrow(beg, end, itr, min_shift, n_lvls, bidx);
    } else {
        return reg2bins_wide(beg, end, itr, min_shift, n_lvls, bidx);
    }
}

/*
 * Append the chunks of one index bin to the iterator's offset list.
 *
 * Only chunks that intersect the file-offset window (v > min_off and
 * u < max_off) are added, and each added chunk is clipped to
 * [min_off, max_off].  The 'max' field of every added entry is set to
 * (tid << 32) | interval so the entry can be tied back to the region
 * list entry it was generated for.
 *
 * iter->off is grown to hold all of the bin's chunks up front, even
 * though some may be filtered out.
 *
 * Returns 0 on success (including when the bin has no chunks), or -2 if
 * the offset list cannot be enlarged.
 */
static inline int add_to_interval(hts_itr_t *iter, bins_t *bin,
                                  int tid, uint32_t interval,
                                  uint64_t min_off, uint64_t max_off)
{
    hts_pair64_max_t *off;
    size_t needed;
    int j;

    if (!bin->n)
        return 0;

    // Work out the new element count in size_t and refuse totals that
    // iter->n_off could not index or whose byte size would overflow,
    // mirroring the check made when growing the bin list in reg2bins.
    needed = (size_t) iter->n_off + (size_t) bin->n;
    if (needed > INT_MAX || needed > SIZE_MAX / sizeof(*off)) {
        errno = ENOMEM;
        return -2;
    }

    off = realloc(iter->off, needed * sizeof(*off));
    if (!off)
        return -2;

    iter->off = off;
    for (j = 0; j < bin->n; ++j) {
        if (bin->list[j].v > min_off && bin->list[j].u < max_off) {
            iter->off[iter->n_off].u = min_off > bin->list[j].u
                ? min_off : bin->list[j].u;
            iter->off[iter->n_off].v = max_off < bin->list[j].v
                ? max_off : bin->list[j].v;
            // hts_pair64_max_t::max is now used to link
            // file offsets to region list entries.
            // The iterator can use this to decide if it
            // can skip some file regions.
            iter->off[iter->n_off].max = ((uint64_t) tid << 32) | interval;
            iter->n_off++;
        }
    }
    return 0;
}

/*
 * Add the chunks of every bin overlapping [beg, end) to iter->off by
 * enumerating the covering bin numbers level by level and looking each
 * one up in the hash table bidx.
 *
 * tid, interval, min_off and max_off are passed through unchanged to
 * add_to_interval() for each bin that is found.
 *
 * Returns 0 on success, or the first negative value returned by
 * add_to_interval().
 */
static inline int reg2intervals_narrow(hts_itr_t *iter, const bidx_t *bidx,
                                       int tid, int64_t beg, int64_t end,
                                       uint32_t interval,
                                       uint64_t min_off, uint64_t max_off,
                                       int min_shift, int n_lvls)
{
    int level = 0;
    int level_first = 0;                    // number of the first bin at this level
    int shift = min_shift + n_lvls * 3;     // bin width at this level, as a shift

    --end;                                  // make the end coordinate inclusive
    while (level <= n_lvls) {
        hts_pos_t first_bin = level_first + (beg >> shift);
        hts_pos_t last_bin = level_first + (end >> shift);
        hts_pos_t id;

        for (id = first_bin; id <= last_bin; ++id) {
            khint_t slot = kh_get(bin, bidx, id);
            int res;

            if (slot == kh_end(bidx))
                continue;                   // bin not present in the index

            res = add_to_interval(iter, &kh_value(bidx, slot), tid,
                                  interval, min_off, max_off);
            if (res < 0)
                return res;
        }

        // Step down one level: bins are 8x narrower and numbering
        // continues after the 8^level bins of the current level.
        shift -= 3;
        level_first += 1 << ((level << 1) + level);
        ++level;
    }
    return 0;
}

/*
 * Add the chunks of every bin overlapping [beg, end) to iter->off by
 * scanning all entries of the hash table bidx and testing each stored
 * bin against the region, instead of enumerating covering bin numbers.
 *
 * A negative beg is treated as 0.  tid, interval, min_off and max_off are
 * passed through unchanged to add_to_interval() for each matching bin.
 *
 * Returns 0 on success, or the first negative value returned by
 * add_to_interval().
 */
static inline int reg2intervals_wide(hts_itr_t *iter, const bidx_t *bidx,
                                     int tid, int64_t beg, int64_t end,
                                     uint32_t interval,
                                     uint64_t min_off, uint64_t max_off,
                                     int min_shift, int n_lvls)
{
    khint_t i;
    hts_pos_t max_shift = 3 * n_lvls + min_shift;
    --end;  // make the end coordinate inclusive
    if (beg < 0) beg = 0;
    for (i = kh_begin(bidx); i != kh_end(bidx); i++) {
        if (!kh_exist(bidx, i)) continue;
        hts_pos_t bin_id = (hts_pos_t) kh_key(bidx, i);
        int level = hts_bin_level(bin_id);
        if (level > n_lvls) continue; // Dodgy index?
        // Range of bin numbers the region covers at this bin's level
        hts_pos_t first = hts_bin_first(level);
        hts_pos_t beg_at_level = first + (beg >> (max_shift - 3 * level));
        hts_pos_t end_at_level = first + (end >> (max_shift - 3 * level));
        if (beg_at_level <= bin_id && bin_id <= end_at_level) {
            // Distinct name for the bin record so it no longer shadows
            // the numeric bin id tested above.
            bins_t *bin_rec = &kh_value(bidx, i);
            int res = add_to_interval(iter, bin_rec, tid, interval, min_off, max_off);
            if (res < 0)
                return res;
        }
    }
    return 0;
}

static inline int reg2intervals(hts_itr_t *iter, const hts_idx_t *idx, int tid, int64_t beg, int64_t end, uint32_t interval, uint64_t min_off, uint64_t max_off, int min_shift, int n_lvls)
{
int l, t, s;
int i, j;
hts_pos_t b, e;
hts_pair64_max_t *off;
hts_pos_t end1;
bidx_t *bidx;
khint_t k;
int start_n_off = iter->n_off;
int start_n_off;
size_t reg_bin_count = 0, hash_bin_count;
int res;

if (!iter || !idx || (bidx = idx->bidx[tid]) == NULL || beg >= end)
return -1;

hash_bin_count = kh_n_buckets(bidx);

s = min_shift + (n_lvls<<1) + n_lvls;
if (end >= 1LL<<s)
end = 1LL<<s;

for (--end, l = 0, t = 0; l <= n_lvls; s -= 3, t += 1<<((l<<1)+l), ++l) {
b = t + (beg>>s); e = t + (end>>s);

for (i = b; i <= e; ++i) {
if ((k = kh_get(bin, bidx, i)) != kh_end(bidx)) {
bins_t *p = &kh_value(bidx, k);
end1 = end - 1;
// Count bins to see if it's faster to iterate through the hash table
// or the set of bins covering the region
for (l = 0, t = 0; l <= n_lvls; s -= 3, t += 1<<((l<<1)+l), ++l) {
reg_bin_count += (end1 >> s) - (beg >> s) + 1;
}

if (p->n) {
off = realloc(iter->off, (iter->n_off + p->n) * sizeof(*off));
if (!off)
return -2;
start_n_off = iter->n_off;

iter->off = off;
for (j = 0; j < p->n; ++j) {
if (p->list[j].v > min_off && p->list[j].u < max_off) {
iter->off[iter->n_off].u = min_off > p->list[j].u
? min_off : p->list[j].u;
iter->off[iter->n_off].v = max_off < p->list[j].v
? max_off : p->list[j].v;
// hts_pair64_max_t::max is now used to link
// file offsets to region list entries.
// The iterator can use this to decide if it
// can skip some file regions.
iter->off[iter->n_off].max = ((uint64_t) tid << 32) | interval;
iter->n_off++;
}
}
}
}
}
// Populate iter->off with the intervals for this region
if (reg_bin_count < hash_bin_count) {
res = reg2intervals_narrow(iter, bidx, tid, beg, end, interval,
min_off, max_off, min_shift, n_lvls);
} else {
res = reg2intervals_wide(iter, bidx, tid, beg, end, interval,
min_off, max_off, min_shift, n_lvls);
}
if (res < 0)
return res;

if (iter->n_off - start_n_off > 1) {
ks_introsort(_off_max, iter->n_off - start_n_off, iter->off + start_n_off);
Expand Down Expand Up @@ -3159,7 +3287,10 @@ hts_itr_t *hts_itr_query(const hts_idx_t *idx, int tid, hts_pos_t beg, hts_pos_t
}

// retrieve bins
reg2bins(beg, end, iter, idx->min_shift, idx->n_lvls);
if (reg2bins(beg, end, iter, idx->min_shift, idx->n_lvls, bidx) < 0) {
hts_itr_destroy(iter);
return NULL;
}

for (i = n_off = 0; i < iter->bins.n; ++i)
if ((k = kh_get(bin, bidx, iter->bins.a[i])) != kh_end(bidx))
Expand Down