Skip to content

Commit

Permalink
apacheGH-40644: [Python] Allow passing a mapping of column names to `…
Browse files Browse the repository at this point in the history
…rename_columns` (apache#40645)

### Rationale for this change

See apache#40644 

### What changes are included in this PR?

### Are these changes tested?

Yes.

Tests have been added.

### Are there any user-facing changes?

* GitHub Issue: apache#40644

Authored-by: Judah Rand <[email protected]>
Signed-off-by: AlenkaF <[email protected]>
  • Loading branch information
judahrand authored and vibhatha committed May 25, 2024
1 parent 81b2866 commit 36c87b6
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 8 deletions.
83 changes: 75 additions & 8 deletions python/pyarrow/table.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -2816,8 +2816,17 @@ cdef class RecordBatch(_Tabular):
Parameters
----------
names : list of str
List of new column names.
names : list[str] or dict[str, str]
List of new column names or mapping of old column names to new column names.
If a mapping of old to new column names is passed, then all columns which are
found to match a provided old column name will be renamed to the new column name.
If any column names are not found in the mapping, a KeyError will be raised.
Raises
------
KeyError
If any of the column names passed in the names mapping do not exist.
Returns
-------
Expand All @@ -2838,13 +2847,38 @@ cdef class RecordBatch(_Tabular):
----
n: [2,4,5,100]
name: ["Flamingo","Horse","Brittle stars","Centipede"]
>>> new_names = {"n_legs": "n", "animals": "name"}
>>> batch.rename_columns(new_names)
pyarrow.RecordBatch
n: int64
name: string
----
n: [2,4,5,100]
name: ["Flamingo","Horse","Brittle stars","Centipede"]
"""
cdef:
shared_ptr[CRecordBatch] c_batch
vector[c_string] c_names

for name in names:
c_names.push_back(tobytes(name))
if isinstance(names, list):
for name in names:
c_names.push_back(tobytes(name))
elif isinstance(names, dict):
idx_to_new_name = {}
for name, new_name in names.items():
indices = self.schema.get_all_field_indices(name)

if not indices:
raise KeyError("Column {!r} not found".format(name))

for index in indices:
idx_to_new_name[index] = new_name

for i in range(self.num_columns):
new_name = idx_to_new_name.get(i, self.column_names[i])
c_names.push_back(tobytes(new_name))
else:
raise TypeError(f"names must be a list or dict not {type(names)!r}")

with nogil:
c_batch = GetResultValue(self.batch.RenameColumns(move(c_names)))
Expand Down Expand Up @@ -5215,8 +5249,17 @@ cdef class Table(_Tabular):
Parameters
----------
names : list of str
List of new column names.
names : list[str] or dict[str, str]
List of new column names or mapping of old column names to new column names.
If a mapping of old to new column names is passed, then all columns which are
found to match a provided old column name will be renamed to the new column name.
If any column names are not found in the mapping, a KeyError will be raised.
Raises
------
KeyError
If any of the column names passed in the names mapping do not exist.
Returns
-------
Expand All @@ -5237,13 +5280,37 @@ cdef class Table(_Tabular):
----
n: [[2,4,5,100]]
name: [["Flamingo","Horse","Brittle stars","Centipede"]]
>>> new_names = {"n_legs": "n", "animals": "name"}
>>> table.rename_columns(new_names)
pyarrow.Table
n: int64
name: string
----
n: [[2,4,5,100]]
name: [["Flamingo","Horse","Brittle stars","Centipede"]]
"""
cdef:
shared_ptr[CTable] c_table
vector[c_string] c_names

for name in names:
c_names.push_back(tobytes(name))
if isinstance(names, list):
for name in names:
c_names.push_back(tobytes(name))
elif isinstance(names, dict):
idx_to_new_name = {}
for name, new_name in names.items():
indices = self.schema.get_all_field_indices(name)

if not indices:
raise KeyError("Column {!r} not found".format(name))

for index in indices:
idx_to_new_name[index] = new_name

for i in range(self.num_columns):
c_names.push_back(tobytes(idx_to_new_name.get(i, self.schema[i].name)))
else:
raise TypeError(f"names must be a list or dict not {type(names)!r}")

with nogil:
c_table = GetResultValue(self.table.RenameColumns(move(c_names)))
Expand Down
37 changes: 37 additions & 0 deletions python/pyarrow/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1737,6 +1737,43 @@ def test_table_rename_columns(cls):
expected = cls.from_arrays(data, names=['eh', 'bee', 'sea'])
assert t2.equals(expected)

message = "names must be a list or dict not <class 'str'>"
with pytest.raises(TypeError, match=message):
table.rename_columns('not a list')


@pytest.mark.parametrize(
('cls'),
[
(pa.Table),
(pa.RecordBatch)
]
)
def test_table_rename_columns_mapping(cls):
data = [
pa.array(range(5)),
pa.array([-10, -5, 0, 5, 10]),
pa.array(range(5, 10))
]
table = cls.from_arrays(data, names=['a', 'b', 'c'])
assert table.column_names == ['a', 'b', 'c']

expected = cls.from_arrays(data, names=['eh', 'b', 'sea'])
t1 = table.rename_columns({'a': 'eh', 'c': 'sea'})
t1.validate()
assert t1 == expected

# Test renaming duplicate column names
table = cls.from_arrays(data, names=['a', 'a', 'c'])
expected = cls.from_arrays(data, names=['eh', 'eh', 'sea'])
t2 = table.rename_columns({'a': 'eh', 'c': 'sea'})
t2.validate()
assert t2 == expected

# Test column not found
with pytest.raises(KeyError, match=r"Column 'd' not found"):
table.rename_columns({'a': 'eh', 'd': 'sea'})


def test_table_flatten():
ty1 = pa.struct([pa.field('x', pa.int16()),
Expand Down

0 comments on commit 36c87b6

Please sign in to comment.