'AssertionError' when trying to return 'variant_allele' from biallelic_snp_calls() #516

Closed
tristanpwdennis opened this issue Mar 26, 2024 · 3 comments
Labels
bug Something isn't working

tristanpwdennis commented Mar 26, 2024

When I try to retrieve 'variant_allele' data from the results of biallelic_snp_calls() with the n_snps parameter specified, I get an 'AssertionError'.

snp_calls() works fine:

import malariagen_data
recach = '/Users/dennistpw/Projects/malariagen_results'
ag3 = malariagen_data.Ag3(pre=True,
                          results_cache=recach)

ds_snps = ag3.snp_calls(sample_sets='AG1000G-AO',
                        region='2L:1000000-1010000')

ds_snps['variant_allele'].compute()

Returns:

array([[b'T', b'A', b'C', b'G'],
       [b'G', b'A', b'C', b'T'],
       [b'T', b'A', b'C', b'G'],
       ...,
       [b'C', b'A', b'T', b'G'],
       [b'G', b'A', b'C', b'T'],
       [b'G', b'A', b'C', b'T']], dtype='|S1')

biallelic_snp_calls() also works fine:

ds_snps_bi = ag3.biallelic_snp_calls(sample_sets='AG1000G-AO',
                                     region='2L:1000000-1010000')

ds_snps_bi['variant_allele'].compute()

Returns:

array([[b'A', b'T'],
       [b'G', b'T'],
       [b'G', b'A'],
       ...,
       [b'C', b'T'],
       [b'G', b'A'],
       [b'G', b'T']], dtype='|S1')

biallelic_snp_calls() with n_snps specified fails:

ds_snps_bi_sub = ag3.biallelic_snp_calls(sample_sets='AG1000G-AO',
                                         region='2L:1000000-5000000',
                                         n_snps=2000)
ds_snps_bi_sub['variant_allele'].compute()

Raises:

{
	"name": "AssertionError",
	"message": "",
	"stack": "---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[12], line 1
----> 1 ds_snps_bi_sub['variant_allele'].compute()

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/xarray/core/dataarray.py:1101, in DataArray.compute(self, **kwargs)
   1082 \"\"\"Manually trigger loading of this array's data from disk or a
   1083 remote source into memory and return a new array. The original is
   1084 left unaltered.
   (...)
   1098 dask.compute
   1099 \"\"\"
   1100 new = self.copy(deep=False)
-> 1101 return new.load(**kwargs)

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/xarray/core/dataarray.py:1075, in DataArray.load(self, **kwargs)
   1057 def load(self: T_DataArray, **kwargs) -> T_DataArray:
   1058     \"\"\"Manually trigger loading of this array's data from disk or a
   1059     remote source into memory and return this array.
   1060 
   (...)
   1073     dask.compute
   1074     \"\"\"
-> 1075     ds = self._to_temp_dataset().load(**kwargs)
   1076     new = self._from_temp_dataset(ds)
   1077     self._variable = new._variable

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/xarray/core/dataset.py:747, in Dataset.load(self, **kwargs)
    744 import dask.array as da
    746 # evaluate all the dask arrays simultaneously
--> 747 evaluated_data = da.compute(*lazy_data.values(), **kwargs)
    749 for k, data in zip(lazy_data, evaluated_data):
    750     self.variables[k].data = data

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/dask/base.py:599, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
    596     keys.append(x.__dask_keys__())
    597     postcomputes.append(x.__dask_postcompute__())
--> 599 results = schedule(dsk, keys, **kwargs)
    600 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/dask/threaded.py:89, in get(dsk, keys, cache, num_workers, pool, **kwargs)
     86     elif isinstance(pool, multiprocessing.pool.Pool):
     87         pool = MultiprocessingPoolExecutor(pool)
---> 89 results = get_async(
     90     pool.submit,
     91     pool._max_workers,
     92     dsk,
     93     keys,
     94     cache=cache,
     95     get_id=_thread_get_id,
     96     pack_exception=pack_exception,
     97     **kwargs,
     98 )
    100 # Cleanup pools associated to dead threads
    101 with pools_lock:

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/dask/local.py:511, in get_async(submit, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, chunksize, **kwargs)
    509         _execute_task(task, data)  # Re-execute locally
    510     else:
--> 511         raise_exception(exc, tb)
    512 res, worker_id = loads(res_info)
    513 state[\"cache\"][key] = res

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/dask/local.py:319, in reraise(exc, tb)
    317 if exc.__traceback__ is not tb:
    318     raise exc.with_traceback(tb)
--> 319 raise exc

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/dask/local.py:224, in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    222 try:
    223     task, data = loads(task_info)
--> 224     result = _execute_task(task, data)
    225     id = get_id()
    226     result = dumps((result, id))

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/dask/core.py:119, in <genexpr>(.0)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/dask/optimization.py:990, in SubgraphCallable.__call__(self, *args)
    988 if not len(args) == len(self.inkeys):
    989     raise ValueError(\"Expected %d args, got %d\" % (len(self.inkeys), len(args)))
--> 990 return core.get(self.dsk, self.outkey, dict(zip(self.inkeys, args)))

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/dask/core.py:149, in get(dsk, out, cache)
    147 for key in toposort(dsk):
    148     task = dsk[key]
--> 149     result = _execute_task(task, cache)
    150     cache[key] = result
    151 result = _execute_task(out, cache)

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/Projects/malariagen-data-python/malariagen_data/anoph/snp_data.py:1629, in AnophelesSnpData.biallelic_snp_calls.<locals>.<lambda>(block)
   1626 variant_allele = ds_bi[\"variant_allele\"].data
   1627 variant_allele = variant_allele.rechunk((variant_allele.chunks[0], -1))
   1628 variant_allele_out = da.map_blocks(
-> 1629     lambda block: apply_allele_mapping(block, allele_mapping, max_allele=1),
   1630     variant_allele,
   1631     dtype=variant_allele.dtype,
   1632     chunks=(variant_allele.chunks[0], [2]),
   1633 )
   1634 data_vars[\"variant_allele\"] = (\"variants\", \"alleles\"), variant_allele_out
   1636 # Store allele counts, transformed, so we don't have to recompute.

File ~/Projects/malariagen-data-python/malariagen_data/util.py:1281, in apply_allele_mapping()
   1279 n_sites = x.shape[0]
   1280 n_alleles = x.shape[1]
-> 1281 assert mapping.shape[0] == n_sites
   1282 assert (
   1283     mapping.shape[1] == n_alleles
   1284 )  # these are not the same, work out what's going on - try running code with debugger? or print statementsd
   1286 # Create output array.

AssertionError: "
}

This looks like a mismatch between the array shapes that apply_allele_mapping() expects and the shapes it actually receives.

@leehart leehart added the bug Something isn't working label Jul 29, 2024

tristanpwdennis commented Aug 8, 2024

Sorry it's taken me a while to come back to this!

To my (very inexperienced) eye, the bug arises when da.map_blocks applies apply_allele_mapping() to the variant_allele and allele_mapping arrays. variant_allele is a chunked dask array, whereas allele_mapping is an in-memory numpy array. When map_blocks runs, each call receives a single chunk of variant_allele but the entire allele_mapping array, so the size check...

    n_sites = x.shape[0]
    n_alleles = x.shape[1]
    assert mapping.shape[0] == n_sites
    assert mapping.shape[1] == n_alleles

...fails, because it compares the shape of a single chunk of variant_allele with the shape of the entire allele_mapping array.
I've attempted to fix this in my PR by chunking allele_mapping to match variant_allele (see here). It now seems to be working OK.
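
For anyone following along, here is a minimal sketch of the failure mode and of the general idea behind the fix, using toy data. It is not the code from the PR: apply_mapping is a simplified, hypothetical stand-in for apply_allele_mapping(), and the arrays and chunk sizes are made up for illustration. The only assumption about dask is that da.map_blocks hands aligned blocks to the function when it is given several dask arrays with matching chunks.

```python
import dask.array as da
import numpy as np


def apply_mapping(x, mapping, max_allele):
    """Hypothetical stand-in for apply_allele_mapping(): reorder alleles per site."""
    n_sites, n_alleles = x.shape
    # Same shape checks as the ones that fail in the traceback.
    assert mapping.shape[0] == n_sites
    assert mapping.shape[1] == n_alleles
    out = np.zeros((n_sites, max_allele + 1), dtype=x.dtype)
    for i in range(n_sites):
        for j in range(n_alleles):
            k = mapping[i, j]
            if 0 <= k <= max_allele:
                out[i, k] = x[i, j]
    return out


# Toy data: 6 sites, 4 alleles each; -1 in the mapping means "drop this allele".
alleles = np.array([[b"T", b"A", b"C", b"G"]] * 6, dtype="S1")
mapping = np.array([[0, 1, -1, -1]] * 3 + [[0, -1, 1, -1]] * 3)

# Three blocks of 2 sites each along the variants dimension.
variant_allele = da.from_array(alleles, chunks=(2, 4))
out_chunks = (variant_allele.chunks[0], (2,))

# Failure mode: the lambda closes over the whole numpy mapping (6 rows),
# but each call only receives one 2-row block of variant_allele.
broken = da.map_blocks(
    lambda block: apply_mapping(block, mapping, max_allele=1),
    variant_allele,
    dtype=variant_allele.dtype,
    chunks=out_chunks,
)
try:
    broken.compute()
    failed = False
except AssertionError:
    failed = True

# Sketch of a fix: chunk the mapping the same way along the variants
# dimension and pass it to map_blocks, so blocks are paired up.
mapping_chunked = da.from_array(
    mapping, chunks=(variant_allele.chunks[0], (mapping.shape[1],))
)
fixed = da.map_blocks(
    apply_mapping,
    variant_allele,
    mapping_chunked,
    max_allele=1,
    dtype=variant_allele.dtype,
    chunks=out_chunks,
)
result = fixed.compute()
```

In this sketch the first version raises AssertionError at compute time, while the second produces a (6, 2) array, because each call now sees a 2-row block of both arrays.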

@alimanfoo (Member)

Hi @tristanpwdennis, nice work getting to the bottom of this one, the fix you have in #515 LGTM, thanks so much!

@alimanfoo (Member)

Just to update, I rolled the fix for this into other work I was doing on improving the biallelic SNP calls functions, via #623. Thanks again for figuring it out!
