diff --git a/PopPUNK/mandrake.py b/PopPUNK/mandrake.py index 2d14a350..fc9e81f5 100644 --- a/PopPUNK/mandrake.py +++ b/PopPUNK/mandrake.py @@ -70,23 +70,29 @@ def generate_embedding(seqLabels, accMat, perplexity, outPrefix, overwrite, kNN weights = np.ones((len(seqLabels))) random.Random() seed = random.randint(0, 2**32) - if use_gpu and gpu_fn_available: - sys.stderr.write("Running on GPU\n") - n_workers = 65536 - maxIter = round(maxIter / n_workers) - wtsne_call = partial(wtsne_gpu_fp64, - perplexity=perplexity, - maxIter=maxIter, - blockSize=128, - n_workers=n_workers, - nRepuSamp=5, - eta0=1, - bInit=0, - animated=False, - cpu_threads=n_threads, - device_id=device_id, - seed=seed) - else: + gpu_analysis_complete = False + try: + if use_gpu and gpu_fn_available: + sys.stderr.write("Running on GPU\n") + n_workers = 65536 + maxIter = round(maxIter / n_workers) + wtsne_call = partial(wtsne_gpu_fp64, + perplexity=perplexity, + maxIter=maxIter, + blockSize=128, + n_workers=n_workers, + nRepuSamp=5, + eta0=1, + bInit=0, + animated=False, + cpu_threads=n_threads, + device_id=device_id, + seed=seed) + gpu_analysis_complete = True + except: + # If installed through conda/mamba mandrake is not GPU-enabled by default + sys.stderr.write('Mandrake analysis with GPU failed; trying with CPU\n') + if not gpu_analysis_complete: sys.stderr.write("Running on CPU\n") maxIter = round(maxIter / n_threads) wtsne_call = partial(wtsne,