Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(dtypes): allow passing nullable kwarg to string parsed dtypes #10632

Merged
merged 1 commit into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions ibis/expr/datatypes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def dtype(value: Any, nullable: bool = True) -> DataType:


@dtype.register(str)
def from_string(value):
return DataType.from_string(value)
def from_string(value, nullable: bool = True):
return DataType.from_string(value, nullable)


@dtype.register("numpy.dtype")
Expand Down Expand Up @@ -165,14 +165,18 @@ def castable(self, to, **kwargs) -> bool:
return castable(self, to, **kwargs)

@classmethod
def from_string(cls, value) -> Self:
def from_string(cls, value, nullable: bool = True) -> Self:
from ibis.expr.datatypes.parse import parse

try:
return parse(value)
typ = parse(value)
except SyntaxError:
raise TypeError(f"{value!r} cannot be parsed as a datatype")

if not nullable:
return typ.copy(nullable=nullable)
return typ

@classmethod
def from_typehint(cls, typ, nullable=True) -> Self:
origin_type = get_origin(typ)
Expand Down
5 changes: 3 additions & 2 deletions ibis/expr/datatypes/tests/test_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ibis.common.annotations import ValidationError


@pytest.mark.parametrize("nullable", [True, False])
@pytest.mark.parametrize(
("spec", "expected"),
[
Expand Down Expand Up @@ -43,8 +44,8 @@
("multipolygon", dt.multipolygon),
],
)
def test_primitive_from_string(spec, expected):
assert dt.dtype(spec) == expected
def test_primitive_from_string(nullable, spec, expected):
assert dt.dtype(spec, nullable=nullable) == expected(nullable=nullable)


@pytest.mark.parametrize(
Expand Down
Loading