Skip to content

Commit

Permalink
refactor(sql): automatically add simple ops implementations (ibis-pro…
Browse files Browse the repository at this point in the history
…ject#8349)

Follow up to ibis-project#8338 to clean up the SIMPLE_OPS boilerplate.
  • Loading branch information
cpcloud authored Feb 14, 2024
1 parent 8d3fe7f commit 2c64b3f
Show file tree
Hide file tree
Showing 18 changed files with 738 additions and 925 deletions.
403 changes: 254 additions & 149 deletions ibis/backends/base/sqlglot/compiler.py

Large diffs are not rendered by default.

130 changes: 56 additions & 74 deletions ibis/backends/bigquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,62 @@ class BigQueryCompiler(SQLGlotCompiler):
this=sge.convert("-Infinity"), to=sge.DataType(this=sge.DataType.Type.DOUBLE)
)

SIMPLE_OPS = {
ops.StringAscii: "ascii",
ops.BitAnd: "bit_and",
ops.BitOr: "bit_or",
ops.BitXor: "bit_xor",
ops.DateFromYMD: "date",
ops.Divide: "ieee_divide",
ops.EndsWith: "ends_with",
ops.GeoArea: "st_area",
ops.GeoAsBinary: "st_asbinary",
ops.GeoAsText: "st_astext",
ops.GeoAzimuth: "st_azimuth",
ops.GeoBuffer: "st_buffer",
ops.GeoCentroid: "st_centroid",
ops.GeoContains: "st_contains",
ops.GeoCoveredBy: "st_coveredby",
ops.GeoCovers: "st_covers",
ops.GeoDWithin: "st_dwithin",
ops.GeoDifference: "st_difference",
ops.GeoDisjoint: "st_disjoint",
ops.GeoDistance: "st_distance",
ops.GeoEndPoint: "st_endpoint",
ops.GeoEquals: "st_equals",
ops.GeoGeometryType: "st_geometrytype",
ops.GeoIntersection: "st_intersection",
ops.GeoIntersects: "st_intersects",
ops.GeoLength: "st_length",
ops.GeoMaxDistance: "st_maxdistance",
ops.GeoNPoints: "st_numpoints",
ops.GeoPerimeter: "st_perimeter",
ops.GeoPoint: "st_geogpoint",
ops.GeoPointN: "st_pointn",
ops.GeoStartPoint: "st_startpoint",
ops.GeoTouches: "st_touches",
ops.GeoUnaryUnion: "st_union_agg",
ops.GeoUnion: "st_union",
ops.GeoWithin: "st_within",
ops.GeoX: "st_x",
ops.GeoY: "st_y",
ops.Hash: "farm_fingerprint",
ops.IsInf: "is_inf",
ops.IsNan: "is_nan",
ops.Log10: "log10",
ops.LPad: "lpad",
ops.RPad: "rpad",
ops.Levenshtein: "edit_distance",
ops.Modulus: "mod",
ops.RandomScalar: "rand",
ops.RegexReplace: "regexp_replace",
ops.RegexSearch: "regexp_contains",
ops.Time: "time",
ops.TimeFromHMS: "time",
ops.TimestampFromYMDHMS: "datetime",
ops.TimestampNow: "current_timestamp",
}

def _aggregate(self, funcname: str, *args, where):
func = self.f[funcname]

Expand Down Expand Up @@ -665,77 +721,3 @@ def visit_CountDistinct(self, op, *, arg, where):
if where is not None:
arg = self.if_(where, arg, NULL)
return self.f.count(sge.Distinct(expressions=[arg]))


_SIMPLE_OPS = {
ops.StringAscii: "ascii",
ops.BitAnd: "bit_and",
ops.BitOr: "bit_or",
ops.BitXor: "bit_xor",
ops.DateFromYMD: "date",
ops.Divide: "ieee_divide",
ops.EndsWith: "ends_with",
ops.GeoArea: "st_area",
ops.GeoAsBinary: "st_asbinary",
ops.GeoAsText: "st_astext",
ops.GeoAzimuth: "st_azimuth",
ops.GeoBuffer: "st_buffer",
ops.GeoCentroid: "st_centroid",
ops.GeoContains: "st_contains",
ops.GeoCoveredBy: "st_coveredby",
ops.GeoCovers: "st_covers",
ops.GeoDWithin: "st_dwithin",
ops.GeoDifference: "st_difference",
ops.GeoDisjoint: "st_disjoint",
ops.GeoDistance: "st_distance",
ops.GeoEndPoint: "st_endpoint",
ops.GeoEquals: "st_equals",
ops.GeoGeometryType: "st_geometrytype",
ops.GeoIntersection: "st_intersection",
ops.GeoIntersects: "st_intersects",
ops.GeoLength: "st_length",
ops.GeoMaxDistance: "st_maxdistance",
ops.GeoNPoints: "st_numpoints",
ops.GeoPerimeter: "st_perimeter",
ops.GeoPoint: "st_geogpoint",
ops.GeoPointN: "st_pointn",
ops.GeoStartPoint: "st_startpoint",
ops.GeoTouches: "st_touches",
ops.GeoUnaryUnion: "st_union_agg",
ops.GeoUnion: "st_union",
ops.GeoWithin: "st_within",
ops.GeoX: "st_x",
ops.GeoY: "st_y",
ops.Hash: "farm_fingerprint",
ops.IsInf: "is_inf",
ops.IsNan: "is_nan",
ops.Log10: "log10",
ops.LPad: "lpad",
ops.RPad: "rpad",
ops.Levenshtein: "edit_distance",
ops.Modulus: "mod",
ops.RandomScalar: "rand",
ops.RegexReplace: "regexp_replace",
ops.RegexSearch: "regexp_contains",
ops.Time: "time",
ops.TimeFromHMS: "time",
ops.TimestampFromYMDHMS: "datetime",
ops.TimestampNow: "current_timestamp",
}


