Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TYPING: Added types for some tests #29205

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,4 @@ doc/build/html/index.html
doc/tmp.sv
env/
doc/source/savefig/
.dmypy.json
2 changes: 2 additions & 0 deletions pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pandas.core.indexes.base import Index # noqa: F401
from pandas.core.series import Series # noqa: F401
from pandas.core.generic import NDFrame # noqa: F401
from pandas.core.base import IndexOpsMixin # noqa: F401


AnyArrayLike = TypeVar("AnyArrayLike", "ExtensionArray", "Index", "Series", np.ndarray)
Expand All @@ -32,6 +33,7 @@
FilePathOrBuffer = Union[str, Path, IO[AnyStr]]

FrameOrSeries = TypeVar("FrameOrSeries", bound="NDFrame")
IndexOrSeries = TypeVar("IndexOrSeries", bound="IndexOpsMixin")
Copy link
Contributor

@jreback jreback Oct 24, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this meant to include array likes, eg EA? as well

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, just Index and Series.

Scalar = Union[str, int, float, bool]
Axis = Union[str, int]
Ordered = Optional[bool]
Expand Down
12 changes: 12 additions & 0 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import pandas as pd
from pandas import DataFrame
from pandas._typing import IndexOrSeries
from pandas.core import ops
import pandas.util.testing as tm

Expand Down Expand Up @@ -790,6 +791,17 @@ def tick_classes(request):
return request.param


index_or_series_params = [pd.Index, pd.Series] # type: IndexOrSeries
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# type: List[IndexOrSeries]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be List[Type[IndexOrSeries]]? Surprised this passed checks as is; seems like something wonky going on

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch.

agree that strange it doesn't fail



@pytest.fixture(params=index_or_series_params, ids=["series", "index"])
WillAyd marked this conversation as resolved.
Show resolved Hide resolved
def index_or_series(request) -> IndexOrSeries:
"""
Parametrized fixture providing the Index or Series class.
"""
return request.param


# ----------------------------------------------------------------
# Global setup for tests using Hypothesis

Expand Down
51 changes: 11 additions & 40 deletions pandas/tests/arithmetic/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from decimal import Decimal
from itertools import combinations
import operator
from typing import List, Type, Union

import numpy as np
import pytest
Expand Down Expand Up @@ -74,33 +75,22 @@ def test_compare_invalid(self):

# ------------------------------------------------------------------
# Numeric dtypes Arithmetic with Datetime/Timedelta Scalar
index_or_series_params = [
pd.Series,
pd.Index,
] # type: List[Union[Type[pd.Index], Type[pd.RangeIndex], Type[pd.Series]]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
] # type: List[Union[Type[pd.Index], Type[pd.RangeIndex], Type[pd.Series]]]
] # type: Sequence[Type[Union[pd.Index, pd.Series]]]

Can shorten this quite a bit if you put the Union inside of the Type. Also, if you decide to use Sequence instead of list it covariant, so can implicitly handle RangeIndex being a subclass of Index

left = [pd.RangeIndex(10, 40, 10)] # type: List[Union[Index, Series]]
for cls in index_or_series_params:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn’t this any_numeric_dtype ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to include float16. Aside from that, they look the same at a glance.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we barely support float16 would just remove and use the fixture

for dtype in ["i1", "i2", "i4", "i8", "u1", "u2", "u4", "u8", "f2", "f4", "f8"]:
left.append(cls([10, 20, 30], dtype=dtype))


class TestNumericArraylikeArithmeticWithDatetimeLike:

