Commit
feat(duckdb): maps in progress...
gforsyth committed Sep 5, 2023
1 parent 16ce2a6 commit 3af79e1
Showing 1 changed file with 23 additions and 14 deletions.
37 changes: 23 additions & 14 deletions ibis/backends/duckdb/compiler/values.py
@@ -542,8 +542,8 @@ def _levenshtein(op, **kw):
ops.NullIf: "nullIf",
ops.MapContains: "mapContains", # TODO
ops.MapLength: "length",
ops.MapKeys: "mapKeys", # TODO
ops.MapValues: "mapValues", # TODO
ops.MapKeys: "map_keys",
ops.MapValues: "map_values",
ops.MapMerge: "mapUpdate", # TODO
ops.ArraySort: "list_sort",
ops.ArrayContains: "has",
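The two remapped entries point `ops.MapKeys` and `ops.MapValues` at DuckDB's native `map_keys` and `map_values` functions; the old camelCase names (like the other entries still tagged TODO in this table) look like ClickHouse-style carryovers. A quick check of the DuckDB functions, not part of the commit and assuming the `duckdb` Python client is installed:

```python
import duckdb

con = duckdb.connect()
# map(['a', 'b'], [1, 2]) builds a MAP; map_keys/map_values pull its parts back out.
con.sql("SELECT map_keys(map(['a', 'b'], [1, 2])) AS ks").show()    # ks = [a, b]
con.sql("SELECT map_values(map(['a', 'b'], [1, 2])) AS vs").show()  # vs = [1, 2]
```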
@@ -650,7 +650,6 @@ def _in_column(op, **kw):
### LITERALLY


# TODO: need to go through this carefully
@translate_val.register(ops.Literal)
def _literal(op, **kw):
value = op.value
@@ -681,7 +680,14 @@ def _literal(op, **kw):
)

# TODO: handle if `value` is "Infinity"

# precision = sg.expressions.DataTypeParam(
# this=sg.expressions.Literal(this=f"{precision}", is_string=False)
# )
# scale = sg.expressions.DataTypeParam(
# this=sg.expressions.Literal(this=f"{scale}", is_string=False)
# )
# need sg.expressions.DataTypeParam to be available
# ...
return f"{value!s}::decimal({precision}, {scale})"
elif dtype.is_numeric():
if math.isinf(value):
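The commented-out block above is parked until `sg.expressions.DataTypeParam` is available in the pinned sqlglot, so the decimal literal falls back to plain f-string casting. As a rough sketch (an assumption, not the commit's code), a similar cast can already be built through sqlglot's expression API by parsing a parameterized DECIMAL type string:

```python
import sqlglot as sg

# Placeholder precision/scale/value; the real ones come from the ibis dtype and op.value.
precision, scale = 15, 3
value = "1234.567"

cast_expr = sg.expressions.cast(
    sg.expressions.Literal.number(value),
    to=sg.expressions.DataType.build(f"DECIMAL({precision}, {scale})"),
)
# Roughly: CAST(1234.567 AS DECIMAL(15, 3))
print(cast_expr.sql(dialect="duckdb"))
```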
@@ -734,7 +740,7 @@ def _literal(op, **kw):
elif dtype.is_map():
value_type = dtype.value_type
values = ", ".join(
f"{k!r}, {_literal(ops.Literal(v, dtype=value_type), **kw)}"
f"[{k!r}, {_literal(ops.Literal(v, dtype=value_type), **kw)}]"
for k, v in value.items()
)
return f"map({values})"
@@ -1198,6 +1204,10 @@ def _map(op, **kw):
keys = translate_val(op.keys, **kw)
values = translate_val(op.values, **kw)
typ = serialize(op.dtype)
breakpoint()
sg_expr = sg.expressions.Map(keys=keys, values=values)
breakpoint()
return sg_expr
return f"CAST(({keys}, {values}) AS {typ})"


@@ -1481,7 +1491,8 @@ def _array_map(op, **kw):
def _array_filter(op, **kw):
arg = translate_val(op.arg, **kw)
result = translate_val(op.result, **kw)
return sg.func("list_filter", arg, f"{op.parameter} -> {result}")
func = sg.func("list_filter", arg, f"{op.parameter} -> {result}")
return func
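For reference, not part of the commit: the DuckDB `list_filter` call with an `x -> ...` lambda string that this translation produces, assuming the `duckdb` Python client is installed:

```python
import duckdb

# list_filter(list, lambda) keeps the elements for which the lambda is true.
duckdb.sql("SELECT list_filter([1, 2, 3, 4], x -> x > 2) AS filtered").show()
# filtered = [3, 4]
```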


@translate_val.register(ops.ArrayPosition)
@@ -1501,13 +1512,11 @@ def _array_union(op, **kw):
return translate_val(ops.ArrayDistinct(ops.ArrayConcat(op.left, op.right)), **kw)


# TODO
# TODO: need to do this as an array map + struct pack -- look at existing
# alchemy backend implementation
@translate_val.register(ops.ArrayZip)
def _array_zip(op: ops.ArrayZip, **kw: Any) -> str:
arglist = []
for arg in op.arg:
sql_arg = translate_val(arg, **kw)
with contextlib.suppress(AttributeError):
sql_arg = sql_arg.sql(dialect="duckdb")
arglist.append(sql_arg)
return f"arrayZip({', '.join(arglist)})"
zipped = sg.expressions.ArrayJoin().from_arg_list(
[translate_val(arg, **kw) for arg in op.arg]
)
return zipped
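The TODO above points at the intended shape: an array map plus struct pack, as in the existing SQLAlchemy backend, since the `ArrayJoin` placeholder here is not a zip. A rough SQL-level illustration of that approach (an assumption, not the commit's code; the `f1`/`f2` field names are placeholders):

```python
import duckdb

# Zip two lists into a list of structs by mapping over 1-based positions.
duckdb.sql(
    """
    SELECT list_transform(
        range(1, len(a) + 1),
        i -> struct_pack(f1 := a[i], f2 := b[i])
    ) AS zipped
    FROM (SELECT ['a', 'b'] AS a, [1, 2] AS b)
    """
).show()
# zipped = [{'f1': a, 'f2': 1}, {'f1': b, 'f2': 2}]
```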
