Skip to content

Commit

Permalink
adds print message for mutate removing groupBy until rectified. adds …
Browse files Browse the repository at this point in the history
…mutate across support, improves rounding
  • Loading branch information
drizk1 committed Mar 28, 2024
1 parent 6baedb6 commit 240d2b2
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 104 deletions.
50 changes: 36 additions & 14 deletions src/TBD_macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,30 @@ macro arrange(sqlquery, columns...)
end
end



function process_mutate_expression(expr, sq, select_expressions)
if isa(expr, Expr) && expr.head == :(=) && isa(expr.args[1], Symbol)
col_name = string(expr.args[1])
col_expr = expr_to_sql(expr.args[2], sq) # Convert to SQL expression

# Determine whether the column already exists or needs to be added
if col_name in [col for col in sq.metadata[!, "name"]]
# Replace the existing column expression with the mutation
select_expr_index = findfirst(==(col_name), select_expressions)
select_expressions[select_expr_index] = string(col_expr, " AS ", col_name)
else
# Append the mutation as a new column expression
push!(select_expressions, string(col_expr, " AS ", col_name))
# Update metadata to include this new column
push!(sq.metadata, Dict("name" => col_name, "type" => "UNKNOWN", "current_selxn" => 1))
end
else
throw("Unsupported expression format in @mutate: $(expr)")
end
end


"""
$docstring_mutate
"""
Expand Down Expand Up @@ -185,22 +209,17 @@ macro mutate(sqlquery, mutations...)
all_columns = sq.metadata[sq.metadata.current_selxn .== 1, :name]
select_expressions = [col for col in all_columns] # Start with all currently selected columns

for expr in $(esc(mutations))
if isa(expr, Expr) && expr.head == :(=) && isa(expr.args[1], Symbol)
col_name = string(expr.args[1])
col_expr = expr_to_sql(expr.args[2], sq) # Ensure you have a function that can handle this conversion

if col_name in all_columns
# Replace the existing column expression with the mutation
select_expressions[findfirst(==(col_name), select_expressions)] = string(col_expr, " AS ", col_name)
else
# Append the mutation as a new column expression
push!(select_expressions, string(col_expr, " AS ", col_name))
# Update metadata to include this new column
push!(sq.metadata, Dict("name" => col_name, "type" => "UNKNOWN", "current_selxn" => 1))
for expr in $mutations
# Transform 'across' expressions first
if isa(expr, Expr) && expr.head == :call && expr.args[1] == :across
expr = parse_across(expr, $(esc(sqlquery)).metadata) # Assume expr_to_sql can handle 'across' and returns a tuple of expressions
end
if isa(expr, Expr) && expr.head == :tuple
for subexpr in expr.args
process_mutate_expression(subexpr, sq, select_expressions)
end
else
throw("Unsupported expression format in @mutate: $expr")
process_mutate_expression(expr, sq, select_expressions)
end
end
cte_sql = " " * join(select_expressions, ", ") * " FROM " * sq.from
Expand All @@ -224,6 +243,9 @@ macro mutate(sqlquery, mutations...)
sq.from = string(cte_name)

sq.select = "*" # This selects everything from the CTE without duplicating transformations
if !isempty(sq.groupBy)
println("@mutate removed grouping after applying mutations.")
end
sq.groupBy =""
else
error("Expected sqlquery to be an instance of SQLQuery")
Expand Down
1 change: 1 addition & 0 deletions src/TidierDB.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ include("structs.jl")
include("db_parsing.jl")
include("TBD_macros.jl")
include("postgresparsing.jl")
include("sqlite_parsing.jl")
include("joins_sq.jl")
include("slices_sq.jl")

Expand Down
90 changes: 0 additions & 90 deletions src/db_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,96 +95,6 @@ function parse_tidy_db(exprs, metadata::DataFrame)

return included_columns
end
function expr_to_sql_lite(expr, sq; from_summarize::Bool)
expr = parse_char_matching(expr)
expr = exc_capture_bug(expr, names_to_modify)

MacroTools.postwalk(expr) do x

# Handle basic arithmetic and functions
if @capture(x, a_ + b_)
return :($a + $b)
elseif @capture(x, a_ - b_)
return :($a - $b)
elseif @capture(x, a_ * b_)
return :($a * $b)
elseif @capture(x, a_ / b_)
return :($a / $b)
elseif @capture(x, a_ ^ b_)
return :(POWER($a, $b))
elseif @capture(x, round(a_))
return :(ROUND($a))
elseif @capture(x, mean(a_))
if from_summarize
return :(AVG($a))
else
window_clause = construct_window_clause(sq)
return "AVG($(string(a))) $(window_clause)"
end
elseif @capture(x, minimum(a_))
if from_summarize
return :(MIN($a))
else
window_clause = construct_window_clause(sq)
return "MIN($(string(a))) $(window_clause)"
end
elseif @capture(x, maximum(a_))
if from_summarize
return :(MAX($a))
else
window_clause = construct_window_clause(sq)
return "MAX($(string(a))) $(window_clause)"
end
elseif @capture(x, sum(a_))
if from_summarize
return :(SUM($a))
else
window_clause = construct_window_clause(sq)
return "SUM($(string(a))) $(window_clause)"
end
elseif @capture(x, cumsum(a_))
if from_summarize
error("cumsum is only available through a windowed @mutate")
else
# sq.windowFrame = "ROWS UNBOUNDED PRECEDING "
window_clause = construct_window_clause(sq, from_cumsum = true)
return "SUM($(string(a))) $(window_clause)"
end
# exc_capture_bug used above to allow proper _ function name capturing
elseif @capture(x, replacemissing(column_, replacement_value_))
return :(COALESCE($column, $replacement_value))
elseif @capture(x, missingif(column_, value_to_replace_))
return :(NULLIF($column, $value_to_replace))
elseif @capture(x, ismissing(a_))
return "($(string(a)) IS NULL)"
elseif isa(x, Expr) && x.head == :call
if x.args[1] == :if_else && length(x.args) == 4
return parse_if_else(x)
elseif x.args[1] == :as_float && length(x.args) == 2
column = x.args[2]
# Return the SQL CAST statement directly as a string
return "CAST(" * string(column) * " AS DOUBLE)"
elseif x.args[1] == :as_integer && length(x.args) == 2
column = x.args[2]
return "CAST(" * string(column) * " AS INT)"
elseif x.args[1] == :as_string && length(x.args) == 2
column = x.args[2]
return "CAST(" * string(column) * " AS STRING)"
elseif x.args[1] == :case_when
return parse_case_when(x)
elseif isa(x, Expr) && x.head == :call && x.args[1] == :! && length(x.args) == 2
inner_expr = expr_to_sql_lite(x.args[2], sq) # Recursively transform the inner expression
return string("NOT (", inner_expr, ")")
elseif x.args[1] == :str_detect && length(x.args) == 3
column, pattern = x.args[2], x.args[3]
return string(column, " LIKE \'%", pattern, "%'")
elseif isa(x, Expr) && x.head == :call && x.args[1] == :n && length(x.args) == 1
return "COUNT(*)"
end
end
return x
end
end

