From effc980a289474ee8617e69dfba480869f0cc8a4 Mon Sep 17 00:00:00 2001 From: Luis Fabregas Date: Tue, 22 Nov 2022 08:24:14 +0100 Subject: [PATCH] bootstrap_analysis: fix error when analyzing scalar variables --- deerlab/bootstrap_analysis.py | 2 +- test/test_bootstrap_analysis.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/deerlab/bootstrap_analysis.py b/deerlab/bootstrap_analysis.py index 368571a3..f3c277ed 100644 --- a/deerlab/bootstrap_analysis.py +++ b/deerlab/bootstrap_analysis.py @@ -129,7 +129,7 @@ def sample(): # Assert that all outputs are strictly numerical for var in varargout: - if not all(isnumeric(x) for x in var): + if not all(isnumeric(x) for x in np.atleast_1d(var)): raise ValueError('Non-numeric output arguments by the analyzed function are not accepted.') # Check that the full bootstrap analysis will not exceed the memory limits diff --git a/test/test_bootstrap_analysis.py b/test/test_bootstrap_analysis.py index e0e6e552..e356bb67 100644 --- a/test/test_bootstrap_analysis.py +++ b/test/test_bootstrap_analysis.py @@ -1,6 +1,6 @@ import numpy as np -from deerlab import dipolarkernel, whitegaussnoise, bootstrap_analysis, snlls +from deerlab import whitegaussnoise, bootstrap_analysis, snlls from deerlab.dd_models import dd_gauss from deerlab.utils import assert_docstring @@ -25,7 +25,7 @@ def fitfcn_global(ys): def fitfcn_multiout(yexp): fit = snlls(yexp,model,[1,3],uq=False) - return fit.nonlin*fit.lin, fit.model + return fit.nonlin*fit.lin, fit.model, fit.nonlin[0] def fitfcn_complex(yexp): fit = snlls(yexp,model,[1,3],uq=False) @@ -58,10 +58,10 @@ def test_multiple_ouputs(): # ====================================================================== "Check that both bootstrap handles the correct number outputs" - parfit,yfit = fitfcn_multiout(yexp) + parfit,yfit,_ = fitfcn_multiout(yexp) paruq = bootstrap_analysis(fitfcn_multiout,yexp,model(parfit),10) - assert len(paruq)==2 and all(abs(paruq[0].mean - parfit)) and all(abs(paruq[1].mean - yfit)) + assert len(paruq)==3 and all(abs(paruq[0].mean - parfit)) and all(abs(paruq[1].mean - yfit)) # ====================================================================== def test_multiple_datasets():