diff --git a/pandas/_typing.py b/pandas/_typing.py
index b237013ac7805..7aef5c02e290f 100644
--- a/pandas/_typing.py
+++ b/pandas/_typing.py
@@ -15,6 +15,7 @@
     List,
     Mapping,
     Optional,
+    Sequence,
     Type,
     TypeVar,
     Union,
@@ -82,6 +83,7 @@
 
 Axis = Union[str, int]
 Label = Optional[Hashable]
+IndexLabel = Optional[Union[Label, Sequence[Label]]]
 Level = Union[Label, int]
 Ordered = Optional[bool]
 JSONSerializable = Optional[Union[PythonScalar, List, Dict]]
diff --git a/pandas/core/generic.py b/pandas/core/generic.py
index 40f0c6200e835..fffd2e068ebcf 100644
--- a/pandas/core/generic.py
+++ b/pandas/core/generic.py
@@ -40,6 +40,7 @@
     CompressionOptions,
     FilePathOrBuffer,
     FrameOrSeries,
+    IndexLabel,
     JSONSerializable,
     Label,
     Level,
@@ -3160,7 +3161,7 @@ def to_csv(
         columns: Optional[Sequence[Label]] = None,
         header: Union[bool_t, List[str]] = True,
         index: bool_t = True,
-        index_label: Optional[Union[bool_t, str, Sequence[Label]]] = None,
+        index_label: IndexLabel = None,
         mode: str = "w",
         encoding: Optional[str] = None,
         compression: CompressionOptions = "infer",
diff --git a/pandas/io/common.py b/pandas/io/common.py
index 3f130401558dd..f177e08ac0089 100644
--- a/pandas/io/common.py
+++ b/pandas/io/common.py
@@ -208,6 +208,21 @@ def get_filepath_or_buffer(
     # handle compression dict
     compression_method, compression = get_compression_method(compression)
     compression_method = infer_compression(filepath_or_buffer, compression_method)
+
+    # GH21227 internal compression is not used for non-binary handles.
+    if (
+        compression_method
+        and hasattr(filepath_or_buffer, "write")
+        and mode
+        and "b" not in mode
+    ):
+        warnings.warn(
+            "compression has no effect when passing a non-binary object as input.",
+            RuntimeWarning,
+            stacklevel=2,
+        )
+        compression_method = None
+
     compression = dict(compression, method=compression_method)
 
     # bz2 and xz do not write the byte order mark for utf-16 and utf-32
diff --git a/pandas/io/formats/csvs.py b/pandas/io/formats/csvs.py
index 15cd5c026c6b6..90ab6f61f4d74 100644
--- a/pandas/io/formats/csvs.py
+++ b/pandas/io/formats/csvs.py
@@ -5,13 +5,18 @@
 import csv as csvlib
 from io import StringIO, TextIOWrapper
 import os
-from typing import Hashable, List, Optional, Sequence, Union
-import warnings
+from typing import Any, Dict, Hashable, Iterator, List, Optional, Sequence, Union
 
 import numpy as np
 
 from pandas._libs import writers as libwriters
-from pandas._typing import CompressionOptions, FilePathOrBuffer, StorageOptions
+from pandas._typing import (
+    CompressionOptions,
+    FilePathOrBuffer,
+    IndexLabel,
+    Label,
+    StorageOptions,
+)
 
 from pandas.core.dtypes.generic import (
     ABCDatetimeIndex,
@@ -21,6 +26,8 @@
 )
 from pandas.core.dtypes.missing import notna
 
+from pandas.core.indexes.api import Index
+
 from pandas.io.common import get_filepath_or_buffer, get_handle
 
 
@@ -32,10 +39,10 @@ def __init__(
         sep: str = ",",
         na_rep: str = "",
         float_format: Optional[str] = None,
-        cols=None,
+        cols: Optional[Sequence[Label]] = None,
         header: Union[bool, Sequence[Hashable]] = True,
         index: bool = True,
-        index_label: Optional[Union[bool, Hashable, Sequence[Hashable]]] = None,
+        index_label: IndexLabel = None,
         mode: str = "w",
         encoding: Optional[str] = None,
         errors: str = "strict",
@@ -43,7 +50,7 @@ def __init__(
         quoting: Optional[int] = None,
         line_terminator="\n",
         chunksize: Optional[int] = None,
-        quotechar='"',
+        quotechar: Optional[str] = '"',
         date_format: Optional[str] = None,
         doublequote: bool = True,
         escapechar: Optional[str] = None,
@@ -52,16 +59,19 @@ def __init__(
     ):
         self.obj = obj
 
+        self.encoding = encoding or "utf-8"
+
         if path_or_buf is None:
             path_or_buf = StringIO()
 
         ioargs = get_filepath_or_buffer(
             path_or_buf,
-            encoding=encoding,
+            encoding=self.encoding,
             compression=compression,
             mode=mode,
             storage_options=storage_options,
         )
+
         self.compression = ioargs.compression.pop("method")
         self.compression_args = ioargs.compression
         self.path_or_buf = ioargs.filepath_or_buffer
@@ -72,46 +82,79 @@ def __init__(
         self.na_rep = na_rep
         self.float_format = float_format
         self.decimal = decimal
-
         self.header = header
         self.index = index
         self.index_label = index_label
-        if encoding is None:
-            encoding = "utf-8"
-        self.encoding = encoding
         self.errors = errors
+        self.quoting = quoting or csvlib.QUOTE_MINIMAL
+        self.quotechar = quotechar
+        self.doublequote = doublequote
+        self.escapechar = escapechar
+        self.line_terminator = line_terminator or os.linesep
+        self.date_format = date_format
+        self.cols = cols  # type: ignore[assignment]
+        self.chunksize = chunksize  # type: ignore[assignment]
+
+    @property
+    def index_label(self) -> IndexLabel:
+        return self._index_label
+
+    @index_label.setter
+    def index_label(self, index_label: IndexLabel) -> None:
+        if index_label is not False:
+            if index_label is None:
+                index_label = self._get_index_label_from_obj()
+            elif not isinstance(index_label, (list, tuple, np.ndarray, ABCIndexClass)):
+                # given a string for a DF with Index
+                index_label = [index_label]
+        self._index_label = index_label
+
+    def _get_index_label_from_obj(self) -> List[str]:
+        if isinstance(self.obj.index, ABCMultiIndex):
+            return self._get_index_label_multiindex()
+        else:
+            return self._get_index_label_flat()
+
+    def _get_index_label_multiindex(self) -> List[str]:
+        return [name or "" for name in self.obj.index.names]
 
-        if quoting is None:
-            quoting = csvlib.QUOTE_MINIMAL
-        self.quoting = quoting
+    def _get_index_label_flat(self) -> List[str]:
+        index_label = self.obj.index.name
+        return [""] if index_label is None else [index_label]
 
-        if quoting == csvlib.QUOTE_NONE:
+    @property
+    def quotechar(self) -> Optional[str]:
+        if self.quoting != csvlib.QUOTE_NONE:
             # prevents crash in _csv
-            quotechar = None
-        self.quotechar = quotechar
+            return self._quotechar
+        return None
 
-        self.doublequote = doublequote
-        self.escapechar = escapechar
+    @quotechar.setter
+    def quotechar(self, quotechar: Optional[str]) -> None:
+        self._quotechar = quotechar
 
-        self.line_terminator = line_terminator or os.linesep
+    @property
+    def has_mi_columns(self) -> bool:
+        return bool(isinstance(self.obj.columns, ABCMultiIndex))
 
-        self.date_format = date_format
+    @property
+    def cols(self) -> Sequence[Label]:
+        return self._cols
 
-        self.has_mi_columns = isinstance(obj.columns, ABCMultiIndex)
+    @cols.setter
+    def cols(self, cols: Optional[Sequence[Label]]) -> None:
+        self._cols = self._refine_cols(cols)
 
+    def _refine_cols(self, cols: Optional[Sequence[Label]]) -> Sequence[Label]:
         # validate mi options
         if self.has_mi_columns:
             if cols is not None:
-                raise TypeError("cannot specify cols with a MultiIndex on the columns")
+                msg = "cannot specify cols with a MultiIndex on the columns"
+                raise TypeError(msg)
 
         if cols is not None:
             if isinstance(cols, ABCIndexClass):
-                cols = cols.to_native_types(
-                    na_rep=na_rep,
-                    float_format=float_format,
-                    date_format=date_format,
-                    quoting=self.quoting,
-                )
+                cols = cols.to_native_types(**self._number_format)
             else:
                 cols = list(cols)
             self.obj = self.obj.loc[:, cols]
@@ -120,58 +163,90 @@ def __init__(
         # and make sure sure cols is just a list of labels
         cols = self.obj.columns
         if isinstance(cols, ABCIndexClass):
-            cols = cols.to_native_types(
-                na_rep=na_rep,
-                float_format=float_format,
-                date_format=date_format,
-                quoting=self.quoting,
-            )
+            return cols.to_native_types(**self._number_format)
         else:
-            cols = list(cols)
+            assert isinstance(cols, Sequence)
+            return list(cols)
 
-        # save it
-        self.cols = cols
+    @property
+    def _number_format(self) -> Dict[str, Any]:
+        """Dictionary used for storing number formatting settings."""
+        return dict(
+            na_rep=self.na_rep,
+            float_format=self.float_format,
+            date_format=self.date_format,
+            quoting=self.quoting,
+            decimal=self.decimal,
+        )
 
-        # preallocate data 2d list
-        ncols = self.obj.shape[-1]
-        self.data = [None] * ncols
+    @property
+    def chunksize(self) -> int:
+        return self._chunksize
 
+    @chunksize.setter
+    def chunksize(self, chunksize: Optional[int]) -> None:
         if chunksize is None:
             chunksize = (100000 // (len(self.cols) or 1)) or 1
-        self.chunksize = int(chunksize)
+        assert chunksize is not None
+        self._chunksize = int(chunksize)
 
-        self.data_index = obj.index
+    @property
+    def data_index(self) -> Index:
+        data_index = self.obj.index
         if (
-            isinstance(self.data_index, (ABCDatetimeIndex, ABCPeriodIndex))
-            and date_format is not None
+            isinstance(data_index, (ABCDatetimeIndex, ABCPeriodIndex))
+            and self.date_format is not None
         ):
-            from pandas import Index
-
-            self.data_index = Index(
-                [x.strftime(date_format) if notna(x) else "" for x in self.data_index]
+            data_index = Index(
+                [x.strftime(self.date_format) if notna(x) else "" for x in data_index]
             )
+        return data_index
+
+    @property
+    def nlevels(self) -> int:
+        if self.index:
+            return getattr(self.data_index, "nlevels", 1)
+        else:
+            return 0
+
+    @property
+    def _has_aliases(self) -> bool:
+        return isinstance(self.header, (tuple, list, np.ndarray, ABCIndexClass))
+
+    @property
+    def _need_to_save_header(self) -> bool:
+        return bool(self._has_aliases or self.header)
+
+    @property
+    def write_cols(self) -> Sequence[Label]:
+        if self._has_aliases:
+            assert not isinstance(self.header, bool)
+            if len(self.header) != len(self.cols):
+                raise ValueError(
+                    f"Writing {len(self.cols)} cols but got {len(self.header)} aliases"
+                )
+            else:
+                return self.header
+        else:
+            return self.cols
+
+    @property
+    def encoded_labels(self) -> List[Label]:
+        encoded_labels: List[Label] = []
+
+        if self.index and self.index_label:
+            assert isinstance(self.index_label, Sequence)
+            encoded_labels = list(self.index_label)
 
-        self.nlevels = getattr(self.data_index, "nlevels", 1)
-        if not index:
-            self.nlevels = 0
+        if not self.has_mi_columns or self._has_aliases:
+            encoded_labels += list(self.write_cols)
+
+        return encoded_labels
 
     def save(self) -> None:
         """
         Create the writer & save.
         """
-        # GH21227 internal compression is not used for non-binary handles.
-        if (
-            self.compression
-            and hasattr(self.path_or_buf, "write")
-            and "b" not in self.mode
-        ):
-            warnings.warn(
-                "compression has no effect when passing a non-binary object as input.",
-                RuntimeWarning,
-                stacklevel=2,
-            )
-            self.compression = None
-
         # get a handle or wrap an existing handle to take care of 1) compression and
         # 2) text -> byte conversion
         f, handles = get_handle(
@@ -215,133 +290,63 @@ def save(self) -> None:
             for _fh in handles:
                 _fh.close()
 
-    def _save_header(self):
-        writer = self.writer
-        obj = self.obj
-        index_label = self.index_label
-        cols = self.cols
-        has_mi_columns = self.has_mi_columns
-        header = self.header
-        encoded_labels: List[str] = []
-
-        has_aliases = isinstance(header, (tuple, list, np.ndarray, ABCIndexClass))
-        if not (has_aliases or self.header):
-            return
-        if has_aliases:
-            if len(header) != len(cols):
-                raise ValueError(
-                    f"Writing {len(cols)} cols but got {len(header)} aliases"
-                )
-            else:
-                write_cols = header
-        else:
-            write_cols = cols
-
-        if self.index:
-            # should write something for index label
-            if index_label is not False:
-                if index_label is None:
-                    if isinstance(obj.index, ABCMultiIndex):
-                        index_label = []
-                        for i, name in enumerate(obj.index.names):
-                            if name is None:
-                                name = ""
-                            index_label.append(name)
-                    else:
-                        index_label = obj.index.name
-                        if index_label is None:
-                            index_label = [""]
-                        else:
-                            index_label = [index_label]
-                elif not isinstance(
-                    index_label, (list, tuple, np.ndarray, ABCIndexClass)
-                ):
-                    # given a string for a DF with Index
-                    index_label = [index_label]
-
-                encoded_labels = list(index_label)
-            else:
-                encoded_labels = []
-
-        if not has_mi_columns or has_aliases:
-            encoded_labels += list(write_cols)
-            writer.writerow(encoded_labels)
-        else:
-            # write out the mi
-            columns = obj.columns
-
-            # write out the names for each level, then ALL of the values for
-            # each level
-            for i in range(columns.nlevels):
-
-                # we need at least 1 index column to write our col names
-                col_line = []
-                if self.index:
-
-                    # name is the first column
-                    col_line.append(columns.names[i])
-
-                    if isinstance(index_label, list) and len(index_label) > 1:
-                        col_line.extend([""] * (len(index_label) - 1))
-
-                col_line.extend(columns._get_level_values(i))
-
-                writer.writerow(col_line)
-
-            # Write out the index line if it's not empty.
-            # Otherwise, we will print out an extraneous
-            # blank line between the mi and the data rows.
-            if encoded_labels and set(encoded_labels) != {""}:
-                encoded_labels.extend([""] * len(columns))
-                writer.writerow(encoded_labels)
-
     def _save(self) -> None:
-        self._save_header()
+        if self._need_to_save_header:
+            self._save_header()
+        self._save_body()
 
+    def _save_header(self) -> None:
+        if not self.has_mi_columns or self._has_aliases:
+            self.writer.writerow(self.encoded_labels)
+        else:
+            for row in self._generate_multiindex_header_rows():
+                self.writer.writerow(row)
+
+    def _generate_multiindex_header_rows(self) -> Iterator[List[Label]]:
+        columns = self.obj.columns
+        for i in range(columns.nlevels):
+            # we need at least 1 index column to write our col names
+            col_line = []
+            if self.index:
+                # name is the first column
+                col_line.append(columns.names[i])
+
+                if isinstance(self.index_label, list) and len(self.index_label) > 1:
+                    col_line.extend([""] * (len(self.index_label) - 1))
+
+            col_line.extend(columns._get_level_values(i))
+            yield col_line
+
+        # Write out the index line if it's not empty.
+        # Otherwise, we will print out an extraneous
+        # blank line between the mi and the data rows.
+        if self.encoded_labels and set(self.encoded_labels) != {""}:
+            yield self.encoded_labels + [""] * len(columns)
+
+    def _save_body(self) -> None:
         nrows = len(self.data_index)
-
-        # write in chunksize bites
-        chunksize = self.chunksize
-        chunks = int(nrows / chunksize) + 1
-
+        chunks = int(nrows / self.chunksize) + 1
         for i in range(chunks):
-            start_i = i * chunksize
-            end_i = min((i + 1) * chunksize, nrows)
+            start_i = i * self.chunksize
+            end_i = min(start_i + self.chunksize, nrows)
            if start_i >= end_i:
                 break
-
             self._save_chunk(start_i, end_i)
 
     def _save_chunk(self, start_i: int, end_i: int) -> None:
-        data_index = self.data_index
+        ncols = self.obj.shape[-1]
+        data = [None] * ncols
 
         # create the data for a chunk
         slicer = slice(start_i, end_i)
         df = self.obj.iloc[slicer]
-        blocks = df._mgr.blocks
-
-        for i in range(len(blocks)):
-            b = blocks[i]
-            d = b.to_native_types(
-                na_rep=self.na_rep,
-                float_format=self.float_format,
-                decimal=self.decimal,
-                date_format=self.date_format,
-                quoting=self.quoting,
-            )
 
-            for col_loc, col in zip(b.mgr_locs, d):
-                # self.data is a preallocated list
-                self.data[col_loc] = col
+        for block in df._mgr.blocks:
+            d = block.to_native_types(**self._number_format)
 
-        ix = data_index.to_native_types(
-            slicer=slicer,
-            na_rep=self.na_rep,
-            float_format=self.float_format,
-            decimal=self.decimal,
-            date_format=self.date_format,
-            quoting=self.quoting,
-        )
+            for col_loc, col in zip(block.mgr_locs, d):
+                data[col_loc] = col
 
-        libwriters.write_csv_rows(self.data, ix, self.nlevels, self.cols, self.writer)
+        ix = self.data_index.to_native_types(slicer=slicer, **self._number_format)
+        libwriters.write_csv_rows(data, ix, self.nlevels, self.cols, self.writer)
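Reviewer note (not part of the patch): the relocated GH21227 check in pandas/io/common.py keeps the existing user-facing behaviour. The snippet below is only an illustration of that behaviour under the assumption that DataFrame.to_csv still routes buffer arguments through get_filepath_or_buffer; the DataFrame contents and variable names are made up, while the warning text is the one used in the hunk above.

import io
import warnings

import pandas as pd

df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})  # illustrative data

buf = io.StringIO()  # an already-open, non-binary handle
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # compression cannot be applied to a text handle, so pandas warns
    # and writes plain, uncompressed CSV into the buffer instead
    df.to_csv(buf, compression="gzip")

print([str(w.message) for w in caught])
# expected to include:
# "compression has no effect when passing a non-binary object as input."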
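Reviewer note: the new index_label property/setter pair replaces the ad-hoc normalisation that used to live in _save_header. As a standalone sketch — not the pandas implementation, with an invented helper name and a simplified type alias — the setter's rules reduce to: keep False as-is, derive labels from the index name(s) when None is given (empty string for unnamed levels), and wrap a single scalar label in a list.

from typing import Hashable, List, Optional, Sequence, Union

IndexLabelLike = Optional[Union[bool, Hashable, Sequence[Hashable]]]


def normalize_index_label(index_label: IndexLabelLike, index_names: List[Optional[str]]):
    # hypothetical helper mirroring the setter plus the _get_index_label_* methods
    if index_label is False:
        return False  # caller asked for no index label at all
    if index_label is None:
        return [name or "" for name in index_names]  # unnamed levels become ""
    if not isinstance(index_label, (list, tuple)):
        return [index_label]  # a single label for a flat index
    return list(index_label)


print(normalize_index_label(None, [None]))       # ['']
print(normalize_index_label("id", [None]))       # ['id']
print(normalize_index_label(None, ["a", None]))  # ['a', '']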
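Reviewer note: the _save_body loop keeps the old chunking arithmetic (end_i = min(start_i + chunksize, nrows) is equivalent to the previous (i + 1) * chunksize form), only rewritten against the chunksize property. A minimal sketch of the bounds it produces, with an illustrative helper name that is not part of pandas:

def iter_chunk_bounds(nrows: int, chunksize: int):
    # same arithmetic as _save_body: fixed-size slices, breaking out of the
    # final iteration once start_i has reached nrows
    chunks = int(nrows / chunksize) + 1
    for i in range(chunks):
        start_i = i * chunksize
        end_i = min(start_i + chunksize, nrows)
        if start_i >= end_i:
            break
        yield start_i, end_i


print(list(iter_chunk_bounds(10, 4)))  # [(0, 4), (4, 8), (8, 10)]
print(list(iter_chunk_bounds(8, 4)))   # [(0, 4), (4, 8)]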