Skip to content

Commit

Permalink
fix(compilers): ensure that string constants are compiled as such
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Jun 6, 2024
1 parent 2c76d78 commit f4423b3
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 40 deletions.
8 changes: 3 additions & 5 deletions ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,12 +1008,10 @@ def list_tables(
.from_(sg.table("tables", db="information_schema"))
.distinct()
.where(
C.table_catalog.eq(catalog).or_(
C.table_catalog.eq(sge.convert("temp"))
),
C.table_schema.eq(database),
C.table_catalog.isin(sge.convert(catalog), sge.convert("temp")),
C.table_schema.eq(sge.convert(database)),
)
.sql(self.name, pretty=True)
.sql(self.dialect)
)
out = self.con.execute(sql).fetch_arrow_table()

Expand Down
20 changes: 14 additions & 6 deletions ibis/backends/duckdb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,40 +238,48 @@ def visit_MapMerge(self, op, *, left, right):

def visit_ToJSONMap(self, op, *, arg):
return self.if_(
self.f.json_type(arg).eq("OBJECT"),
self.f.json_type(arg).eq(sge.convert("OBJECT")),
self.cast(self.cast(arg, dt.json), op.dtype),
NULL,
)

def visit_ToJSONArray(self, op, *, arg):
return self.if_(
self.f.json_type(arg).eq("ARRAY"),
self.f.json_type(arg).eq(sge.convert("ARRAY")),
self.cast(self.cast(arg, dt.json), op.dtype),
NULL,
)

def visit_UnwrapJSONString(self, op, *, arg):
return self.if_(
self.f.json_type(arg).eq("VARCHAR"),
self.f.json_type(arg).eq(sge.convert("VARCHAR")),
self.f.json_extract_string(arg, "$"),
NULL,
)

def visit_UnwrapJSONInt64(self, op, *, arg):
arg_type = self.f.json_type(arg)
return self.if_(
arg_type.isin("UBIGINT", "BIGINT"), self.cast(arg, op.dtype), NULL
arg_type.isin(sge.convert("UBIGINT"), sge.convert("BIGINT")),
self.cast(arg, op.dtype),
NULL,
)

def visit_UnwrapJSONFloat64(self, op, *, arg):
arg_type = self.f.json_type(arg)
return self.if_(
arg_type.isin("UBIGINT", "BIGINT", "DOUBLE"), self.cast(arg, op.dtype), NULL
arg_type.isin(
sge.convert("UBIGINT"), sge.convert("BIGINT"), sge.convert("DOUBLE")
),
self.cast(arg, op.dtype),
NULL,
)

def visit_UnwrapJSONBoolean(self, op, *, arg):
return self.if_(
self.f.json_type(arg).eq("BOOLEAN"), self.cast(arg, op.dtype), NULL
self.f.json_type(arg).eq(sge.convert("BOOLEAN")),
self.cast(arg, op.dtype),
NULL,
)