for _op, _name in _SIMPLE_OPS.items():
assert isinstance(type(_op), type), type(_op)
if issubclass(_op, ops.Reduction):

def _fmt(self, op, *, _name: str = _name, where, **kw):
return self.agg[_name](*kw.values(), where=where)

else:

def _fmt(self, op, *, _name: str = _name, **kw):
return self.f[_name](*kw.values())

setattr(BigQueryCompiler, f"visit_{_op.__name__}", _fmt)

del _op, _name, _fmt
157 changes: 70 additions & 87 deletions ibis/backends/clickhouse/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,76 @@ class ClickHouseCompiler(SQLGlotCompiler):
)
)

SIMPLE_OPS = {
ops.All: "min",
ops.Any: "max",
ops.ApproxCountDistinct: "uniqHLL12",
ops.ApproxMedian: "median",
ops.ArgMax: "argMax",
ops.ArgMin: "argMin",
ops.ArrayCollect: "groupArray",
ops.ArrayContains: "has",
ops.ArrayFlatten: "arrayFlatten",
ops.ArrayIntersect: "arrayIntersect",
ops.ArrayPosition: "indexOf",
ops.BitwiseAnd: "bitAnd",
ops.BitwiseLeftShift: "bitShiftLeft",
ops.BitwiseNot: "bitNot",
ops.BitwiseOr: "bitOr",
ops.BitwiseRightShift: "bitShiftRight",
ops.BitwiseXor: "bitXor",
ops.Capitalize: "initcap",
ops.CountDistinct: "uniq",
ops.Date: "toDate",
ops.E: "e",
ops.EndsWith: "endsWith",
ops.ExtractAuthority: "netloc",
ops.ExtractDay: "toDayOfMonth",
ops.ExtractDayOfYear: "toDayOfYear",
ops.ExtractEpochSeconds: "toRelativeSecondNum",
ops.ExtractFragment: "fragment",
ops.ExtractHost: "domain",
ops.ExtractHour: "toHour",
ops.ExtractMinute: "toMinute",
ops.ExtractMonth: "toMonth",
ops.ExtractPath: "path",
ops.ExtractProtocol: "protocol",
ops.ExtractQuarter: "toQuarter",
ops.ExtractSecond: "toSecond",
ops.ExtractWeekOfYear: "toISOWeek",
ops.ExtractYear: "toYear",
ops.First: "any",
ops.IntegerRange: "range",
ops.IsInf: "isInfinite",
ops.IsNan: "isNaN",
ops.IsNull: "isNull",
ops.LStrip: "trimLeft",
ops.Last: "anyLast",
ops.Ln: "log",
ops.Log10: "log10",
ops.MapContains: "mapContains",
ops.MapKeys: "mapKeys",
ops.MapLength: "length",
ops.MapMerge: "mapUpdate",
ops.MapValues: "mapValues",
ops.Median: "quantileExactExclusive",
ops.NotNull: "isNotNull",
ops.NullIf: "nullIf",
ops.RStrip: "trimRight",
ops.RandomScalar: "randCanonical",
ops.RegexReplace: "replaceRegexpAll",
ops.RowNumber: "row_number",
ops.StartsWith: "startsWith",
ops.StrRight: "right",
ops.Strftime: "formatDateTime",
ops.StringLength: "length",
ops.StringReplace: "replaceAll",
ops.Strip: "trimBoth",
ops.TimestampNow: "now",
ops.TypeOf: "toTypeName",
ops.Unnest: "arrayJoin",
}

def _aggregate(self, funcname: str, *args, where):
has_filter = where is not None
func = self.f[funcname + "If" * has_filter]
Expand Down Expand Up @@ -589,90 +659,3 @@ def visit_RegexSplit(self, op, *, arg, pattern):
@staticmethod
def _generate_groups(groups):
return groups


