Skip to content

Commit

Permalink
address: fix null handling in postal UDFs
Browse files Browse the repository at this point in the history
  • Loading branch information
NickCrews committed Jul 17, 2024
1 parent df909cb commit 140756b
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 6 deletions.
56 changes: 56 additions & 0 deletions mismo/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import dataclasses
from typing import Callable, Iterable

import ibis
Expand Down Expand Up @@ -62,3 +63,58 @@ def t2(table_factory) -> ir.Table:
"array": [["b"], ["c"], ["d"], None],
}
)


@dataclasses.dataclass
class ToShape:
    """Pair of conversions for running one test against differently shaped inputs.

    ``forward`` is applied to an input before the function under test is
    called, and ``revert`` is applied to that function's result.
    """

    # Applied to the first positional argument before the wrapped call.
    forward: Callable[[ir.Value], ir.Value]
    # Applied to the wrapped call's result before it is handed back.
    revert: Callable[[ir.Value], ir.Value]

    def call(self, f: Callable, *args, **kwargs) -> ir.Value:
        """Call ``f`` with its first positional argument reshaped.

        Only the first positional argument is passed through ``forward``;
        the remaining positional and keyword arguments are given to ``f``
        untouched. The result of ``f`` is passed through ``revert``.

        Raises
        ------
        ValueError
            If no positional argument is supplied, since there is then
            nothing to pass through ``forward``.
        """
        if not args:
            # Without this guard the unpacking below fails with an opaque
            # "not enough values to unpack" message.
            raise ValueError(
                "ToShape.call() needs at least one positional argument "
                "to pass through `forward`"
            )
        first, *rest = args
        result = f(self.forward(first), *rest, **kwargs)
        return self.revert(result)


@pytest.fixture(params=["scalar", "column"])
def to_shape(request) -> ToShape:
    """Fixture that lets you test a function with both scalar and column inputs.

    The fixture is parametrized, so a test that requests it runs once per
    param:

    - ``"scalar"``: ``forward`` and ``revert`` return their input unchanged.
    - ``"column"``: ``forward`` wraps the value in a one-element array and
      unnests it; ``revert`` calls ``.as_scalar()`` on the result.

    Say you had some function ``add_one(x: ir.Value) -> ir.Value``.
    You already have a test like

    ```python
    inp = literal(1)
    result = add_one(inp)
    assert result.execute() == 2
    ```

    You can use this fixture to test add_one with a column input:

    ```python
    inp = literal(1)
    result = to_shape.revert(add_one(to_shape.forward(inp)))
    assert result.execute() == 2
    ```

    Or, to do it in one step:

    ```python
    inp = literal(1)
    result = to_shape.call(add_one, inp)
    assert result.execute() == 2
    ```
    """
    if request.param == "scalar":
        return ToShape(
            forward=lambda x: x,
            revert=lambda x: x,
        )
    if request.param == "column":
        return ToShape(
            forward=lambda x: ibis.array([x]).unnest(),
            revert=lambda x: x.as_scalar(),
        )
    # Raise explicitly rather than `assert False`: assert statements are
    # removed under `python -O`, which would make this fixture return None.
    raise AssertionError(f"unexpected to_shape param: {request.param!r}")
10 changes: 8 additions & 2 deletions mismo/lib/geo/_address.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,10 @@ def postal_parse_address(address_string: ir.StringValue) -> ir.StructValue:
from postal.parser import parse_address as _parse_address

@ibis.udf.scalar.python(signature=((str,), ADDRESS_SCHEMA))
def udf(address_string: str) -> dict[str, str]:
def udf(address_string: str | None) -> dict[str, str] | None:
# remove once https://github.com/ibis-project/ibis/pull/9625 is fixed
if address_string is None:
return None
parsed_fields = _parse_address(address_string)
label_to_values = defaultdict(list)
for value, label in parsed_fields:
Expand Down Expand Up @@ -386,7 +389,10 @@ def postal_fingerprint_address(address: ir.StructValue) -> ir.ArrayValue:
from postal.near_dupe import near_dupe_hashes as _hash

@ibis.udf.scalar.python(signature=((ADDRESS_SCHEMA,), str))
def udf(address: dict[str, str]) -> list[str]:
def udf(address: dict[str, str] | None) -> list[str] | None:
# remove once https://github.com/ibis-project/ibis/pull/9625 is fixed
if address is None:
return None
# split street1 into house_number and road
street1 = address["street1"] or ""
house, *rest = street1.split(" ", 1)
Expand Down
11 changes: 7 additions & 4 deletions mismo/lib/geo/tests/test_address.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,10 @@ def test_address_tokens(address, expected):
),
],
)
def test_postal_parse_address(address, expected):
result = _address.postal_parse_address(address).execute()
def test_postal_parse_address(to_shape, address, expected):
address = ibis.literal(address, type=str)
e = to_shape.call(_address.postal_parse_address, address)
result = e.execute()
assert result == expected


Expand Down Expand Up @@ -216,10 +218,11 @@ def test_postal_parse_address(address, expected):
),
],
)
def test_postal_fingerprint_address(address, expected):
def test_postal_fingerprint_address(to_shape, address, expected):
a = ibis.literal(
address,
type="struct<street1: string, street2: string, city: string, state: string, postal_code: string, country: string>", # noqa
)
result = _address.postal_fingerprint_address(a).execute()
e = to_shape.call(_address.postal_fingerprint_address, a)
result = e.execute()
assert result == expected

0 comments on commit 140756b

Please sign in to comment.