# TODO: also check name retentention
@pytest.mark.parametrize("box_cls", [np.array, pd.Index, pd.Series])
@pytest.mark.parametrize(
"left",
[pd.RangeIndex(10, 40, 10)]
+ [
WillAyd marked this conversation as resolved.
Show resolved Hide resolved
cls([10, 20, 30], dtype=dtype)
for dtype in [
"i1",
"i2",
"i4",
"i8",
"u1",
"u2",
"u4",
"u8",
"f2",
"f4",
"f8",
]
for cls in [pd.Series, pd.Index]
],
ids=lambda x: type(x).__name__ + str(x.dtype),
"left", left, ids=lambda x: type(x).__name__ + str(x.dtype)
)
def test_mul_td64arr(self, left, box_cls):
# GH#22390
Expand All @@ -120,26 +110,7 @@ def test_mul_td64arr(self, left, box_cls):
# TODO: also check name retentention
@pytest.mark.parametrize("box_cls", [np.array, pd.Index, pd.Series])
@pytest.mark.parametrize(
"left",
[pd.RangeIndex(10, 40, 10)]
+ [
cls([10, 20, 30], dtype=dtype)
for dtype in [
"i1",
"i2",
"i4",
"i8",
"u1",
"u2",
"u4",
"u8",
"f2",
"f4",
"f8",
]
for cls in [pd.Series, pd.Index]
],
ids=lambda x: type(x).__name__ + str(x.dtype),
"left", left, ids=lambda x: type(x).__name__ + str(x.dtype)
)
def test_div_td64arr(self, left, box_cls):
# GH#22390
Expand Down
5 changes: 2 additions & 3 deletions pandas/tests/arrays/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,8 @@ def _from_sequence(cls, scalars, dtype=None, copy=False):
return super()._from_sequence(scalars, dtype=dtype, copy=copy)


@pytest.mark.parametrize("box", [pd.Series, pd.Index])
def test_array_unboxes(box):
data = box([decimal.Decimal("1"), decimal.Decimal("2")])
def test_array_unboxes(index_or_series):
data = index_or_series([decimal.Decimal("1"), decimal.Decimal("2")])
# make sure it works
with pytest.raises(TypeError):
DecimalArray2._from_sequence(data)
Expand Down
7 changes: 3 additions & 4 deletions pandas/tests/dtypes/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pandas.core.dtypes.concat as _concat

from pandas import DatetimeIndex, Index, Period, PeriodIndex, Series, TimedeltaIndex
from pandas import DatetimeIndex, Period, PeriodIndex, Series, TimedeltaIndex


@pytest.mark.parametrize(
Expand Down Expand Up @@ -40,9 +40,8 @@
),
],
)
@pytest.mark.parametrize("klass", [Index, Series])
def test_get_dtype_kinds(klass, to_concat, expected):
to_concat_klass = [klass(c) for c in to_concat]
def test_get_dtype_kinds(index_or_series, to_concat, expected):
to_concat_klass = [index_or_series(c) for c in to_concat]
result = _concat.get_dtype_kinds(to_concat_klass)
assert result == set(expected)

Expand Down
81 changes: 37 additions & 44 deletions pandas/tests/indexing/test_coercion.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,55 +515,52 @@ def _assert_where_conversion(
res = target.where(cond, values)
self._assert(res, expected, expected_dtype)

@pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"])
@pytest.mark.parametrize(
"fill_val,exp_dtype",
[(1, np.object), (1.1, np.object), (1 + 1j, np.object), (True, np.object)],
)
def test_where_object(self, klass, fill_val, exp_dtype):
obj = klass(list("abcd"))
def test_where_object(self, index_or_series, fill_val, exp_dtype):
obj = index_or_series(list("abcd"))
assert obj.dtype == np.object
cond = klass([True, False, True, False])
cond = index_or_series([True, False, True, False])

if fill_val is True and klass is pd.Series:
if fill_val is True and index_or_series is pd.Series:
ret_val = 1
else:
ret_val = fill_val

exp = klass(["a", ret_val, "c", ret_val])
exp = index_or_series(["a", ret_val, "c", ret_val])
self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype)

if fill_val is True:
values = klass([True, False, True, True])
values = index_or_series([True, False, True, True])
else:
values = klass(fill_val * x for x in [5, 6, 7, 8])
values = index_or_series(fill_val * x for x in [5, 6, 7, 8])