_SIMPLE_OPS = {
ops.All: "min",
ops.Any: "max",
ops.ApproxCountDistinct: "uniqHLL12",
ops.ApproxMedian: "median",
ops.ArgMax: "argMax",
ops.ArgMin: "argMin",
ops.ArrayCollect: "groupArray",
ops.ArrayContains: "has",
ops.ArrayFlatten: "arrayFlatten",
ops.ArrayIntersect: "arrayIntersect",
ops.ArrayPosition: "indexOf",
ops.BitwiseAnd: "bitAnd",
ops.BitwiseLeftShift: "bitShiftLeft",
ops.BitwiseNot: "bitNot",
ops.BitwiseOr: "bitOr",
ops.BitwiseRightShift: "bitShiftRight",
ops.BitwiseXor: "bitXor",
ops.Capitalize: "initcap",
ops.CountDistinct: "uniq",
ops.Date: "toDate",
ops.E: "e",
ops.EndsWith: "endsWith",
ops.ExtractAuthority: "netloc",
ops.ExtractDay: "toDayOfMonth",
ops.ExtractDayOfYear: "toDayOfYear",
ops.ExtractEpochSeconds: "toRelativeSecondNum",
ops.ExtractFragment: "fragment",
ops.ExtractHost: "domain",
ops.ExtractHour: "toHour",
ops.ExtractMinute: "toMinute",
ops.ExtractMonth: "toMonth",
ops.ExtractPath: "path",
ops.ExtractProtocol: "protocol",
ops.ExtractQuarter: "toQuarter",
ops.ExtractSecond: "toSecond",
ops.ExtractWeekOfYear: "toISOWeek",
ops.ExtractYear: "toYear",
ops.First: "any",
ops.IntegerRange: "range",
ops.IsInf: "isInfinite",
ops.IsNan: "isNaN",
ops.IsNull: "isNull",
ops.LStrip: "trimLeft",
ops.Last: "anyLast",
ops.Ln: "log",
ops.Log10: "log10",
ops.MapContains: "mapContains",
ops.MapKeys: "mapKeys",
ops.MapLength: "length",
ops.MapMerge: "mapUpdate",
ops.MapValues: "mapValues",
ops.Median: "quantileExactExclusive",
ops.NotNull: "isNotNull",
ops.NullIf: "nullIf",
ops.RStrip: "trimRight",
ops.RandomScalar: "randCanonical",
ops.RegexReplace: "replaceRegexpAll",
ops.RowNumber: "row_number",
ops.StartsWith: "startsWith",
ops.StrRight: "right",
ops.Strftime: "formatDateTime",
ops.StringLength: "length",
ops.StringReplace: "replaceAll",
ops.Strip: "trimBoth",
ops.TimestampNow: "now",
ops.TypeOf: "toTypeName",
ops.Unnest: "arrayJoin",
}

for _op, _name in _SIMPLE_OPS.items():
assert isinstance(type(_op), type), type(_op)
if issubclass(_op, ops.Reduction):

def _fmt(self, op, *, _name: str = _name, where, **kw):
return self.agg[_name](*kw.values(), where=where)

else:

def _fmt(self, op, *, _name: str = _name, **kw):
return self.f[_name](*kw.values())

setattr(ClickHouseCompiler, f"visit_{_op.__name__}", _fmt)

del _op, _name, _fmt
45 changes: 15 additions & 30 deletions ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,21 @@ class DataFusionCompiler(SQLGlotCompiler):
)
)

SIMPLE_OPS = {
ops.ApproxMedian: "approx_median",
ops.ArrayRemove: "array_remove_all",
ops.BitAnd: "bit_and",
ops.BitOr: "bit_or",
ops.BitXor: "bit_xor",
ops.Cot: "cot",
ops.ExtractMicrosecond: "extract_microsecond",
ops.First: "first_value",
ops.Last: "last_value",
ops.Median: "median",
ops.StringLength: "character_length",
ops.RegexSplit: "regex_split",
}

def _aggregate(self, funcname: str, *args, where):
expr = self.f[funcname](*args)
if where is not None:
Expand Down Expand Up @@ -465,33 +480,3 @@ def visit_Aggregate(self, op, *, parent, groups, metrics):
sel = sel.group_by(*by_names_quoted)

return sel


_SIMPLE_OPS = {
ops.ApproxMedian: "approx_median",
ops.ArrayRemove: "array_remove_all",
ops.BitAnd: "bit_and",
ops.BitOr: "bit_or",
ops.BitXor: "bit_xor",
ops.Cot: "cot",
ops.ExtractMicrosecond: "extract_microsecond",
ops.First: "first_value",
ops.Last: "last_value",
ops.Median: "median",
ops.StringLength: "character_length",
ops.RegexSplit: "regex_split",
}

for _op, _name in _SIMPLE_OPS.items():
assert isinstance(type(_op), type), type(_op)
if issubclass(_op, ops.Reduction):

def _fmt(self, op, *, _name: str = _name, where, **kw):
return self.agg[_name](*kw.values(), where=where)

else:

def _fmt(self, op, *, _name: str = _name, **kw):
return self.f[_name](*kw.values())

setattr(DataFusionCompiler, f"visit_{_op.__name__}", _fmt)
Loading

0 comments on commit 2c64b3f

Please sign in to comment.