diff --git a/asv_bench/benchmarks/io/stata.py b/asv_bench/benchmarks/io/stata.py index a7f854a853f50c..454eef4e282935 100644 --- a/asv_bench/benchmarks/io/stata.py +++ b/asv_bench/benchmarks/io/stata.py @@ -12,12 +12,12 @@ class Stata(BaseIO): def setup(self, convert_dates): self.fname = '__test__.dta' - N = 100000 - C = 5 + N = self.N = 100000 + C = self.C = 5 self.df = DataFrame(np.random.randn(N, C), columns=['float{}'.format(i) for i in range(C)], index=date_range('20000101', periods=N, freq='H')) - self.df['object'] = tm.makeStringIndex(N) + self.df['object'] = tm.makeStringIndex(self.N) self.df['int8_'] = np.random.randint(np.iinfo(np.int8).min, np.iinfo(np.int8).max - 27, N) self.df['int16_'] = np.random.randint(np.iinfo(np.int16).min, @@ -36,4 +36,14 @@ def time_write_stata(self, convert_dates): self.df.to_stata(self.fname, self.convert_dates) +class StataMissing(Stata): + def setup(self, convert_dates): + super(StataMissing, self).setup(convert_dates) + for i in range(25): + missing_data = np.random.randn(self.N) + missing_data[missing_data < 0] = np.nan + self.df['missing_{0}'.format(i)] = missing_data + self.df.to_stata(self.fname, self.convert_dates) + + from ..pandas_vb_common import setup # noqa: F401 diff --git a/doc/source/whatsnew/v0.25.0.rst b/doc/source/whatsnew/v0.25.0.rst index 8e72ce83ac0280..b298acdc57cf3c 100644 --- a/doc/source/whatsnew/v0.25.0.rst +++ b/doc/source/whatsnew/v0.25.0.rst @@ -235,7 +235,7 @@ I/O - Bug in :func:`json_normalize` for ``errors='ignore'`` where missing values in the input data, were filled in resulting ``DataFrame`` with the string "nan" instead of ``numpy.nan`` (:issue:`25468`) - :meth:`DataFrame.to_html` now raises ``TypeError`` when using an invalid type for the ``classes`` parameter instead of ``AsseertionError`` (:issue:`25608`) - Bug in :meth:`DataFrame.to_string` and :meth:`DataFrame.to_latex` that would lead to incorrect output when the ``header`` keyword is used (:issue:`16718`) -- +- Improved performance in :meth:`pandas.read_stata` and :class:`pandas.io.stata.StataReader` when converting columns that have missing values (:issue:`25772`) Plotting diff --git a/pandas/io/stata.py b/pandas/io/stata.py index 62a9dbdc4657ea..221e6c25022bc1 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -31,7 +31,8 @@ from pandas.core.dtypes.common import ( ensure_object, is_categorical_dtype, is_datetime64_dtype) -from pandas import DatetimeIndex, compat, isna, to_datetime, to_timedelta +from pandas import (DatetimeIndex, compat, isna, to_datetime, to_timedelta, + concat) from pandas.core.arrays import Categorical from pandas.core.base import StringMixin from pandas.core.frame import DataFrame @@ -1572,7 +1573,7 @@ def read(self, nrows=None, convert_dates=None, data = DataFrame.from_dict(OrderedDict(data_formatted)) del data_formatted - self._do_convert_missing(data, convert_missing) + data = self._do_convert_missing(data, convert_missing) if convert_dates: cols = np.where(lmap(lambda x: any(x.startswith(fmt) @@ -1616,7 +1617,7 @@ def read(self, nrows=None, convert_dates=None, def _do_convert_missing(self, data, convert_missing): # Check for missing values, and replace if found - + replacements = {} for i, colname in enumerate(data): fmt = self.typlist[i] if fmt not in self.VALID_RANGE: @@ -1646,8 +1647,14 @@ def _do_convert_missing(self, data, convert_missing): dtype = np.float64 replacement = Series(series, dtype=dtype) replacement[missing] = np.nan - - data[colname] = replacement + replacements[colname] = replacement + if replacements: + columns = data.columns + replacements = DataFrame(replacements) + data = concat([data.drop(replacements.columns, 1), + replacements], 1) + data = data[columns] + return data def _insert_strls(self, data): if not hasattr(self, 'GSO') or len(self.GSO) == 0: @@ -1712,7 +1719,7 @@ def _do_convert_categoricals(self, data, value_label_dict, lbllist, except ValueError: vc = Series(categories).value_counts() repeats = list(vc.index[vc > 1]) - repeats = '\n' + '-' * 80 + '\n'.join(repeats) + repeats = '\n' + '-' * 80 + '\n' + '\n'.join(repeats) raise ValueError('Value labels for column {col} are not ' 'unique. The repeated labels are:\n' '{repeats}'