From e5b1b33722b7329488e7550ab50b128999c095fd Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Mon, 30 Dec 2024 09:32:37 -0500 Subject: [PATCH] fix(dtypes): allow passing `nullable` kwarg to string parsed dtypes --- ibis/expr/datatypes/core.py | 12 ++++++++---- ibis/expr/datatypes/tests/test_parse.py | 5 +++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/ibis/expr/datatypes/core.py b/ibis/expr/datatypes/core.py index f7c4324ceb1a..e66464663c73 100644 --- a/ibis/expr/datatypes/core.py +++ b/ibis/expr/datatypes/core.py @@ -73,8 +73,8 @@ def dtype(value: Any, nullable: bool = True) -> DataType: @dtype.register(str) -def from_string(value): - return DataType.from_string(value) +def from_string(value, nullable: bool = True): + return DataType.from_string(value, nullable) @dtype.register("numpy.dtype") @@ -165,14 +165,18 @@ def castable(self, to, **kwargs) -> bool: return castable(self, to, **kwargs) @classmethod - def from_string(cls, value) -> Self: + def from_string(cls, value, nullable: bool = True) -> Self: from ibis.expr.datatypes.parse import parse try: - return parse(value) + typ = parse(value) except SyntaxError: raise TypeError(f"{value!r} cannot be parsed as a datatype") + if not nullable: + return typ.copy(nullable=nullable) + return typ + @classmethod def from_typehint(cls, typ, nullable=True) -> Self: origin_type = get_origin(typ) diff --git a/ibis/expr/datatypes/tests/test_parse.py b/ibis/expr/datatypes/tests/test_parse.py index b332b9110511..fef021ba7cc0 100644 --- a/ibis/expr/datatypes/tests/test_parse.py +++ b/ibis/expr/datatypes/tests/test_parse.py @@ -13,6 +13,7 @@ from ibis.common.annotations import ValidationError +@pytest.mark.parametrize("nullable", [True, False]) @pytest.mark.parametrize( ("spec", "expected"), [ @@ -43,8 +44,8 @@ ("multipolygon", dt.multipolygon), ], ) -def test_primitive_from_string(spec, expected): - assert dt.dtype(spec) == expected +def test_primitive_from_string(nullable, spec, expected): + assert dt.dtype(spec, nullable=nullable) == expected(nullable=nullable) @pytest.mark.parametrize(