function parse_if_else(expr)
transformed_expr = MacroTools.postwalk(expr) do x
Expand Down
2 changes: 2 additions & 0 deletions src/postgresparsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ function expr_to_sql_postgres(expr, sq; from_summarize::Bool)
return :(POWER($a, $b))
elseif @capture(x, round(a_))
return :(ROUND($a))
elseif @capture(x, round(a_, b_))
return :(ROUND($a, $b))
elseif @capture(x, mean(a_))
if from_summarize
return :(AVG($a))
Expand Down
92 changes: 92 additions & 0 deletions src/sqlite_parsing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
function expr_to_sql_lite(expr, sq; from_summarize::Bool)
expr = parse_char_matching(expr)
expr = exc_capture_bug(expr, names_to_modify)

MacroTools.postwalk(expr) do x

# Handle basic arithmetic and functions
if @capture(x, a_ + b_)
return :($a + $b)
elseif @capture(x, a_ - b_)
return :($a - $b)
elseif @capture(x, a_ * b_)
return :($a * $b)
elseif @capture(x, a_ / b_)
return :($a / $b)
elseif @capture(x, a_ ^ b_)
return :(POWER($a, $b))
elseif @capture(x, round(a_))
return :(ROUND($a))
elseif @capture(x, round(a_, b_))
return :(ROUND($a, $b))
elseif @capture(x, mean(a_))
if from_summarize
return :(AVG($a))
else
window_clause = construct_window_clause(sq)
return "AVG($(string(a))) $(window_clause)"
end
elseif @capture(x, minimum(a_))
if from_summarize
return :(MIN($a))
else
window_clause = construct_window_clause(sq)
return "MIN($(string(a))) $(window_clause)"
end
elseif @capture(x, maximum(a_))
if from_summarize
return :(MAX($a))
else
window_clause = construct_window_clause(sq)
return "MAX($(string(a))) $(window_clause)"
end
elseif @capture(x, sum(a_))
if from_summarize
return :(SUM($a))
else
window_clause = construct_window_clause(sq)
return "SUM($(string(a))) $(window_clause)"
end
elseif @capture(x, cumsum(a_))
if from_summarize
error("cumsum is only available through a windowed @mutate")
else
# sq.windowFrame = "ROWS UNBOUNDED PRECEDING "
window_clause = construct_window_clause(sq, from_cumsum = true)
return "SUM($(string(a))) $(window_clause)"
end
# exc_capture_bug used above to allow proper _ function name capturing
elseif @capture(x, replacemissing(column_, replacement_value_))
return :(COALESCE($column, $replacement_value))
elseif @capture(x, missingif(column_, value_to_replace_))
return :(NULLIF($column, $value_to_replace))
elseif @capture(x, ismissing(a_))
return "($(string(a)) IS NULL)"
elseif isa(x, Expr) && x.head == :call
if x.args[1] == :if_else && length(x.args) == 4
return parse_if_else(x)
elseif x.args[1] == :as_float && length(x.args) == 2
column = x.args[2]
# Return the SQL CAST statement directly as a string
return "CAST(" * string(column) * " AS DOUBLE)"
elseif x.args[1] == :as_integer && length(x.args) == 2
column = x.args[2]
return "CAST(" * string(column) * " AS INT)"
elseif x.args[1] == :as_string && length(x.args) == 2
column = x.args[2]
return "CAST(" * string(column) * " AS STRING)"
elseif x.args[1] == :case_when
return parse_case_when(x)
elseif isa(x, Expr) && x.head == :call && x.args[1] == :! && length(x.args) == 2
inner_expr = expr_to_sql_lite(x.args[2], sq) # Recursively transform the inner expression
return string("NOT (", inner_expr, ")")
elseif x.args[1] == :str_detect && length(x.args) == 3
column, pattern = x.args[2], x.args[3]
return string(column, " LIKE \'%", pattern, "%'")
elseif isa(x, Expr) && x.head == :call && x.args[1] == :n && length(x.args) == 1
return "COUNT(*)"
end
end
return x
end
end

0 comments on commit 240d2b2

Please sign in to comment.