From 7d7ea4d3635422267251c75b9a090c30fec054b5 Mon Sep 17 00:00:00 2001 From: alixdamman Date: Fri, 14 Jun 2024 16:44:30 +0200 Subject: [PATCH] PYIODE: (Sample) - renamed get_list_periods() as get_period_list() - replaced argument as_float by astype --- doc/source/api.rst | 2 +- pyiode/time/sample.pyx | 26 ++++++++++++++++++++------ 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/doc/source/api.rst b/doc/source/api.rst index 9d6db4697..9d30252fb 100644 --- a/doc/source/api.rst +++ b/doc/source/api.rst @@ -282,7 +282,7 @@ TIME Sample.end Sample.nb_periods Sample.index - Sample.get_list_periods + Sample.get_period_list Sample.intersection diff --git a/pyiode/time/sample.pyx b/pyiode/time/sample.pyx index 734eaae99..32089f516 100644 --- a/pyiode/time/sample.pyx +++ b/pyiode/time/sample.pyx @@ -21,6 +21,9 @@ except ImportError: la = None Axis = Any +# TODO : add Period ? +_ALLOWED_TYPES_FOR_PERIOD = {'str': str, 'float': float} + # Sample wrapper class # see https://cython.readthedocs.io/en/latest/src/userguide/wrapping_CPlusPlus.html#create-cython-wrapper-class @@ -104,10 +107,16 @@ cdef class Sample: cdef string str_period = period.encode() return self.c_sample.get_period_position(str_period) - def get_list_periods(self, as_float: bool = False) -> Union[List[str], List[float]]: + def get_period_list(self, astype: Union[type(Any), str] = str) -> List[Any]: """ List of all periods of the sample. - Periods are exported as string (default) or as float + Periods are exported as string (default) or as float. + + Parameters + ---------- + astype: type or str + Allowed returned type for periods are str and float. + Default to str. Returns ------- @@ -117,17 +126,22 @@ cdef class Sample: -------- >>> from iode import variables, SAMPLE_DATA_DIR >>> variables.load(f"{SAMPLE_DATA_DIR}/fun.var") - >>> variables.sample.get_list_periods() #doctest: +ELLIPSIS + >>> variables.sample.get_period_list() #doctest: +ELLIPSIS ['1960Y1', '1961Y1', ..., '2014Y1', '2015Y1'] - >>> variables.sample.get_list_periods(as_float=True) #doctest: +ELLIPSIS + >>> variables.sample.get_period_list(astype=float) #doctest: +ELLIPSIS [1960.0, 1961.0, ..., 2014.0, 2015.0] """ if self.c_sample is NULL: raise RuntimeError("'sample' is not defined") - if as_float: + if isinstance(astype, str): + astype = _ALLOWED_TYPES_FOR_PERIOD[astype] + + if astype == float: return self.c_sample.get_list_periods_as_float() - else: + elif astype == str: return [period.decode() for period in self.c_sample.get_list_periods()] + else: + raise ValueError(f"'astype': type {astype.__name__} is not allowed.") def intersection(self, other_sample: Sample) -> Sample: """