Skip to content

Commit

Permalink
UPDATE with VALUES optimization for Postgres. Various macro refactori…
Browse files Browse the repository at this point in the history
…ng. (#2571)
  • Loading branch information
deusaquilus authored Aug 19, 2022
1 parent 38f7f97 commit 3e62f51
Show file tree
Hide file tree
Showing 55 changed files with 1,359 additions and 335 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -382,15 +382,15 @@ class CqlIdiomSpec extends Spec {

"ident" in {
val a: Ast = Ident("a")
translate(a, Quat.Unknown, ExecutionType.Unknown, TranspileConfig.Empty) mustBe ((a, stmt"a", ExecutionType.Unknown))
translate(a, Quat.Unknown, ExecutionType.Unknown, IdiomContext.Empty) mustBe ((a, stmt"a", ExecutionType.Unknown))
}
"assignment" in {
val a: Ast = Assignment(Ident("a"), Ident("b"), Ident("c"))
translate(a: Ast, Quat.Unknown, ExecutionType.Unknown, TranspileConfig.Empty) mustBe ((a, stmt"b = c", ExecutionType.Unknown))
translate(a: Ast, Quat.Unknown, ExecutionType.Unknown, IdiomContext.Empty) mustBe ((a, stmt"b = c", ExecutionType.Unknown))
}
"assignmentDual" in {
val a: Ast = AssignmentDual(Ident("a1"), Ident("a2"), Ident("b"), Ident("c"))
translate(a: Ast, Quat.Unknown, ExecutionType.Unknown, TranspileConfig.Empty) mustBe ((a, stmt"b = c", ExecutionType.Unknown))
translate(a: Ast, Quat.Unknown, ExecutionType.Unknown, IdiomContext.Empty) mustBe ((a, stmt"b = c", ExecutionType.Unknown))
}
"aggregation" in {
val t = implicitly[Tokenizer[AggregationOperator]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ import io.getquill.ast.Renameable.Fixed

import scala.language.implicitConversions
import scala.language.experimental.macros
import io.getquill.ast._
import io.getquill.ast.{ External, _ }
import io.getquill.quat._

import scala.reflect.macros.whitebox.{ Context => MacroContext }
import io.getquill.util.Messages._

import scala.util.DynamicVariable
import scala.reflect.ClassTag
import io.getquill.{ ActionReturning, Delete, EntityQuery, Insert, Ord, Query, Update, Action => DslAction, Quoted }
import io.getquill.{ ActionReturning, Delete, EntityQuery, Insert, Ord, Query, Quoted, Update, Action => DslAction }

import scala.annotation.tailrec

Expand Down Expand Up @@ -164,7 +164,7 @@ trait DynamicQueryDsl {
}

protected def spliceLift[O](o: O)(implicit enc: Encoder[O]) =
splice[O](ScalarValueLift("o", o, enc, Quat.Value))
splice[O](ScalarValueLift("o", External.Source.Parser, o, enc, Quat.Value))

object DynamicQuery {
def apply[T](p: Quoted[Query[T]]) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ trait DynamicQueryDsl {
}

protected def spliceLift[O](o: O)(implicit enc: Encoder[O]) =
splice[O](ScalarValueLift("o", o, enc, Quat.Value))
splice[O](ScalarValueLift("o", External.Source.Parser, o, enc, Quat.Value))

object DynamicQuery {
def apply[T](p: Quoted[Query[T]]) =
Expand Down
2 changes: 1 addition & 1 deletion quill-core/src/main/scala/io/getquill/MirrorContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import scala.util.{ Failure, Success, Try }
object mirrorContextWithQueryProbing
extends MirrorContext(MirrorIdiom, Literal) with QueryProbing

case class BatchActionMirrorGeneric[Row](groups: List[(String, List[Row])], info: ExecutionInfo)
case class BatchActionMirrorGeneric[A](groups: List[(String, List[A])], info: ExecutionInfo)
case class BatchActionReturningMirrorGeneric[T, PrepareRow, Extractor[_]](groups: List[(String, ReturnAction, List[PrepareRow])], extractor: Extractor[T], info: ExecutionInfo)

/**
Expand Down
102 changes: 67 additions & 35 deletions quill-core/src/main/scala/io/getquill/context/ActionMacro.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.getquill.context

import io.getquill.IdiomContext
import io.getquill.ast._
import io.getquill.norm.BetaReduction
import io.getquill.quat.Quat
Expand All @@ -20,18 +21,20 @@ class ActionMacro(val c: MacroContext)
def translateQuery(quoted: Tree): Tree =
translateQueryPrettyPrint(quoted, q"false")

def translateQueryPrettyPrint(quoted: Tree, prettyPrint: Tree): Tree =
def translateQueryPrettyPrint(quoted: Tree, prettyPrint: Tree): Tree = {
val expanded = expand(extractAst(quoted), inferQuat(quoted.tpe))
c.untypecheck {
q"""
..${EnableReflectiveCalls(c)}
val expanded = ${expand(extractAst(quoted), inferQuat(quoted.tpe))}
val (idiomContext, expanded) = $expanded
${c.prefix}.translateQuery(
expanded.string,
expanded.prepare,
prettyPrint = ${prettyPrint}
)(io.getquill.context.ExecutionInfo.unknown, ())
"""
}
}

def translateBatchQuery(quoted: Tree): Tree =
translateBatchQueryPrettyPrint(quoted, q"false")
Expand All @@ -43,7 +46,7 @@ class ActionMacro(val c: MacroContext)
..${EnableReflectiveCalls(c)}
${c.prefix}.translateBatchQuery(
$batch.map { $param =>
val expanded = $expanded
val (idiomContext, expanded) = $expanded
(expanded.string, expanded.prepare)
}.groupBy(_._1).map {
case (string, items) =>
Expand All @@ -54,23 +57,26 @@ class ActionMacro(val c: MacroContext)
"""
}

def runAction(quoted: Tree): Tree =
def runAction(quoted: Tree): Tree = {
val expanded = expand(extractAst(quoted), Quat.Value)
c.untypecheck {
q"""
..${EnableReflectiveCalls(c)}
val expanded = ${expand(extractAst(quoted), Quat.Value)}
val (idiomContext, expanded) = $expanded
${c.prefix}.executeAction(
expanded.string,
expanded.prepare
)(io.getquill.context.ExecutionInfo.unknown, ())
"""
}
}

def runActionReturning[T](quoted: Tree)(implicit t: WeakTypeTag[T]): Tree =
def runActionReturning[T](quoted: Tree)(implicit t: WeakTypeTag[T]): Tree = {
val expanded = expand(extractAst(quoted), inferQuat(t.tpe))
c.untypecheck {
q"""
..${EnableReflectiveCalls(c)}
val expanded = ${expand(extractAst(quoted), inferQuat(t.tpe))}
val (idiomContext, expanded) = $expanded
${c.prefix}.executeActionReturning(
expanded.string,
expanded.prepare,
Expand All @@ -79,12 +85,14 @@ class ActionMacro(val c: MacroContext)
)(io.getquill.context.ExecutionInfo.unknown, ())
"""
}
}

def runActionReturningMany[T](quoted: Tree)(implicit t: WeakTypeTag[T]): Tree =
c.untypecheck { // TODO return expanded.executionType since we now have this info
def runActionReturningMany[T](quoted: Tree)(implicit t: WeakTypeTag[T]): Tree = {
val expanded = expand(extractAst(quoted), inferQuat(t.tpe))
c.untypecheck {
q"""
..${EnableReflectiveCalls(c)}
val expanded = ${expand(extractAst(quoted), inferQuat(t.tpe))}
val (idiomContext, expanded) = $expanded
${c.prefix}.executeActionReturningMany(
expanded.string,
expanded.prepare,
Expand All @@ -93,6 +101,7 @@ class ActionMacro(val c: MacroContext)
)(io.getquill.context.ExecutionInfo.unknown, ())
"""
}
}

// Called from: run(BatchAction)
def runBatchAction(quoted: Tree): Tree = batchAction(quoted, "executeBatchAction")
Expand All @@ -109,7 +118,7 @@ class ActionMacro(val c: MacroContext)

def batchActionRows(quoted: Tree, method: String, numRows: Tree): Tree =
expandBatchActionNew(quoted) {
case (batch, param, expanded, injectableLiftList, idiomNamingOriginalAstVars) =>
case (batch, param, expanded, injectableLiftList, idiomNamingOriginalAstVars, idiomContext) =>
q"""
..${EnableReflectiveCalls(c)}
${c.prefix}.${TermName(method)}({
Expand All @@ -120,12 +129,13 @@ class ActionMacro(val c: MacroContext)
*/
import io.getquill.util.OrderedGroupByExt._
val originalAst = $idiomNamingOriginalAstVars
val idiomContext = $idiomContext
/* for liftQuery(people:List[Person]) `batch` is `people` */
/* TODO Need secondary check to see if context is actually capable of batch-values insert */
/* If there is a INSERT ... VALUES clause this will be cnoded as ValuesClauseToken(lifts) which we need to duplicate */
/* batches: List[List[Person]] */
val batches =
if (io.getquill.context.CanDoBatchedInsert(originalAst, $numRows, idiom, naming, false, ${ConfigLiftables.transpileConfigLiftable(transpileConfig)}) && $numRows != 1) {
if (io.getquill.context.CanDoBatchedInsert(originalAst, $numRows, idiom, naming, false, idiomContext) && $numRows != 1) {
$batch.toList.grouped($numRows).toList
} else {
$batch.toList.map(element => List(element))
Expand Down Expand Up @@ -164,14 +174,15 @@ class ActionMacro(val c: MacroContext)

def batchActionReturningRows[T](quoted: Tree, numRows: Tree)(implicit t: WeakTypeTag[T]): Tree =
expandBatchActionNew(quoted) {
case (batch, param, expanded, injectableLiftList, idiomNamingOriginalAstVars) =>
case (batch, param, expanded, injectableLiftList, idiomNamingOriginalAstVars, idiomContext) =>
q"""
..${EnableReflectiveCalls(c)}
${c.prefix}.executeBatchActionReturning({
import io.getquill.util.OrderedGroupByExt._
val originalAst = $idiomNamingOriginalAstVars
val idiomContext = $idiomContext
val batches =
if (io.getquill.context.CanDoBatchedInsert(originalAst, $numRows, idiom, naming, true, ${ConfigLiftables.transpileConfigLiftable(transpileConfig)}) && $numRows != 1) {
if (io.getquill.context.CanDoBatchedInsert(originalAst, $numRows, idiom, naming, true, idiomContext) && $numRows != 1) {
$batch.toList.grouped($numRows).toList
} else {
$batch.toList.map(element => List(element))
Expand All @@ -190,9 +201,9 @@ class ActionMacro(val c: MacroContext)
"""
}

def expandBatchActionNew(quoted: Tree)(call: (Tree, Tree, Tree, Tree, Tree) => Tree): Tree =
def expandBatchActionNew(quoted: Tree)(call: (Tree, Tree, Tree, Tree, Tree, Tree) => Tree): Tree =
BetaReduction(extractAst(quoted)) match {
case Foreach(lift: Lift, alias, body) =>
case totalAst @ Foreach(lift: Lift, alias, body) =>
// for liftQuery(people:List[Person]) this is: `people`
val batch = lift.value.asInstanceOf[Tree]
// This would be the Type[Person]
Expand All @@ -205,47 +216,61 @@ class ActionMacro(val c: MacroContext)
val nestedLift =
lift match {
case ScalarQueryLift(name, batch: Tree, encoder: Tree, quat) =>
ScalarValueLift("value", q"$values", encoder, quat)
ScalarValueLift("value", External.Source.UnparsedProperty("value"), q"$values", encoder, quat)
case CaseClassQueryLift(name, batch: Tree, quat) =>
CaseClassValueLift("value", q"$values", quat)
CaseClassValueLift("value", "value", q"$values", quat)
}

// So then on the AST-level we transform the alias `p` **
// from this: `foreach(people).map(p => insert(p.name, p.age))`
// into this: `foreach(people).map(p => insert(CaseClassValue("value", value:Person, encoder[Person], quatOf[Person]).name, CCV(...).age)))
// ReifyLiftings will then turn it
// into this: `foreach(people).map(p => insert(CaseClassValue("value", (value:Person).name, encoder[Person], quatOf[Person]), CCV(... (value:Person).age ...))))
// into this: `foreach(people).map(p => insert(CaseClassValue("value", (value:Person).name, encoder[String], quatOf[Person]), CCV(... (value:Person).age ...))))
//
// (** Note that I mixing the scala-api way of seeing this DSL i.e. foreach instead of ast.Foreach
// and the regular one i.e CaseClassValue. That's the only way to see what's going on without information-overload.
// also CCV:=CaseClassValue)
//
// Note that update cases are more complex:
// from this: `foreach(people).map(p => filter(pp => pp.id == p.id).update(p.name, p.age))`
// into this: `foreach(people).map(p => filter(pp => pp.id == CCV(value:Person,...)).update(CCV(value:Person,...).name, CCV(...).age)))
// ReifyLiftings will then turn it
// into this: `foreach(people).map(p => filter(pp => pp.id == CCV((value:Person).id,...).update(CCV((value:Person).name), CCV(... (value:Person).age ...))))
// in order to be able to do things like VALUES-clause inserts we need to preserve the original knowledge that the property was `Property(Id(p),"name").
val (valuePluggingAst, _) = reifyLiftings(BetaReduction(body, alias -> nestedLift))
// this is the ast with ScalarTag placeholders for the lifts
val (ast, valuePlugList) = ExtractLiftings.of(valuePluggingAst)
val liftUnlift = new { override val mctx: c.type = c } with TokenLift(ast.countQuatFields)
// List(id1 -> ((p: Person) => CCV(p.name), id2 -> ((p: Person) => CCV(p.age), ...)
// For regular lifts (e.g. liftQuery(people).foreach(p => query[Person].filter(pp => pp.name == lift("a regular lift").insert(...)))
// we can just do (p: Person) => "a regular lift" and nothing will be done with `p`
val injectableLiftListTrees =
valuePlugList.map {
case (id, valuePlugLift) =>
q"($id, ($param) => ${liftUnlift.astLiftable(valuePlugLift)})"
}
val injectableLiftList = q"$injectableLiftListTrees"
val queryType = IdiomContext.QueryType.discoverFromAst(totalAst, Some(alias.name))
val idiomContext = IdiomContext(transpileConfig, queryType)

// Splice into the code to tokenize the ast (i.e. the Expand class) and compile-time translate the AST if possible
val expanded =
q"""
val (ast, statement, executionType) = ${translate(ast, Quat.Unknown, transpileConfig)}
val (ast, statement, executionType, _) = ${translate(ast, Quat.Unknown, Some(alias.name))}
io.getquill.context.ExpandWithInjectables(${c.prefix}, ast, statement, idiom, naming, executionType, subBatch, $injectableLiftList)
"""

val idiomNamingOriginalAstVars =
q"""
val (idiom, naming) = ${idiomAndNamingDynamic}
val (idiom, naming) = ${idiomAndNamingDynamic};
${liftUnlift.astLiftable.apply((ast))}
"""

val transpileContextExpr =
ConfigLiftables.transpileContextLiftable(idiomContext)

c.untypecheck {
call(batch, param, expanded, injectableLiftList, idiomNamingOriginalAstVars)
call(batch, param, expanded, injectableLiftList, idiomNamingOriginalAstVars, transpileContextExpr)
}
}
case other =>
Expand All @@ -255,24 +280,30 @@ class ActionMacro(val c: MacroContext)
object ExtractLiftings {
def of(ast: Ast): (Ast, List[(String, ScalarLift)]) = {
val (outputAst, extracted) = ExtractLiftings(List())(ast)
(outputAst, extracted.state)
(outputAst, extracted.state.map { case (tag, lift) => (tag.uid, lift) })
}
}
case class ExtractLiftings(state: List[(String, ScalarLift)]) extends StatefulTransformer[List[(String, ScalarLift)]] {
override def apply(e: Action): (Action, StatefulTransformer[List[(String, ScalarLift)]]) =
case class ExtractLiftings(state: List[(ScalarTag, ScalarLift)]) extends StatefulTransformer[List[(ScalarTag, ScalarLift)]] {

override def apply(e: Action): (Action, StatefulTransformer[List[(ScalarTag, ScalarLift)]]) =
e match {
// TODO Can we absolutely assume that this insert will yield a Values clause?
case Insert(body, assignments) =>
val (newAssignments, assignmentMappings) = apply(assignments)(_.apply)
(Insert(body, newAssignments), assignmentMappings)
case _ =>
super.apply(e)
}
override def apply(e: Ast): (Ast, StatefulTransformer[List[(String, ScalarLift)]]) =

// Only extrace lifts that come from values-clauses:
// liftQuery(people).foreach(ps => query[Person].filter(_.name == lift("not this")).insertValue(_.name -> <these!>, ...))
override def apply(e: Ast): (Ast, StatefulTransformer[List[(ScalarTag, ScalarLift)]]) =
e match {
case lift: ScalarLift =>
case rawLift @ ScalarValueLift(_, rawSource @ External.Source.UnparsedProperty(rawSourceName), _, _, _) =>
val uuid = UUID.randomUUID().toString
(ScalarTag(uuid), ExtractLiftings((uuid -> lift) +: state))
val source = External.Source.UnparsedProperty(rawSourceName.stripPrefix("value.").replace(".", "_"))
val scalarTag = ScalarTag(uuid, source)
val lift = rawLift.copy(source = source)
(scalarTag, ExtractLiftings((scalarTag -> lift) +: state))
case _ => super.apply(e)
}
}
Expand All @@ -283,29 +314,30 @@ class ActionMacro(val c: MacroContext)
case ret: io.getquill.ast.ReturningAction =>
io.getquill.norm.ExpandReturning.applyMap(ret)(
(ast, statement) => io.getquill.context.Expand(${c.prefix}, ast, statement, idiom, naming, io.getquill.context.ExecutionType.Unknown).string
)(idiom, naming, ${ConfigLiftables.transpileConfigLiftable(transpileConfig)})
)(idiom, naming, idiomContext)
case ast =>
io.getquill.util.Messages.fail(s"Can't find returning column. Ast: '$$ast'")
})
"""

def expandBatchAction(quoted: Tree)(call: (Tree, Tree, Tree) => Tree): Tree =
BetaReduction(extractAst(quoted)) match {
case ast @ Foreach(lift: Lift, alias, body) =>
case totalAst @ Foreach(lift: Lift, alias, body) =>
val batch = lift.value.asInstanceOf[Tree]
val batchItemType = batch.tpe.typeArgs.head
c.typecheck(q"(value: $batchItemType) => value") match {
case q"($param) => $value" =>
val nestedLift =
lift match {
case ScalarQueryLift(name, batch: Tree, encoder: Tree, quat) =>
ScalarValueLift("value", value, encoder, quat)
ScalarValueLift("value", External.Source.UnparsedProperty("value"), value, encoder, quat)
case CaseClassQueryLift(name, batch: Tree, quat) =>
CaseClassValueLift("value", value, quat)
CaseClassValueLift("value", "value", value, quat)
}
val (ast, _) = reifyLiftings(BetaReduction(body, alias -> nestedLift))
val expanded = expand(ast, Quat.Unknown)
c.untypecheck {
call(batch, param, expand(ast, Quat.Unknown))
call(batch, param, expanded)
}
}
case other =>
Expand All @@ -316,7 +348,7 @@ class ActionMacro(val c: MacroContext)
c.untypecheck {
q"""
..${EnableReflectiveCalls(c)}
val expanded = ${expand(extractAst(quoted), Quat.Value)}
val (idiomContext, expanded) = ${expand(extractAst(quoted), Quat.Value)}
${c.prefix}.prepareAction(
expanded.string,
expanded.prepare
Expand Down
Loading

0 comments on commit 3e62f51

Please sign in to comment.