

ARROW-4324: [Python] Triage broken type inference logic in presence of a mix of NumPy dtype-having objects and other scalar values

In investigating the innocuous bug report from ARROW-4324, I stumbled on a pile of hacks and flawed design around type inference:

```
test_list = [np.dtype('int32').type(10), np.dtype('float32').type(0.5)]
test_array = pa.array(test_list)

# Expected
# test_array
# <pyarrow.lib.DoubleArray object at 0x7f009963bf48>
# [
#   10,
#   0.5
# ]

# Got
# test_array
# <pyarrow.lib.Int32Array object at 0x7f009963bf48>
# [
#   10,
#   0
# ]
```

It turns out there are several issues:

* There was a kludge around handling the `numpy.nan` value, which is a PyFloat, not a NumPy float64 scalar
* Type inference assumed "NaN is null", which should not be hard-coded, so I added a flag to switch between pandas semantics and non-pandas semantics
* Mixing NumPy scalar values and non-NumPy scalars (like our evil friend `numpy.nan`) caused the output type to be simply incorrect. For example, `[np.float16(1.5), 2.5]` would yield a `pa.float16()` output type. Yuck. (A sketch of the corrected behavior follows this list.)
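
To make the intended semantics concrete, here is a minimal sketch of the post-fix behavior, assuming a pyarrow build that includes this change; the exact inferred types reflect my reading of the rules above rather than documented guarantees:

```
import numpy as np
import pyarrow as pa

# numpy.nan is a plain Python float: with from_pandas=True it acts as a
# null sentinel, with from_pandas=False it stays an ordinary float value
assert pa.array([1.0, np.nan], from_pandas=True).null_count == 1
assert pa.array([1.0, np.nan], from_pandas=False).null_count == 0

# A NumPy float16 scalar mixed with a plain Python float should no longer
# leak float16 into the output type
assert pa.array([np.float16(1.5), 2.5]).type == pa.float64()
```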

I inserted some hacks to force what I believe to be the correct behavior and fixed a couple of unit tests that actually exhibited buggy behavior before (see within). I don't have time to do the "right thing" right now, which is to more or less rewrite the hot path of `arrow/python/inference.cc`, so at least this gets the unit tests asserting what is correct so that refactoring will be more productive later.
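
The commits below also make `pyarrow.array` default `from_pandas` to `None` so that an explicit user choice is respected. A hedged sketch of the kind of unit test this enables (my reading of the commit messages, not a documented contract):

```
import numpy as np
import pandas as pd
import pyarrow as pa

s = pd.Series([1.0, np.nan])

# Passing a pandas Series picks up pandas null semantics by default...
assert pa.array(s).null_count == 1

# ...but an explicit from_pandas=False is now respected: NaN stays a value
assert pa.array(s, from_pandas=False).null_count == 0
```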

Author: Wes McKinney <[email protected]>

Closes #4527 from wesm/ARROW-4324 and squashes the following commits:

e396958 <Wes McKinney> Add unit test for passing pandas Series with from_pandas=False
754468a <Wes McKinney> Set from_pandas to None by default in pyarrow.array so that user wishes can be respected
e1b8393 <Wes McKinney> Remove outdated unit test, add Python unit test that shows behavior from ARROW-2240 that's been changed
4bc8c81 <Wes McKinney> Triage type inference logic in presence of a mix of NumPy dtype-having objects and other typed values, pending more serious refactor in ARROW-5564
wesm committed Jun 12, 2019
1 parent 4ea86ff commit 25b4a46
Showing 9 changed files with 182 additions and 87 deletions.
14 changes: 3 additions & 11 deletions cpp/src/arrow/python/arrow_to_pandas.cc
@@ -173,19 +173,11 @@ inline void set_numpy_metadata(int type, DataType* datatype, PyArray_Descr* out)
}
}

-static inline PyArray_Descr* GetSafeNumPyDtype(int type) {
-  if (type == NPY_DATETIME) {
-    // It is not safe to mutate the result of DescrFromType
-    return PyArray_DescrNewFromType(type);
-  } else {
-    return PyArray_DescrFromType(type);
-  }
-}
static inline PyObject* NewArray1DFromType(DataType* arrow_type, int type, int64_t length,
void* data) {
npy_intp dims[1] = {length};

-  PyArray_Descr* descr = GetSafeNumPyDtype(type);
+  PyArray_Descr* descr = internal::GetSafeNumPyDtype(type);
if (descr == nullptr) {
// Error occurred, trust error state is set
return nullptr;
@@ -244,7 +236,7 @@ class PandasBlock {
Status AllocateNDArray(int npy_type, int ndim = 2) {
PyAcquireGIL lock;

-    PyArray_Descr* descr = GetSafeNumPyDtype(npy_type);
+    PyArray_Descr* descr = internal::GetSafeNumPyDtype(npy_type);

PyObject* block_arr;
if (ndim == 2) {
@@ -1220,7 +1212,7 @@ class CategoricalBlock : public PandasBlock {

PyAcquireGIL lock;

-    PyArray_Descr* descr = GetSafeNumPyDtype(npy_type);
+    PyArray_Descr* descr = internal::GetSafeNumPyDtype(npy_type);
if (descr == nullptr) {
// Error occurred, trust error state is set
return Status::OK();
134 changes: 96 additions & 38 deletions cpp/src/arrow/python/inference.cc
@@ -42,19 +42,27 @@ namespace py {

#define _NUMPY_UNIFY_NOOP(DTYPE) \
case NPY_##DTYPE: \
-    return NOOP;
+    return OK;

#define _NUMPY_UNIFY_PROMOTE(DTYPE) \
case NPY_##DTYPE: \
-    return PROMOTE;
+    current_type_num_ = dtype;  \
+    current_dtype_ = descr;     \
+    return OK;
+
+#define _NUMPY_UNIFY_PROMOTE_TO(DTYPE, NEW_TYPE)                \
+  case NPY_##DTYPE:                                             \
+    current_type_num_ = NPY_##NEW_TYPE;                         \
+    current_dtype_ = PyArray_DescrFromType(current_type_num_);  \
+    return OK;

-// Form a consensus NumPy dtype to use for Arrow conversion for a collection of dtype
-// objects observed one at a time
+// Form a consensus NumPy dtype to use for Arrow conversion for a
+// collection of dtype objects observed one at a time
class NumPyDtypeUnifier {
public:
-  enum Action { NOOP, PROMOTE, INVALID };
+  enum Action { OK, INVALID };

-  NumPyDtypeUnifier() : current_type_num_(-1), current_dtype_(NULLPTR) {}
+  NumPyDtypeUnifier() : current_type_num_(-1), current_dtype_(nullptr) {}

Status InvalidMix(int new_dtype) {
return Status::Invalid("Cannot mix NumPy dtypes ",
@@ -97,7 +105,7 @@ class NumPyDtypeUnifier {
_NUMPY_UNIFY_PROMOTE(INT64);
_NUMPY_UNIFY_NOOP(UINT8);
_NUMPY_UNIFY_NOOP(UINT16);
-      _NUMPY_UNIFY_PROMOTE(FLOAT32);
+      _NUMPY_UNIFY_PROMOTE_TO(FLOAT32, FLOAT64);
_NUMPY_UNIFY_PROMOTE(FLOAT64);
default:
return INVALID;
@@ -113,7 +121,7 @@
_NUMPY_UNIFY_NOOP(UINT8);
_NUMPY_UNIFY_NOOP(UINT16);
_NUMPY_UNIFY_NOOP(UINT32);
-      _NUMPY_UNIFY_PROMOTE(FLOAT32);
+      _NUMPY_UNIFY_PROMOTE_TO(FLOAT32, FLOAT64);
_NUMPY_UNIFY_PROMOTE(FLOAT64);
default:
return INVALID;
@@ -149,7 +157,7 @@
_NUMPY_UNIFY_NOOP(UINT8);
_NUMPY_UNIFY_NOOP(UINT16);
_NUMPY_UNIFY_PROMOTE(UINT64);
-      _NUMPY_UNIFY_PROMOTE(FLOAT32);
+      _NUMPY_UNIFY_PROMOTE_TO(FLOAT32, FLOAT64);
_NUMPY_UNIFY_PROMOTE(FLOAT64);
default:
return INVALID;
@@ -161,7 +169,7 @@
_NUMPY_UNIFY_NOOP(UINT8);
_NUMPY_UNIFY_NOOP(UINT16);
_NUMPY_UNIFY_NOOP(UINT32);
-      _NUMPY_UNIFY_PROMOTE(FLOAT32);
+      _NUMPY_UNIFY_PROMOTE_TO(FLOAT32, FLOAT64);
_NUMPY_UNIFY_PROMOTE(FLOAT64);
default:
return INVALID;
@@ -210,12 +218,11 @@

int Observe_DATETIME(PyArray_Descr* dtype_obj) {
// TODO: check that units are all the same
-    // current_dtype_ = dtype_obj->type_num;
-    return NOOP;
+    return OK;
}

Status Observe(PyArray_Descr* descr) {
-    const int dtype = fix_numpy_type_num(descr->type_num);
+    int dtype = fix_numpy_type_num(descr->type_num);

if (current_type_num_ == -1) {
current_dtype_ = descr;
@@ -230,7 +237,7 @@
action = Observe_##DTYPE(descr, dtype); \
break;

-    int action = NOOP;
+    int action = OK;
switch (current_type_num_) {
OBSERVE_CASE(BOOL);
OBSERVE_CASE(INT8);
@@ -253,9 +260,6 @@

if (action == INVALID) {
return InvalidMix(dtype);
-    } else if (action == PROMOTE) {
-      current_type_num_ = dtype;
-      current_dtype_ = descr;
}
return Status::OK();
}
@@ -264,6 +268,8 @@

PyArray_Descr* current_dtype() const { return current_dtype_; }

+  int current_type_num() const { return current_type_num_; }

private:
int current_type_num_;
PyArray_Descr* current_dtype_;
@@ -278,8 +284,10 @@ class TypeInferrer {
// early with long sequences that may have problems up front
// \param make_unions permit mixed-type data by creating union types (not yet
// implemented)
-  explicit TypeInferrer(int64_t validate_interval = 100, bool make_unions = false)
-      : validate_interval_(validate_interval),
+  explicit TypeInferrer(bool pandas_null_sentinels = false,
+                        int64_t validate_interval = 100, bool make_unions = false)
+      : pandas_null_sentinels_(pandas_null_sentinels),
+        validate_interval_(validate_interval),
make_unions_(make_unions),
total_count_(0),
none_count_(0),
@@ -297,6 +305,7 @@
decimal_count_(0),
list_count_(0),
struct_count_(0),
+        numpy_dtype_count_(0),
max_decimal_metadata_(std::numeric_limits<int32_t>::min(),
std::numeric_limits<int32_t>::min()),
decimal_type_() {
@@ -311,12 +320,12 @@
Status Visit(PyObject* obj, bool* keep_going) {
++total_count_;

-    if (obj == Py_None || internal::PyFloat_IsNaN(obj)) {
+    if (obj == Py_None || (pandas_null_sentinels_ && internal::PyFloat_IsNaN(obj))) {
++none_count_;
} else if (PyBool_Check(obj)) {
++bool_count_;
*keep_going = make_unions_;
-    } else if (internal::PyFloatScalar_Check(obj)) {
+    } else if (PyFloat_Check(obj)) {
++float_count_;
*keep_going = make_unions_;
} else if (internal::IsPyInteger(obj)) {
@@ -367,19 +376,56 @@
});
}

-  Status GetType(std::shared_ptr<DataType>* out) const {
+  Status GetType(std::shared_ptr<DataType>* out) {
    // TODO(wesm): handle forming unions
if (make_unions_) {
return Status::NotImplemented("Creating union types not yet supported");
}

RETURN_NOT_OK(Validate());

-    if (numpy_unifier_.current_dtype() != nullptr) {
-      std::shared_ptr<DataType> type;
-      RETURN_NOT_OK(NumPyDtypeToArrow(numpy_unifier_.current_dtype(), &type));
-      *out = type;
-    } else if (list_count_) {
+    if (numpy_dtype_count_ > 0) {
+      // All NumPy scalars and Nones/nulls
+      if (numpy_dtype_count_ + none_count_ == total_count_) {
+        std::shared_ptr<DataType> type;
+        RETURN_NOT_OK(NumPyDtypeToArrow(numpy_unifier_.current_dtype(), &type));
+        *out = type;
+        return Status::OK();
+      }
+
+      // The "bad path": data contains a mix of NumPy scalars and
+      // other kinds of scalars. Note this can happen innocuously
+      // because numpy.nan is not a NumPy scalar (it's a built-in
+      // PyFloat)
+
+      // TODO(ARROW-5564): Merge together type unification so this
+      // hack is not necessary
+      switch (numpy_unifier_.current_type_num()) {
+        case NPY_BOOL:
+          bool_count_ += numpy_dtype_count_;
+          break;
+        case NPY_INT8:
+        case NPY_INT16:
+        case NPY_INT32:
+        case NPY_INT64:
+        case NPY_UINT8:
+        case NPY_UINT16:
+        case NPY_UINT32:
+        case NPY_UINT64:
+          int_count_ += numpy_dtype_count_;
+          break;
+        case NPY_FLOAT32:
+        case NPY_FLOAT64:
+          float_count_ += numpy_dtype_count_;
+          break;
+        case NPY_DATETIME:
+          return Status::Invalid(
+              "numpy.datetime64 scalars cannot be mixed "
+              "with other Python scalar values currently");
+      }
+    }
+
+    if (list_count_) {
std::shared_ptr<DataType> value_type;
RETURN_NOT_OK(list_inferrer_->GetType(&value_type));
*out = list(value_type);
@@ -439,13 +485,15 @@ class TypeInferrer {
Status VisitDType(PyArray_Descr* dtype, bool* keep_going) {
// Continue visiting dtypes for now.
// TODO(wesm): devise approach for unions
+    ++numpy_dtype_count_;
*keep_going = true;
return numpy_unifier_.Observe(dtype);
}

Status VisitList(PyObject* obj, bool* keep_going /* unused */) {
if (!list_inferrer_) {
-      list_inferrer_.reset(new TypeInferrer(validate_interval_, make_unions_));
+      list_inferrer_.reset(
+          new TypeInferrer(pandas_null_sentinels_, validate_interval_, make_unions_));
}
++list_count_;
return list_inferrer_->VisitSequence(obj);
@@ -458,9 +506,15 @@
}
// Not an object array: infer child Arrow type from dtype
if (!list_inferrer_) {
-      list_inferrer_.reset(new TypeInferrer(validate_interval_, make_unions_));
+      list_inferrer_.reset(
+          new TypeInferrer(pandas_null_sentinels_, validate_interval_, make_unions_));
}
++list_count_;

+    // XXX(wesm): In ARROW-4324 I added accounting to check whether
+    // all of the non-null values have NumPy dtypes, but the
+    // total_count was not being properly incremented here
+    ++(*list_inferrer_).total_count_;
return list_inferrer_->VisitDType(dtype, keep_going);
}

@@ -484,7 +538,8 @@
if (it == struct_inferrers_.end()) {
it = struct_inferrers_
.insert(
-                       std::make_pair(key, TypeInferrer(validate_interval_, make_unions_)))
+                       std::make_pair(key, TypeInferrer(pandas_null_sentinels_,
+                                                        validate_interval_, make_unions_)))
.first;
}
TypeInferrer* visitor = &it->second;
@@ -503,9 +558,9 @@
return Status::OK();
}

-  Status GetStructType(std::shared_ptr<DataType>* out) const {
+  Status GetStructType(std::shared_ptr<DataType>* out) {
std::vector<std::shared_ptr<Field>> fields;
-    for (const auto& it : struct_inferrers_) {
+    for (auto&& it : struct_inferrers_) {
std::shared_ptr<DataType> field_type;
RETURN_NOT_OK(it.second.GetType(&field_type));
fields.emplace_back(field(it.first, field_type));
@@ -515,6 +570,7 @@
}

private:
+  bool pandas_null_sentinels_;
int64_t validate_interval_;
bool make_unions_;
int64_t total_count_;
@@ -532,8 +588,9 @@
int64_t unicode_count_;
int64_t decimal_count_;
int64_t list_count_;
-  std::unique_ptr<TypeInferrer> list_inferrer_;
   int64_t struct_count_;
+  int64_t numpy_dtype_count_;
+  std::unique_ptr<TypeInferrer> list_inferrer_;
std::map<std::string, TypeInferrer> struct_inferrers_;

// If we observe a strongly-typed value in e.g. a NumPy array, we can store
@@ -548,9 +605,10 @@
};

// Non-exhaustive type inference
-Status InferArrowType(PyObject* obj, std::shared_ptr<DataType>* out_type) {
+Status InferArrowType(PyObject* obj, bool pandas_null_sentinels,
+                      std::shared_ptr<DataType>* out_type) {
   PyDateTime_IMPORT;
-  TypeInferrer inferrer;
+  TypeInferrer inferrer(pandas_null_sentinels);
RETURN_NOT_OK(inferrer.VisitSequence(obj));
RETURN_NOT_OK(inferrer.GetType(out_type));
if (*out_type == nullptr) {
@@ -560,7 +618,7 @@ Status InferArrowType(PyObject* obj, std::shared_ptr<DataType>* out_type) {
return Status::OK();
}

-Status InferArrowTypeAndSize(PyObject* obj, int64_t* size,
+Status InferArrowTypeAndSize(PyObject* obj, bool pandas_null_sentinels, int64_t* size,
std::shared_ptr<DataType>* out_type) {
if (!PySequence_Check(obj)) {
return Status::TypeError("Object is not a sequence");
@@ -572,7 +630,7 @@ Status InferArrowTypeAndSize(PyObject* obj, int64_t* size,
*out_type = null();
return Status::OK();
}
-  RETURN_NOT_OK(InferArrowType(obj, out_type));
+  RETURN_NOT_OK(InferArrowType(obj, pandas_null_sentinels, out_type));

return Status::OK();
}
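
As a Python-level summary of the inference.cc changes above, a hedged sketch (assuming a build with this commit) of what the new `_NUMPY_UNIFY_PROMOTE_TO` cases and the mixed-scalar triage path imply:

```
import numpy as np
import pyarrow as pa

# int32 and float32 NumPy scalars now unify to float64 rather than
# letting the first-seen dtype win
assert pa.array([np.int32(10), np.float32(0.5)]).type == pa.float64()

# Mixing numpy.datetime64 scalars with other Python scalars is rejected
# outright by the new triage path
try:
    pa.array([np.datetime64("2019-06-12"), 1.5])
    raise AssertionError("expected ArrowInvalid")
except pa.ArrowInvalid:
    pass
```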
6 changes: 4 additions & 2 deletions cpp/src/arrow/python/inference.h
@@ -40,10 +40,12 @@ namespace py {

// These three functions take a sequence input, not arbitrary iterables
ARROW_PYTHON_EXPORT
-arrow::Status InferArrowType(PyObject* obj, std::shared_ptr<arrow::DataType>* out_type);
+arrow::Status InferArrowType(PyObject* obj, bool pandas_null_sentinels,
+                             std::shared_ptr<arrow::DataType>* out_type);

ARROW_PYTHON_EXPORT
-arrow::Status InferArrowTypeAndSize(PyObject* obj, int64_t* size,
+arrow::Status InferArrowTypeAndSize(PyObject* obj, bool pandas_null_sentinels,
+                                    int64_t* size,
std::shared_ptr<arrow::DataType>* out_type);

/// Checks whether the passed Python object is a boolean scalar
9 changes: 9 additions & 0 deletions cpp/src/arrow/python/numpy-internal.h
@@ -171,6 +171,15 @@ inline bool PyBoolScalar_Check(PyObject* obj) {
return PyBool_Check(obj) || PyArray_IsScalar(obj, Bool);
}

+static inline PyArray_Descr* GetSafeNumPyDtype(int type) {
+  if (type == NPY_DATETIME) {
+    // It is not safe to mutate the result of DescrFromType
+    return PyArray_DescrNewFromType(type);
+  } else {
+    return PyArray_DescrFromType(type);
+  }
+}

} // namespace internal

} // namespace py
