Skip to content

Commit

Permalink
BUG: Avoid RangeIndex conversion in read_csv if dtype is specified (#…
Browse files Browse the repository at this point in the history
…59316)

* BUG: Avoid RangeIndex conversion in read_csv if dtype is specified

* Undo change

* Typing
  • Loading branch information
mroeschke authored Jul 30, 2024
1 parent 12c8ec4 commit 7acd629
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 17 deletions.
39 changes: 27 additions & 12 deletions pandas/io/parsers/base_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from copy import copy
import csv
from enum import Enum
import itertools
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -271,7 +272,7 @@ def _maybe_make_multi_index_columns(

@final
def _make_index(
self, data, alldata, columns, indexnamerow: list[Scalar] | None = None
self, alldata, columns, indexnamerow: list[Scalar] | None = None
) -> tuple[Index | None, Sequence[Hashable] | MultiIndex]:
index: Index | None
if isinstance(self.index_col, list) and len(self.index_col):
Expand Down Expand Up @@ -326,7 +327,11 @@ def _agg_index(self, index) -> Index:
converters = self._clean_mapping(self.converters)
clean_dtypes = self._clean_mapping(self.dtype)

for i, arr in enumerate(index):
if self.index_names is not None:
names: Iterable = self.index_names
else:
names = itertools.cycle([None])
for i, (arr, name) in enumerate(zip(index, names)):
if self._should_parse_dates(i):
arr = date_converter(
arr,
Expand Down Expand Up @@ -369,12 +374,17 @@ def _agg_index(self, index) -> Index:
arr, _ = self._infer_types(
arr, col_na_values | col_na_fvalues, cast_type is None, try_num_bool
)
arrays.append(arr)

names = self.index_names
index = ensure_index_from_sequences(arrays, names)
if cast_type is not None:
# Don't perform RangeIndex inference
idx = Index(arr, name=name, dtype=cast_type)
else:
idx = ensure_index_from_sequences([arr], [name])
arrays.append(idx)

return index
if len(arrays) == 1:
return arrays[0]
else:
return MultiIndex.from_arrays(arrays)

@final
def _set_noconvert_dtype_columns(
Expand Down Expand Up @@ -704,12 +714,11 @@ def _get_empty_meta(
dtype_dict: defaultdict[Hashable, Any]
if not is_dict_like(dtype):
# if dtype == None, default will be object.
default_dtype = dtype or object
dtype_dict = defaultdict(lambda: default_dtype)
dtype_dict = defaultdict(lambda: dtype)
else:
dtype = cast(dict, dtype)
dtype_dict = defaultdict(
lambda: object,
lambda: None,
{columns[k] if is_integer(k) else k: v for k, v in dtype.items()},
)

Expand All @@ -726,8 +735,14 @@ def _get_empty_meta(
if (index_col is None or index_col is False) or index_names is None:
index = default_index(0)
else:
data = [Series([], dtype=dtype_dict[name]) for name in index_names]
index = ensure_index_from_sequences(data, names=index_names)
# TODO: We could return default_index(0) if dtype_dict[name] is None
data = [
Index([], name=name, dtype=dtype_dict[name]) for name in index_names
]
if len(data) == 1:
index = data[0]
else:
index = MultiIndex.from_arrays(data)
index_col.sort()

for i, n in enumerate(index_col):
Expand Down
2 changes: 1 addition & 1 deletion pandas/io/parsers/c_parser_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def read(
data = {k: v for k, (i, v) in zip(names, data_tups)}

date_data = self._do_date_conversions(names, data)
index, column_names = self._make_index(date_data, alldata, names)
index, column_names = self._make_index(alldata, names)

return index, column_names, date_data

Expand Down
4 changes: 1 addition & 3 deletions pandas/io/parsers/python_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,7 @@ def read(
conv_data = self._convert_data(data)
conv_data = self._do_date_conversions(columns, conv_data)

index, result_columns = self._make_index(
conv_data, alldata, columns, indexnamerow
)
index, result_columns = self._make_index(alldata, columns, indexnamerow)

return index, result_columns, conv_data

Expand Down
18 changes: 17 additions & 1 deletion pandas/tests/io/parser/dtypes/test_dtypes_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
"ignore:Passing a BlockManager to DataFrame:DeprecationWarning"
)

xfail_pyarrow = pytest.mark.usefixtures("pyarrow_xfail")


@pytest.mark.parametrize("dtype", [str, object])
@pytest.mark.parametrize("check_orig", [True, False])
Expand Down Expand Up @@ -614,6 +616,7 @@ def test_string_inference_object_dtype(all_parsers, dtype):
tm.assert_frame_equal(result, expected)


@xfail_pyarrow
def test_accurate_parsing_of_large_integers(all_parsers):
# GH#52505
data = """SYMBOL,MOMENT,ID,ID_DEAL
Expand All @@ -624,7 +627,7 @@ def test_accurate_parsing_of_large_integers(all_parsers):
AMZN,20230301181139587,2023552585717889759,2023552585717263360
MSFT,20230301181139587,2023552585717889863,2023552585717263361
NVDA,20230301181139587,2023552585717889827,2023552585717263361"""
orders = pd.read_csv(StringIO(data), dtype={"ID_DEAL": pd.Int64Dtype()})
orders = all_parsers.read_csv(StringIO(data), dtype={"ID_DEAL": pd.Int64Dtype()})
assert len(orders.loc[orders["ID_DEAL"] == 2023552585717263358, "ID_DEAL"]) == 1
assert len(orders.loc[orders["ID_DEAL"] == 2023552585717263359, "ID_DEAL"]) == 1
assert len(orders.loc[orders["ID_DEAL"] == 2023552585717263360, "ID_DEAL"]) == 2
Expand All @@ -646,3 +649,16 @@ def test_dtypes_with_usecols(all_parsers):
values = ["1", "4"]
expected = DataFrame({"a": pd.Series(values, dtype=object), "c": [3, 6]})
tm.assert_frame_equal(result, expected)


def test_index_col_with_dtype_no_rangeindex(all_parsers):
data = StringIO("345.5,519.5,0\n519.5,726.5,1")
result = all_parsers.read_csv(
data,
header=None,
names=["start", "stop", "bin_id"],
dtype={"start": np.float32, "stop": np.float32, "bin_id": np.uint32},
index_col="bin_id",
).index
expected = pd.Index([0, 1], dtype=np.uint32, name="bin_id")
tm.assert_index_equal(result, expected)

0 comments on commit 7acd629

Please sign in to comment.