ARROW-134 Cannot encode pandas NA objects #118

Merged
merged 6 commits into from
Jan 11, 2023
Changes from 4 commits

25 changes: 21 additions & 4 deletions bindings/python/pymongoarrow/api.py
@@ -16,6 +16,7 @@
import numpy as np
import pymongo.errors
from bson import encode
from bson.codec_options import TypeEncoder, TypeRegistry
from bson.raw_bson import RawBSONDocument
from pyarrow import Schema as ArrowSchema
from pyarrow import Table
@@ -26,9 +27,10 @@
    ndarray = None

try:
    from pandas import DataFrame
    from pandas import NA, DataFrame
except ImportError:
    DataFrame = None
    NA = None

from pymongo.bulk import BulkWriteError
from pymongo.common import MAX_WRITE_BATCH_SIZE
@@ -316,6 +318,16 @@ def _tabular_generator(tabular):
return


class _PandasNACodec(TypeEncoder):
    """A custom type codec for Pandas NA objects."""

    python_type = NA.__class__  # type:ignore[assignment]

    def transform_python(self, _):
        """Transform an NA object into 'None'"""
        return None


def write(collection, tabular):
    """Write data from `tabular` into the given MongoDB `collection`.

@@ -352,6 +364,13 @@ def write(collection, tabular):
)

    tabular_gen = _tabular_generator(tabular)

    # Handle Pandas NA objects.
    codec_options = collection.codec_options
    if DataFrame is not None:
        type_registry = TypeRegistry([_PandasNACodec()])
        codec_options = codec_options.with_options(type_registry=type_registry)
Collaborator @ShaneHarvey, Dec 20, 2022

I'm not sure it's ideal to completely replace the collection's type_registry. What if the app already configured a type_registry for other types they want to encode?

Would it be possible to keep the existing registry but add the _PandasNACodec?

Member Author

We could, but we'd have to use private APIs from TypeRegistry.

Collaborator

Ah my mistake for assuming TypeRegistry would have the ability to add/edit a type. You know, like a registry. Could you open an ARROW ticket for this and backlog it? I'd say it's low priority.

    while cur_offset < tab_size:
        cur_size = 0
        cur_batch = []
@@ -361,9 +380,7 @@ def write(collection, tabular):
            and len(cur_batch) <= _MAX_WRITE_BATCH_SIZE
            and cur_offset + i < tab_size
        ):
            enc_tab = RawBSONDocument(
                encode(next(tabular_gen), codec_options=collection.codec_options)
            )
            enc_tab = RawBSONDocument(encode(next(tabular_gen), codec_options=codec_options))
            cur_batch.append(enc_tab)
            cur_size += len(enc_tab.raw)
            i += 1
46 changes: 32 additions & 14 deletions bindings/python/test/test_pandas.py
@@ -16,6 +16,7 @@
import tempfile
import unittest
import unittest.mock as mock
import warnings
from test import client_context
from test.utils import AllowListEventListener, TestNullsBase

@@ -98,13 +99,21 @@ def test_aggregate_simple(self):
        self.assertEqual(agg_cmd.command["pipeline"][0]["$project"], projection)
        self.assertEqual(agg_cmd.command["pipeline"][1]["$project"], {"_id": True, "data": True})

    def _assert_frames_equal(self, incoming, outgoing):
        for name in incoming.columns:
            col = incoming[name]
            val = outgoing[name]
            if str(val.dtype) in ["object", "float64"]:
Collaborator

A comment here or a method docstring would be helpful.

Member Author

Done

                val = val.astype(col.dtype)
            pd.testing.assert_series_equal(col, val)
Collaborator

Is using "val" as the name of a column idiomatic?

Member Author

Fixed
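
A small sketch of why the helper casts before comparing, assuming (as the CSV test in this PR suggests) that a nullable column comes back with a plain NumPy dtype. The column names `n` and `label` and the in-memory buffer are illustrative only.

```python
import io

import pandas as pd

# A nullable integer column with one missing value, plus a plain string column.
original = pd.DataFrame(
    {
        "n": pd.array([0, 1, None], dtype="Int64"),
        "label": ["a", "b", "c"],
    }
)

buffer = io.StringIO()
original.to_csv(buffer, index=False, na_rep="")
buffer.seek(0)
restored = pd.read_csv(buffer)

# The missing value forces read_csv to fall back to float64 for "n".
print(restored["n"].dtype)

# Casting back to the original dtype makes the two series comparable.
casted = restored["n"].astype(original["n"].dtype)
pd.testing.assert_series_equal(original["n"], casted)
```

Without the cast, `assert_series_equal` would fail on the dtype mismatch alone, even though the values agree.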


    def round_trip(self, data, schema, coll=None):
        if coll is None:
            coll = self.coll
        coll.drop()
        res = write(self.coll, data)
        self.assertEqual(len(data), res.raw_result["insertedCount"])
        pd.testing.assert_frame_equal(data, find_pandas_all(coll, {}, schema=schema))
        self._assert_frames_equal(data, find_pandas_all(coll, {}, schema=schema))
        return res

    def test_write_error(self):
@@ -129,23 +138,34 @@ def _create_data(self):
            if k.__name__ not in ("ObjectId", "Decimal128")
        }
        schema = {k: v.to_pandas_dtype() for k, v in arrow_schema.items()}
        schema["Int64"] = pd.Int64Dtype()
        schema["int"] = pd.Int32Dtype()
        schema["str"] = "U8"
        schema["datetime"] = "datetime64[ns]"

        data = pd.DataFrame(
            data={
                "Int64": [i for i in range(2)],
                "float": [i for i in range(2)],
                "int": [i for i in range(2)],
                "datetime": [datetime.datetime(1970 + i, 1, 1) for i in range(2)],
                "str": [f"a{i}" for i in range(2)],
                "bool": [True, False],
                "Int64": [i for i in range(2)] + [None],
                "float": [i for i in range(2)] + [None],
                "int": [i for i in range(2)] + [None],
                "datetime": [datetime.datetime(1970 + i, 1, 1) for i in range(2)] + [None],
                "str": [f"a{i}" for i in range(2)] + [None],
                "bool": [True, False, None],
            }
        ).astype(schema)
        return arrow_schema, data

    def test_write_schema_validation(self):
        arrow_schema, data = self._create_data()

        # Work around https://github.com/pandas-dev/pandas/issues/11453.
        def new_replace(k):
            if k.value < 1:
                return datetime.datetime(1970, 1, 1)
            return k.replace(tzinfo=None)

        data["datetime"] = data.apply(lambda row: new_replace(row["datetime"]), axis=1)
Collaborator

That pandas ticket seems to be about .time() not working on NA types but we don't use .time() anywhere. Can you explain?

Member Author

Ah, meant to link to pandas-dev/pandas#16248. Also added context.

Collaborator @ShaneHarvey, Dec 20, 2022

I'm still confused since this code is changing the data before passing it to any pymongoarrow methods. What happens if the user actually calls write() with a NaT datetime? And why are we clamping the datetime to datetime.datetime(1970, 1, 1)?

Member Author

They would see the same error as https://jira.mongodb.org/browse/FREE-165786, for which this is a workaround. I just chose a random valid date.

Collaborator

Thanks, two follow-up questions.

  1. Can we change this to something like the following?
    if k is pd.NaT:
        return None
    return k
  2. This looks like something we should ideally work around in write() itself. Can you open a ticket for making write() encode NaT datetimes as BSON null?

Member Author

If we return None we get the original error, so I had to return the workaround datetime object. I filed https://jira.mongodb.org/browse/ARROW-136.
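
One plausible reason returning None does not help here (an assumption, not something confirmed in this thread) is that pandas stores None as NaT in a datetime64 column, so the value reaching write() would be unchanged. A minimal check, with an illustrative series in place of the test data:

```python
import pandas as pd

# Assumption: the column has a datetime64 dtype, as in the test schema.
column = pd.Series([pd.Timestamp("1970-01-01"), None], dtype="datetime64[ns]")

# The None placed in the column is held as NaT, not as Python None.
second = column.iloc[1]
print(second is pd.NaT)
print(column.isna().tolist())
```

If that is what happens in the test, substituting a concrete datetime (as the workaround does) is the only way to keep a missing marker out of the column before the write.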


        self.round_trip(
            data,
            Schema(arrow_schema),
@@ -280,14 +300,12 @@ def test_csv(self):
        _, data = self._create_data()
        with tempfile.NamedTemporaryFile(suffix=".csv") as f:
            f.close()
            data.to_csv(f.name, index=False)
            # May give RuntimeWarning due to the nulls.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", RuntimeWarning)
                data.to_csv(f.name, index=False, na_rep="")
            out = pd.read_csv(f.name)
            for name in data.columns:
                col = data[name]
                val = out[name]
                if str(val.dtype) == "object":
                    val = val.astype(col.dtype)
                pd.testing.assert_series_equal(col, val)
            self._assert_frames_equal(data, out)
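
The scoped warning filter used in test_csv can be sketched with the standard library alone. The `emit` function below is a hypothetical stand-in for a call that raises a RuntimeWarning; the outer recording context exists only so the effect can be observed.

```python
import warnings


def emit():
    """Hypothetical stand-in for a call that warns about nulls."""
    warnings.warn("nulls encountered", RuntimeWarning)


with warnings.catch_warnings(record=True) as recorded:
    warnings.simplefilter("always")  # record every warning in the outer scope

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        emit()  # suppressed inside the inner block

    emit()  # recorded: the "ignore" filter ended with the inner block

print(len(recorded))
```

Keeping the filter inside `catch_warnings()` limits the suppression to the one call that needs it, so unrelated RuntimeWarnings elsewhere in the test are still reported.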


class TestBSONTypes(PandasTestBase):