Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable unit tests using pyspark backend #55

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ jobs:

- name: Run tests
run: pdm run just test

- name: Run tests on pyspark backend
run: pdm run just test-pyspark mismo # the "mismo" arg is needed so pytest locates conftest.py
lint-and-docs:
name: Lint
runs-on: ubuntu-latest
Expand Down
4 changes: 4 additions & 0 deletions justfile
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ lint:
test *FILES:
pytest {{FILES}}

# run tests on the pyspark backend
test-pyspark *FILES:
pytest --backend=pyspark {{FILES}}

# include --dev-addr localhost:8001 to avoid conflicts with other mkdocs instances
# serve docs for live editing
docs:
Expand Down
9 changes: 8 additions & 1 deletion mismo/arrays/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,17 @@ def test_array_max(backend, dtype, inp, exp):
@pytest.mark.parametrize(
"inp,type,exp",
[
# duckdb seems to be inconsistent in how it computes the median of
# even-length arrays. Skip testing that for now.
# SELECT MEDIAN(x) from (SELECT unnest([0.0, 2.0]) as x); -> 0.0
# SELECT MEDIAN(x) from (SELECT unnest([0, 2]) as x); -> 1.0
# This sounds to me like the exact inverse of the docs,
# which say that ordinal values get floored and quantitative values get meaned.
# https://discord.com/channels/909674491309850675/921073327009853451/1278856039164411916
pytest.param([0, 1, 2], "int", 1.0, id="happy"),
pytest.param([0, 2], "int", 1.0, id="split"),
pytest.param([0, 1], "int", 0.5, id="frac"),
pytest.param([0.0, 1.0], "float", 0.5, id="frac_float"),
pytest.param([0.0, 1.0, 2.0], "float", 1.0, id="frac_float"),
pytest.param([0, 1, None, 2], "int", 1.0, id="with_null"),
pytest.param([], "int", None, id="empty"),
pytest.param(None, "int", None, id="null"),
Expand Down
2 changes: 1 addition & 1 deletion mismo/block/_key_blocker.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def pair_counts(
The total number of pairs that would be generated is easy to find:

>>> counts.n.sum().execute()
4
np.int64(4)
""" # noqa: E501
if task is None:
task = "dedupe" if left is right else "link"
Expand Down
12 changes: 6 additions & 6 deletions mismo/compare/_match_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __getitem__(self, key):
>>> NameMatchLevel[ibis.literal(1)].execute()
'NEAR'
>>> NameMatchLevel[ibis.literal("NEAR")].execute()
1
np.int8(1)
>>> NameMatchLevel[100]
Traceback (most recent call last):
...
Expand Down Expand Up @@ -147,7 +147,7 @@ class MatchLevel(metaclass=_LevelsMeta):
>>> NameMatchLevel[ibis.literal(1)].execute()
'NEAR'
>>> NameMatchLevel[ibis.literal("NEAR")].execute()
1
np.int8(1)

You can construct your own values, both from python literals...

Expand All @@ -171,25 +171,25 @@ class MatchLevel(metaclass=_LevelsMeta):
2 NEAR
3 None
Name: NameMatchLevel, dtype: object
>>> levels.as_integer().execute()
>>> levels.as_integer().name("levels").execute()
0 0
1 2
2 1
3 99
Name: Array(), dtype: int8
Name: levels, dtype: int8

Comparisons work as you expect:

>>> NameMatchLevel.NEAR == 1
True
>>> NameMatchLevel(1) == "NEAR"
True
>>> (levels_raw == NameMatchLevel.NEAR).execute()
>>> (levels_raw == NameMatchLevel.NEAR).name("eq").execute()
0 False
1 False
2 True
3 False
Name: Equals(Array(), 1), dtype: bool
Name: eq, dtype: bool

However, implicit ordering is not supported
(file an issue if you think it should be):
Expand Down
37 changes: 29 additions & 8 deletions mismo/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import dataclasses
import os
from typing import Any, Callable, Iterable, Protocol
import uuid
import warnings

import ibis
from ibis.expr import datatypes as dt
Expand All @@ -13,25 +15,44 @@
pytest.register_assert_rewrite("mismo.tests.util")


@pytest.fixture
def backend() -> ibis.BaseBackend:
return ibis.duckdb.connect()
def pytest_addoption(parser):
    """Register the ``--backend`` command line option.

    Lets a test run select which ibis backend the ``backend`` fixture
    connects to: ``duckdb`` (the default) or ``pyspark``.
    """
    option_kwargs = {
        "action": "store",
        "default": "duckdb",
        "help": "Specify the backend to use: duckdb (default) or pyspark",
    }
    parser.addoption("--backend", **option_kwargs)


_count = 0
@pytest.fixture
def backend(request) -> ibis.BaseBackend:
    """Connect to the backend chosen via the ``--backend`` CLI option.

    Returns a fresh duckdb connection by default, or a pyspark connection
    when ``--backend=pyspark`` was passed. Raises ``ValueError`` for any
    other value.
    """
    chosen = request.config.getoption("--backend")
    if chosen == "duckdb":
        return ibis.duckdb.connect()
    if chosen == "pyspark":
        # Suppress warnings from PySpark
        for category in (UserWarning, DeprecationWarning):
            warnings.filterwarnings("ignore", category=category, module="pyspark")
        return ibis.pyspark.connect()
    raise ValueError(f"Unsupported backend: {chosen}")


@pytest.fixture
def table_factory(backend: ibis.BaseBackend) -> Callable[..., ir.Table]:
    """Yield a factory that creates tables on ``backend`` and drops them at teardown.

    The factory accepts the same ``data``/``schema``/``columns`` arguments as
    ``ibis.memtable`` plus any ``create_table`` keyword arguments, and returns
    the created table expression. Every table created through the factory is
    dropped after the test finishes.
    """
    created_tables = []

    def factory(data, schema=None, columns=None, **kwargs):
        # uuid-based names are unique across sessions, unlike the old
        # module-global counter, which collided when several test sessions
        # shared one backend (e.g. a single spark cluster).
        name = f"__mismo_test_{uuid.uuid4().hex}"
        mt = ibis.memtable(data, schema=schema, columns=columns)
        result = backend.create_table(name, mt, **kwargs)
        created_tables.append(name)
        return result

    yield factory

    # Teardown: remove every table this test created; force=True so a table
    # that was already dropped by the test itself does not raise.
    for name in created_tables:
        backend.drop_table(name, force=True)


class ColumnFactory(Protocol):
Expand Down
8 changes: 4 additions & 4 deletions mismo/tests/test_factorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@
),
pytest.param(
[(4, 4), (3, 3), (2, 2), (None, None)],
"uint64",
id="integers",
"int64",
id="int64",
),
pytest.param(
[(4, 2), (3, 1), (2, 0), (None, None)],
"int64",
id="integers",
"uint64",
id="uint64",
),
pytest.param(
[
Expand Down
8 changes: 4 additions & 4 deletions mismo/text/_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,13 @@ def levenshtein_ratio(s1: ir.StringValue, s2: ir.StringValue) -> ir.FloatingValu
--------
>>> from mismo.text import levenshtein_ratio
>>> levenshtein_ratio("mile", "mike").execute()
0.75
np.float64(0.75)
>>> levenshtein_ratio("mile", "mile").execute()
1.0
np.float64(1.0)
>>> levenshtein_ratio("mile", "").execute()
0.0
np.float64(0.0)
>>> levenshtein_ratio("", "").execute()
nan
np.float64(nan)
"""
return _dist_ratio(s1, s2, lambda a, b: a.levenshtein(b))

Expand Down
10 changes: 5 additions & 5 deletions mismo/vector/_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ def dot(a: T, b: T) -> ir.FloatingValue:
>>> v1 = ibis.array([1, 2])
>>> v2 = ibis.array([4, 5])
>>> dot(v1, v2).execute() # 1*4 + 2*5
14.0
np.float64(14.0)
>>> m1 = ibis.map({"a": 1, "b": 2})
>>> m2 = ibis.map({"b": 3, "c": 4})
>>> dot(m1, m2).execute() # 2*3
6.0
np.float64(6.0)
"""
a_vals, b_vals = _shared_vals(a, b)
return _array_dot_product(a_vals, b_vals)
Expand Down Expand Up @@ -94,7 +94,7 @@ def cosine_similarity(a: T, b: T) -> ir.FloatingValue:
Orthogonal vectors:

>>> cosine_similarity(ibis.array([1, 0]), ibis.array([0, 1])).execute()
0.0
np.float64(0.0)
""" # noqa: E501
a_vals, b_vals = _shared_vals(a, b)
return _array_cosine_similarity(a_vals, b_vals)
Expand Down Expand Up @@ -137,10 +137,10 @@ def norm(vec: T, *, metric: Literal["l1", "l2"] = "l2") -> ir.FloatingValue:
>>> from mismo.vector import norm
>>> v = ibis.array([-3, 4])
>>> norm(v).execute()
5.0
np.float64(5.0)
>>> m = ibis.map({"a": -3, "b": 4})
>>> norm(m, metric="l1").execute()
7.0
np.float64(7.0)
"""
if isinstance(vec, ir.ArrayValue):
vals = vec
Expand Down
Loading