exp = klass(["a", values[1], "c", values[3]])
exp = index_or_series(["a", values[1], "c", values[3]])
self._assert_where_conversion(obj, cond, values, exp, exp_dtype)

@pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"])
@pytest.mark.parametrize(
"fill_val,exp_dtype",
[(1, np.int64), (1.1, np.float64), (1 + 1j, np.complex128), (True, np.object)],
)
def test_where_int64(self, klass, fill_val, exp_dtype):
if klass is pd.Index and exp_dtype is np.complex128:
def test_where_int64(self, index_or_series, fill_val, exp_dtype):
if index_or_series is pd.Index and exp_dtype is np.complex128:
pytest.skip("Complex Index not supported")
obj = klass([1, 2, 3, 4])
obj = index_or_series([1, 2, 3, 4])
assert obj.dtype == np.int64
cond = klass([True, False, True, False])
cond = index_or_series([True, False, True, False])

exp = klass([1, fill_val, 3, fill_val])
exp = index_or_series([1, fill_val, 3, fill_val])
self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype)

if fill_val is True:
values = klass([True, False, True, True])
values = index_or_series([True, False, True, True])
else:
values = klass(x * fill_val for x in [5, 6, 7, 8])
exp = klass([1, values[1], 3, values[3]])
values = index_or_series(x * fill_val for x in [5, 6, 7, 8])
exp = index_or_series([1, values[1], 3, values[3]])
self._assert_where_conversion(obj, cond, values, exp, exp_dtype)

@pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"])
@pytest.mark.parametrize(
"fill_val, exp_dtype",
[
Expand All @@ -573,21 +570,21 @@ def test_where_int64(self, klass, fill_val, exp_dtype):
(True, np.object),
],
)
def test_where_float64(self, klass, fill_val, exp_dtype):
if klass is pd.Index and exp_dtype is np.complex128:
def test_where_float64(self, index_or_series, fill_val, exp_dtype):
if index_or_series is pd.Index and exp_dtype is np.complex128:
pytest.skip("Complex Index not supported")
obj = klass([1.1, 2.2, 3.3, 4.4])
obj = index_or_series([1.1, 2.2, 3.3, 4.4])
assert obj.dtype == np.float64
cond = klass([True, False, True, False])
cond = index_or_series([True, False, True, False])

exp = klass([1.1, fill_val, 3.3, fill_val])
exp = index_or_series([1.1, fill_val, 3.3, fill_val])
self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype)

if fill_val is True:
values = klass([True, False, True, True])
values = index_or_series([True, False, True, True])
else:
values = klass(x * fill_val for x in [5, 6, 7, 8])
exp = klass([1.1, values[1], 3.3, values[3]])
values = index_or_series(x * fill_val for x in [5, 6, 7, 8])
exp = index_or_series([1.1, values[1], 3.3, values[3]])
self._assert_where_conversion(obj, cond, values, exp, exp_dtype)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -783,19 +780,17 @@ def _assert_fillna_conversion(self, original, value, expected, expected_dtype):
res = target.fillna(value)
self._assert(res, expected, expected_dtype)

@pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"])
@pytest.mark.parametrize(
"fill_val, fill_dtype",
[(1, np.object), (1.1, np.object), (1 + 1j, np.object), (True, np.object)],
)
def test_fillna_object(self, klass, fill_val, fill_dtype):
obj = klass(["a", np.nan, "c", "d"])
def test_fillna_object(self, index_or_series, fill_val, fill_dtype):
obj = index_or_series(["a", np.nan, "c", "d"])
assert obj.dtype == np.object

exp = klass(["a", fill_val, "c", "d"])
exp = index_or_series(["a", fill_val, "c", "d"])
self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype)

@pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"])
@pytest.mark.parametrize(
"fill_val,fill_dtype",
[
Expand All @@ -805,15 +800,15 @@ def test_fillna_object(self, klass, fill_val, fill_dtype):
(True, np.object),
],
)
def test_fillna_float64(self, klass, fill_val, fill_dtype):
obj = klass([1.1, np.nan, 3.3, 4.4])
def test_fillna_float64(self, index_or_series, fill_val, fill_dtype):
obj = index_or_series([1.1, np.nan, 3.3, 4.4])
assert obj.dtype == np.float64