def visit_ArrayConcat(self, op, *, arg):
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/exasol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def version(self) -> str:
query = (
sg.select("param_value")
.from_(sg.table("EXA_METADATA", catalog="SYS"))
.where(C.param_name.eq("databaseProductVersion"))
.where(C.param_name.eq(sge.convert("databaseProductVersion")))
)
with self._safe_raw_sql(query) as result:
[(version,)] = result.fetchall()
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,6 @@ def list_tables(
sg_db.args["quoted"] = False
conditions = [C.table_schema.eq(sge.convert(table_loc.sql(self.name)))]

# conditions.append(C.table_schema.eq(table_loc))

col = "table_name"
sql = (
sg.select(col)
Expand Down
14 changes: 10 additions & 4 deletions ibis/backends/mysql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,22 +336,28 @@ def visit_TimestampAdd(self, op, *, left, right):

def visit_UnwrapJSONString(self, op, *, arg):
return self.if_(
self.f.json_type(arg).eq("STRING"), self.f.json_unquote(arg), NULL
self.f.json_type(arg).eq(sge.convert("STRING")),
self.f.json_unquote(arg),
NULL,
)

def visit_UnwrapJSONInt64(self, op, *, arg):
return self.if_(
self.f.json_type(arg).eq("INTEGER"), self.cast(arg, op.dtype), NULL
self.f.json_type(arg).eq(sge.convert("INTEGER")),
self.cast(arg, op.dtype),
NULL,
)

def visit_UnwrapJSONFloat64(self, op, *, arg):
return self.if_(
self.f.json_type(arg).isin("DOUBLE", "INTEGER"),
self.f.json_type(arg).isin(sge.convert("DOUBLE"), sge.convert("INTEGER")),
self.cast(arg, op.dtype),
NULL,
)

def visit_UnwrapJSONBoolean(self, op, *, arg):
return self.if_(
self.f.json_type(arg).eq("BOOLEAN"), self.if_(arg.eq("true"), 1, 0), NULL
self.f.json_type(arg).eq(sge.convert("BOOLEAN")),
self.if_(arg.eq(sge.convert("true")), 1, 0),
NULL,
)
4 changes: 2 additions & 2 deletions ibis/backends/oracle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,10 +550,10 @@ def transformer(node):
C.data_type,
C.data_precision,
C.data_scale,
C.nullable.eq("Y"),
C.nullable.eq(sge.convert("Y")),
)
.from_("all_tab_columns")
.where(C.table_name.eq(name))
.where(C.table_name.eq(sge.convert(name)))
.order_by(C.column_id)
.sql(dialect)
)
Expand Down
12 changes: 6 additions & 6 deletions ibis/backends/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,11 +352,11 @@ def list_tables(
if (db := table_loc.args["db"]) is not None:
db.args["quoted"] = False
db = db.sql(dialect=self.name)
conditions.append(C.table_schema.eq(db))
conditions.append(C.table_schema.eq(sge.convert(db)))
if (catalog := table_loc.args["catalog"]) is not None:
catalog.args["quoted"] = False
catalog = catalog.sql(dialect=self.name)
conditions.append(C.table_catalog.eq(catalog))
conditions.append(C.table_catalog.eq(sge.convert(catalog)))

sql = (
sg.select("table_name")
Expand Down Expand Up @@ -385,7 +385,7 @@ def _fetch_temp_tables(self):
sg.select("table_name")
.from_(sg.table("tables", db="information_schema"))
.distinct()
.where(C.table_type.eq("LOCAL TEMPORARY"))
.where(C.table_type.eq(sge.convert("LOCAL TEMPORARY")))
.sql(self.dialect)
)

Expand Down Expand Up @@ -434,7 +434,7 @@ def function(self, name: str, *, database: str | None = None) -> Callable:
p = ColGen(table="p")
f = self.compiler.f

predicates = [p.proname.eq(name)]
predicates = [p.proname.eq(sge.convert(name))]

if database is not None:
predicates.append(n.nspname.rlike(sge.convert(f"^({database})$")))
Expand Down Expand Up @@ -585,8 +585,8 @@ def get_schema(
.where(
a.attnum > 0,
sg.not_(a.attisdropped),
n.nspname.eq(database) if database is not None else TRUE,
c.relname.eq(name),
n.nspname.eq(sge.convert(database)) if database is not None else TRUE,
c.relname.eq(sge.convert(name)),
)
.order_by(a.attnum)
)
Expand Down
10 changes: 6 additions & 4 deletions ibis/backends/postgres/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def visit_StructField(self, op, *, arg, field):

def visit_UnwrapJSONString(self, op, *, arg):
return self.if_(
self.f.json_typeof(arg).eq("string"),
self.f.json_typeof(arg).eq(sge.convert("string")),
self.f.json_extract_path_text(
arg,
# this is apparently how you pass in no additional arguments to
Expand All @@ -336,7 +336,7 @@ def visit_UnwrapJSONInt64(self, op, *, arg):
arg, sge.Var(this="VARIADIC ARRAY[]::TEXT[]")
)
return self.if_(
self.f.json_typeof(arg).eq("number"),
self.f.json_typeof(arg).eq(sge.convert("number")),
self.cast(
self.if_(self.f.regexp_like(text, r"^\d+$", "g"), text, NULL),
op.dtype,
Expand All @@ -349,12 +349,14 @@ def visit_UnwrapJSONFloat64(self, op, *, arg):
arg, sge.Var(this="VARIADIC ARRAY[]::TEXT[]")
)
return self.if_(
self.f.json_typeof(arg).eq("number"), self.cast(text, op.dtype), NULL
self.f.json_typeof(arg).eq(sge.convert("number")),
self.cast(text, op.dtype),
NULL,
)

def visit_UnwrapJSONBoolean(self, op, *, arg):
return self.if_(
self.f.json_typeof(arg).eq("boolean"),
self.f.json_typeof(arg).eq(sge.convert("boolean")),
self.cast(
self.f.json_extract_path_text(
arg, sge.Var(this="VARIADIC ARRAY[]::TEXT[]")
Expand Down
12 changes: 6 additions & 6 deletions ibis/backends/sqlite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,14 @@ def list_tables(
sg.select("name")
.from_(F.pragma_table_list())
.where(
C.schema.eq(database),
C.type.isin("table", "view"),
C.schema.eq(sge.convert(database)),
C.type.isin(sge.convert("table"), sge.convert("view")),
~(
C.name.isin(
"sqlite_schema",
"sqlite_master",
"sqlite_temp_schema",
"sqlite_temp_master",
sge.convert("sqlite_schema"),
sge.convert("sqlite_master"),
sge.convert("sqlite_temp_schema"),
sge.convert("sqlite_temp_master"),
)
),
)
Expand Down
10 changes: 6 additions & 4 deletions ibis/backends/sqlite/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,19 +223,21 @@ def _visit_arg_reduction(self, func, op, *, arg, key, where):

def visit_UnwrapJSONString(self, op, *, arg):
return self.if_(
self.f.json_type(arg).eq("text"), self.f.json_extract_scalar(arg, "$"), NULL
self.f.json_type(arg).eq(sge.convert("text")),
self.f.json_extract_scalar(arg, "$"),
NULL,
)

def visit_UnwrapJSONInt64(self, op, *, arg):
return self.if_(
self.f.json_type(arg).eq("integer"),
self.f.json_type(arg).eq(sge.convert("integer")),
self.cast(self.f.json_extract_scalar(arg, "$"), op.dtype),
NULL,
)

def visit_UnwrapJSONFloat64(self, op, *, arg):
return self.if_(
self.f.json_type(arg).isin("integer", "real"),
self.f.json_type(arg).isin(sge.convert("integer"), sge.convert("real")),
self.cast(self.f.json_extract_scalar(arg, "$"), op.dtype),
NULL,
)
Expand All @@ -244,7 +246,7 @@ def visit_UnwrapJSONBoolean(self, op, *, arg):
return self.if_(
# isin doesn't work here, with a strange error from sqlite about a
# misused row value
self.f.json_type(arg).isin("true", "false"),
self.f.json_type(arg).isin(sge.convert("true"), sge.convert("false")),
self.cast(self.f.json_extract_scalar(arg, "$"), dt.int64),
NULL,
)
Expand Down

0 comments on commit f4423b3

Please sign in to comment.