Skip to content

Commit

Permalink
BUG: Set index when reading stata file
Browse files Browse the repository at this point in the history
Ensures index is set when requested when reading state dta file

closes #16342
  • Loading branch information
bashtage committed Aug 24, 2017
1 parent 96f92eb commit b8e36ac
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 3 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.21.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ I/O
- Bug in :func:`read_csv` when called with ``low_memory=False`` in which a CSV with at least one column > 2GB in size would incorrectly raise a ``MemoryError`` (:issue:`16798`).
- Bug in :func:`read_csv` when called with a single-element list ``header`` would return a ``DataFrame`` of all NaN values (:issue:`7757`)
- Bug in :func:`read_stata` where value labels could not be read when using an iterator (:issue:`16923`)
- Bug in :func:`read_stata` where the index was not set (:issue:`16342`)
- Bug in :func:`read_html` where import check fails when run in multiple threads (:issue:`16928`)

Plotting
Expand Down
11 changes: 8 additions & 3 deletions pandas/io/stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -1486,6 +1486,8 @@ def read(self, nrows=None, convert_dates=None,
columns = self._columns
if order_categoricals is None:
order_categoricals = self._order_categoricals
if index is None:
index = self._index

if nrows is None:
nrows = self.nobs
Expand Down Expand Up @@ -1526,7 +1528,7 @@ def read(self, nrows=None, convert_dates=None,
if len(data) == 0:
data = DataFrame(columns=self.varlist, index=index)
else:
data = DataFrame.from_records(data, index=index)
data = DataFrame.from_records(data)
data.columns = self.varlist

# If index is not specified, use actual row number rather than
Expand All @@ -1553,7 +1555,7 @@ def read(self, nrows=None, convert_dates=None,
cols_ = np.where(self.dtyplist)[0]

# Convert columns (if needed) to match input type
index = data.index
ix = data.index
requires_type_conversion = False
data_formatted = []
for i in cols_:
Expand All @@ -1563,7 +1565,7 @@ def read(self, nrows=None, convert_dates=None,
if dtype != np.dtype(object) and dtype != self.dtyplist[i]:
requires_type_conversion = True
data_formatted.append(
(col, Series(data[col], index, self.dtyplist[i])))
(col, Series(data[col], ix, self.dtyplist[i])))
else:
data_formatted.append((col, data[col]))
if requires_type_conversion:
Expand Down Expand Up @@ -1606,6 +1608,9 @@ def read(self, nrows=None, convert_dates=None,
if convert:
data = DataFrame.from_items(retyped_data)

if index is not None:
data = data.set_index(data.pop(index))

return data

def _do_convert_missing(self, data, convert_missing):
Expand Down
8 changes: 8 additions & 0 deletions pandas/tests/io/test_stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -1309,3 +1309,11 @@ def test_value_labels_iterator(self, write_index):
dta_iter = pd.read_stata(path, iterator=True)
value_labels = dta_iter.value_labels()
assert value_labels == {'A': {0: 'A', 1: 'B', 2: 'C', 3: 'E'}}

def test_set_index(self):
df = tm.makeDataFrame()
df.index.name = 'index'
with tm.ensure_clean() as path:
df.to_stata(path)
reread = pd.read_stata(path, index='index')
tm.assert_frame_equal(df, reread)

0 comments on commit b8e36ac

Please sign in to comment.