-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ARROW-134 Cannot encode pandas NA objects #118
Changes from 4 commits
46fed4b
1116dc4
774653a
3e73b0c
a5e35d1
f9450f4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
import tempfile | ||
import unittest | ||
import unittest.mock as mock | ||
import warnings | ||
from test import client_context | ||
from test.utils import AllowListEventListener, TestNullsBase | ||
|
||
|
@@ -98,13 +99,21 @@ def test_aggregate_simple(self): | |
self.assertEqual(agg_cmd.command["pipeline"][0]["$project"], projection) | ||
self.assertEqual(agg_cmd.command["pipeline"][1]["$project"], {"_id": True, "data": True}) | ||
|
||
def _assert_frames_equal(self, incoming, outgoing): | ||
for name in incoming.columns: | ||
col = incoming[name] | ||
val = outgoing[name] | ||
if str(val.dtype) in ["object", "float64"]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A comment here or a method docstring would be helpful. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
val = val.astype(col.dtype) | ||
pd.testing.assert_series_equal(col, val) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is using "val" as the name of a column idomatic? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed |
||
|
||
def round_trip(self, data, schema, coll=None): | ||
if coll is None: | ||
coll = self.coll | ||
coll.drop() | ||
res = write(self.coll, data) | ||
self.assertEqual(len(data), res.raw_result["insertedCount"]) | ||
pd.testing.assert_frame_equal(data, find_pandas_all(coll, {}, schema=schema)) | ||
self._assert_frames_equal(data, find_pandas_all(coll, {}, schema=schema)) | ||
return res | ||
|
||
def test_write_error(self): | ||
|
@@ -129,23 +138,34 @@ def _create_data(self): | |
if k.__name__ not in ("ObjectId", "Decimal128") | ||
} | ||
schema = {k: v.to_pandas_dtype() for k, v in arrow_schema.items()} | ||
schema["Int64"] = pd.Int64Dtype() | ||
schema["int"] = pd.Int32Dtype() | ||
schema["str"] = "U8" | ||
schema["datetime"] = "datetime64[ns]" | ||
|
||
data = pd.DataFrame( | ||
data={ | ||
"Int64": [i for i in range(2)], | ||
"float": [i for i in range(2)], | ||
"int": [i for i in range(2)], | ||
"datetime": [datetime.datetime(1970 + i, 1, 1) for i in range(2)], | ||
"str": [f"a{i}" for i in range(2)], | ||
"bool": [True, False], | ||
"Int64": [i for i in range(2)] + [None], | ||
"float": [i for i in range(2)] + [None], | ||
"int": [i for i in range(2)] + [None], | ||
"datetime": [datetime.datetime(1970 + i, 1, 1) for i in range(2)] + [None], | ||
"str": [f"a{i}" for i in range(2)] + [None], | ||
"bool": [True, False, None], | ||
} | ||
).astype(schema) | ||
return arrow_schema, data | ||
|
||
def test_write_schema_validation(self): | ||
arrow_schema, data = self._create_data() | ||
|
||
# Work around https://github.com/pandas-dev/pandas/issues/11453. | ||
def new_replace(k): | ||
if k.value < 1: | ||
return datetime.datetime(1970, 1, 1) | ||
return k.replace(tzinfo=None) | ||
|
||
data["datetime"] = data.apply(lambda row: new_replace(row["datetime"]), axis=1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That pandas ticket seems to be about .time() not working on NA types but we don't use .time() anywhere. Can you explain? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, meant to link to pandas-dev/pandas#16248. Also added context. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm still confused since this code is changing the data before passing it to any pymongoarrow methods. What happens if the user actually calls write() with a NaT datetime? And why are we clamping the datetime to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. They would see the same error as https://jira.mongodb.org/browse/FREE-165786, for which this is a workaround. I just chose a random valid date. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, 2 follow up questions.
if isinstance(k, Nat):
return None
return k
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we return |
||
|
||
self.round_trip( | ||
data, | ||
Schema(arrow_schema), | ||
|
@@ -280,14 +300,12 @@ def test_csv(self): | |
_, data = self._create_data() | ||
with tempfile.NamedTemporaryFile(suffix=".csv") as f: | ||
f.close() | ||
data.to_csv(f.name, index=False) | ||
# May give RuntimeWarning due to the nulls. | ||
with warnings.catch_warnings(): | ||
warnings.simplefilter("ignore", RuntimeWarning) | ||
data.to_csv(f.name, index=False, na_rep="") | ||
out = pd.read_csv(f.name) | ||
for name in data.columns: | ||
col = data[name] | ||
val = out[name] | ||
if str(val.dtype) == "object": | ||
val = val.astype(col.dtype) | ||
pd.testing.assert_series_equal(col, val) | ||
self._assert_frames_equal(data, out) | ||
|
||
|
||
class TestBSONTypes(PandasTestBase): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure it's ideal to completely replace the collection's type_registry. What if the app already configured a type_registry for other types they want to encode?
Would it be possible to keep the existing registry but add the _PandasNACodec?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could, but we'd have to use private APIs from
TypeRegistry
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah my mistake for assuming TypeRegistry would have the ability to add/edit a type. You know, like a registry. Could you open an ARROW ticket for this and backlog it? I'd say it's low priority.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://jira.mongodb.org/browse/ARROW-135