exp = klass([1.1, fill_val, 3.3, 4.4])
exp = index_or_series([1.1, fill_val, 3.3, 4.4])
# float + complex -> we don't support a complex Index
# complex for Series,
# object for Index
if fill_dtype == np.complex128 and klass == pd.Index:
if fill_dtype == np.complex128 and index_or_series == pd.Index:
fill_dtype = np.object
self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype)

Expand All @@ -833,7 +828,6 @@ def test_fillna_series_complex128(self, fill_val, fill_dtype):
exp = pd.Series([1 + 1j, fill_val, 3 + 3j, 4 + 4j])
self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype)

@pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"])
@pytest.mark.parametrize(
"fill_val,fill_dtype",
[
Expand All @@ -844,8 +838,8 @@ def test_fillna_series_complex128(self, fill_val, fill_dtype):
],
ids=["datetime64", "datetime64tz", "object", "object"],
)
def test_fillna_datetime(self, klass, fill_val, fill_dtype):
obj = klass(
def test_fillna_datetime(self, index_or_series, fill_val, fill_dtype):
obj = index_or_series(
[
pd.Timestamp("2011-01-01"),
pd.NaT,
Expand All @@ -855,7 +849,7 @@ def test_fillna_datetime(self, klass, fill_val, fill_dtype):
)
assert obj.dtype == "datetime64[ns]"

exp = klass(
exp = index_or_series(
[
pd.Timestamp("2011-01-01"),
fill_val,
Expand All @@ -865,7 +859,6 @@ def test_fillna_datetime(self, klass, fill_val, fill_dtype):
)
self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype)

@pytest.mark.parametrize("klass", [pd.Series, pd.Index])
@pytest.mark.parametrize(
"fill_val,fill_dtype",
[
Expand All @@ -876,10 +869,10 @@ def test_fillna_datetime(self, klass, fill_val, fill_dtype):
("x", np.object),
],
)
def test_fillna_datetime64tz(self, klass, fill_val, fill_dtype):
def test_fillna_datetime64tz(self, index_or_series, fill_val, fill_dtype):
tz = "US/Eastern"

obj = klass(
obj = index_or_series(
[
pd.Timestamp("2011-01-01", tz=tz),
pd.NaT,
Expand All @@ -889,7 +882,7 @@ def test_fillna_datetime64tz(self, klass, fill_val, fill_dtype):
)
assert obj.dtype == "datetime64[ns, US/Eastern]"

exp = klass(
exp = index_or_series(
[
pd.Timestamp("2011-01-01", tz=tz),
fill_val,
Expand Down
10 changes: 4 additions & 6 deletions pandas/tests/io/json/test_json_table_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,17 +431,15 @@ def test_date_format_raises(self):
self.df.to_json(orient="table", date_format="iso")
self.df.to_json(orient="table")

@pytest.mark.parametrize("kind", [pd.Series, pd.Index])
def test_convert_pandas_type_to_json_field_int(self, kind):
def test_convert_pandas_type_to_json_field_int(self, index_or_series):
data = [1, 2, 3]
result = convert_pandas_type_to_json_field(kind(data, name="name"))
result = convert_pandas_type_to_json_field(index_or_series(data, name="name"))
expected = {"name": "name", "type": "integer"}
assert result == expected

@pytest.mark.parametrize("kind", [pd.Series, pd.Index])
def test_convert_pandas_type_to_json_field_float(self, kind):
def test_convert_pandas_type_to_json_field_float(self, index_or_series):
data = [1.0, 2.0, 3.0]
result = convert_pandas_type_to_json_field(kind(data, name="name"))
result = convert_pandas_type_to_json_field(index_or_series(data, name="name"))
expected = {"name": "name", "type": "number"}
assert result == expected

Expand Down
Loading