diff --git a/quill-cassandra/src/test/scala/io/getquill/context/cassandra/CqlIdiomSpec.scala b/quill-cassandra/src/test/scala/io/getquill/context/cassandra/CqlIdiomSpec.scala index d4a74cfbd4..e38e55010c 100644 --- a/quill-cassandra/src/test/scala/io/getquill/context/cassandra/CqlIdiomSpec.scala +++ b/quill-cassandra/src/test/scala/io/getquill/context/cassandra/CqlIdiomSpec.scala @@ -382,15 +382,15 @@ class CqlIdiomSpec extends Spec { "ident" in { val a: Ast = Ident("a") - translate(a, Quat.Unknown, ExecutionType.Unknown, TranspileConfig.Empty) mustBe ((a, stmt"a", ExecutionType.Unknown)) + translate(a, Quat.Unknown, ExecutionType.Unknown, IdiomContext.Empty) mustBe ((a, stmt"a", ExecutionType.Unknown)) } "assignment" in { val a: Ast = Assignment(Ident("a"), Ident("b"), Ident("c")) - translate(a: Ast, Quat.Unknown, ExecutionType.Unknown, TranspileConfig.Empty) mustBe ((a, stmt"b = c", ExecutionType.Unknown)) + translate(a: Ast, Quat.Unknown, ExecutionType.Unknown, IdiomContext.Empty) mustBe ((a, stmt"b = c", ExecutionType.Unknown)) } "assignmentDual" in { val a: Ast = AssignmentDual(Ident("a1"), Ident("a2"), Ident("b"), Ident("c")) - translate(a: Ast, Quat.Unknown, ExecutionType.Unknown, TranspileConfig.Empty) mustBe ((a, stmt"b = c", ExecutionType.Unknown)) + translate(a: Ast, Quat.Unknown, ExecutionType.Unknown, IdiomContext.Empty) mustBe ((a, stmt"b = c", ExecutionType.Unknown)) } "aggregation" in { val t = implicitly[Tokenizer[AggregationOperator]] diff --git a/quill-core/js/src/main/scala/io/getquill/dsl/DynamicQueryDSL.scala b/quill-core/js/src/main/scala/io/getquill/dsl/DynamicQueryDSL.scala index bae3ba89fb..354c564869 100644 --- a/quill-core/js/src/main/scala/io/getquill/dsl/DynamicQueryDSL.scala +++ b/quill-core/js/src/main/scala/io/getquill/dsl/DynamicQueryDSL.scala @@ -4,7 +4,7 @@ import io.getquill.ast.Renameable.Fixed import scala.language.implicitConversions import scala.language.experimental.macros -import io.getquill.ast._ +import io.getquill.ast.{ External, _ } import io.getquill.quat._ import scala.reflect.macros.whitebox.{ Context => MacroContext } @@ -12,7 +12,7 @@ import io.getquill.util.Messages._ import scala.util.DynamicVariable import scala.reflect.ClassTag -import io.getquill.{ ActionReturning, Delete, EntityQuery, Insert, Ord, Query, Update, Action => DslAction, Quoted } +import io.getquill.{ ActionReturning, Delete, EntityQuery, Insert, Ord, Query, Quoted, Update, Action => DslAction } import scala.annotation.tailrec @@ -164,7 +164,7 @@ trait DynamicQueryDsl { } protected def spliceLift[O](o: O)(implicit enc: Encoder[O]) = - splice[O](ScalarValueLift("o", o, enc, Quat.Value)) + splice[O](ScalarValueLift("o", External.Source.Parser, o, enc, Quat.Value)) object DynamicQuery { def apply[T](p: Quoted[Query[T]]) = diff --git a/quill-core/jvm/src/main/scala/io/getquill/dsl/DynamicQueryDSL.scala b/quill-core/jvm/src/main/scala/io/getquill/dsl/DynamicQueryDSL.scala index 2f10d21a4c..025b3a90ff 100644 --- a/quill-core/jvm/src/main/scala/io/getquill/dsl/DynamicQueryDSL.scala +++ b/quill-core/jvm/src/main/scala/io/getquill/dsl/DynamicQueryDSL.scala @@ -185,7 +185,7 @@ trait DynamicQueryDsl { } protected def spliceLift[O](o: O)(implicit enc: Encoder[O]) = - splice[O](ScalarValueLift("o", o, enc, Quat.Value)) + splice[O](ScalarValueLift("o", External.Source.Parser, o, enc, Quat.Value)) object DynamicQuery { def apply[T](p: Quoted[Query[T]]) = diff --git a/quill-core/src/main/scala/io/getquill/MirrorContext.scala b/quill-core/src/main/scala/io/getquill/MirrorContext.scala index 55333a64c3..3e58c497d3 100644 --- a/quill-core/src/main/scala/io/getquill/MirrorContext.scala +++ b/quill-core/src/main/scala/io/getquill/MirrorContext.scala @@ -11,7 +11,7 @@ import scala.util.{ Failure, Success, Try } object mirrorContextWithQueryProbing extends MirrorContext(MirrorIdiom, Literal) with QueryProbing -case class BatchActionMirrorGeneric[Row](groups: List[(String, List[Row])], info: ExecutionInfo) +case class BatchActionMirrorGeneric[A](groups: List[(String, List[A])], info: ExecutionInfo) case class BatchActionReturningMirrorGeneric[T, PrepareRow, Extractor[_]](groups: List[(String, ReturnAction, List[PrepareRow])], extractor: Extractor[T], info: ExecutionInfo) /** diff --git a/quill-core/src/main/scala/io/getquill/context/ActionMacro.scala b/quill-core/src/main/scala/io/getquill/context/ActionMacro.scala index 34e14e5cd7..3a94763d19 100644 --- a/quill-core/src/main/scala/io/getquill/context/ActionMacro.scala +++ b/quill-core/src/main/scala/io/getquill/context/ActionMacro.scala @@ -1,5 +1,6 @@ package io.getquill.context +import io.getquill.IdiomContext import io.getquill.ast._ import io.getquill.norm.BetaReduction import io.getquill.quat.Quat @@ -20,11 +21,12 @@ class ActionMacro(val c: MacroContext) def translateQuery(quoted: Tree): Tree = translateQueryPrettyPrint(quoted, q"false") - def translateQueryPrettyPrint(quoted: Tree, prettyPrint: Tree): Tree = + def translateQueryPrettyPrint(quoted: Tree, prettyPrint: Tree): Tree = { + val expanded = expand(extractAst(quoted), inferQuat(quoted.tpe)) c.untypecheck { q""" ..${EnableReflectiveCalls(c)} - val expanded = ${expand(extractAst(quoted), inferQuat(quoted.tpe))} + val (idiomContext, expanded) = $expanded ${c.prefix}.translateQuery( expanded.string, expanded.prepare, @@ -32,6 +34,7 @@ class ActionMacro(val c: MacroContext) )(io.getquill.context.ExecutionInfo.unknown, ()) """ } + } def translateBatchQuery(quoted: Tree): Tree = translateBatchQueryPrettyPrint(quoted, q"false") @@ -43,7 +46,7 @@ class ActionMacro(val c: MacroContext) ..${EnableReflectiveCalls(c)} ${c.prefix}.translateBatchQuery( $batch.map { $param => - val expanded = $expanded + val (idiomContext, expanded) = $expanded (expanded.string, expanded.prepare) }.groupBy(_._1).map { case (string, items) => @@ -54,23 +57,26 @@ class ActionMacro(val c: MacroContext) """ } - def runAction(quoted: Tree): Tree = + def runAction(quoted: Tree): Tree = { + val expanded = expand(extractAst(quoted), Quat.Value) c.untypecheck { q""" ..${EnableReflectiveCalls(c)} - val expanded = ${expand(extractAst(quoted), Quat.Value)} + val (idiomContext, expanded) = $expanded ${c.prefix}.executeAction( expanded.string, expanded.prepare )(io.getquill.context.ExecutionInfo.unknown, ()) """ } + } - def runActionReturning[T](quoted: Tree)(implicit t: WeakTypeTag[T]): Tree = + def runActionReturning[T](quoted: Tree)(implicit t: WeakTypeTag[T]): Tree = { + val expanded = expand(extractAst(quoted), inferQuat(t.tpe)) c.untypecheck { q""" ..${EnableReflectiveCalls(c)} - val expanded = ${expand(extractAst(quoted), inferQuat(t.tpe))} + val (idiomContext, expanded) = $expanded ${c.prefix}.executeActionReturning( expanded.string, expanded.prepare, @@ -79,12 +85,14 @@ class ActionMacro(val c: MacroContext) )(io.getquill.context.ExecutionInfo.unknown, ()) """ } + } - def runActionReturningMany[T](quoted: Tree)(implicit t: WeakTypeTag[T]): Tree = - c.untypecheck { // TODO return expanded.executionType since we now have this info + def runActionReturningMany[T](quoted: Tree)(implicit t: WeakTypeTag[T]): Tree = { + val expanded = expand(extractAst(quoted), inferQuat(t.tpe)) + c.untypecheck { q""" ..${EnableReflectiveCalls(c)} - val expanded = ${expand(extractAst(quoted), inferQuat(t.tpe))} + val (idiomContext, expanded) = $expanded ${c.prefix}.executeActionReturningMany( expanded.string, expanded.prepare, @@ -93,6 +101,7 @@ class ActionMacro(val c: MacroContext) )(io.getquill.context.ExecutionInfo.unknown, ()) """ } + } // Called from: run(BatchAction) def runBatchAction(quoted: Tree): Tree = batchAction(quoted, "executeBatchAction") @@ -109,7 +118,7 @@ class ActionMacro(val c: MacroContext) def batchActionRows(quoted: Tree, method: String, numRows: Tree): Tree = expandBatchActionNew(quoted) { - case (batch, param, expanded, injectableLiftList, idiomNamingOriginalAstVars) => + case (batch, param, expanded, injectableLiftList, idiomNamingOriginalAstVars, idiomContext) => q""" ..${EnableReflectiveCalls(c)} ${c.prefix}.${TermName(method)}({ @@ -120,12 +129,13 @@ class ActionMacro(val c: MacroContext) */ import io.getquill.util.OrderedGroupByExt._ val originalAst = $idiomNamingOriginalAstVars + val idiomContext = $idiomContext /* for liftQuery(people:List[Person]) `batch` is `people` */ /* TODO Need secondary check to see if context is actually capable of batch-values insert */ /* If there is a INSERT ... VALUES clause this will be cnoded as ValuesClauseToken(lifts) which we need to duplicate */ /* batches: List[List[Person]] */ val batches = - if (io.getquill.context.CanDoBatchedInsert(originalAst, $numRows, idiom, naming, false, ${ConfigLiftables.transpileConfigLiftable(transpileConfig)}) && $numRows != 1) { + if (io.getquill.context.CanDoBatchedInsert(originalAst, $numRows, idiom, naming, false, idiomContext) && $numRows != 1) { $batch.toList.grouped($numRows).toList } else { $batch.toList.map(element => List(element)) @@ -164,14 +174,15 @@ class ActionMacro(val c: MacroContext) def batchActionReturningRows[T](quoted: Tree, numRows: Tree)(implicit t: WeakTypeTag[T]): Tree = expandBatchActionNew(quoted) { - case (batch, param, expanded, injectableLiftList, idiomNamingOriginalAstVars) => + case (batch, param, expanded, injectableLiftList, idiomNamingOriginalAstVars, idiomContext) => q""" ..${EnableReflectiveCalls(c)} ${c.prefix}.executeBatchActionReturning({ import io.getquill.util.OrderedGroupByExt._ val originalAst = $idiomNamingOriginalAstVars + val idiomContext = $idiomContext val batches = - if (io.getquill.context.CanDoBatchedInsert(originalAst, $numRows, idiom, naming, true, ${ConfigLiftables.transpileConfigLiftable(transpileConfig)}) && $numRows != 1) { + if (io.getquill.context.CanDoBatchedInsert(originalAst, $numRows, idiom, naming, true, idiomContext) && $numRows != 1) { $batch.toList.grouped($numRows).toList } else { $batch.toList.map(element => List(element)) @@ -190,9 +201,9 @@ class ActionMacro(val c: MacroContext) """ } - def expandBatchActionNew(quoted: Tree)(call: (Tree, Tree, Tree, Tree, Tree) => Tree): Tree = + def expandBatchActionNew(quoted: Tree)(call: (Tree, Tree, Tree, Tree, Tree, Tree) => Tree): Tree = BetaReduction(extractAst(quoted)) match { - case Foreach(lift: Lift, alias, body) => + case totalAst @ Foreach(lift: Lift, alias, body) => // for liftQuery(people:List[Person]) this is: `people` val batch = lift.value.asInstanceOf[Tree] // This would be the Type[Person] @@ -205,47 +216,61 @@ class ActionMacro(val c: MacroContext) val nestedLift = lift match { case ScalarQueryLift(name, batch: Tree, encoder: Tree, quat) => - ScalarValueLift("value", q"$values", encoder, quat) + ScalarValueLift("value", External.Source.UnparsedProperty("value"), q"$values", encoder, quat) case CaseClassQueryLift(name, batch: Tree, quat) => - CaseClassValueLift("value", q"$values", quat) + CaseClassValueLift("value", "value", q"$values", quat) } // So then on the AST-level we transform the alias `p` ** // from this: `foreach(people).map(p => insert(p.name, p.age))` // into this: `foreach(people).map(p => insert(CaseClassValue("value", value:Person, encoder[Person], quatOf[Person]).name, CCV(...).age))) // ReifyLiftings will then turn it - // into this: `foreach(people).map(p => insert(CaseClassValue("value", (value:Person).name, encoder[Person], quatOf[Person]), CCV(... (value:Person).age ...)))) + // into this: `foreach(people).map(p => insert(CaseClassValue("value", (value:Person).name, encoder[String], quatOf[Person]), CCV(... (value:Person).age ...)))) // // (** Note that I mixing the scala-api way of seeing this DSL i.e. foreach instead of ast.Foreach // and the regular one i.e CaseClassValue. That's the only way to see what's going on without information-overload. // also CCV:=CaseClassValue) + // + // Note that update cases are more complex: + // from this: `foreach(people).map(p => filter(pp => pp.id == p.id).update(p.name, p.age))` + // into this: `foreach(people).map(p => filter(pp => pp.id == CCV(value:Person,...)).update(CCV(value:Person,...).name, CCV(...).age))) + // ReifyLiftings will then turn it + // into this: `foreach(people).map(p => filter(pp => pp.id == CCV((value:Person).id,...).update(CCV((value:Person).name), CCV(... (value:Person).age ...)))) + // in order to be able to do things like VALUES-clause inserts we need to preserve the original knowledge that the property was `Property(Id(p),"name"). val (valuePluggingAst, _) = reifyLiftings(BetaReduction(body, alias -> nestedLift)) // this is the ast with ScalarTag placeholders for the lifts val (ast, valuePlugList) = ExtractLiftings.of(valuePluggingAst) val liftUnlift = new { override val mctx: c.type = c } with TokenLift(ast.countQuatFields) // List(id1 -> ((p: Person) => CCV(p.name), id2 -> ((p: Person) => CCV(p.age), ...) + // For regular lifts (e.g. liftQuery(people).foreach(p => query[Person].filter(pp => pp.name == lift("a regular lift").insert(...))) + // we can just do (p: Person) => "a regular lift" and nothing will be done with `p` val injectableLiftListTrees = valuePlugList.map { case (id, valuePlugLift) => q"($id, ($param) => ${liftUnlift.astLiftable(valuePlugLift)})" } val injectableLiftList = q"$injectableLiftListTrees" + val queryType = IdiomContext.QueryType.discoverFromAst(totalAst, Some(alias.name)) + val idiomContext = IdiomContext(transpileConfig, queryType) // Splice into the code to tokenize the ast (i.e. the Expand class) and compile-time translate the AST if possible val expanded = q""" - val (ast, statement, executionType) = ${translate(ast, Quat.Unknown, transpileConfig)} + val (ast, statement, executionType, _) = ${translate(ast, Quat.Unknown, Some(alias.name))} io.getquill.context.ExpandWithInjectables(${c.prefix}, ast, statement, idiom, naming, executionType, subBatch, $injectableLiftList) """ val idiomNamingOriginalAstVars = q""" - val (idiom, naming) = ${idiomAndNamingDynamic} + val (idiom, naming) = ${idiomAndNamingDynamic}; ${liftUnlift.astLiftable.apply((ast))} """ + val transpileContextExpr = + ConfigLiftables.transpileContextLiftable(idiomContext) + c.untypecheck { - call(batch, param, expanded, injectableLiftList, idiomNamingOriginalAstVars) + call(batch, param, expanded, injectableLiftList, idiomNamingOriginalAstVars, transpileContextExpr) } } case other => @@ -255,24 +280,30 @@ class ActionMacro(val c: MacroContext) object ExtractLiftings { def of(ast: Ast): (Ast, List[(String, ScalarLift)]) = { val (outputAst, extracted) = ExtractLiftings(List())(ast) - (outputAst, extracted.state) + (outputAst, extracted.state.map { case (tag, lift) => (tag.uid, lift) }) } } - case class ExtractLiftings(state: List[(String, ScalarLift)]) extends StatefulTransformer[List[(String, ScalarLift)]] { - override def apply(e: Action): (Action, StatefulTransformer[List[(String, ScalarLift)]]) = + case class ExtractLiftings(state: List[(ScalarTag, ScalarLift)]) extends StatefulTransformer[List[(ScalarTag, ScalarLift)]] { + + override def apply(e: Action): (Action, StatefulTransformer[List[(ScalarTag, ScalarLift)]]) = e match { - // TODO Can we absolutely assume that this insert will yield a Values clause? case Insert(body, assignments) => val (newAssignments, assignmentMappings) = apply(assignments)(_.apply) (Insert(body, newAssignments), assignmentMappings) case _ => super.apply(e) } - override def apply(e: Ast): (Ast, StatefulTransformer[List[(String, ScalarLift)]]) = + + // Only extrace lifts that come from values-clauses: + // liftQuery(people).foreach(ps => query[Person].filter(_.name == lift("not this")).insertValue(_.name -> , ...)) + override def apply(e: Ast): (Ast, StatefulTransformer[List[(ScalarTag, ScalarLift)]]) = e match { - case lift: ScalarLift => + case rawLift @ ScalarValueLift(_, rawSource @ External.Source.UnparsedProperty(rawSourceName), _, _, _) => val uuid = UUID.randomUUID().toString - (ScalarTag(uuid), ExtractLiftings((uuid -> lift) +: state)) + val source = External.Source.UnparsedProperty(rawSourceName.stripPrefix("value.").replace(".", "_")) + val scalarTag = ScalarTag(uuid, source) + val lift = rawLift.copy(source = source) + (scalarTag, ExtractLiftings((scalarTag -> lift) +: state)) case _ => super.apply(e) } } @@ -283,7 +314,7 @@ class ActionMacro(val c: MacroContext) case ret: io.getquill.ast.ReturningAction => io.getquill.norm.ExpandReturning.applyMap(ret)( (ast, statement) => io.getquill.context.Expand(${c.prefix}, ast, statement, idiom, naming, io.getquill.context.ExecutionType.Unknown).string - )(idiom, naming, ${ConfigLiftables.transpileConfigLiftable(transpileConfig)}) + )(idiom, naming, idiomContext) case ast => io.getquill.util.Messages.fail(s"Can't find returning column. Ast: '$$ast'") }) @@ -291,7 +322,7 @@ class ActionMacro(val c: MacroContext) def expandBatchAction(quoted: Tree)(call: (Tree, Tree, Tree) => Tree): Tree = BetaReduction(extractAst(quoted)) match { - case ast @ Foreach(lift: Lift, alias, body) => + case totalAst @ Foreach(lift: Lift, alias, body) => val batch = lift.value.asInstanceOf[Tree] val batchItemType = batch.tpe.typeArgs.head c.typecheck(q"(value: $batchItemType) => value") match { @@ -299,13 +330,14 @@ class ActionMacro(val c: MacroContext) val nestedLift = lift match { case ScalarQueryLift(name, batch: Tree, encoder: Tree, quat) => - ScalarValueLift("value", value, encoder, quat) + ScalarValueLift("value", External.Source.UnparsedProperty("value"), value, encoder, quat) case CaseClassQueryLift(name, batch: Tree, quat) => - CaseClassValueLift("value", value, quat) + CaseClassValueLift("value", "value", value, quat) } val (ast, _) = reifyLiftings(BetaReduction(body, alias -> nestedLift)) + val expanded = expand(ast, Quat.Unknown) c.untypecheck { - call(batch, param, expand(ast, Quat.Unknown)) + call(batch, param, expanded) } } case other => @@ -316,7 +348,7 @@ class ActionMacro(val c: MacroContext) c.untypecheck { q""" ..${EnableReflectiveCalls(c)} - val expanded = ${expand(extractAst(quoted), Quat.Value)} + val (idiomContext, expanded) = ${expand(extractAst(quoted), Quat.Value)} ${c.prefix}.prepareAction( expanded.string, expanded.prepare diff --git a/quill-core/src/main/scala/io/getquill/context/ContextMacro.scala b/quill-core/src/main/scala/io/getquill/context/ContextMacro.scala index 5310ac26f0..e368e3728c 100644 --- a/quill-core/src/main/scala/io/getquill/context/ContextMacro.scala +++ b/quill-core/src/main/scala/io/getquill/context/ContextMacro.scala @@ -5,9 +5,8 @@ import io.getquill.ast.{ Ast, Dynamic, Lift, Tag } import io.getquill.quotation.{ IsDynamic, LiftUnlift, Quotation } import io.getquill.util.LoadObject import io.getquill.util.MacroContextExt._ -import io.getquill.NamingStrategy +import io.getquill.{ NamingStrategy, IdiomContext } import io.getquill.idiom._ -import io.getquill.norm.TranspileConfig import io.getquill.quat.Quat import scala.util.Success @@ -21,22 +20,23 @@ trait ContextMacro extends Quotation { summonPhaseDisable() q""" val (idiom, naming) = ${idiomAndNamingDynamic} - val (ast, statement, executionType) = ${translate(ast, topLevelQuat, transpileConfig)} - io.getquill.context.Expand(${c.prefix}, ast, statement, idiom, naming, executionType) + val (ast, statement, executionType, idiomContext) = ${translate(ast, topLevelQuat, None)} + (idiomContext, io.getquill.context.Expand(${c.prefix}, ast, statement, idiom, naming, executionType)) """ } - protected def extractAst[T](quoted: Tree): Ast = + protected def extractAst[T](quoted: Tree): Ast = { unquote[Ast](c.typecheck(q"quote($quoted)")) .map(VerifyFreeVariables(c)) .getOrElse { Dynamic(quoted) } + } - def translate(ast: Ast, topLevelQuat: Quat, transpileConfig: TranspileConfig): Tree = + def translate(ast: Ast, topLevelQuat: Quat, batchAlias: Option[String]): Tree = IsDynamic(ast) match { - case false => translateStatic(ast, topLevelQuat, transpileConfig) - case true => translateDynamic(ast, topLevelQuat, transpileConfig) + case false => translateStatic(ast, topLevelQuat, batchAlias) + case true => translateDynamic(ast, topLevelQuat, batchAlias) } abstract class TokenLift(numQuatFields: Int) extends LiftUnlift(numQuatFields) { @@ -57,13 +57,16 @@ trait ContextMacro extends Quotation { } } - private def translateStatic(ast: Ast, topLevelQuat: Quat, transpileConfig: TranspileConfig): Tree = { + private def translateStatic(ast: Ast, topLevelQuat: Quat, batchAlias: Option[String]): Tree = { val liftUnlift = new { override val mctx: c.type = c } with TokenLift(ast.countQuatFields) import liftUnlift._ + val transpileConfig = summonTranspileConfig() + val queryType = IdiomContext.QueryType.discoverFromAst(ast, batchAlias) + val idiomContext = IdiomContext(transpileConfig, queryType) idiomAndNamingStatic match { case Success((idiom, naming)) => - val (normalizedAst, statement, _) = idiom.translate(ast, topLevelQuat, ExecutionType.Static, transpileConfig)(naming) + val (normalizedAst, statement, _) = idiom.translate(ast, topLevelQuat, ExecutionType.Static, idiomContext)(naming) val (string, _) = ReifyStatement( @@ -77,26 +80,32 @@ trait ContextMacro extends Quotation { c.query(string, idiom) - q"($normalizedAst, ${statement: Token}, io.getquill.context.ExecutionType.Static)" + q"($normalizedAst, ${statement: Token}, io.getquill.context.ExecutionType.Static, ${ConfigLiftables.transpileContextLiftable(idiomContext)})" case Failure(ex) => c.info(s"Can't translate query at compile time because the idiom and/or the naming strategy aren't known at this point.") - translateDynamic(ast, topLevelQuat, transpileConfig) + translateDynamic(ast, topLevelQuat, batchAlias) } } - private def translateDynamic(ast: Ast, topLevelQuat: Quat, transpileConfig: TranspileConfig): Tree = { - // TODO Need to build Liftables for transpileConfig + private def translateDynamic(ast: Ast, topLevelQuat: Quat, batchAlias: Option[String]): Tree = { val liftUnlift = new { override val mctx: c.type = c } with TokenLift(ast.countQuatFields) import liftUnlift._ val liftQuat: Liftable[Quat] = liftUnlift.quatLiftable + val transpileConfig = summonTranspileConfig() + val transpileConfigExpr = ConfigLiftables.transpileConfigLiftable(transpileConfig) + // Compile-time AST might have Dynamic parts, we need those resoved (i.e. at runtime to be able to get the query type) + val queryTypeExpr = q"_root_.io.getquill.IdiomContext.QueryType.discoverFromAst($ast, $batchAlias)" c.info("Dynamic query") val translateMethod = if (io.getquill.util.Messages.cacheDynamicQueries) { q"idiom.translateCached" } else q"idiom.translate" - // The `transpileConfig` variable uses scala's provided list-liftable and the optionalPhaseLiftable + // The `idiomContext` variable uses scala's provided list-liftable and the optionalPhaseLiftable q""" val (idiom, naming) = ${idiomAndNamingDynamic} - $translateMethod(new _root_.io.getquill.norm.RepropagateQuats(${ConfigLiftables.transpileConfigLiftable(transpileConfig)}.traceConfig)($ast), ${liftQuat(topLevelQuat)}, io.getquill.context.ExecutionType.Dynamic, ${ConfigLiftables.transpileConfigLiftable(transpileConfig)})(naming) + val traceConfig = ${ConfigLiftables.traceConfigLiftable(transpileConfig.traceConfig)} + val idiomContext = _root_.io.getquill.IdiomContext($transpileConfigExpr, $queryTypeExpr) + val (ast, statement, executionType) = $translateMethod(new _root_.io.getquill.norm.RepropagateQuats(traceConfig)($ast), ${liftQuat(topLevelQuat)}, io.getquill.context.ExecutionType.Dynamic, idiomContext)(naming) + (ast, statement, executionType, idiomContext) """ } diff --git a/quill-core/src/main/scala/io/getquill/context/Expand.scala b/quill-core/src/main/scala/io/getquill/context/Expand.scala index 2be4cd10d6..6ebcdf271a 100644 --- a/quill-core/src/main/scala/io/getquill/context/Expand.scala +++ b/quill-core/src/main/scala/io/getquill/context/Expand.scala @@ -3,11 +3,11 @@ package io.getquill.context import io.getquill.ast._ import io.getquill.NamingStrategy import io.getquill.idiom._ -import io.getquill.norm.TranspileConfig +import io.getquill.IdiomContext import io.getquill.quat.Quat object CanDoBatchedInsert { - def apply(ast: Ast, numRows: Int, idiom: Idiom, naming: NamingStrategy, isReturning: Boolean, transpileConfig: TranspileConfig): Boolean = { + def apply(ast: Ast, numRows: Int, idiom: Idiom, naming: NamingStrategy, isReturning: Boolean, idiomContext: IdiomContext): Boolean = { // find any actions that could have a VALUES clause. Right now just ast.Insert, // in the future might be Update and Dlete val actions = CollectAst.byType[Action](ast) @@ -19,7 +19,7 @@ object CanDoBatchedInsert { else { // In order to see if there's a VALUES-clause in the action, we don't need to tokenize the entire AST, // just the ast.Insert (or ast.Update or ast.Delete) - val statement = idiom.translate(actions.head, Quat.Unknown, ExecutionType.Unknown, transpileConfig)(naming)._2 + val statement = idiom.translate(actions.head, Quat.Unknown, ExecutionType.Unknown, idiomContext)(naming)._2 val validations = for { diff --git a/quill-core/src/main/scala/io/getquill/context/QueryMacro.scala b/quill-core/src/main/scala/io/getquill/context/QueryMacro.scala index 56399b414f..e54ce84640 100644 --- a/quill-core/src/main/scala/io/getquill/context/QueryMacro.scala +++ b/quill-core/src/main/scala/io/getquill/context/QueryMacro.scala @@ -126,7 +126,7 @@ class QueryMacro(val c: MacroContext) extends ContextMacro { q""" ..${EnableReflectiveCalls(c)} val staticTopLevelQuat = ${if (Messages.attachTopLevelQuats) liftQuat(topLevelQuat) else q"io.getquill.quat.Quat.Unknown"} - val expanded = ${expand(ast, topLevelQuat)} + val (idiomContext, expanded) = ${expand(ast, topLevelQuat)} ${invocation} """ } @@ -205,7 +205,7 @@ class QueryMacro(val c: MacroContext) extends ContextMacro { q""" ..${EnableReflectiveCalls(c)} val staticTopLevelQuat = ${if (Messages.attachTopLevelQuats) liftQuat(topLevelQuat) else q"io.getquill.quat.Quat.Unknown"} - val expanded = ${expand(ast, topLevelQuat)} + val (idiomContext, expanded) = ${expand(ast, topLevelQuat)} ${invocation} """ } diff --git a/quill-core/src/main/scala/io/getquill/quat/QuatMaking.scala b/quill-core/src/main/scala/io/getquill/quat/QuatMaking.scala index 8133c03542..dbd0cc71d4 100644 --- a/quill-core/src/main/scala/io/getquill/quat/QuatMaking.scala +++ b/quill-core/src/main/scala/io/getquill/quat/QuatMaking.scala @@ -1,9 +1,10 @@ package io.getquill.quat +import io.getquill.quotation.{ MacroUtilUniverse, MacroUtilBase } + import java.lang.reflect.Method -import io.getquill.Quoted -import io.getquill.util.{ Messages, OptionalTypecheck } import io.getquill.{ Embedded, Udt } +import io.getquill.util.{ Messages, OptionalTypecheck } import scala.annotation.tailrec import scala.reflect.ClassTag @@ -15,11 +16,11 @@ object QuatMaking { case object IgnoreDecoders extends IgnoreDecoders } -trait QuatMaking extends QuatMakingBase { +trait QuatMaking extends QuatMakingBase with MacroUtilBase { val c: Context - type Uni = c.universe.type + override type Uni = c.universe.type // NOTE: u needs to be lazy otherwise sets value from c before c can be initialized by higher level classes - lazy val u: Uni = c.universe + override lazy val u: Uni = c.universe import u.{ Block => _, Constant => _, Function => _, Ident => _, If => _, _ } import collection.mutable.HashMap; @@ -139,7 +140,7 @@ abstract class TypeTaggedQuatMaking extends QuatMakingBase { def quatValueTypes: List[universe.Type] } -trait QuatMakingBase { +trait QuatMakingBase extends MacroUtilUniverse { type Uni <: Universe val u: Uni import u.{ Block => _, Constant => _, Function => _, Ident => _, If => _, _ } @@ -344,48 +345,6 @@ trait QuatMakingBase { parseTopLevelType(tpe) } - object QuotedType { - def unapply(tpe: Type) = - paramOf(tpe, typeOf[Quoted[Any]]) - } - - object QueryType { - def unapply(tpe: Type) = - paramOf(tpe, typeOf[io.getquill.Query[Any]]) - } - - object TypeSigParam { - def unapply(tpe: Type): Option[Type] = - tpe.typeSymbol.typeSignature.typeParams match { - case head :: tail => Some(head.typeSignature) - case Nil => None - } - } - - def paramOf(tpe: Type, of: Type, maxDepth: Int = 10): Option[Type] = { - //println(s"### Attempting to check paramOf ${tpe} assuming it is a ${of}") - tpe match { - case _ if (maxDepth == 0) => - throw new IllegalArgumentException(s"Max Depth reached with type: ${tpe}") - case _ if (!(tpe <:< of)) => - //println(s"### ${tpe} is not a ${of}") - None - case _ if (tpe =:= typeOf[Nothing] || tpe =:= typeOf[Any]) => - //println(s"### ${tpe} is Nothing or Any") - None - case TypeRef(_, cls, List(arg)) => - //println(s"### ${tpe} is a TypeRef whose arg is ${arg}") - Some(arg) - case TypeSigParam(param) => - //println(s"### ${tpe} is a type signature whose type is ${param}") - Some(param) - case _ => - val base = tpe.baseType(of.typeSymbol) - //println(s"### Going to base type for ${tpe} for expected base type ${of}") - paramOf(base, of, maxDepth - 1) - } - } - @tailrec private[getquill] final def innerOptionParam(tpe: Type, maxDepth: Option[Int]): Type = tpe match { case TypeRef(_, cls, List(arg)) if (cls.isClass && cls.asClass.fullName == "scala.Option") && maxDepth.forall(_ > 0) => diff --git a/quill-core/src/main/scala/io/getquill/quotation/Liftables.scala b/quill-core/src/main/scala/io/getquill/quotation/Liftables.scala index 96638bd937..175df081c0 100644 --- a/quill-core/src/main/scala/io/getquill/quotation/Liftables.scala +++ b/quill-core/src/main/scala/io/getquill/quotation/Liftables.scala @@ -11,6 +11,11 @@ trait Liftables extends QuatLiftable { private val pack = q"io.getquill.ast" + implicit val stringOptionLiftable: Liftable[Option[String]] = Liftable[Option[String]] { + case None => q"scala.None" + case Some(value) => q"scala.Some[String]($value)" + } + implicit val astLiftable: Liftable[Ast] = Liftable[Ast] { case ast: Query => queryLiftable(ast) case ast: Action => actionLiftable(ast) @@ -19,7 +24,7 @@ trait Liftables extends QuatLiftable { case ast: ExternalIdent => externalIdentLiftable(ast) case ast: Ordering => orderingLiftable(ast) case ast: Lift => liftLiftable(ast) - case ScalarTag(uid) => q"$pack.ScalarTag($uid)" + case ast: Tag => tagLiftable(ast) case ast: Assignment => assignmentLiftable(ast) case ast: AssignmentDual => assignmentDualLiftable(ast) case ast: OptionOperation => optionOperationLiftable(ast) @@ -207,13 +212,17 @@ trait Liftables extends QuatLiftable { } implicit val liftLiftable: Liftable[Lift] = Liftable[Lift] { - case ScalarValueLift(a, b: Tree, c: Tree, quat: Quat) => q"$pack.ScalarValueLift($a, $b, $c, $quat)" - case CaseClassValueLift(a, b: Tree, quat: Quat) => q"$pack.CaseClassValueLift($a, $b, $quat)" - case ScalarQueryLift(a, b: Tree, c: Tree, quat: Quat) => q"$pack.ScalarQueryLift($a, $b, $c, $quat)" - case CaseClassQueryLift(a, b: Tree, quat: Quat) => q"$pack.CaseClassQueryLift($a, $b, $quat)" + case ScalarValueLift(a, a1, b: Tree, c: Tree, quat: Quat) => q"$pack.ScalarValueLift($a, $a1, $b, $c, $quat)" + case CaseClassValueLift(a, a1, b: Tree, quat: Quat) => q"$pack.CaseClassValueLift($a, $a1, $b, $quat)" + case ScalarQueryLift(a, b: Tree, c: Tree, quat: Quat) => q"$pack.ScalarQueryLift($a, $b, $c, $quat)" + case CaseClassQueryLift(a, b: Tree, quat: Quat) => q"$pack.CaseClassQueryLift($a, $b, $quat)" } implicit val tagLiftable: Liftable[Tag] = Liftable[Tag] { - case ScalarTag(uid) => q"$pack.ScalarTag($uid)" - case QuotationTag(uid) => q"$pack.QuotationTag($uid)" + case ScalarTag(uid, originalName) => q"$pack.ScalarTag($uid, $originalName)" + case QuotationTag(uid) => q"$pack.QuotationTag($uid)" + } + implicit val sourceLiftable: Liftable[External.Source] = Liftable[External.Source] { + case External.Source.Parser => q"$pack.External.Source.Parser" + case External.Source.UnparsedProperty(prop) => q"$pack.External.Source.UnparsedProperty($prop)" } } diff --git a/quill-core/src/main/scala/io/getquill/quotation/MacroUtilUniverse.scala b/quill-core/src/main/scala/io/getquill/quotation/MacroUtilUniverse.scala new file mode 100644 index 0000000000..4708144807 --- /dev/null +++ b/quill-core/src/main/scala/io/getquill/quotation/MacroUtilUniverse.scala @@ -0,0 +1,96 @@ +package io.getquill.quotation + +import io.getquill.{ IdiomContext, Quoted } + +import scala.reflect.api.Universe +import scala.reflect.macros.whitebox.Context + +trait MacroUtilBase extends MacroUtilUniverse { + val c: Context + type Uni = c.universe.type + // NOTE: u needs to be lazy otherwise sets value from c before c can be initialized by higher level classes + lazy val u: Uni = c.universe +} + +trait MacroUtilUniverse { + type Uni <: Universe + val u: Uni + import u.{ Block => _, Constant => _, Function => _, Ident => _, If => _, _ } + + object QuotedType { + def unapply(tpe: Type) = + paramOf(tpe, typeOf[Quoted[Any]]) + } + + object QueryType { + def unapply(tpe: Type) = + paramOf(tpe, typeOf[io.getquill.Query[Any]]) + } + + object BatchType { + def unapply(tpe: Type) = + paramOf(tpe, typeOf[io.getquill.BatchAction[_]]) + } + + // Note: These will not match if they are not existential + object ActionType { + object Insert { + def unapply(tpe: Type) = + paramOf(tpe, typeOf[io.getquill.Insert[_]]) + } + object Update { + def unapply(tpe: Type) = + paramOf(tpe, typeOf[io.getquill.Update[_]]) + } + object Delete { + def unapply(tpe: Type) = + paramOf(tpe, typeOf[io.getquill.Delete[_]]) + } + } + + object TypeSigParam { + def unapply(tpe: Type): Option[Type] = + tpe.typeSymbol.typeSignature.typeParams match { + case head :: tail => Some(head.typeSignature) + case Nil => None + } + } + + def parseQueryType(tpe: Type): Option[IdiomContext.QueryType] = { + println(s"Trying to match: ${show(tpe)}") + tpe match { + case QuotedType(tpe) => parseQueryType(tpe) + case BatchType(tpe) => parseQueryType(tpe) + case QueryType(tpe) => Some(IdiomContext.QueryType.Select) + case ActionType.Insert(tpe) => Some(IdiomContext.QueryType.Insert) + case ActionType.Update(tpe) => Some(IdiomContext.QueryType.Update) + case ActionType.Delete(tpe) => Some(IdiomContext.QueryType.Delete) + case _ => None + } + } + + def paramOf(tpe: Type, of: Type, maxDepth: Int = 10): Option[Type] = { + //println(s"### Attempting to check paramOf ${tpe} assuming it is a ${of}") + tpe match { + case _ if (maxDepth == 0) => + throw new IllegalArgumentException(s"Max Depth reached with type: ${tpe}") + case _ if (!(tpe <:< of)) => + //println(s"### ${tpe} is not a ${of}") + None + case _ if (tpe =:= typeOf[Nothing] || tpe =:= typeOf[Any]) => + //println(s"### ${tpe} is Nothing or Any") + None + case TypeRef(_, cls, List(arg)) => + //println(s"### ${tpe} is a TypeRef whose arg is ${arg}") + Some(arg) + case TypeSigParam(param) => + //println(s"### ${tpe} is a type signature whose type is ${param}") + Some(param) + case _ => + val base = tpe.baseType(of.typeSymbol) + //println(s"### Going to base type for ${tpe} for expected base type ${of}") + paramOf(base, of, maxDepth - 1) + } + } + +} diff --git a/quill-core/src/main/scala/io/getquill/quotation/Parsing.scala b/quill-core/src/main/scala/io/getquill/quotation/Parsing.scala index b2c3a2b137..926670a0e7 100644 --- a/quill-core/src/main/scala/io/getquill/quotation/Parsing.scala +++ b/quill-core/src/main/scala/io/getquill/quotation/Parsing.scala @@ -20,7 +20,7 @@ import io.getquill.util.Messages.TraceType import io.getquill.util.{ Interleave, Interpolator, Messages } import io.getquill.{ Quoted, Delete => DslDelete, Insert => DslInsert, Query => DslQuery, Update => DslUpdate } -trait Parsing extends ValueComputation with QuatMaking { +trait Parsing extends ValueComputation with QuatMaking with MacroUtilBase { this: Quotation => import c.universe.{ Ident => _, Constant => _, Function => _, If => _, Block => _, _ } @@ -133,14 +133,14 @@ trait Parsing extends ValueComputation with QuatMaking { val liftParser: Parser[Lift] = Parser[Lift] { - case q"$pack.liftScalar[$t]($value)($encoder)" => ScalarValueLift(value.toString, value, encoder, inferQuat(q"$t".tpe)) - case q"$pack.liftCaseClass[$t]($value)" => CaseClassValueLift(value.toString, value, inferQuat(q"$t".tpe)) + case q"$pack.liftScalar[$t]($value)($encoder)" => ScalarValueLift(value.toString, External.Source.Parser, value, encoder, inferQuat(q"$t".tpe)) + case q"$pack.liftCaseClass[$t]($value)" => CaseClassValueLift(value.toString, value.toString, value, inferQuat(q"$t".tpe)) case q"$pack.liftQueryScalar[$u, $t]($value)($encoder)" => ScalarQueryLift(value.toString, value, encoder, inferQuat(q"$t".tpe)) case q"$pack.liftQueryCaseClass[$u, $t]($value)" => CaseClassQueryLift(value.toString, value, inferQuat(q"$t".tpe)) // Unused, it's here only to make eclipse's presentation compiler happy :( - case q"$pack.lift[$t]($value)" => ScalarValueLift(value.toString, value, q"null", inferQuat(q"$t".tpe)) + case q"$pack.lift[$t]($value)" => ScalarValueLift(value.toString, External.Source.Parser, value, q"null", inferQuat(q"$t".tpe)) case q"$pack.liftQuery[$t, $u]($value)" => ScalarQueryLift(value.toString, value, q"null", inferQuat(q"$t".tpe)) } diff --git a/quill-core/src/main/scala/io/getquill/quotation/ReifyLiftings.scala b/quill-core/src/main/scala/io/getquill/quotation/ReifyLiftings.scala index ad18e5367e..1201565c44 100644 --- a/quill-core/src/main/scala/io/getquill/quotation/ReifyLiftings.scala +++ b/quill-core/src/main/scala/io/getquill/quotation/ReifyLiftings.scala @@ -32,34 +32,50 @@ trait ReifyLiftings extends QuatMaking with TranspileConfigSummoning { private case class ReifyLiftings(state: Map[TermName, Reified]) extends StatefulTransformer[Map[TermName, Reified]] { + case class Unparsed(tree: Tree, name: String) + private def reify(lift: Lift) = lift match { - case ScalarValueLift(name, value: Tree, encoder: Tree, _) => Reified(value, Some(encoder)) - case CaseClassValueLift(name, value: Tree, _) => Reified(value, None) - case ScalarQueryLift(name, value: Tree, encoder: Tree, _) => Reified(value, Some(encoder)) - case CaseClassQueryLift(name, value: Tree, _) => Reified(value, None) + case ScalarValueLift(name, simpleName, value: Tree, encoder: Tree, _) => Reified(value, Some(encoder)) + case CaseClassValueLift(name, simpleName, value: Tree, _) => Reified(value, None) + case ScalarQueryLift(name, value: Tree, encoder: Tree, _) => Reified(value, Some(encoder)) + case CaseClassQueryLift(name, value: Tree, _) => Reified(value, None) } - private def unparse(ast: Ast): Tree = + private def unparse(ast: Ast): Unparsed = ast match { - case Property(Ident(alias, _), name) => q"${TermName(alias)}.${TermName(name)}" - case Property(nested, name) => q"${unparse(nested)}.${TermName(name)}" + case Property(Ident(alias, _), name) => + Unparsed(q"${TermName(alias)}.${TermName(name)}", name) + + case Property(nested, name) => + val Unparsed(nestedTree, nestedName) = unparse(nested) + Unparsed(q"${nestedTree}.${TermName(name)}", s"$nestedName.$name") + case OptionTableMap(ast2, Ident(alias, _), body) => - q"${unparse(ast2)}.map((${TermName(alias)}: ${tq""}) => ${unparse(body)})" + val Unparsed(ast2Tree, ast2Name) = unparse(ast2) + val Unparsed(bodyTree, bodyName) = unparse(body) + Unparsed(q"${ast2Tree}.map((${TermName(alias)}: ${tq""}) => ${bodyTree})", s"$ast2Name.$bodyName") + case OptionMap(ast2, Ident(alias, _), body) => - q"${unparse(ast2)}.map((${TermName(alias)}: ${tq""}) => ${unparse(body)})" - case CaseClassValueLift(_, v: Tree, _) => v - case other => c.fail(s"Unsupported AST: $other") + val Unparsed(ast2Tree, ast2Name) = unparse(ast2) + val Unparsed(bodyTree, bodyName) = unparse(body) + Unparsed(q"${ast2Tree}.map((${TermName(alias)}: ${tq""}) => ${bodyTree})", s"$ast2Name.$bodyName") + + case CaseClassValueLift(_, simpleName, v: Tree, _) => + Unparsed(v, simpleName) + + case other => c.fail(s"Unsupported AST: $other") } - private def lift(v: Tree): Lift = { + private def lift(value: Unparsed): Lift = { + val Unparsed(v, originalName) = value val tpe = c.typecheck(q"import _root_.scala.language.reflectiveCalls; $v").tpe OptionalTypecheck(c)(q"implicitly[${c.prefix}.Encoder[$tpe]]") match { - case Some(enc) => ScalarValueLift(v.toString, v, enc, inferQuat(tpe)) + case Some(enc) => ScalarValueLift(v.toString, External.Source.UnparsedProperty(originalName), v, enc, inferQuat(tpe)) case None => tpe.baseType(c.symbolOf[Product]) match { case NoType => c.fail(s"Can't find an encoder for the lifted case class property '$v'") - case _ => CaseClassValueLift(v.toString, v, inferQuat(tpe)) + case _ => CaseClassValueLift(v.toString, originalName, v, inferQuat(tpe)) } } } @@ -118,10 +134,10 @@ trait ReifyLiftings extends QuatMaking with TranspileConfigSummoning { val nested = q"$ref.$liftings.${encode(lift.name)}" lift match { - case ScalarValueLift(name, value, encoder, quat) => - ScalarValueLift(s"$ref.$name", q"$nested.value", q"$nested.encoder", quat) - case CaseClassValueLift(name, value, quat) => - CaseClassValueLift(s"$ref.$name", q"$nested.value", quat) + case ScalarValueLift(name, source, value, encoder, quat) => + ScalarValueLift(s"$ref.$name", source, q"$nested.value", q"$nested.encoder", quat) + case CaseClassValueLift(name, simpleName, value, quat) => + CaseClassValueLift(s"$ref.$name", simpleName, q"$nested.value", quat) case ScalarQueryLift(name, value, encoder, quat) => ScalarQueryLift(s"$ref.$name", q"$nested.value", q"$nested.encoder", quat) case CaseClassQueryLift(name, value, quat) => diff --git a/quill-core/src/main/scala/io/getquill/quotation/TranspileConfigSummoning.scala b/quill-core/src/main/scala/io/getquill/quotation/TranspileConfigSummoning.scala index 3f3c926a4a..0ce74cc8f2 100644 --- a/quill-core/src/main/scala/io/getquill/quotation/TranspileConfigSummoning.scala +++ b/quill-core/src/main/scala/io/getquill/quotation/TranspileConfigSummoning.scala @@ -1,5 +1,8 @@ package io.getquill.quotation +import io.getquill.{ IdiomContext } +import io.getquill.IdiomContext.QueryType +import io.getquill.IdiomContext.QueryType.{ Batch, Regular } import io.getquill.norm.{ OptionalPhase, TranspileConfig } import io.getquill.util.Messages.TraceType import io.getquill.util.TraceConfig @@ -109,5 +112,26 @@ trait TranspileConfigSummoning { implicit val transpileConfigLiftable: Liftable[TranspileConfig] = Liftable[TranspileConfig] { case TranspileConfig(disablePhases, traceConfig) => q"io.getquill.norm.TranspileConfig(${disablePhases}, ${traceConfig})" } + + implicit val queryTypeRegularLiftable: Liftable[Regular] = Liftable[Regular] { + case QueryType.Select => q"io.getquill.IdiomContext.QueryType.Select" + case QueryType.Insert => q"io.getquill.IdiomContext.QueryType.Insert" + case QueryType.Update => q"io.getquill.IdiomContext.QueryType.Update" + case QueryType.Delete => q"io.getquill.IdiomContext.QueryType.Delete" + } + + implicit val queryTypeBatchLiftable: Liftable[Batch] = Liftable[Batch] { + case QueryType.BatchInsert(foreachAlias) => q"io.getquill.IdiomContext.QueryType.BatchInsert($foreachAlias)" + case QueryType.BatchUpdate(foreachAlias) => q"io.getquill.IdiomContext.QueryType.BatchUpdate($foreachAlias)" + } + + implicit val queryTypeLiftable: Liftable[QueryType] = Liftable[QueryType] { + case v: Regular => queryTypeRegularLiftable(v) + case v: Batch => queryTypeBatchLiftable(v) + } + + implicit val transpileContextLiftable: Liftable[IdiomContext] = Liftable[IdiomContext] { + case IdiomContext(transpileConfig, queryType) => q"io.getquill.IdiomContext(${transpileConfig}, ${queryType})" + } } } \ No newline at end of file diff --git a/quill-core/src/main/scala/io/getquill/quotation/Unliftables.scala b/quill-core/src/main/scala/io/getquill/quotation/Unliftables.scala index 4b223f9dac..5d46104f1f 100644 --- a/quill-core/src/main/scala/io/getquill/quotation/Unliftables.scala +++ b/quill-core/src/main/scala/io/getquill/quotation/Unliftables.scala @@ -8,9 +8,14 @@ trait Unliftables extends QuatUnliftable { val mctx: Context import mctx.universe.{ Ident => _, Constant => _, Function => _, If => _, _ } + implicit val stringOptionUnliftable: Unliftable[Option[String]] = Unliftable[Option[String]] { + case q"scala.None" => None + case q"scala.Some[String](${ value: String })" => Some(value) + } + implicit val astUnliftable: Unliftable[Ast] = Unliftable[Ast] { case liftUnliftable(ast) => ast - case q"ScalarTag(${ uid: String })" => ScalarTag(uid) + case tagUnliftable(ast) => ast case queryUnliftable(ast) => ast case actionUnliftable(ast) => ast case valueUnliftable(ast) => ast @@ -206,13 +211,17 @@ trait Unliftables extends QuatUnliftable { } implicit val liftUnliftable: Unliftable[Lift] = Unliftable[Lift] { - case q"$pack.ScalarValueLift.apply(${ a: String }, $b, $c, ${ quat: Quat })" => ScalarValueLift(a, b, c, quat) - case q"$pack.CaseClassValueLift.apply(${ a: String }, $b, ${ quat: Quat })" => CaseClassValueLift(a, b, quat) - case q"$pack.ScalarQueryLift.apply(${ a: String }, $b, $c, ${ quat: Quat })" => ScalarQueryLift(a, b, c, quat) - case q"$pack.CaseClassQueryLift.apply(${ a: String }, $b, ${ quat: Quat })" => CaseClassQueryLift(a, b, quat) + case q"$pack.ScalarValueLift.apply(${ a: String }, ${ a1: External.Source }, $b, $c, ${ quat: Quat })" => ScalarValueLift(a, a1, b, c, quat) + case q"$pack.CaseClassValueLift.apply(${ a: String }, ${ a1: String }, $b, ${ quat: Quat })" => CaseClassValueLift(a, a1, b, quat) + case q"$pack.ScalarQueryLift.apply(${ a: String }, $b, $c, ${ quat: Quat })" => ScalarQueryLift(a, b, c, quat) + case q"$pack.CaseClassQueryLift.apply(${ a: String }, $b, ${ quat: Quat })" => CaseClassQueryLift(a, b, quat) + } + implicit val sourceUnliftable: Unliftable[External.Source] = Unliftable[External.Source] { + case q"$pack.External.Source.Parser" => External.Source.Parser + case q"$pack.External.Source.UnparsedProperty.apply(${ prop: String })" => External.Source.UnparsedProperty(prop) } implicit val tagUnliftable: Unliftable[Tag] = Unliftable[Tag] { - case q"$pack.ScalarTag.apply(${ uid: String })" => ScalarTag(uid) - case q"$pack.QuotationTag.apply(${ uid: String })" => QuotationTag(uid) + case q"$pack.ScalarTag.apply(${ uid: String }, ${ source: External.Source })" => ScalarTag(uid, source) + case q"$pack.QuotationTag.apply(${ uid: String })" => QuotationTag(uid) } } diff --git a/quill-core/src/test/scala/io/getquill/norm/ExpandReturningSpec.scala b/quill-core/src/test/scala/io/getquill/norm/ExpandReturningSpec.scala index 9f3e83b235..05d4522f7e 100644 --- a/quill-core/src/test/scala/io/getquill/norm/ExpandReturningSpec.scala +++ b/quill-core/src/test/scala/io/getquill/norm/ExpandReturningSpec.scala @@ -25,7 +25,7 @@ class ExpandReturningSpec extends Spec { query[Person].insertValue(lift(Person("Joe", 123))).returning(p => (p.name, p.age)) } val list = - ExpandReturning.apply(q.ast.asInstanceOf[Returning])(MirrorIdiom, Literal, TranspileConfig.Empty) + ExpandReturning.apply(q.ast.asInstanceOf[Returning])(MirrorIdiom, Literal, IdiomContext.Empty) list must matchPattern { case List((Property(ExternalIdent("p", `quat`), "name"), _), (Property(ExternalIdent("p", `quat`), "age"), _)) => } @@ -36,7 +36,7 @@ class ExpandReturningSpec extends Spec { query[Person].insertValue(lift(Person("Joe", 123))).returning(p => Foo(p.name, p.age)) } val list = - ExpandReturning.apply(q.ast.asInstanceOf[Returning])(MirrorIdiom, Literal, TranspileConfig.Empty) + ExpandReturning.apply(q.ast.asInstanceOf[Returning])(MirrorIdiom, Literal, IdiomContext.Empty) list must matchPattern { case List((Property(ExternalIdent("p", `quat`), "name"), _), (Property(ExternalIdent("p", `quat`), "age"), _)) => } @@ -52,7 +52,7 @@ class ExpandReturningSpec extends Spec { query[Person].insertValue(lift(Person("Joe", 123))).returning(p => (p.name, p.age)) } val list = - ExpandReturning.apply(q.ast.asInstanceOf[Returning], Some("OTHER"))(MirrorIdiom, SnakeCase, TranspileConfig.Empty) + ExpandReturning.apply(q.ast.asInstanceOf[Returning], Some("OTHER"))(MirrorIdiom, SnakeCase, IdiomContext.Empty) list must matchPattern { case List((Property(ExternalIdent("OTHER", `quat`), "name"), _), (Property(ExternalIdent("OTHER", `quat`), "age"), _)) => } @@ -63,7 +63,7 @@ class ExpandReturningSpec extends Spec { query[Person].insertValue(lift(Person("Joe", 123))).returning(p => Foo(p.name, p.age)) } val list = - ExpandReturning.apply(q.ast.asInstanceOf[Returning], Some("OTHER"))(MirrorIdiom, SnakeCase, TranspileConfig.Empty) + ExpandReturning.apply(q.ast.asInstanceOf[Returning], Some("OTHER"))(MirrorIdiom, SnakeCase, IdiomContext.Empty) list must matchPattern { case List((Property(ExternalIdent("OTHER", `quat`), "name"), _), (Property(ExternalIdent("OTHER", `quat`), "age"), _)) => } @@ -81,7 +81,7 @@ class ExpandReturningSpec extends Spec { val ret = ExpandReturning.applyMap(qi.ast.asInstanceOf[Returning]) { case (ast, stmt) => fail("Should not use this method for the returning clause") - }(mi, Literal, TranspileConfig.Empty) + }(mi, Literal, IdiomContext.Empty) ret mustBe ReturnRecord } @@ -90,7 +90,7 @@ class ExpandReturningSpec extends Spec { val ret = ExpandReturning.applyMap(qi.ast.asInstanceOf[Returning]) { case (ast, stmt) => fail("Should not use this method for the returning clause") - }(mi, Literal, TranspileConfig.Empty) + }(mi, Literal, IdiomContext.Empty) ret mustBe ReturnRecord } @@ -99,7 +99,7 @@ class ExpandReturningSpec extends Spec { val ret = ExpandReturning.applyMap(qi.ast.asInstanceOf[Returning]) { case (ast, stmt) => fail("Should not use this method for the returning clause") - }(mi, Literal, TranspileConfig.Empty) + }(mi, Literal, IdiomContext.Empty) ret mustBe ReturnRecord } @@ -116,7 +116,7 @@ class ExpandReturningSpec extends Spec { val ret = ExpandReturning.applyMap(qi.ast.asInstanceOf[Returning]) { case (ast, stmt) => Expand(ctx, ast, stmt, mi, Literal, ExecutionType.Unknown).string - }(mi, Literal, TranspileConfig.Empty) + }(mi, Literal, IdiomContext.Empty) ret mustBe ReturnColumns(List("name", "age")) } "should expand case classes" in { @@ -124,7 +124,7 @@ class ExpandReturningSpec extends Spec { val ret = ExpandReturning.applyMap(qi.ast.asInstanceOf[Returning]) { case (ast, stmt) => Expand(ctx, ast, stmt, mi, Literal, ExecutionType.Unknown).string - }(mi, Literal, TranspileConfig.Empty) + }(mi, Literal, IdiomContext.Empty) ret mustBe ReturnColumns(List("name", "age")) } "should expand case classes (converted to tuple in parser)" in { @@ -132,7 +132,7 @@ class ExpandReturningSpec extends Spec { val ret = ExpandReturning.applyMap(qi.ast.asInstanceOf[Returning]) { case (ast, stmt) => Expand(ctx, ast, stmt, mi, Literal, ExecutionType.Unknown).string - }(mi, Literal, TranspileConfig.Empty) + }(mi, Literal, IdiomContext.Empty) ret mustBe ReturnColumns(List("name", "age")) } } @@ -162,14 +162,14 @@ class ExpandReturningSpec extends Spec { assertThrows[IllegalArgumentException] { ExpandReturning.applyMap(retMulti) { case (ast, stmt) => Expand(ctx, ast, stmt, mi, Literal, ExecutionType.Unknown).string - }(mi, Literal, TranspileConfig.Empty) + }(mi, Literal, IdiomContext.Empty) } } "should succeed if single field encountered" in { val ret = ExpandReturning.applyMap(retSingle) { case (ast, stmt) => Expand(ctx, ast, stmt, mi, Literal, ExecutionType.Unknown).string - }(mi, Literal, TranspileConfig.Empty) + }(mi, Literal, IdiomContext.Empty) ret mustBe ReturnColumns(List("name")) } } @@ -181,14 +181,14 @@ class ExpandReturningSpec extends Spec { assertThrows[IllegalArgumentException] { ExpandReturning.applyMap(retMulti) { case (ast, stmt) => Expand(ctx, ast, stmt, mi, Literal, ExecutionType.Unknown).string - }(mi, Literal, TranspileConfig.Empty) + }(mi, Literal, IdiomContext.Empty) } } "should fail if single field encountered" in { assertThrows[IllegalArgumentException] { ExpandReturning.applyMap(retSingle) { case (ast, stmt) => Expand(ctx, ast, stmt, mi, Literal, ExecutionType.Unknown).string - }(mi, Literal, TranspileConfig.Empty) + }(mi, Literal, IdiomContext.Empty) } } } diff --git a/quill-core/src/test/scala/io/getquill/norm/StablizeLiftsSpec.scala b/quill-core/src/test/scala/io/getquill/norm/StablizeLiftsSpec.scala index 7a452dae7c..7efd0f418a 100644 --- a/quill-core/src/test/scala/io/getquill/norm/StablizeLiftsSpec.scala +++ b/quill-core/src/test/scala/io/getquill/norm/StablizeLiftsSpec.scala @@ -17,7 +17,7 @@ class StablizeLiftsSpec extends Spec { val astQuat = quatOf[Int] val (stablized, state) = StablizeLifts.stablize(ast) stablized must matchPattern { - case ScalarValueLift("scalarValue", StablizeLifts.Token(0), _, `astQuat`) => + case ScalarValueLift("scalarValue", External.Source.Parser, StablizeLifts.Token(0), _, `astQuat`) => } state.replaceTable mustEqual (IMap(StablizeLifts.Token(0) -> scalarValue)) StablizeLifts.revert(stablized, state) mustEqual (ast) @@ -39,7 +39,7 @@ class StablizeLiftsSpec extends Spec { val astQuat = quatOf[Foo] val (stablized, state) = StablizeLifts.stablize(ast) stablized must matchPattern { - case CaseClassValueLift("caseClass", StablizeLifts.Token(0), `astQuat`) => + case CaseClassValueLift("caseClass", "caseClass", StablizeLifts.Token(0), `astQuat`) => } state.replaceTable mustEqual (IMap(StablizeLifts.Token(0) -> caseClass)) StablizeLifts.revert(stablized, state) mustEqual (ast) @@ -66,8 +66,8 @@ class StablizeLiftsSpec extends Spec { val (stablized, state) = StablizeLifts.stablize(ast) stablized must matchPattern { case BinaryOperation( - ScalarValueLift("a", StablizeLifts.Token(0), _, `quatA`), - StringOperator.`+`, ScalarValueLift("b", StablizeLifts.Token(1), _, `quatB`)) => + ScalarValueLift("a", External.Source.Parser, StablizeLifts.Token(0), _, `quatA`), + StringOperator.`+`, ScalarValueLift("b", External.Source.Parser, StablizeLifts.Token(1), _, `quatB`)) => } val expectedTable = IMap(StablizeLifts.Token(0) -> a, StablizeLifts.Token(1) -> b) state.replaceTable must contain theSameElementsAs (expectedTable) diff --git a/quill-engine/src/main/scala/io/getquill/AstPrinter.scala b/quill-engine/src/main/scala/io/getquill/AstPrinter.scala index 0d503ca26e..7c255a1c31 100644 --- a/quill-engine/src/main/scala/io/getquill/AstPrinter.scala +++ b/quill-engine/src/main/scala/io/getquill/AstPrinter.scala @@ -111,7 +111,7 @@ class AstPrinter(traceOpinions: Boolean, traceAstSimple: Boolean, traceQuats: Qu case q: Quat => Tree.Literal(q.shortString) - case s: ScalarValueLift => Tree.Apply("ScalarValueLift", treemake("..." + s.name.reverse.take(15).reverse).withQuat(s.bestQuat).make) + case s: ScalarValueLift => Tree.Apply("ScalarValueLift", treemake(s.name, s.source).withQuat(s.bestQuat).make) case p: Property if (traceOpinions) => TreeApplyList("Property", l(treeify(p.ast)) ++ l(treeify(p.name)) ++ diff --git a/quill-engine/src/main/scala/io/getquill/H2Dialect.scala b/quill-engine/src/main/scala/io/getquill/H2Dialect.scala index f30a098271..9909d6bb15 100644 --- a/quill-engine/src/main/scala/io/getquill/H2Dialect.scala +++ b/quill-engine/src/main/scala/io/getquill/H2Dialect.scala @@ -8,7 +8,6 @@ import io.getquill.context.{ CanInsertReturningWithMultiValues, CanInsertWithMul import io.getquill.context.sql.idiom.PositionalBindVariables import io.getquill.context.sql.idiom.SqlIdiom import io.getquill.context.sql.idiom.ConcatSupport -import io.getquill.norm.TranspileConfig import io.getquill.util.Messages.fail trait H2Dialect @@ -24,13 +23,13 @@ trait H2Dialect override def prepareForProbing(string: String) = s"PREPARE p${preparedStatementId.incrementAndGet.toString.token} AS $string}" - override def astTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[Ast] = + override def astTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[Ast] = Tokenizer[Ast] { case c: OnConflict => c.token case ast => super.astTokenizer.token(ast) } - implicit def conflictTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[OnConflict] = { + implicit def conflictTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[OnConflict] = { import OnConflict._ def tokenizer(implicit astTokenizer: Tokenizer[Ast]) = Tokenizer[OnConflict] { diff --git a/quill-engine/src/main/scala/io/getquill/IdiomContext.scala b/quill-engine/src/main/scala/io/getquill/IdiomContext.scala new file mode 100644 index 0000000000..ce1d47d836 --- /dev/null +++ b/quill-engine/src/main/scala/io/getquill/IdiomContext.scala @@ -0,0 +1,64 @@ +package io.getquill + +import io.getquill.ast.{ Ast, CollectAst } +import io.getquill.norm.TranspileConfig +import io.getquill.ast + +case class IdiomContext(config: TranspileConfig, queryType: IdiomContext.QueryType) { + def traceConfig = config.traceConfig +} + +object IdiomContext { + def Empty = IdiomContext(TranspileConfig.Empty, QueryType.Insert) + sealed trait QueryType { + def isBatch: Boolean + def batchAlias: Option[String] + } + object QueryType { + case object Select extends Regular + case object Insert extends Regular + case object Update extends Regular + case object Delete extends Regular + + case class BatchInsert(foreachAlias: String) extends Batch { val batchAlias = Some(foreachAlias) } + case class BatchUpdate(foreachAlias: String) extends Batch { val batchAlias = Some(foreachAlias) } + + sealed trait Regular extends QueryType { val isBatch = false; val batchAlias = None } + sealed trait Batch extends QueryType { val isBatch = true } + + object Regular { + def unapply(qt: QueryType): Boolean = + qt match { + case r: Regular => true + case _ => false + } + } + + object Batch { + def unapply(qt: QueryType) = + qt match { + case BatchInsert(foreachAlias) => Some(foreachAlias) + case BatchUpdate(foreachAlias) => Some(foreachAlias) + case _ => None + } + } + + def discoverFromAst(theAst: Ast, batchAlias: Option[String]): QueryType = { + val actions = + CollectAst(theAst) { + case _: ast.Insert => QueryType.Insert + case _: ast.Update => QueryType.Update + case _: ast.Delete => QueryType.Delete + } + if (actions.length > 1) println(s"[WARN] Found more then one type of Query: ${actions}. Using 1st one!") + // if we have not found it to specifically be an action, it must just be a regular select query + val resultType: QueryType.Regular = actions.headOption.getOrElse(QueryType.Select) + resultType match { + case QueryType.Insert => batchAlias.map(QueryType.BatchInsert(_)) getOrElse QueryType.Insert + case QueryType.Update => batchAlias.map(QueryType.BatchUpdate(_)) getOrElse QueryType.Update + case QueryType.Delete => Delete + case QueryType.Select => Select + } + } + } +} \ No newline at end of file diff --git a/quill-engine/src/main/scala/io/getquill/MirrorIdiom.scala b/quill-engine/src/main/scala/io/getquill/MirrorIdiom.scala index 8aa6819913..d0e04f09a0 100644 --- a/quill-engine/src/main/scala/io/getquill/MirrorIdiom.scala +++ b/quill-engine/src/main/scala/io/getquill/MirrorIdiom.scala @@ -6,9 +6,10 @@ import io.getquill.ast.{ Action => AstAction, Query => AstQuery, _ } import io.getquill.context.{ CanReturnClause, ExecutionType } import io.getquill.idiom.{ Idiom, SetContainsToken, Statement } import io.getquill.idiom.StatementInterpolator._ -import io.getquill.norm.{ Normalize, NormalizeCaching, TranspileConfig } +import io.getquill.norm.{ Normalize, NormalizeCaching } import io.getquill.quat.Quat import io.getquill.util.Interleave +import io.getquill.IdiomContext object MirrorIdiom extends MirrorIdiom class MirrorIdiom extends MirrorIdiomBase with CanReturnClause @@ -25,14 +26,14 @@ trait MirrorIdiomBase extends Idiom { override def liftingPlaceholder(index: Int): String = "?" - override def translateCached(ast: Ast, topLevelQuat: Quat, executionType: ExecutionType, transpileConfig: TranspileConfig)(implicit naming: NamingStrategy): (Ast, Statement, ExecutionType) = { - val normalize = new Normalize(transpileConfig) + override def translateCached(ast: Ast, topLevelQuat: Quat, executionType: ExecutionType, idiomContext: IdiomContext)(implicit naming: NamingStrategy): (Ast, Statement, ExecutionType) = { + val normalize = new Normalize(idiomContext.config) val normalizedAst = NormalizeCaching(normalize.apply)(ast) (normalizedAst, stmt"${normalizedAst.token}", executionType) } - override def translate(ast: Ast, topLevelQuat: Quat, executionType: ExecutionType, transpileConfig: TranspileConfig)(implicit naming: NamingStrategy): (Ast, Statement, ExecutionType) = { - val normalize = new Normalize(transpileConfig) + override def translate(ast: Ast, topLevelQuat: Quat, executionType: ExecutionType, idiomContext: IdiomContext)(implicit naming: NamingStrategy): (Ast, Statement, ExecutionType) = { + val normalize = new Normalize(idiomContext.config) val normalizedAst = normalize(ast) (normalizedAst, stmt"${normalizedAst.token}", executionType) } diff --git a/quill-engine/src/main/scala/io/getquill/MySQLDialect.scala b/quill-engine/src/main/scala/io/getquill/MySQLDialect.scala index 47740d45e3..1b5e76dc17 100644 --- a/quill-engine/src/main/scala/io/getquill/MySQLDialect.scala +++ b/quill-engine/src/main/scala/io/getquill/MySQLDialect.scala @@ -7,7 +7,6 @@ import io.getquill.context.sql.idiom.SqlIdiom.ActionTableAliasBehavior import io.getquill.context.sql.idiom.{ NoConcatSupport, QuestionMarkBindVariables, SqlIdiom } import io.getquill.idiom.StatementInterpolator._ import io.getquill.idiom.{ Statement, Token } -import io.getquill.norm.TranspileConfig import io.getquill.util.Messages.fail trait MySQLDialect @@ -27,13 +26,13 @@ trait MySQLDialect override def defaultAutoGeneratedToken(field: Token) = stmt"($field) VALUES (DEFAULT)" - override def astTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[Ast] = + override def astTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[Ast] = Tokenizer[Ast] { case c: OnConflict => c.token case ast => super.astTokenizer.token(ast) } - implicit def conflictTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[OnConflict] = { + implicit def conflictTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[OnConflict] = { import OnConflict._ lazy val insertIgnoreTokenizer = @@ -54,7 +53,7 @@ trait MySQLDialect stmt"${i.token} ON DUPLICATE KEY UPDATE $assignments" case OnConflict(i: io.getquill.ast.Action, NoTarget, Ignore) => - actionTokenizer(insertIgnoreTokenizer)(actionAstTokenizer, strategy, transpileConfig).token(i) + actionTokenizer(insertIgnoreTokenizer)(actionAstTokenizer, strategy, idiomContext).token(i) case _ => fail("This upsert construct is not supported in MySQL. Please refer documentation for details.") @@ -62,7 +61,7 @@ trait MySQLDialect // TODO Are there situations where you could have invisible properties here? val customAstTokenizer = - Tokenizer.withFallback[Ast](MySQLDialect.this.astTokenizer(_, strategy, transpileConfig)) { + Tokenizer.withFallback[Ast](MySQLDialect.this.astTokenizer(_, strategy, idiomContext)) { case Property.Opinionated(Excluded(_), name, renameable, _) => renameable.fixedOr(name.token)(stmt"VALUES(${strategy.column(name).token})") diff --git a/quill-engine/src/main/scala/io/getquill/OracleDialect.scala b/quill-engine/src/main/scala/io/getquill/OracleDialect.scala index 529ee17736..ac83eb5dfb 100644 --- a/quill-engine/src/main/scala/io/getquill/OracleDialect.scala +++ b/quill-engine/src/main/scala/io/getquill/OracleDialect.scala @@ -8,7 +8,7 @@ import io.getquill.context.sql.idiom._ import io.getquill.idiom.StatementInterpolator._ import io.getquill.idiom.{ Statement, StringToken, Token } import io.getquill.norm.ConcatBehavior.NonAnsiConcat -import io.getquill.norm.{ ConcatBehavior, TranspileConfig } +import io.getquill.norm.ConcatBehavior import io.getquill.sql.idiom.BooleanLiteralSupport trait OracleDialect @@ -22,8 +22,8 @@ trait OracleDialect override def useActionTableAliasAs: ActionTableAliasBehavior = ActionTableAliasBehavior.Hide - class OracleFlattenSqlQueryTokenizerHelper(q: FlattenSqlQuery)(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig) - extends FlattenSqlQueryTokenizerHelper(q)(astTokenizer, strategy, transpileConfig) { + class OracleFlattenSqlQueryTokenizerHelper(q: FlattenSqlQuery)(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext) + extends FlattenSqlQueryTokenizerHelper(q)(astTokenizer, strategy, idiomContext) { import q._ override def withFrom: Statement = from match { @@ -34,7 +34,7 @@ trait OracleDialect } } - override implicit def sqlQueryTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[SqlQuery] = Tokenizer[SqlQuery] { + override implicit def sqlQueryTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[SqlQuery] = Tokenizer[SqlQuery] { case q: FlattenSqlQuery => new OracleFlattenSqlQueryTokenizerHelper(q).apply case other => diff --git a/quill-engine/src/main/scala/io/getquill/PostgresDialect.scala b/quill-engine/src/main/scala/io/getquill/PostgresDialect.scala index affa3b2932..6206cb7edf 100644 --- a/quill-engine/src/main/scala/io/getquill/PostgresDialect.scala +++ b/quill-engine/src/main/scala/io/getquill/PostgresDialect.scala @@ -1,11 +1,21 @@ package io.getquill import java.util.concurrent.atomic.AtomicInteger -import io.getquill.ast._ +import io.getquill.ast.{ Action, Query, _ } +import io.getquill.ast +import io.getquill.context.sql.idiom +import io.getquill.context.sql.idiom.SqlIdiom.{ InsertUpdateStmt, copyIdiom } import io.getquill.context.{ CanInsertReturningWithMultiValues, CanInsertWithMultiValues, CanReturnClause } import io.getquill.context.sql.idiom._ +import io.getquill.idiom.{ ScalarTagToken, Statement, Token, ValuesClauseToken } import io.getquill.idiom.StatementInterpolator._ -import io.getquill.norm.{ ProductAggregationToken, TranspileConfig } +import io.getquill.norm.{ BetaReduction, ExpandReturning, ProductAggregationToken } +import io.getquill.quat.Quat +import io.getquill.sql.norm.NormalizeFilteredActionAliases +import io.getquill.util.Messages.fail + +import scala.annotation.tailrec +import scala.collection.immutable.{ ListMap, ListSet, Queue } trait PostgresDialect extends SqlIdiom @@ -18,7 +28,7 @@ trait PostgresDialect override protected def productAggregationToken: ProductAggregationToken = ProductAggregationToken.VariableDotStar - override def astTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[Ast] = + override def astTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[Ast] = Tokenizer[Ast] { case ListContains(ast, body) => stmt"${body.token} = ANY(${ast.token})" case c: OnConflict => conflictTokenizer.token(c) @@ -42,6 +52,184 @@ trait PostgresDialect }) s"PREPARE p${preparedStatementId.incrementAndGet.toString.token} AS $query" } + + private[getquill] case class ReplaceReturningAlias(batchAlias: String) extends StatelessTransformer { + override def apply(e: ast.Action): ast.Action = + e match { + case Returning(action, alias, property) => + val newAlias = alias.copy(name = batchAlias) + val newProperty = BetaReduction(property, alias -> newAlias) + Returning(action, newAlias, newProperty) + case ReturningGenerated(action, alias, property) => + val newAlias = alias.copy(name = batchAlias) + val newProperty = BetaReduction(property, alias -> newAlias) + ReturningGenerated(action, newAlias, newProperty) + case _ => super.apply(e) + } + } + + override protected def actionTokenizer(insertEntityTokenizer: Tokenizer[Entity])(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[ast.Action] = + Tokenizer[ast.Action] { + // Don't need to check if this is supported, we know it is since it's postgres. + // Also, only do it for updates, for inserts we don't want the Returning Alias to be the returning-clause otherwise + // it would be something like + // INSERT ... RETURNING {batchAlias}.property + // which we don't want. We want it to just be: + // INSERT ... RETURNING property + + case returning @ ReturningAction(action: ast.Update, alias, prop) if (idiomContext.queryType.isBatch) => + val batchAlias = + idiomContext.queryType.batchAlias.getOrElse { + throw new IllegalArgumentException(s"Batch alias not found in the action: ${idiomContext.queryType} but it is a batch context. This should not be possible.") + } + val returningNew = ReplaceReturningAlias(batchAlias)(returning).asInstanceOf[ReturningAction] + stmt"${(action: Ast).token} RETURNING ${tokenizeReturningClause(returningNew, Some(returningNew.alias.name))}" + + case ConcatableBatchUpdate(output) => + output + + case other => + super.actionTokenizer(insertEntityTokenizer).token(other) + } + + protected def specialPropertyTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext) = + Tokenizer.withFallback[Ast](this.astTokenizer(_, strategy, idiomContext)) { + case p: Property => this.propertyTokenizer.token(p) + } + + object ConcatableBatchUpdate { + + private[getquill] def columnsAndValuesTogether(assignments: List[Assignment])(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext) = { + assignments.map(assignment => + assignment.property match { + case Property.Opinionated(_, key, renameable, visibility) => + ( + tokenizeColumn(strategy, key, renameable).token, + specialPropertyTokenizer.token(assignment.value) + ) + case _ => fail(s"Invalid assignment value of ${assignment}. Must be a Property object.") + }) + } + + //case class UpdateWithValues(action: Statement, where: Statement) + def unapply(action: ast.Update)(implicit actionAstTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext): Option[Statement] = { + + // Typical Postgres batch update syntax + // UPDATE people AS p SET id = p.id, name = p.name, age = p.age + // FROM (values (1, 'Joe', 111), (2, 'Jack', 222)) + // AS c(id, name, age) + // WHERE c.id = p.id + + // Uses the `alias` passed in as `actionAlias` since that is now assigned to the copied SqlIdiom + (action, idiomContext.queryType) match { + case (clause @ Update(Filter(table: Entity, origTableAlias, _), _), IdiomContext.QueryType.Batch(batchAlias)) => + + // Original Query looks like: + // liftQuery(people).foreach(ps => query[Person].filter(p => p.id == ps.id).update(_.name -> ps.name)) + // This has already been transpiled to (foreach part has been removed): + // query[Person].filter(p => p.id == STag(A)).update(_.name -> STag(B)) + // SQL Needs to look like: + // UPDATE person AS p SET name = ps.name FROM (VALUES ('Joe', 123)) AS ps(name, id) WHERE ps.id = p.id + // I.e. + // UPDATE person AS p SET name = ps.name FROM (VALUES (STag(B), STag(A))) AS ps(name, id) WHERE ps.id = p.id + // Conceptually, that means the query needs to look like: + // query[Person].filter(p => p.id == ps.id).update(_.name -> ps.id) with VALUES (STag(B), STag(A)) + // We don't actually change it to this, we yield the SQL directly but it is a good conceptual model + + // Let's consider this odd case for all examples. There could have the same id-column name in multiple places. + // (NOTE: STag := ScalarTag, the UUIDs are random so I am just assigning numbers to them for reference. Also when the query is tokenize then turn into `?`) + // (Also [stuff] is short for List(stuff) syntax) + // Need to work around how that happens + // liftQuery(people).foreach(ps => query[Person].filter(p => p.id == ps.id).update(_.name -> ps.name, _.id -> ps.id) + // This has already been transpiled to (foreach part has been removed): + // query[Person].filter(p => p.id == STag(uid:3)).update(_.name -> STag(uid:1), _.id -> STag(uid:2)) + // For now, blindly shove the name into the aliases section and dedupe + // UPDATE person AS p SET name = ps.name, id = ps.id FROM (VALUES ('Joe', 123, 123)) AS ps(name, id, id1) WHERE ps.id = p.id1 + // This should actually be + // UPDATE person AS p SET name = ps.name, id = ps.id FROM (VALUES (STag(uid:1), STag(uid:2), STag(uid:3))) AS ps(name, id, id1) WHERE ps.id = p.id1 + // (note `ps` is the batchAlias var) + + // replacedWhere: + // All the lifts in the WHERE clause that we need to put into the actual VALUES clause instead + // Originally was `WHERE ps.id = STag(uid:3)` + // (replacedWhere: `WHERE ps.id = p.id1`, additionalColumns: [id] /*and any other column names of STags in WHERE*/, additionalLifts: [STag(uid:3)]) + val (Update(Filter(table: Entity, tableAlias, replacedWhere), assignments), valuesColumns, valuesLifts) = { + ReplaceLiftings.of(clause)(batchAlias, List()) + } + + // The SET columns/values i.e. ([name, id], [STag(uid:1), STag(uid:2)] + val columnsAndValues = columnsAndValuesTogether(assignments) + // the `ps` + val colsId = batchAlias + // The columns that go in the SET clause i.e. `SET name = ps.name, id = ps.id` + val setColumns = columnsAndValues.map { case (column, value) => stmt"$column = ${value}" }.mkStmt(", ") + // The columns that go inside ps(name, id, id1) i.e. stmt"name, id, id1" + val asColumns = valuesColumns.toList.mkStmt(", ") + val output = stmt"UPDATE ${table.token} AS ${tableAlias.token} SET $setColumns FROM (VALUES ${ValuesClauseToken(stmt"(${valuesLifts.toList.map(v => v: External).mkStmt(", ")})")}) AS ${colsId.token}($asColumns) WHERE ${specialPropertyTokenizer.token(replacedWhere)}" + Some(output) + + case (clause @ Update(_: Entity, _), IdiomContext.QueryType.Batch(batchAlias)) => + val (Update(table: Entity, assignments), valuesColumns, valuesLifts) = + ReplaceLiftings.of(clause)(batchAlias, List()) + // Choose table alias based on how assignments clauses were realized. Batch-Alias should mean the same thing as when NormalizeFilteredActionAliases was run in Idiom should the + // value should be the same thing as the cluases that were realiased. + val tableAlias = NormalizeFilteredActionAliases.chooseAlias(table.name, Some(batchAlias)) + val colsId = batchAlias + val columnsAndValues = columnsAndValuesTogether(assignments) + val setColumns = columnsAndValues.map { case (column, value) => stmt"$column = ${value}" }.mkStmt(", ") + val asColumns = valuesColumns.toList.mkStmt(", ") + val output = stmt"UPDATE ${table.token} AS ${tableAlias.token} SET $setColumns FROM (VALUES ${ValuesClauseToken(stmt"(${valuesLifts.toList.map(v => v: External).mkStmt(", ")})")}) AS ${colsId.token}($asColumns)" + Some(output) + + case _ => + None + } + } + } } object PostgresDialect extends PostgresDialect + +case class ReplaceAssignmentAliases(newAlias: Ident) extends StatelessTransformer { + override def apply(e: Assignment): Assignment = + Assignment(newAlias, BetaReduction(e.property, e.alias -> newAlias), BetaReduction(e.value, e.alias -> newAlias)) +} + +case class ReplaceLiftings(foreachIdentName: String, existingColumnNames: List[String], state: ListMap[String, ScalarTag]) extends StatefulTransformer[ListMap[String, ScalarTag]] { + + private def columnExists(col: String) = + existingColumnNames.contains(col) || state.keySet.contains(col) + + def freshIdent(newCol: String) = { + @tailrec + def loop(id: String, n: Int): String = { + val fresh = s"${id}${n}" + if (!columnExists(fresh)) + fresh + else + loop(id, n + 1) + } + if (!columnExists(newCol)) + newCol + else + loop(newCol, 1) + } + + private def parseName(name: String) = + name.replace(".", "_") + + override def apply(e: Ast): (Ast, StatefulTransformer[ListMap[String, ScalarTag]]) = + e match { + case lift @ ScalarTag(_, External.Source.UnparsedProperty(propNameRaw)) => + val id = Ident(foreachIdentName, lift.quat) + val propName = freshIdent(propNameRaw) + (Property(id, propName), ReplaceLiftings(foreachIdentName, existingColumnNames, state + (propName -> lift))) + case _ => super.apply(e) + } +} +object ReplaceLiftings { + def of(ast: Ast)(foreachIdent: String, existingColumnNames: List[String]) = { + val (newAst, transform) = new ReplaceLiftings(foreachIdent, existingColumnNames, ListMap()).apply(ast) + (newAst, transform.state.map(_._1), transform.state.map(_._2)) + } +} diff --git a/quill-engine/src/main/scala/io/getquill/SQLServerDialect.scala b/quill-engine/src/main/scala/io/getquill/SQLServerDialect.scala index cc48fcfc5d..49c659538c 100644 --- a/quill-engine/src/main/scala/io/getquill/SQLServerDialect.scala +++ b/quill-engine/src/main/scala/io/getquill/SQLServerDialect.scala @@ -9,7 +9,7 @@ import io.getquill.context.sql.{ FlattenSqlQuery, SqlQuery, SqlQueryApply } import io.getquill.idiom.StatementInterpolator._ import io.getquill.idiom.{ Statement, StringToken, Token } import io.getquill.norm.EqualityBehavior.NonAnsiEquality -import io.getquill.norm.{ EqualityBehavior, TranspileConfig } +import io.getquill.norm.EqualityBehavior import io.getquill.sql.idiom.BooleanLiteralSupport import io.getquill.util.Messages.fail import io.getquill.util.TraceConfig @@ -25,7 +25,7 @@ trait SQLServerDialect override def useActionTableAliasAs: ActionTableAliasBehavior = ActionTableAliasBehavior.Hide - override def querifyAst(ast: Ast, transpileConfig: TraceConfig) = AddDropToNestedOrderBy(new SqlQueryApply(transpileConfig)(ast)) + override def querifyAst(ast: Ast, idiomContext: TraceConfig) = AddDropToNestedOrderBy(new SqlQueryApply(idiomContext)(ast)) override def emptySetContainsToken(field: Token) = StringToken("1 <> 1") @@ -43,7 +43,7 @@ trait SQLServerDialect case other => super.limitOffsetToken(query).token(other) } - override implicit def sqlQueryTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[SqlQuery] = + override implicit def sqlQueryTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[SqlQuery] = Tokenizer[SqlQuery] { case flatten: FlattenSqlQuery if flatten.orderBy.isEmpty && flatten.offset.nonEmpty => fail(s"SQLServer does not support OFFSET without ORDER BY") @@ -56,7 +56,7 @@ trait SQLServerDialect case other => super.operationTokenizer.token(other) } - override protected def actionTokenizer(insertEntityTokenizer: Tokenizer[Entity])(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[ast.Action] = + override protected def actionTokenizer(insertEntityTokenizer: Tokenizer[Entity])(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[ast.Action] = Tokenizer[ast.Action] { // Update(Filter(...)) and Delete(Filter(...)) usually cause a table alias i.e. `UPDATE People SET ... WHERE ...` or `DELETE FROM People WHERE ...` // since the alias is used in the WHERE clause. This functionality removes that because SQLServer doesn't support aliasing in actions. diff --git a/quill-engine/src/main/scala/io/getquill/SqliteDialect.scala b/quill-engine/src/main/scala/io/getquill/SqliteDialect.scala index 16d2deae49..2be8f86b59 100644 --- a/quill-engine/src/main/scala/io/getquill/SqliteDialect.scala +++ b/quill-engine/src/main/scala/io/getquill/SqliteDialect.scala @@ -7,7 +7,6 @@ import io.getquill.idiom.{ StringToken, Token } import io.getquill.ast._ import io.getquill.context.{ CanInsertReturningWithSingleValue, CanInsertWithMultiValues, CanReturnField } import io.getquill.context.sql.OrderByCriteria -import io.getquill.norm.TranspileConfig trait SqliteDialect extends SqlIdiom @@ -22,7 +21,7 @@ trait SqliteDialect override def prepareForProbing(string: String) = s"sqlite3_prepare_v2($string)" - override def astTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[Ast] = + override def astTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[Ast] = Tokenizer[Ast] { case c: OnConflict => conflictTokenizer.token(c) case ast => super.astTokenizer.token(ast) diff --git a/quill-engine/src/main/scala/io/getquill/ast/Ast.scala b/quill-engine/src/main/scala/io/getquill/ast/Ast.scala index 11a8ac7b01..14065873ec 100644 --- a/quill-engine/src/main/scala/io/getquill/ast/Ast.scala +++ b/quill-engine/src/main/scala/io/getquill/ast/Ast.scala @@ -718,6 +718,13 @@ object Dynamic { case class QuotedReference(tree: Any, ast: Ast) extends Ast { def quat = ast.quat; def bestQuat = ast.bestQuat; } sealed trait External extends Ast +object External { + sealed trait Source + object Source { + case class UnparsedProperty(name: String) extends Source + case object Parser extends Source + } +} /***********************************************************************/ /* Only Quill 2 */ @@ -732,26 +739,26 @@ sealed trait ScalarLift extends Lift with Terminal { val encoder: Any } -final class ScalarValueLift(val name: String, val value: Any, val encoder: Any)(theQuat: => Quat) +final class ScalarValueLift(val name: String, val source: External.Source, val value: Any, val encoder: Any)(theQuat: => Quat) extends ScalarLift { def quat: Quat = theQuat def bestQuat = quat override def withQuat(quat: => Quat) = this.copy(quat = quat) - private val id = ScalarValueLift.Id(name, value, encoder) + private val id = ScalarValueLift.Id(name, source, value, encoder) override def hashCode(): Int = id.hashCode() override def equals(obj: Any): Boolean = obj match { case e: ScalarValueLift => e.id == this.id case _ => false } - def copy(name: String = this.name, value: Any = this.value, encoder: Any = this.encoder, quat: => Quat = this.quat) = - ScalarValueLift(name, value, encoder, quat) + def copy(name: String = this.name, source: External.Source = this.source, value: Any = this.value, encoder: Any = this.encoder, quat: => Quat = this.quat) = + ScalarValueLift(name, source, value, encoder, quat) } object ScalarValueLift { - private case class Id(name: String, value: Any, encoder: Any) - def apply(name: String, value: Any, encoder: Any, quat: => Quat): ScalarValueLift = new ScalarValueLift(name, value, encoder)(quat) - def unapply(svl: ScalarValueLift) = Some((svl.name, svl.value, svl.encoder, svl.quat)) + private case class Id(name: String, source: External.Source, value: Any, encoder: Any) + def apply(name: String, source: External.Source, value: Any, encoder: Any, quat: => Quat): ScalarValueLift = new ScalarValueLift(name, source, value, encoder)(quat) + def unapply(svl: ScalarValueLift) = Some((svl.name, svl.source, svl.value, svl.encoder, svl.quat)) } final class ScalarQueryLift(val name: String, val value: Any, val encoder: Any)(theQuat: => Quat) @@ -779,25 +786,25 @@ object ScalarQueryLift { sealed trait CaseClassLift extends Lift -final class CaseClassValueLift(val name: String, val value: Any)(theQuat: => Quat) extends CaseClassLift { +final class CaseClassValueLift(val name: String, val simpleName: String, val value: Any)(theQuat: => Quat) extends CaseClassLift { def quat: Quat = theQuat def bestQuat = quat override def withQuat(quat: => Quat) = this.copy(quat = quat) - private val id = CaseClassValueLift.Id(name, value) + private val id = CaseClassValueLift.Id(name, simpleName, value) override def hashCode(): Int = id.hashCode() override def equals(obj: Any): Boolean = obj match { case e: CaseClassValueLift => e.id == this.id case _ => false } - def copy(name: String = this.name, value: Any = this.value, quat: => Quat = this.quat) = - CaseClassValueLift(name, value, quat) + def copy(name: String = this.name, simpleName: String = this.simpleName, value: Any = this.value, quat: => Quat = this.quat) = + CaseClassValueLift(name, simpleName, value, quat) } object CaseClassValueLift { - private case class Id(name: String, value: Any) - def apply(name: String, value: Any, quat: => Quat): CaseClassValueLift = new CaseClassValueLift(name, value)(quat) - def unapply(l: CaseClassValueLift) = Some((l.name, l.value, l.quat)) + private case class Id(name: String, simpleName: String, value: Any) + def apply(name: String, simpleName: String, value: Any, quat: => Quat): CaseClassValueLift = new CaseClassValueLift(name, simpleName, value)(quat) + def unapply(l: CaseClassValueLift) = Some((l.name, l.simpleName, l.value, l.quat)) } final class CaseClassQueryLift(val name: String, val value: Any)(theQuat: => Quat) extends CaseClassLift { @@ -832,7 +839,7 @@ sealed trait Tag extends External { } case class ScalarTagId(uid: String) -case class ScalarTag(uid: String) extends Tag { +case class ScalarTag(uid: String, source: External.Source) extends Tag { def quat = Quat.Value def bestQuat = quat diff --git a/quill-engine/src/main/scala/io/getquill/context/cassandra/CqlIdiom.scala b/quill-engine/src/main/scala/io/getquill/context/cassandra/CqlIdiom.scala index d609d0ab02..97c446ef06 100644 --- a/quill-engine/src/main/scala/io/getquill/context/cassandra/CqlIdiom.scala +++ b/quill-engine/src/main/scala/io/getquill/context/cassandra/CqlIdiom.scala @@ -9,9 +9,10 @@ import io.getquill.idiom.StatementInterpolator._ import io.getquill.idiom.Statement import io.getquill.idiom.SetContainsToken import io.getquill.idiom.Token -import io.getquill.norm.{ NormalizeCaching, TranspileConfig } +import io.getquill.norm.NormalizeCaching import io.getquill.quat.Quat import io.getquill.util.Interleave +import io.getquill.IdiomContext object CqlIdiom extends CqlIdiom with CannotReturn @@ -23,14 +24,14 @@ trait CqlIdiom extends Idiom { override def prepareForProbing(string: String) = string - override def translate(ast: Ast, topLevelQuat: Quat, executionType: ExecutionType, transpileConfig: TranspileConfig)(implicit naming: NamingStrategy) = { - val cqlNormalize = new CqlNormalize(transpileConfig) + override def translate(ast: Ast, topLevelQuat: Quat, executionType: ExecutionType, idiomContext: IdiomContext)(implicit naming: NamingStrategy) = { + val cqlNormalize = new CqlNormalize(idiomContext.config) val normalizedAst = cqlNormalize(ast) (normalizedAst, stmt"${normalizedAst.token}", executionType) } - override def translateCached(ast: Ast, topLevelQuat: Quat, executionType: ExecutionType, transpileConfig: TranspileConfig)(implicit naming: NamingStrategy) = { - val cqlNormalize = new CqlNormalize(transpileConfig) + override def translateCached(ast: Ast, topLevelQuat: Quat, executionType: ExecutionType, idiomContext: IdiomContext)(implicit naming: NamingStrategy) = { + val cqlNormalize = new CqlNormalize(idiomContext.config) val normalizedAst = NormalizeCaching(cqlNormalize.apply)(ast) (normalizedAst, stmt"${normalizedAst.token}", executionType) } diff --git a/quill-engine/src/main/scala/io/getquill/idiom/Idiom.scala b/quill-engine/src/main/scala/io/getquill/idiom/Idiom.scala index 70b19ed03e..b628c91023 100644 --- a/quill-engine/src/main/scala/io/getquill/idiom/Idiom.scala +++ b/quill-engine/src/main/scala/io/getquill/idiom/Idiom.scala @@ -2,8 +2,8 @@ package io.getquill.idiom import io.getquill.ast._ import io.getquill.NamingStrategy -import io.getquill.context.{ IdiomReturningCapability, ExecutionType } -import io.getquill.norm.TranspileConfig +import io.getquill.IdiomContext +import io.getquill.context.{ ExecutionType, IdiomReturningCapability } import io.getquill.quat.Quat trait Idiom extends IdiomReturningCapability { @@ -14,9 +14,9 @@ trait Idiom extends IdiomReturningCapability { def liftingPlaceholder(index: Int): String - def translate(ast: Ast, topLevelQuat: Quat, executionType: ExecutionType, transpileConfig: TranspileConfig)(implicit naming: NamingStrategy): (Ast, Statement, ExecutionType) + def translate(ast: Ast, topLevelQuat: Quat, executionType: ExecutionType, transpileConfig: IdiomContext)(implicit naming: NamingStrategy): (Ast, Statement, ExecutionType) - def translateCached(ast: Ast, topLevelQuat: Quat, executionType: ExecutionType, transpileConfig: TranspileConfig)(implicit naming: NamingStrategy): (Ast, Statement, ExecutionType) + def translateCached(ast: Ast, topLevelQuat: Quat, executionType: ExecutionType, transpileConfig: IdiomContext)(implicit naming: NamingStrategy): (Ast, Statement, ExecutionType) def format(queryString: String): String = queryString diff --git a/quill-engine/src/main/scala/io/getquill/idiom/ReifyStatement.scala b/quill-engine/src/main/scala/io/getquill/idiom/ReifyStatement.scala index 31f8d7fca6..666ff72388 100644 --- a/quill-engine/src/main/scala/io/getquill/idiom/ReifyStatement.scala +++ b/quill-engine/src/main/scala/io/getquill/idiom/ReifyStatement.scala @@ -54,7 +54,7 @@ object ReifyStatement { lift.value.asInstanceOf[Iterable[Any]].toList match { case Nil => tokens :+ emptySetContainsToken(a) case values => - val liftings = values.map(v => ScalarLiftToken(ScalarValueLift(lift.name, v, lift.encoder, lift.quat))) + val liftings = values.map(v => ScalarLiftToken(ScalarValueLift(lift.name, External.Source.Parser, v, lift.encoder, lift.quat))) val separators = List.fill(liftings.size - 1)(StringToken(", ")) (tokens :+ stmt"$a $op (") ++ Interleave(liftings, separators) :+ StringToken(")") } @@ -164,7 +164,7 @@ object ReifyStatementWithInjectables { lift.value.asInstanceOf[Iterable[Any]].toList match { case Nil => tokens :+ emptySetContainsToken(a) case values => - val liftings = values.map(v => ScalarLiftToken(ScalarValueLift(lift.name, v, lift.encoder, lift.quat))) + val liftings = values.map(v => ScalarLiftToken(ScalarValueLift(lift.name, External.Source.Parser, v, lift.encoder, lift.quat))) val separators = List.fill(liftings.size - 1)(StringToken(", ")) (tokens :+ stmt"$a $op (") ++ Interleave(liftings, separators) :+ StringToken(")") } diff --git a/quill-engine/src/main/scala/io/getquill/norm/ExpandReturning.scala b/quill-engine/src/main/scala/io/getquill/norm/ExpandReturning.scala index 6f32534924..7776937798 100644 --- a/quill-engine/src/main/scala/io/getquill/norm/ExpandReturning.scala +++ b/quill-engine/src/main/scala/io/getquill/norm/ExpandReturning.scala @@ -5,7 +5,7 @@ import io.getquill.ast.Renameable.Fixed import io.getquill.ast._ import io.getquill.context._ import io.getquill.idiom.{ Idiom, Statement } -import io.getquill.{ NamingStrategy, ReturnAction } +import io.getquill.{ NamingStrategy, ReturnAction, IdiomContext } /** * Take the `.returning` part in a query that contains it and return the array of columns @@ -13,15 +13,15 @@ import io.getquill.{ NamingStrategy, ReturnAction } */ object ExpandReturning { - def applyMap(returning: ReturningAction)(f: (Ast, Statement) => String)(idiom: Idiom, naming: NamingStrategy, transpileConfig: TranspileConfig) = { + def applyMap(returning: ReturningAction)(f: (Ast, Statement) => String)(idiom: Idiom, naming: NamingStrategy, idiomContext: IdiomContext) = { idiom.idiomReturningCapability match { case ReturningClauseSupported | OutputClauseSupported => ReturnAction.ReturnRecord case ReturningMultipleFieldSupported => - val initialExpand = ExpandReturning(returning)(idiom, naming, transpileConfig) + val initialExpand = ExpandReturning(returning)(idiom, naming, idiomContext) ReturnColumns(initialExpand.map { case (ast, statement) => f(ast, statement) }) case ReturningSingleFieldSupported => - val initialExpand = ExpandReturning(returning)(idiom, naming, transpileConfig) + val initialExpand = ExpandReturning(returning)(idiom, naming, idiomContext) if (initialExpand.length == 1) ReturnColumns(initialExpand.map { case (ast, statement) => f(ast, statement) }) else @@ -31,7 +31,7 @@ object ExpandReturning { } } - def apply(returning: ReturningAction, renameAlias: Option[String] = None)(idiom: Idiom, naming: NamingStrategy, transpileConfig: TranspileConfig): List[(Ast, Statement)] = { + def apply(returning: ReturningAction, renameAlias: Option[String] = None)(idiom: Idiom, naming: NamingStrategy, idiomContext: IdiomContext): List[(Ast, Statement)] = { val ReturningAction(_, alias, properties) = returning // Ident("j"), Tuple(List(Property(Ident("j"), "name"), BinaryOperation(Property(Ident("j"), "age"), +, Constant(1)))) @@ -53,7 +53,7 @@ object ExpandReturning { implicit val namingStrategy: NamingStrategy = naming // TODO Should propagate ExecutionType from caller of this method. Need to trace - val outputs = deTuplified.map(v => idiom.translate(v, dePropertized.quat, ExecutionType.Unknown, transpileConfig)) + val outputs = deTuplified.map(v => idiom.translate(v, dePropertized.quat, ExecutionType.Unknown, idiomContext)) outputs.map { case (a, b, _) => (a, b) } diff --git a/quill-engine/src/main/scala/io/getquill/sql/SqlQuery.scala b/quill-engine/src/main/scala/io/getquill/sql/SqlQuery.scala index c147ed7d54..d5436fa7fe 100644 --- a/quill-engine/src/main/scala/io/getquill/sql/SqlQuery.scala +++ b/quill-engine/src/main/scala/io/getquill/sql/SqlQuery.scala @@ -2,11 +2,11 @@ package io.getquill.context.sql import io.getquill.ast._ import io.getquill.context.sql.norm.{ ExpandSelection, FlattenGroupByAggregation } -import io.getquill.norm.{ BetaReduction, TranspileConfig } +import io.getquill.norm.BetaReduction import io.getquill.quat.Quat import io.getquill.util.{ Interpolator, TraceConfig } import io.getquill.util.Messages.{ TraceType, fail } -import io.getquill.{ Literal, PseudoAst } +import io.getquill.{ Literal, PseudoAst, IdiomContext } import io.getquill.sql.Common.ContainsImpurities case class OrderByCriteria(ast: Ast, ordering: PropertyOrdering) @@ -24,7 +24,7 @@ sealed trait SqlQuery { override def toString = { import io.getquill.MirrorSqlDialect._ import io.getquill.idiom.StatementInterpolator._ - implicit val transpileConfig = TranspileConfig.Empty + implicit val idiomContext = IdiomContext.Empty implicit val naming: Literal = Literal implicit val tokenizer: Tokenizer[Ast] = defaultTokenizer this.token.toString diff --git a/quill-engine/src/main/scala/io/getquill/sql/idiom/BooleanLiteralSupport.scala b/quill-engine/src/main/scala/io/getquill/sql/idiom/BooleanLiteralSupport.scala index 51d2303e02..0b2dfd84f2 100644 --- a/quill-engine/src/main/scala/io/getquill/sql/idiom/BooleanLiteralSupport.scala +++ b/quill-engine/src/main/scala/io/getquill/sql/idiom/BooleanLiteralSupport.scala @@ -6,15 +6,16 @@ import io.getquill.context.sql.idiom.SqlIdiom import io.getquill.context.sql.norm.SqlNormalize import io.getquill.idiom.StatementInterpolator._ import io.getquill.idiom.StringToken -import io.getquill.norm.{ ConcatBehavior, EqualityBehavior, TranspileConfig } +import io.getquill.norm.{ ConcatBehavior, EqualityBehavior } import io.getquill.quat.Quat import io.getquill.sql.norm.VendorizeBooleans import io.getquill.util.Messages +import io.getquill.IdiomContext trait BooleanLiteralSupport extends SqlIdiom { - override def normalizeAst(ast: Ast, concatBehavior: ConcatBehavior, equalityBehavior: EqualityBehavior, transpileConfig: TranspileConfig) = { - val norm = SqlNormalize(ast, transpileConfig, concatBehavior, equalityBehavior) + override def normalizeAst(ast: Ast, concatBehavior: ConcatBehavior, equalityBehavior: EqualityBehavior, idiomContext: IdiomContext) = { + val norm = SqlNormalize(ast, idiomContext.config, concatBehavior, equalityBehavior) if (Messages.smartBooleans) VendorizeBooleans(norm) else diff --git a/quill-engine/src/main/scala/io/getquill/sql/idiom/OnConflictSupport.scala b/quill-engine/src/main/scala/io/getquill/sql/idiom/OnConflictSupport.scala index 7106719884..d9664afa48 100644 --- a/quill-engine/src/main/scala/io/getquill/sql/idiom/OnConflictSupport.scala +++ b/quill-engine/src/main/scala/io/getquill/sql/idiom/OnConflictSupport.scala @@ -4,23 +4,23 @@ import io.getquill.ast._ import io.getquill.idiom.StatementInterpolator._ import io.getquill.idiom.Token import io.getquill.NamingStrategy -import io.getquill.norm.TranspileConfig +import io.getquill.IdiomContext import io.getquill.util.Messages.fail trait OnConflictSupport { self: SqlIdiom => - implicit def conflictTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[OnConflict] = { + implicit def conflictTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[OnConflict] = { val customEntityTokenizer = Tokenizer[Entity] { case Entity.Opinionated(name, _, _, renameable) => stmt"INTO ${renameable.fixedOr(name.token)(strategy.table(name).token)} AS t" } val customAstTokenizer = - Tokenizer.withFallback[Ast](self.astTokenizer(_, strategy, transpileConfig)) { + Tokenizer.withFallback[Ast](self.astTokenizer(_, strategy, idiomContext)) { case _: OnConflict.Excluded => stmt"EXCLUDED" case OnConflict.Existing(a) => stmt"${a.token}" - case a: Action => self.actionTokenizer(customEntityTokenizer)(actionAstTokenizer, strategy, transpileConfig).token(a) + case a: Action => self.actionTokenizer(customEntityTokenizer)(actionAstTokenizer, strategy, idiomContext).token(a) } import OnConflict._ diff --git a/quill-engine/src/main/scala/io/getquill/sql/idiom/SqlIdiom.scala b/quill-engine/src/main/scala/io/getquill/sql/idiom/SqlIdiom.scala index 9200f78109..a7876dacad 100644 --- a/quill-engine/src/main/scala/io/getquill/sql/idiom/SqlIdiom.scala +++ b/quill-engine/src/main/scala/io/getquill/sql/idiom/SqlIdiom.scala @@ -1,7 +1,7 @@ package io.getquill.context.sql.idiom import com.github.takayahilton.sqlformatter._ -import io.getquill.NamingStrategy +import io.getquill.{ NamingStrategy, IdiomContext } import io.getquill.ast.BooleanOperator._ import io.getquill.ast.Renameable.Fixed import io.getquill.ast.Visibility.Hidden @@ -15,7 +15,7 @@ import io.getquill.idiom.StatementInterpolator._ import io.getquill.idiom._ import io.getquill.norm.ConcatBehavior.AnsiConcat import io.getquill.norm.EqualityBehavior.AnsiEquality -import io.getquill.norm.{ ConcatBehavior, EqualityBehavior, ExpandReturning, NormalizeCaching, ProductAggregationToken, TranspileConfig } +import io.getquill.norm.{ ConcatBehavior, EqualityBehavior, ExpandReturning, NormalizeCaching, ProductAggregationToken } import io.getquill.quat.Quat import io.getquill.sql.norm.{ HideTopLevelFilterAlias, NormalizeFilteredActionAliases, RemoveExtraAlias, RemoveUnusedSelects } import io.getquill.util.{ Interleave, Messages, TraceConfig } @@ -36,37 +36,37 @@ trait SqlIdiom extends Idiom { override def format(queryString: String): String = SqlFormatter.format(queryString) - def normalizeAst(ast: Ast, concatBehavior: ConcatBehavior, equalityBehavior: EqualityBehavior, transpileConfig: TranspileConfig) = - SqlNormalize(ast, transpileConfig, concatBehavior, equalityBehavior) + def normalizeAst(ast: Ast, concatBehavior: ConcatBehavior, equalityBehavior: EqualityBehavior, idiomContext: IdiomContext) = + SqlNormalize(ast, idiomContext.config, concatBehavior, equalityBehavior) def querifyAst(ast: Ast, traceConfig: TraceConfig) = new SqlQueryApply(traceConfig)(ast) // See HideTopLevelFilterAlias for more detail on how this works - def querifyAction(ast: Action) = { - val norm = NormalizeFilteredActionAliases(ast) + def querifyAction(ast: Action, batchAlias: Option[String]) = { + val norm = new NormalizeFilteredActionAliases(batchAlias)(ast) useActionTableAliasAs match { case ActionTableAliasBehavior.Hide => HideTopLevelFilterAlias(norm) case _ => norm } } - private def doTranslate(ast: Ast, cached: Boolean, topLevelQuat: Quat, executionType: ExecutionType, transpileConfig: TranspileConfig)(implicit naming: NamingStrategy): (Ast, Statement, ExecutionType) = { + private def doTranslate(ast: Ast, cached: Boolean, topLevelQuat: Quat, executionType: ExecutionType, idiomContext: IdiomContext)(implicit naming: NamingStrategy): (Ast, Statement, ExecutionType) = { val normalizedAst = { if (cached) { - NormalizeCaching { (a: Ast) => normalizeAst(a, concatBehavior, equalityBehavior, transpileConfig) }(ast) + NormalizeCaching { (a: Ast) => normalizeAst(a, concatBehavior, equalityBehavior, idiomContext) }(ast) } else { - normalizeAst(ast, concatBehavior, equalityBehavior, transpileConfig) + normalizeAst(ast, concatBehavior, equalityBehavior, idiomContext) } } - implicit val transpileConfigImplicit: TranspileConfig = transpileConfig + implicit val transpileContextImplicit: IdiomContext = idiomContext implicit val tokernizer: Tokenizer[Ast] = defaultTokenizer val token = normalizedAst match { case q: Query => - val sql = querifyAst(q, transpileConfig.traceConfig) + val sql = querifyAst(q, idiomContext.traceConfig) trace("sql")(sql) VerifySqlQuery(sql).map(fail) val expanded = ExpandNestedQueries(sql, topLevelQuat) @@ -80,7 +80,7 @@ trait SqlIdiom extends Idiom { tokenized case a: Action => // Mostly we don't use the alias in SQL set-queries but if we do, make sure they are right - val sql = querifyAction(a) + val sql = querifyAction(a, idiomContext.queryType.batchAlias) trace("action sql")(sql) // Run the tokenization, make sure that we're running tokenization from the top-level (i.e. from the Ast-tokenizer, don't go directly to the action tokenizer) (sql: Ast).token @@ -91,22 +91,22 @@ trait SqlIdiom extends Idiom { (normalizedAst, stmt"$token", executionType) } - override def translate(ast: Ast, topLevelQuat: Quat, executionType: ExecutionType, transpileConfig: TranspileConfig)(implicit naming: NamingStrategy): (Ast, Statement, ExecutionType) = { - doTranslate(ast, false, topLevelQuat, executionType, transpileConfig) + override def translate(ast: Ast, topLevelQuat: Quat, executionType: ExecutionType, idiomContext: IdiomContext)(implicit naming: NamingStrategy): (Ast, Statement, ExecutionType) = { + doTranslate(ast, false, topLevelQuat, executionType, idiomContext) } - override def translateCached(ast: Ast, topLevelQuat: Quat, executionType: ExecutionType, transpileConfig: TranspileConfig)(implicit naming: NamingStrategy): (Ast, Statement, ExecutionType) = { - doTranslate(ast, true, topLevelQuat, executionType, transpileConfig) + override def translateCached(ast: Ast, topLevelQuat: Quat, executionType: ExecutionType, idiomContext: IdiomContext)(implicit naming: NamingStrategy): (Ast, Statement, ExecutionType) = { + doTranslate(ast, true, topLevelQuat, executionType, idiomContext) } - def defaultTokenizer(implicit naming: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[Ast] = + def defaultTokenizer(implicit naming: NamingStrategy, idiomContext: IdiomContext): Tokenizer[Ast] = new Tokenizer[Ast] { - private val stableTokenizer = astTokenizer(this, naming, transpileConfig) + private val stableTokenizer = astTokenizer(this, naming, idiomContext) def token(v: Ast) = stableTokenizer.token(v) } - def astTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[Ast] = + def astTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[Ast] = Tokenizer[Ast] { case a: Query => // This case typically happens when you have a select inside of an insert @@ -116,10 +116,10 @@ trait SqlIdiom extends Idiom { // Right now we are not removing extra select clauses here (via RemoveUnusedSelects) since I am not sure what // kind of impact that could have on selects. Can try to do that in the future. if (Messages.querySubexpand) { - val nestedExpanded = ExpandNestedQueries(new SqlQueryApply(transpileConfig.traceConfig)(a)) + val nestedExpanded = ExpandNestedQueries(new SqlQueryApply(idiomContext.traceConfig)(a)) RemoveExtraAlias(strategy)(nestedExpanded).token } else - new SqlQueryApply(transpileConfig.traceConfig)(a).token + new SqlQueryApply(idiomContext.traceConfig)(a).token case a: Operation => a.token case a: Infix => a.token @@ -164,7 +164,7 @@ trait SqlIdiom extends Idiom { protected def tokenizeGroupBy(values: Ast)(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Token = values.token - protected class FlattenSqlQueryTokenizerHelper(q: FlattenSqlQuery)(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig) { + protected class FlattenSqlQueryTokenizerHelper(q: FlattenSqlQuery)(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext) { import q._ @@ -221,7 +221,7 @@ trait SqlIdiom extends Idiom { def apply = stmt"SELECT $withLimitOffset" } - implicit def sqlQueryTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[SqlQuery] = Tokenizer[SqlQuery] { + implicit def sqlQueryTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[SqlQuery] = Tokenizer[SqlQuery] { case q: FlattenSqlQuery => new FlattenSqlQueryTokenizerHelper(q).apply case SetOperationSqlQuery(a, op, b) => @@ -256,7 +256,7 @@ trait SqlIdiom extends Idiom { protected def tokenizeIdentName(strategy: NamingStrategy, name: String): String = name - implicit def selectValueTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[SelectValue] = { + implicit def selectValueTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[SelectValue] = { def tokenizer(implicit astTokenizer: Tokenizer[Ast]) = Tokenizer[SelectValue] { @@ -287,7 +287,7 @@ trait SqlIdiom extends Idiom { } val customAstTokenizer = - Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy, transpileConfig)) { + Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy, idiomContext)) { case Aggregation(op, Ident(id, _: Quat.Product)) => stmt"${op.token}(${makeProductAggregationToken(id)})" // Not too many cases of this. Can happen if doing a leaf-level infix inside of a select clause. For example in postgres: @@ -353,7 +353,7 @@ trait SqlIdiom extends Idiom { protected def tokenOrderBy(criterias: List[OrderByCriteria])(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy) = stmt"ORDER BY ${criterias.token}" - implicit def sourceTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[FromContext] = Tokenizer[FromContext] { + implicit def sourceTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[FromContext] = Tokenizer[FromContext] { case TableContext(name, alias) => stmt"${name.token} ${tokenizeTableAlias(strategy, alias).token}" case QueryContext(query, alias) => stmt"(${query.token})${` AS`} ${tokenizeTableAlias(strategy, alias).token}" case InfixContext(infix, alias) => stmt"(${(infix: Ast).token})${` AS`} ${tokenizeTableAlias(strategy, alias).token}" @@ -530,15 +530,15 @@ trait SqlIdiom extends Idiom { stmt"${prop.token} = ${scopedTokenizer(value)}" } - implicit def defaultAstTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[Action] = { + implicit def defaultAstTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[Action] = { val insertEntityTokenizer = Tokenizer[Entity] { case Entity.Opinionated(name, _, _, renameable) => stmt"INTO ${tokenizeTable(strategy, name, renameable).token}" } - actionTokenizer(insertEntityTokenizer)(actionAstTokenizer, strategy, transpileConfig) + actionTokenizer(insertEntityTokenizer)(actionAstTokenizer, strategy, idiomContext) } - protected def actionAstTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig) = - Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy, transpileConfig)) { + protected def actionAstTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext) = + Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy, idiomContext)) { case q: Query => astTokenizer.token(q) case Property(Property.Opinionated(_, name, renameable, _), "isEmpty") => stmt"${renameable.fixedOr(name)(tokenizeColumn(strategy, name, renameable)).token} IS NULL" case Property(Property.Opinionated(_, name, renameable, _), "isDefined") => stmt"${renameable.fixedOr(name)(tokenizeColumn(strategy, name, renameable)).token} IS NOT NULL" @@ -547,9 +547,9 @@ trait SqlIdiom extends Idiom { case Property.Opinionated(_, name, renameable, _) => renameable.fixedOr(name.token)(tokenizeColumn(strategy, name, renameable).token) } - def returnListTokenizer(implicit tokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[List[Ast]] = { + def returnListTokenizer(implicit tokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[List[Ast]] = { val customAstTokenizer = - Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy, transpileConfig)) { + Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy, idiomContext)) { case sq: Query => stmt"(${tokenizer.token(sq)})" } @@ -569,7 +569,7 @@ trait SqlIdiom extends Idiom { private[getquill] def returningEnabled = !Messages.disableReturning - protected def actionTokenizer(insertEntityTokenizer: Tokenizer[Entity])(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[Action] = + protected def actionTokenizer(insertEntityTokenizer: Tokenizer[Entity])(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[Action] = Tokenizer[Action] { case action @ Update(Filter(_: Entity, alias, _), _) => @@ -615,7 +615,7 @@ trait SqlIdiom extends Idiom { action match { case Insert(entity: Entity, assignments) => val (table, columns, values) = insertInfo(insertEntityTokenizer, entity, assignments) - stmt"INSERT $table${` AS [table]`} (${columns.mkStmt(",")}) OUTPUT ${returnListTokenizer.token(ExpandReturning(r, Some("INSERTED"))(this, strategy, transpileConfig).map(_._1))} VALUES ${ValuesClauseToken(stmt"(${values.mkStmt(", ")})")}" + stmt"INSERT $table${` AS [table]`} (${columns.mkStmt(",")}) OUTPUT ${returnListTokenizer.token(ExpandReturning(r, Some("INSERTED"))(this, strategy, idiomContext).map(_._1))} VALUES ${ValuesClauseToken(stmt"(${values.mkStmt(", ")})")}" // query[Person].filter(...).update/updateValue(...) case action @ Update(Filter(_: Entity, alias, _), _) => @@ -643,12 +643,17 @@ trait SqlIdiom extends Idiom { fail(s"Action ast can't be translated to sql: '$other'") } - def tokenizeReturningClause(r: ReturningAction, alias: Option[String] = None)(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig) = { - returnListTokenizer.token(ExpandReturning(r, alias)(this, strategy, transpileConfig).map(_._1)) + def tokenizeReturningClause(r: ReturningAction, alias: Option[String] = None)(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext) = { + returnListTokenizer.token(ExpandReturning(r, alias)(this, strategy, idiomContext).map(_._1)) } private def insertInfo(insertEntityTokenizer: Tokenizer[Entity], entity: Entity, assignments: List[Assignment])(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy) = { val table = insertEntityTokenizer.token(entity) + val (columns, values) = columnsAndValues(assignments) + (table, columns, values) + } + + private[getquill] def columnsAndValues(assignments: List[Assignment])(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy) = { val columns = assignments.map(assignment => assignment.property match { @@ -656,7 +661,7 @@ trait SqlIdiom extends Idiom { case _ => fail(s"Invalid assignment value of ${assignment}. Must be a Property object.") }) val values = assignments.map(assignment => scopedTokenizer(assignment.value)) - (table, columns, values) + (columns, values) } implicit def entityTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Entity] = Tokenizer[Entity] { @@ -697,13 +702,13 @@ object SqlIdiom { } case class InsertUpdateStmt(action: Statement, where: Statement) - private[getquill] def withActionAlias(parentIdiom: SqlIdiom, action: Action, alias: Ident)(implicit strategy: NamingStrategy, transpileConfig: TranspileConfig): InsertUpdateStmt = { + private[getquill] def withActionAlias(parentIdiom: SqlIdiom, action: Action, alias: Ident)(implicit strategy: NamingStrategy, idiomContext: IdiomContext): InsertUpdateStmt = { val idiom = copyIdiom(parentIdiom, Some(alias)) import idiom._ implicit val stableTokenizer = idiom.astTokenizer(new Tokenizer[Ast] { - override def token(v: Ast): Token = astTokenizer(this, strategy, transpileConfig).token(v) - }, strategy, transpileConfig) + override def token(v: Ast): Token = astTokenizer(this, strategy, idiomContext).token(v) + }, strategy, idiomContext) action match { case Update(Filter(table: Entity, x, where), assignments) => @@ -728,13 +733,13 @@ object SqlIdiom { * (i.e. insert, and update) will be rendered with the specified alias. This is needed for RETURNING clauses that have * queries inside. See #1509 for details. */ - private[getquill] def withActionAlias(parentIdiom: SqlIdiom, query: ReturningAction)(implicit strategy: NamingStrategy, transpileConfig: TranspileConfig) = { + private[getquill] def withActionAlias(parentIdiom: SqlIdiom, query: ReturningAction)(implicit strategy: NamingStrategy, idiomContext: IdiomContext) = { val idiom = copyIdiom(parentIdiom, Some(query.alias)) import idiom._ implicit val stableTokenizer = idiom.astTokenizer(new Tokenizer[Ast] { - override def token(v: Ast): Token = astTokenizer(this, strategy, transpileConfig).token(v) - }, strategy, transpileConfig) + override def token(v: Ast): Token = astTokenizer(this, strategy, idiomContext).token(v) + }, strategy, idiomContext) def ` AS [alias]`(alias: Ident) = useActionTableAliasAs match { diff --git a/quill-engine/src/main/scala/io/getquill/sql/norm/NormalizeActionAliases.scala b/quill-engine/src/main/scala/io/getquill/sql/norm/NormalizeActionAliases.scala index d3ae1df233..d329b3cc64 100644 --- a/quill-engine/src/main/scala/io/getquill/sql/norm/NormalizeActionAliases.scala +++ b/quill-engine/src/main/scala/io/getquill/sql/norm/NormalizeActionAliases.scala @@ -3,7 +3,25 @@ package io.getquill.sql.norm import io.getquill.ast._ import io.getquill.norm.BetaReduction -object NormalizeFilteredActionAliases extends StatelessTransformer { +object NormalizeFilteredActionAliases { + private[getquill] def chooseAlias(entityName: String, batchAlias: Option[String]) = { + val lowerEntityName = entityName.toLowerCase + val possibleEntityNameChar = + if (lowerEntityName.length > 0 && lowerEntityName.take(1).matches("[a-z]")) + Some(lowerEntityName.take(1)) + else + None + (possibleEntityNameChar, batchAlias) match { + case (Some(t), Some(b)) if (b != t) => t + case (Some(t), Some(b)) if (b == t) => s"${t}Tbl" + case (Some(t), None) => t + case (None, Some(b)) => s"${b}Tbl" + case (None, None) => "x" + } + } +} + +case class NormalizeFilteredActionAliases(batchAlias: Option[String]) extends StatelessTransformer { override def apply(e: Action): Action = e match { @@ -29,6 +47,11 @@ object NormalizeFilteredActionAliases extends StatelessTransformer { // (since we don't tokenize the identifier of the SET-clauses) Update(apply(query), assignments.map(a => realiasAssignment(a, alias))) + + case Update(e: Entity, assignments) => + val alias = NormalizeFilteredActionAliases.chooseAlias(e.name, batchAlias) + Update(e, assignments.map(a => realiasAssignment(a, alias))) + case _ => super.apply(e) } @@ -36,6 +59,16 @@ object NormalizeFilteredActionAliases extends StatelessTransformer { a match { case Assignment(alias, prop, value) => val newProp = BetaReduction(prop, alias -> newAlias) - Assignment(newAlias, newProp, value) + val newVal = BetaReduction(value, alias -> newAlias) + Assignment(newAlias, newProp, newVal) + } + + private def realiasAssignment(a: Assignment, newAliasName: String) = + a match { + case Assignment(alias, prop, value) => + val newAlias = alias.copy(name = newAliasName) + val newProp = BetaReduction(prop, alias -> newAlias) + val newVal = BetaReduction(value, alias -> newAlias) + Assignment(newAlias, newProp, newVal) } } diff --git a/quill-jasync-zio-postgres/src/test/resources/logback.xml b/quill-jasync-zio-postgres/src/test/resources/logback.xml new file mode 100644 index 0000000000..3938c39719 --- /dev/null +++ b/quill-jasync-zio-postgres/src/test/resources/logback.xml @@ -0,0 +1,16 @@ + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n%ex + + + + + + + + + + + + diff --git a/quill-jdbc/src/test/resources/application.conf b/quill-jdbc/src/test/resources/application.conf index 891e986663..87a9afa4b1 100644 --- a/quill-jdbc/src/test/resources/application.conf +++ b/quill-jdbc/src/test/resources/application.conf @@ -6,6 +6,7 @@ testMysqlDB.dataSource.cachePrepStmts=true testMysqlDB.dataSource.prepStmtCacheSize=250 testMysqlDB.dataSource.prepStmtCacheSqlLimit=2048 testMysqlDB.maximumPoolSize=1 +#testMysqlDB.dataSource.rewriteBatchedStatements=true testPostgresDB.dataSourceClassName=org.postgresql.ds.PGSimpleDataSource testPostgresDB.dataSource.user=postgres @@ -13,6 +14,7 @@ testPostgresDB.dataSource.password=${?POSTGRES_PASSWORD} testPostgresDB.dataSource.databaseName=quill_test testPostgresDB.dataSource.portNumber=${?POSTGRES_PORT} testPostgresDB.dataSource.serverName=${?POSTGRES_HOST} +#testPostgresDB.dataSource.reWriteBatchedInserts=true testH2DB.dataSourceClassName=org.h2.jdbcx.JdbcDataSource testH2DB.dataSource.url="jdbc:h2:mem:test;DB_CLOSE_DELAY=-1;INIT=RUNSCRIPT FROM 'classpath:sql/h2-schema.sql'" diff --git a/quill-jdbc/src/test/resources/logback.xml b/quill-jdbc/src/test/resources/logback.xml new file mode 100644 index 0000000000..fd0a6cae23 --- /dev/null +++ b/quill-jdbc/src/test/resources/logback.xml @@ -0,0 +1,16 @@ + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n%ex + + + + + + + + + + + + diff --git a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/h2/BatchValuesJdbcSpec.scala b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/h2/BatchValuesJdbcSpec.scala index ec60a8550d..ad5c9f3efe 100644 --- a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/h2/BatchValuesJdbcSpec.scala +++ b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/h2/BatchValuesJdbcSpec.scala @@ -25,4 +25,10 @@ class BatchValuesJdbcSpec extends BatchValuesSpec { ids mustEqual productsOriginal.map(_.id) testContext.run(get) mustEqual productsOriginal } + + "Ex 3 - Batch Insert Mixed" in { + import `Ex 3 - Batch Insert Mixed`._ + testContext.run(op, batchSize) + testContext.run(get).toSet mustEqual result.toSet + } } diff --git a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/mysql/BatchValuesJdbcSpec.scala b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/mysql/BatchValuesJdbcSpec.scala index 6e2d1a3a32..4b3b8bc537 100644 --- a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/mysql/BatchValuesJdbcSpec.scala +++ b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/mysql/BatchValuesJdbcSpec.scala @@ -26,4 +26,10 @@ class BatchValuesJdbcSpec extends BatchValuesSpec { ids.toSet mustEqual productsOriginal.map(_.id).toSet testContext.run(get).toSet mustEqual productsOriginal.toSet } + + "Ex 3 - Batch Insert Mixed" in { + import `Ex 3 - Batch Insert Mixed`._ + testContext.run(op, batchSize) + testContext.run(get).toSet mustEqual result.toSet + } } \ No newline at end of file diff --git a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/oracle/BatchValuesJdbcSpec.scala b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/oracle/BatchValuesJdbcSpec.scala index 11c3641662..0571730422 100644 --- a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/oracle/BatchValuesJdbcSpec.scala +++ b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/oracle/BatchValuesJdbcSpec.scala @@ -26,4 +26,13 @@ class BatchValuesJdbcSpec extends BatchValuesSpec { ids.toSet mustEqual productsOriginal.map(_.id).toSet testContext.run(get).toSet mustEqual productsOriginal.toSet } + + "Ex 3 - Batch Insert Mixed" in { + import `Ex 3 - Batch Insert Mixed`._ + def op = quote { + liftQuery(products).foreach(p => query[Product].insert(_.description -> lift("BlahBlah"), _.sku -> p.sku)) + } + testContext.run(op, batchSize) + testContext.run(get).toSet mustEqual result.toSet + } } diff --git a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/postgres/BatchUpdateJdbcSpec.scala b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/postgres/BatchUpdateJdbcSpec.scala new file mode 100644 index 0000000000..b7dab4c3b3 --- /dev/null +++ b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/postgres/BatchUpdateJdbcSpec.scala @@ -0,0 +1,91 @@ +package io.getquill.context.jdbc.postgres + +import io.getquill.context.sql.base.BatchUpdateValuesSpec +import io.getquill.norm.EnableTrace +import io.getquill.util.Messages.TraceType + +class BatchUpdateValuesJdbcSpec extends BatchUpdateValuesSpec { // + + val context = testContext + import testContext._ + + override def beforeEach(): Unit = { + val schema = quote(querySchema[ContactBase]("Contact")) + testContext.run(schema.delete) + super.beforeEach() + } + + "Ex 1 - Simple Contact" in { + import `Ex 1 - Simple Contact`._ + context.run(insert) + context.run(update, 2) + context.run(get).toSet mustEqual (expect.toSet) + } + + "Ex 1.1 - Simple Contact With Lift" in { + import `Ex 1.1 - Simple Contact With Lift`._ + context.run(insert) + context.run(update, 2) + context.run(get).toSet mustEqual (expect.toSet) + } + + "Ex 1.2 - Simple Contact Mixed Lifts" in { + import `Ex 1.2 - Simple Contact Mixed Lifts`._ + context.run(insert) + context.run(update, 2) + context.run(get).toSet mustEqual (expect.toSet) + } + + "Ex 1.3 - Simple Contact with Multi-Lift-Kinds" in { + import `Ex 1.3 - Simple Contact with Multi-Lift-Kinds`._ + context.run(insert) + context.run(update, 2) + context.run(get).toSet mustEqual (expect.toSet) + } + + "Ex 2 - Optional Embedded with Renames" in { + import `Ex 2 - Optional Embedded with Renames`._ + context.run(insert) + context.run(update, 2) + context.run(get).toSet mustEqual (expect.toSet) + } + + "Ex 3 - Deep Embedded Optional" in { + import `Ex 3 - Deep Embedded Optional`._ + context.run(insert) + context.run(update, 2) + context.run(get).toSet mustEqual (expect.toSet) + } + + "Ex 4 - Returning" in { + import `Ex 4 - Returning`._ + context.run(insert) + val agesReturned = context.run(update, 2) + agesReturned mustEqual expectedReturn + context.run(get).toSet mustEqual (expect.toSet) + } + + "Ex 4 - Returning Multiple" in { + import `Ex 4 - Returning Multiple`._ + context.run(insert) + val agesReturned = context.run(update, 2) + agesReturned mustEqual expectedReturn + context.run(get).toSet mustEqual (expect.toSet) + } + + "Ex 5 - Append Data" in { + System.setProperty("quill.binds.log", "true") + io.getquill.util.Messages.resetCache() + import `Ex 5 - Append Data`._ + context.run(insert) + context.run(update, 2) + context.run(get).toSet mustEqual (expectSpecific.toSet) + } + + "Ex 6 - Append Data No Condition" in { + import `Ex 6 - Append Data No Condition`._ + context.run(insert) + context.run(update, 2) + context.run(get).toSet mustEqual (expectSpecific.toSet) + } +} diff --git a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/postgres/BatchValuesJdbcSpec.scala b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/postgres/BatchValuesJdbcSpec.scala index 0847e66d50..208d6b19ec 100644 --- a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/postgres/BatchValuesJdbcSpec.scala +++ b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/postgres/BatchValuesJdbcSpec.scala @@ -25,4 +25,10 @@ class BatchValuesJdbcSpec extends BatchValuesSpec { ids mustEqual expectedIds testContext.run(get) mustEqual result } + + "Ex 3 - Batch Insert Mixed" in { + import `Ex 3 - Batch Insert Mixed`._ + testContext.run(op, batchSize) + testContext.run(get) mustEqual result + } } diff --git a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlite/BatchValuesJdbcSpec.scala b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlite/BatchValuesJdbcSpec.scala index d832f5487a..118a157cbc 100644 --- a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlite/BatchValuesJdbcSpec.scala +++ b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlite/BatchValuesJdbcSpec.scala @@ -34,4 +34,13 @@ class BatchValuesJdbcSpec extends BatchValuesSpec { ids mustEqual productsOriginal.map(_.id) testContext.run(get) mustEqual productsOriginal } + + "Ex 3 - Batch Insert Mixed" in { + import `Ex 3 - Batch Insert Mixed`._ + def op = quote { + liftQuery(products).foreach(p => query[Product].insert(_.description -> lift("BlahBlah"), _.sku -> p.sku)) + } + testContext.run(op, batchSize) + testContext.run(get).toSet mustEqual result.toSet + } } diff --git a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlserver/BatchValuesJdbcSpec.scala b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlserver/BatchValuesJdbcSpec.scala index 6ae32f2699..a31fea9fb9 100644 --- a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlserver/BatchValuesJdbcSpec.scala +++ b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlserver/BatchValuesJdbcSpec.scala @@ -1,6 +1,6 @@ package io.getquill.context.jdbc.sqlserver -import io.getquill.Delete +import io.getquill.{ Delete, Insert } import io.getquill.context.sql.base.BatchValuesSpec class BatchValuesJdbcSpec extends BatchValuesSpec { @@ -36,4 +36,13 @@ class BatchValuesJdbcSpec extends BatchValuesSpec { ids mustEqual productsOriginal testContext.run(get) mustEqual productsOriginal } + + "Ex 3 - Batch Insert Mixed" in { + import `Ex 3 - Batch Insert Mixed`._ + def splicedOp = quote { + opExt(insert => sql"SET IDENTITY_INSERT Product ON; ${insert}".as[Insert[Product]]) + } + testContext.run(splicedOp, batchSize) + testContext.run(get) mustEqual result + } } \ No newline at end of file diff --git a/quill-ndbc-postgres/src/test/resources/logback.xml b/quill-ndbc-postgres/src/test/resources/logback.xml index 657c117594..f7387fb8bc 100644 --- a/quill-ndbc-postgres/src/test/resources/logback.xml +++ b/quill-ndbc-postgres/src/test/resources/logback.xml @@ -9,6 +9,7 @@ + diff --git a/quill-orientdb/src/main/scala/io/getquill/context/orientdb/OrientDBIdiom.scala b/quill-orientdb/src/main/scala/io/getquill/context/orientdb/OrientDBIdiom.scala index 8eb7465a2f..db786ee105 100644 --- a/quill-orientdb/src/main/scala/io/getquill/context/orientdb/OrientDBIdiom.scala +++ b/quill-orientdb/src/main/scala/io/getquill/context/orientdb/OrientDBIdiom.scala @@ -5,7 +5,7 @@ import io.getquill.idiom.StatementInterpolator._ import io.getquill.context.sql.norm._ import io.getquill.ast.{ AggregationOperator, External, _ } import io.getquill.context.sql._ -import io.getquill.NamingStrategy +import io.getquill.{ IdiomContext, NamingStrategy } import io.getquill.context.{ CannotReturn, ExecutionType } import io.getquill.util.Messages.{ fail, trace } import io.getquill.idiom._ @@ -25,25 +25,25 @@ trait OrientDBIdiom extends Idiom { override def prepareForProbing(string: String): String = string - override def translate(ast: Ast, topLevelQuat: Quat, executionType: ExecutionType, transpileConfig: TranspileConfig)(implicit naming: NamingStrategy): (Ast, Statement, ExecutionType) = { - doTranslate(ast, false, executionType, transpileConfig) + override def translate(ast: Ast, topLevelQuat: Quat, executionType: ExecutionType, idiomContext: IdiomContext)(implicit naming: NamingStrategy): (Ast, Statement, ExecutionType) = { + doTranslate(ast, false, executionType, idiomContext) } - override def translateCached(ast: Ast, topLevelQuat: Quat, executionType: ExecutionType, transpileConfig: TranspileConfig)(implicit naming: NamingStrategy): (Ast, Statement, ExecutionType) = { - doTranslate(ast, true, executionType, transpileConfig) + override def translateCached(ast: Ast, topLevelQuat: Quat, executionType: ExecutionType, idiomContext: IdiomContext)(implicit naming: NamingStrategy): (Ast, Statement, ExecutionType) = { + doTranslate(ast, true, executionType, idiomContext) } - private def doTranslate(ast: Ast, cached: Boolean, executionType: ExecutionType, transpileConfig: TranspileConfig)(implicit naming: NamingStrategy): (Ast, Statement, ExecutionType) = { - implicit val transpileConfigImplicit: TranspileConfig = transpileConfig + private def doTranslate(ast: Ast, cached: Boolean, executionType: ExecutionType, idiomContext: IdiomContext)(implicit naming: NamingStrategy): (Ast, Statement, ExecutionType) = { + implicit val implcitIdiomContext: IdiomContext = idiomContext val normalizedAst = { if (cached) - NormalizeCaching { ast: Ast => SqlNormalize(ast, transpileConfig) }(ast) + NormalizeCaching { ast: Ast => SqlNormalize(ast, idiomContext.config) }(ast) else SqlNormalize(ast, TranspileConfig.Empty) } val token = normalizedAst match { case q: Query => - val sql = new SqlQueryApply(transpileConfig.traceConfig)(q) + val sql = new SqlQueryApply(idiomContext.traceConfig)(q) VerifySqlQuery(sql).map(fail) val expanded = ExpandNestedQueries(sql) trace("expanded sql")(expanded) @@ -61,10 +61,10 @@ trait OrientDBIdiom extends Idiom { (normalizedAst, stmt"$token", executionType) } - implicit def astTokenizer(implicit strategy: NamingStrategy, queryTokenizer: Tokenizer[Query], transpileConfig: TranspileConfig): Tokenizer[Ast] = { + implicit def astTokenizer(implicit strategy: NamingStrategy, queryTokenizer: Tokenizer[Query], idiomContext: IdiomContext): Tokenizer[Ast] = { Tokenizer[Ast] { case a: Query => - new SqlQueryApply(transpileConfig.traceConfig)(a).token + new SqlQueryApply(idiomContext.traceConfig)(a).token case a: Operation => a.token case a: Infix => @@ -95,7 +95,7 @@ trait OrientDBIdiom extends Idiom { } } - implicit def ifTokenizer(implicit strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[If] = Tokenizer[If] { + implicit def ifTokenizer(implicit strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[If] = Tokenizer[If] { case ast: If => def flatten(ast: Ast): (List[(Ast, Ast)], Ast) = ast match { @@ -114,11 +114,11 @@ trait OrientDBIdiom extends Idiom { conditions.head } - implicit def queryTokenizer(implicit strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[Query] = Tokenizer[Query] { - case q => new SqlQueryApply(transpileConfig.traceConfig)(q).token + implicit def queryTokenizer(implicit strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[Query] = Tokenizer[Query] { + case q => new SqlQueryApply(idiomContext.traceConfig)(q).token } - implicit def orientDBQueryTokenizer(implicit strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[SqlQuery] = Tokenizer[SqlQuery] { + implicit def orientDBQueryTokenizer(implicit strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[SqlQuery] = Tokenizer[SqlQuery] { case FlattenSqlQuery(from, where, groupBy, orderBy, limit, offset, select, distinct) => val distinctTokenizer = (if (distinct == DistinctKind.Distinct) "DISTINCT" else "").token @@ -176,7 +176,7 @@ trait OrientDBIdiom extends Idiom { fail("Other operators are not supported yet. Please raise a ticket to support more operations") } - implicit def operationTokenizer(implicit propertyTokenizer: Tokenizer[Property], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[Operation] = Tokenizer[Operation] { + implicit def operationTokenizer(implicit propertyTokenizer: Tokenizer[Property], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[Operation] = Tokenizer[Operation] { case UnaryOperation(op, ast) => stmt"${op.token} (${ast.token})" case BinaryOperation(a, EqualityOperator.`_==`, NullValue) => stmt"${scopedTokenizer(a)} IS NULL" case BinaryOperation(NullValue, EqualityOperator.`_==`, b) => stmt"${scopedTokenizer(b)} IS NULL" @@ -193,17 +193,17 @@ trait OrientDBIdiom extends Idiom { case UnionAllOperation => stmt"UNION ALL" } - protected def tokenOrderBy(criterias: List[OrderByCriteria])(implicit strategy: NamingStrategy, transpileConfig: TranspileConfig) = + protected def tokenOrderBy(criterias: List[OrderByCriteria])(implicit strategy: NamingStrategy, idiomContext: IdiomContext) = stmt"ORDER BY ${criterias.token}" - implicit def sourceTokenizer(implicit strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[FromContext] = Tokenizer[FromContext] { + implicit def sourceTokenizer(implicit strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[FromContext] = Tokenizer[FromContext] { case TableContext(name, alias) => stmt"${name.token}" case QueryContext(query, alias) => stmt"(${query.token})" case InfixContext(infix, alias) => stmt"(${(infix: Ast).token})" case _ => fail("OrientDB sql doesn't support joins") } - implicit def orderByCriteriaTokenizer(implicit strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[OrderByCriteria] = Tokenizer[OrderByCriteria] { + implicit def orderByCriteriaTokenizer(implicit strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[OrderByCriteria] = Tokenizer[OrderByCriteria] { case OrderByCriteria(ast, Asc | AscNullsFirst | AscNullsLast) => stmt"${scopedTokenizer(ast)} ASC" case OrderByCriteria(ast, Desc | DescNullsFirst | DescNullsLast) => stmt"${scopedTokenizer(ast)} DESC" } @@ -238,7 +238,7 @@ trait OrientDBIdiom extends Idiom { case other => fail(s"OrientDB QL doesn't support the '$other' operator.") } - implicit def selectValueTokenizer(implicit strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[SelectValue] = { + implicit def selectValueTokenizer(implicit strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[SelectValue] = { def tokenValue(ast: Ast) = ast match { case Aggregation(op, Ident(_, _)) => stmt"${op.token}(*)" @@ -255,7 +255,7 @@ trait OrientDBIdiom extends Idiom { } } - implicit def propertyTokenizer(implicit valueTokenizer: Tokenizer[Value], identTokenizer: Tokenizer[Ident], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[Property] = { + implicit def propertyTokenizer(implicit valueTokenizer: Tokenizer[Value], identTokenizer: Tokenizer[Ident], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[Property] = { Tokenizer[Property] { case Property(ast, "isEmpty") => stmt"${ast.token} IS NULL" case Property(ast, "nonEmpty") => stmt"${ast.token} IS NOT NULL" @@ -265,7 +265,7 @@ trait OrientDBIdiom extends Idiom { } } - implicit def valueTokenizer(implicit strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[Value] = Tokenizer[Value] { + implicit def valueTokenizer(implicit strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[Value] = Tokenizer[Value] { case Constant(v: String, _) => stmt"'${v.token}'" case Constant((), _) => stmt"1" case Constant(v, _) => stmt"${v.toString.token}" @@ -274,7 +274,7 @@ trait OrientDBIdiom extends Idiom { case CaseClass(values) => stmt"${values.map(_._2).token}" } - implicit def infixTokenizer(implicit propertyTokenizer: Tokenizer[Property], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[Infix] = Tokenizer[Infix] { + implicit def infixTokenizer(implicit propertyTokenizer: Tokenizer[Property], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[Infix] = Tokenizer[Infix] { case Infix(parts, params, _, _, _) => val pt = parts.map(_.token) val pr = params.map(_.token) @@ -287,17 +287,17 @@ trait OrientDBIdiom extends Idiom { implicit def externalIdentTokenizer(implicit strategy: NamingStrategy): Tokenizer[ExternalIdent] = Tokenizer[ExternalIdent](e => strategy.default(e.name).token) - implicit def assignmentTokenizer(implicit propertyTokenizer: Tokenizer[Property], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[Assignment] = Tokenizer[Assignment] { + implicit def assignmentTokenizer(implicit propertyTokenizer: Tokenizer[Property], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[Assignment] = Tokenizer[Assignment] { case Assignment(alias, prop, value) => stmt"${prop.token} = ${scopedTokenizer(value)}" } - implicit def assignmentDualTokenizer(implicit propertyTokenizer: Tokenizer[Property], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[AssignmentDual] = Tokenizer[AssignmentDual] { + implicit def assignmentDualTokenizer(implicit propertyTokenizer: Tokenizer[Property], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[AssignmentDual] = Tokenizer[AssignmentDual] { case AssignmentDual(alias1, alias2, prop, value) => stmt"${prop.token} = ${scopedTokenizer(value)}" } - implicit def actionTokenizer(implicit strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[Action] = { + implicit def actionTokenizer(implicit strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[Action] = { implicit def propertyTokenizer: Tokenizer[Property] = Tokenizer[Property] { case Property(Property.Opinionated(_, name, renameable, _), "isEmpty") => stmt"${renameable.fixedOr(name.token)(strategy.column(name).token)} IS NULL" diff --git a/quill-orientdb/src/test/scala/io/getquill/context/orientdb/OrientDBQuerySpec.scala b/quill-orientdb/src/test/scala/io/getquill/context/orientdb/OrientDBQuerySpec.scala index a5f7a9c500..854f3cc39d 100644 --- a/quill-orientdb/src/test/scala/io/getquill/context/orientdb/OrientDBQuerySpec.scala +++ b/quill-orientdb/src/test/scala/io/getquill/context/orientdb/OrientDBQuerySpec.scala @@ -4,19 +4,18 @@ import io.getquill.ast.{ Action => AstAction, Query => AstQuery, _ } import io.getquill.context.sql._ import io.getquill.idiom.StatementInterpolator._ import io.getquill.idiom.StringToken -import io.getquill.Literal -import io.getquill.Ord +import io.getquill.{ IdiomContext, Literal, Ord } import io.getquill.base.Spec -import io.getquill.norm.TranspileConfig import io.getquill.quat.Quat import io.getquill.util.TraceConfig +import io.getquill.IdiomContext class OrientDBQuerySpec extends Spec { val mirrorContext = orientdb.mirrorContext import mirrorContext._ - implicit val transpileConfig = TranspileConfig.Empty + implicit val idicomContext = IdiomContext.Empty "map" - { "property" in { diff --git a/quill-spark/src/main/scala/io/getquill/context/spark/SparkDialect.scala b/quill-spark/src/main/scala/io/getquill/context/spark/SparkDialect.scala index abf6bf6fda..ac3faa937b 100644 --- a/quill-spark/src/main/scala/io/getquill/context/spark/SparkDialect.scala +++ b/quill-spark/src/main/scala/io/getquill/context/spark/SparkDialect.scala @@ -1,6 +1,6 @@ package io.getquill.context.spark -import io.getquill.NamingStrategy +import io.getquill.{ IdiomContext, NamingStrategy } import io.getquill.ast.{ Ast, BinaryOperation, CaseClass, Constant, ExternalIdent, Ident, Operation, Property, Query, StringOperator, Tuple, Value } import io.getquill.context.spark.norm.EscapeQuestionMarks import io.getquill.context.sql.{ FlattenSqlQuery, SelectValue, SetOperationSqlQuery, SqlQuery, SqlQueryApply, UnaryOperationSqlQuery } @@ -10,14 +10,13 @@ import io.getquill.idiom.StatementInterpolator._ import io.getquill.idiom.Token import io.getquill.util.Messages.trace import io.getquill.context.{ CannotReturn, ExecutionType } -import io.getquill.norm.TranspileConfig import io.getquill.quat.Quat class SparkDialect extends SparkIdiom trait SparkIdiom extends SqlIdiom with CannotReturn { self => - def parentTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig) = super.sqlQueryTokenizer + def parentTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext) = super.sqlQueryTokenizer def liftingPlaceholder(index: Int): String = "?" @@ -25,16 +24,16 @@ trait SparkIdiom extends SqlIdiom with CannotReturn { self => override implicit def externalIdentTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[ExternalIdent] = super.externalIdentTokenizer - override def translate(ast: Ast, topLevelQuat: Quat, executionType: ExecutionType, transpileConfig: TranspileConfig)(implicit naming: NamingStrategy) = { - val normalizedAst = EscapeQuestionMarks(SqlNormalize(ast, transpileConfig)) + override def translate(ast: Ast, topLevelQuat: Quat, executionType: ExecutionType, idiomContext: IdiomContext)(implicit naming: NamingStrategy) = { + val normalizedAst = EscapeQuestionMarks(SqlNormalize(ast, idiomContext.config)) - implicit val transpileConfigImplicit: TranspileConfig = transpileConfig + implicit val implicitIdiomContext: IdiomContext = idiomContext implicit val tokernizer = defaultTokenizer val token = normalizedAst match { case q: Query => - val sql = new SqlQueryApply(transpileConfig.traceConfig)(q) + val sql = new SqlQueryApply(idiomContext.config.traceConfig)(q) trace("sql")(sql) val expanded = SimpleNestedExpansion(sql) trace("expanded sql")(expanded) @@ -63,8 +62,8 @@ trait SparkIdiom extends SqlIdiom with CannotReturn { self => stmt"${name.token}" } - class SparkFlattenSqlQueryTokenizerHelper(q: FlattenSqlQuery)(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig) - extends FlattenSqlQueryTokenizerHelper(q)(astTokenizer, strategy, transpileConfig) { + class SparkFlattenSqlQueryTokenizerHelper(q: FlattenSqlQuery)(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext) + extends FlattenSqlQueryTokenizerHelper(q)(astTokenizer, strategy, idiomContext) { override def selectTokenizer: Token = { // Note that by the time we have reached this point, all Idents representing case classes/tuples in selection have @@ -108,7 +107,7 @@ trait SparkIdiom extends SqlIdiom with CannotReturn { self => } - override implicit def sqlQueryTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, transpileConfig: TranspileConfig): Tokenizer[SqlQuery] = Tokenizer[SqlQuery] { + override implicit def sqlQueryTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy, idiomContext: IdiomContext): Tokenizer[SqlQuery] = Tokenizer[SqlQuery] { case q: FlattenSqlQuery => new SparkFlattenSqlQueryTokenizerHelper(q).apply case SetOperationSqlQuery(a, op, b) => diff --git a/quill-spark/src/test/scala/io/getquill/context/spark/SparkDialectSpec.scala b/quill-spark/src/test/scala/io/getquill/context/spark/SparkDialectSpec.scala index 76799df09e..e580e25c76 100644 --- a/quill-spark/src/test/scala/io/getquill/context/spark/SparkDialectSpec.scala +++ b/quill-spark/src/test/scala/io/getquill/context/spark/SparkDialectSpec.scala @@ -1,9 +1,9 @@ package io.getquill.context.spark -import io.getquill.Literal +import io.getquill.{ IdiomContext, Literal } import io.getquill.base.Spec import io.getquill.context.ExecutionType -import io.getquill.norm.{ SheathLeafClauses, SheathLeafClausesApply, TranspileConfig } +import io.getquill.norm.SheathLeafClausesApply import io.getquill.quat.Quat import io.getquill.util.TraceConfig @@ -23,13 +23,13 @@ class SparkDialectSpec extends Spec { "translate" - { "query" in { val ast = query[Test].ast - val (norm, stmt, _) = SparkDialect.translate(ast, Quat.Unknown, ExecutionType.Unknown, TranspileConfig.Empty)(Literal) + val (norm, stmt, _) = SparkDialect.translate(ast, Quat.Unknown, ExecutionType.Unknown, IdiomContext.Empty)(Literal) norm mustEqual ast stmt.toString mustEqual "SELECT x.i AS i, x.j AS j, x.s AS s FROM Test x" } "non-query" in { val ast = sql"SELECT 1".ast - val (norm, stmt, _) = SparkDialect.translate(ast, Quat.Unknown, ExecutionType.Unknown, TranspileConfig.Empty)(Literal) + val (norm, stmt, _) = SparkDialect.translate(ast, Quat.Unknown, ExecutionType.Unknown, IdiomContext.Empty)(Literal) norm mustEqual ast stmt.toString mustEqual "SELECT 1" } @@ -37,7 +37,7 @@ class SparkDialectSpec extends Spec { "escapes ' " in { val ast = query[Test].map(t => "test'").ast - val (norm, stmt, _) = SparkDialect.translate(ast, Quat.Unknown, ExecutionType.Unknown, TranspileConfig.Empty)(Literal) + val (norm, stmt, _) = SparkDialect.translate(ast, Quat.Unknown, ExecutionType.Unknown, IdiomContext.Empty)(Literal) norm mustEqual ast stmt.toString mustEqual "SELECT 'test\\'' AS x FROM Test t" } @@ -47,7 +47,7 @@ class SparkDialectSpec extends Spec { case class Inner(i: Int) case class Outer(inner: Inner) val ast = query[Outer].filter(t => t.inner.i == 1).ast - val (norm, stmt, _) = SparkDialect.translate(ast, Quat.Unknown, ExecutionType.Unknown, TranspileConfig.Empty)(Literal) + val (norm, stmt, _) = SparkDialect.translate(ast, Quat.Unknown, ExecutionType.Unknown, IdiomContext.Empty)(Literal) norm mustEqual ast stmt.toString mustEqual "SELECT t.inner AS inner FROM Outer t WHERE t.inner.i = 1" } @@ -55,14 +55,14 @@ class SparkDialectSpec extends Spec { // More comprehensive test in MiscQueriesSpec "nested tuple" in { val ast = query[Test].map(t => ((t.i, t.j), t.i + 1)).ast - val (norm, stmt, _) = SparkDialect.translate(ast, Quat.Unknown, ExecutionType.Unknown, TranspileConfig.Empty)(Literal) + val (norm, stmt, _) = SparkDialect.translate(ast, Quat.Unknown, ExecutionType.Unknown, IdiomContext.Empty)(Literal) norm mustEqual ast stmt.toString mustEqual "SELECT struct(t.i AS _1, t.j AS _2) AS _1, t.i + 1 AS _2 FROM Test t" } "concatMap" in { val ast = query[Test].concatMap(t => t.s.split(" ")).ast - val (norm, stmt, _) = SparkDialect.translate(ast, Quat.Unknown, ExecutionType.Unknown, TranspileConfig.Empty)(Literal) + val (norm, stmt, _) = SparkDialect.translate(ast, Quat.Unknown, ExecutionType.Unknown, IdiomContext.Empty)(Literal) norm mustEqual ast stmt.toString mustEqual "SELECT explode(SPLIT(t.s, ' ')) AS x FROM Test t" } @@ -70,21 +70,21 @@ class SparkDialectSpec extends Spec { // More comprehensive test in MiscQueriesSpec "concatMap with filter" in { val ast = query[Test].concatMap(t => t.s.split(" ")).filter(s => s == "s").ast - val (norm, stmt, _) = SparkDialect.translate(ast, Quat.Unknown, ExecutionType.Unknown, TranspileConfig.Empty)(Literal) + val (norm, stmt, _) = SparkDialect.translate(ast, Quat.Unknown, ExecutionType.Unknown, IdiomContext.Empty)(Literal) norm mustEqual new SheathLeafClausesApply(TraceConfig.Empty)(ast) stmt.toString mustEqual "SELECT s.x AS x FROM (SELECT explode(SPLIT(t.s, ' ')) AS x FROM Test t) AS s WHERE s.x = 's'" } "concat string" in { val ast = query[Test].map(t => t.s + " ").ast - val (norm, stmt, _) = SparkDialect.translate(ast, Quat.Unknown, ExecutionType.Unknown, TranspileConfig.Empty)(Literal) + val (norm, stmt, _) = SparkDialect.translate(ast, Quat.Unknown, ExecutionType.Unknown, IdiomContext.Empty)(Literal) norm mustEqual ast stmt.toString mustEqual "SELECT concat(t.s, ' ') AS x FROM Test t" } "groupBy with multiple columns" in { val ast = query[Test].groupBy(t => (t.i, t.j)).map(t => t._2).ast - val (norm, stmt, _) = SparkDialect.translate(ast, Quat.Unknown, ExecutionType.Unknown, TranspileConfig.Empty)(Literal) + val (norm, stmt, _) = SparkDialect.translate(ast, Quat.Unknown, ExecutionType.Unknown, IdiomContext.Empty)(Literal) norm mustEqual ast stmt.toString mustEqual "SELECT t.i AS i, t.j AS j, t.s AS s FROM Test t GROUP BY t.i, t.j" } diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/BatchActionMultiTest.scala b/quill-sql/src/test/scala/io/getquill/context/sql/BatchActionMultiTest.scala index 2cfe5ce9b5..ba113cef77 100644 --- a/quill-sql/src/test/scala/io/getquill/context/sql/BatchActionMultiTest.scala +++ b/quill-sql/src/test/scala/io/getquill/context/sql/BatchActionMultiTest.scala @@ -90,12 +90,14 @@ class BatchActionMultiTest extends Spec { } } - "fallback for non-insert query" - { + "fallback for non-insert query (in a context that doesn't support update)" - { + val ctx: MirrorContext[MySQLDialect, Literal] = new MirrorContext[MySQLDialect, Literal](MySQLDialect, Literal) + import ctx._ val people = List(Person(1, "A", 111), Person(2, "B", 222), Person(3, "C", 333), Person(4, "D", 444), Person(5, "E", 555)) def expect(executionType: ExecutionType) = List( ( - "UPDATE Person AS pt SET id = ?, name = ?, age = ? WHERE pt.id = ?", + "UPDATE Person pt SET id = ?, name = ?, age = ? WHERE pt.id = ?", List(List(1, "A", 111, 1), List(2, "B", 222, 2), List(3, "C", 333, 3), List(4, "D", 444, 4), List(5, "E", 555, 5)), executionType ) @@ -107,6 +109,27 @@ class BatchActionMultiTest extends Spec { } } + "update query" - { + val people = List(Person(1, "A", 111), Person(2, "B", 222), Person(3, "C", 333), Person(4, "D", 444), Person(5, "E", 555)) + def expect(executionType: ExecutionType) = + List( + ( + "UPDATE Person AS pt SET id = p.id1, name = p.name, age = p.age FROM (VALUES (?, ?, ?, ?), (?, ?, ?, ?)) AS p(id, id1, name, age) WHERE pt.id = p.id", + List(List(1, 1, "A", 111, 2, 2, "B", 222), List(3, 3, "C", 333, 4, 4, "D", 444)), + executionType + ), ( + "UPDATE Person AS pt SET id = p.id1, name = p.name, age = p.age FROM (VALUES (?, ?, ?, ?)) AS p(id, id1, name, age) WHERE pt.id = p.id", + List(List(5, 5, "E", 555)), + executionType + ) + ) + + "static" in { + val static = ctx.run(quote(liftQuery(people).foreach(p => updatePeopleById(p))), 2) + static.tripleBatchMulti mustEqual expect(ExecutionType.Unknown) + } + } + "supported contexts" - { val people = List(Person(1, "A", 111), Person(2, "B", 222), Person(3, "C", 333), Person(4, "D", 444), Person(5, "E", 555)) def makeRow(executionType: ExecutionType)(queryA: String, queryB: String) = @@ -157,7 +180,7 @@ class BatchActionMultiTest extends Spec { def expectPostgresReturning(executionType: ExecutionType) = makeRow(executionType)( - "INSERT INTO Person (id,name,age) VALUES (?, ?, ?), (?, ?, ?) RETURNING id", + "INSERT INTO Person (id,name,age) VALUES (?, ?, ?), (?, ?, ?) RETURNING id", // "INSERT INTO Person (id,name,age) VALUES (?, ?, ?) RETURNING id" ) diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/BatchUpdateValuesMirrorSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/BatchUpdateValuesMirrorSpec.scala new file mode 100644 index 0000000000..81a67dcf13 --- /dev/null +++ b/quill-sql/src/test/scala/io/getquill/context/sql/BatchUpdateValuesMirrorSpec.scala @@ -0,0 +1,130 @@ +package io.getquill.context.sql + +import io.getquill.context.ExecutionType.Unknown +import io.getquill.context.sql.base.BatchUpdateValuesSpec +import io.getquill.{ Literal, PostgresDialect, SqlMirrorContext } +import io.getquill.context.sql.util.StringOps._ + +class BatchUpdateValuesMirrorSpec extends BatchUpdateValuesSpec { + + val context = new SqlMirrorContext(PostgresDialect, Literal) + import context._ + + "Ex 1 - Simple Contact" in { + import `Ex 1 - Simple Contact`._ + context.run(update, 2).tripleBatchMulti mustEqual List( + ( + "UPDATE Contact AS p SET firstName = ps.firstName1, lastName = ps.lastName, age = ps.age FROM (VALUES (?, ?, ?, ?), (?, ?, ?, ?)) AS ps(firstName, firstName1, lastName, age) WHERE p.firstName = ps.firstName", + List( + List("Joe", "Joe", "BloggsU", 22, "Jan", "Jan", "RoggsU", 33), + List("James", "James", "JonesU", 44, "Dale", "Dale", "DomesU", 55) + ), + Unknown + ), ( + "UPDATE Contact AS p SET firstName = ps.firstName1, lastName = ps.lastName, age = ps.age FROM (VALUES (?, ?, ?, ?)) AS ps(firstName, firstName1, lastName, age) WHERE p.firstName = ps.firstName", + List( + List("Caboose", "Caboose", "CastleU", 66) + ), + Unknown + ) + ) + } + + "Ex 1.1 - Simple Contact With Lift" in { + import `Ex 1.1 - Simple Contact With Lift`._ + context.run(update, 2).tripleBatchMulti mustEqual List( + ( + "UPDATE Contact AS p SET firstName = ps.firstName1, lastName = ps.lastName, age = ps.age FROM (VALUES (?, ?, ?, ?), (?, ?, ?, ?)) AS ps(firstName, firstName1, lastName, age) WHERE p.firstName = ps.firstName AND p.firstName = ?", + List( + List("Joe", "Joe", "BloggsU", 22, "Jan", "Jan", "RoggsU", 33, "Joe"), + List("James", "James", "JonesU", 44, "Dale", "Dale", "DomesU", 55, "Joe") + ), Unknown + ), ( + "UPDATE Contact AS p SET firstName = ps.firstName1, lastName = ps.lastName, age = ps.age FROM (VALUES (?, ?, ?, ?)) AS ps(firstName, firstName1, lastName, age) WHERE p.firstName = ps.firstName AND p.firstName = ?", + List( + List("Caboose", "Caboose", "CastleU", 66, "Joe") + ), Unknown + ) + ) + } + + "Ex 1.2 - Simple Contact With 2 Lifts" in { + import `Ex 1.2 - Simple Contact Mixed Lifts`._ + context.run(update, 2).tripleBatchMulti mustEqual List( + ( + "UPDATE Contact AS p SET lastName = ps.lastName || ? FROM (VALUES (?, ?), (?, ?)) AS ps(firstName, lastName) WHERE p.firstName = ps.firstName AND (p.firstName = ? OR p.firstName = ?)", + List(List(" Jr.", "Joe", "BloggsU", "Jan", "RoggsU", "Joe", "Jan"), List(" Jr.", "James", "JonesU", "Dale", "DomesU", "Joe", "Jan")), Unknown + ), ( + "UPDATE Contact AS p SET lastName = ps.lastName || ? FROM (VALUES (?, ?)) AS ps(firstName, lastName) WHERE p.firstName = ps.firstName AND (p.firstName = ? OR p.firstName = ?)", + List(List(" Jr.", "Caboose", "CastleU", "Joe", "Jan")), Unknown + ) + ) + } + + "Ex 1.3 - Simple Contact With 2 Lifts and Multi-Lift" in { + import `Ex 1.3 - Simple Contact with Multi-Lift-Kinds`._ + context.run(update, 2).tripleBatchMulti mustEqual List( + ( + "UPDATE Contact AS p SET firstName = ps.firstName1, lastName = ps.lastName, age = ps.age FROM (VALUES (?, ?, ?, ?), (?, ?, ?, ?)) AS ps(firstName, firstName1, lastName, age) WHERE p.firstName = ps.firstName AND (p.firstName = ? OR p.firstName IN (?, ?))", + List( + List("Joe", "Joe", "BloggsU", 22, "Jan", "Jan", "RoggsU", 33, "Joe", "Dale", "Caboose"), + List("James", "James", "JonesU", 44, "Dale", "Dale", "DomesU", 55, "Joe", "Dale", "Caboose") + ), Unknown + ), ( + "UPDATE Contact AS p SET firstName = ps.firstName1, lastName = ps.lastName, age = ps.age FROM (VALUES (?, ?, ?, ?)) AS ps(firstName, firstName1, lastName, age) WHERE p.firstName = ps.firstName AND (p.firstName = ? OR p.firstName IN (?, ?))", + List( + List("Caboose", "Caboose", "CastleU", 66, "Joe", "Dale", "Caboose") + ), Unknown + ) + ) + } + + "Ex 2 - Optional Embedded with Renames" in { + import `Ex 2 - Optional Embedded with Renames`._ + context.run(update, 2).tripleBatchMulti.map(_._1.collapseSpace) mustEqual List( + """UPDATE Contact AS p SET lastName = ps.name_last + |FROM (VALUES (?, ?, ?, ?), (?, ?, ?, ?)) AS ps(name_first, name_first1, name_first2, name_last) + |WHERE + | p.firstName IS NULL AND ps.name_first IS NULL OR + | p.firstName IS NOT NULL AND + | ps.name_first1 IS NOT NULL AND + | p.firstName = ps.name_first2 + |""".collapseSpace, + """UPDATE Contact AS p SET lastName = ps.name_last + |FROM (VALUES (?, ?, ?, ?)) AS ps(name_first, name_first1, name_first2, name_last) + |WHERE + | p.firstName IS NULL AND ps.name_first IS NULL OR + | p.firstName IS NOT NULL AND + | ps.name_first1 IS NOT NULL AND + |p.firstName = ps.name_first2 + |""".collapseSpace + ) + } + + "Ex 4 - Returning" in { + import `Ex 4 - Returning`._ + context.run(update, 2).tripleBatchMulti.map(_._1.collapseSpace) mustEqual List( + """UPDATE Contact AS p SET firstName = ps.firstName1, lastName = ps.lastName, age = ps.age + |FROM (VALUES (?, ?, ?, ?), (?, ?, ?, ?)) AS ps(firstName, firstName1, lastName, age) + |WHERE p.firstName = ps.firstName + |RETURNING ps.age + |""".collapseSpace, + """UPDATE Contact AS p SET firstName = ps.firstName1, lastName = ps.lastName, age = ps.age + |FROM (VALUES (?, ?, ?, ?)) AS ps(firstName, firstName1, lastName, age) + |WHERE p.firstName = ps.firstName + |RETURNING ps.age + |""".collapseSpace + ) + } + + "Ex 5 - Append Data" in { + import `Ex 5 - Append Data`._ + context.run(update, 2).tripleBatchMulti mustEqual List( + ( + "UPDATE Contact AS p SET firstName = p.firstName || ps.firstName, lastName = p.lastName || ps.lastName FROM (VALUES (?, ?), (?, ?)) AS ps(firstName, lastName) WHERE p.firstName IN (?, ?, ?, ?, ?)", + List(List("_A", "_B", "_AA", "_BB", "Joe", "Jan", "James", "Dale", "Caboose")), + Unknown + ) + ) + } +} diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/base/BatchUpdateValuesSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/base/BatchUpdateValuesSpec.scala new file mode 100644 index 0000000000..afbad82145 --- /dev/null +++ b/quill-sql/src/test/scala/io/getquill/context/sql/base/BatchUpdateValuesSpec.scala @@ -0,0 +1,237 @@ +package io.getquill.context.sql.base + +import io.getquill.base.Spec +import io.getquill.context.sql.SqlContext +import org.scalatest.BeforeAndAfterEach + +trait BatchUpdateValuesSpec extends Spec with BeforeAndAfterEach { + + val context: SqlContext[_, _] + import context._ + + case class ContactBase(firstName: String, lastName: String, age: Int) + val dataBase: List[ContactBase] = List( + ContactBase("Joe", "Bloggs", 22), + ContactBase("A", "A", 111), + ContactBase("B", "B", 111), + ContactBase("Jan", "Roggs", 33), + ContactBase("C", "C", 111), + ContactBase("D", "D", 111), + ContactBase("James", "Jones", 44), + ContactBase("Dale", "Domes", 55), + ContactBase("Caboose", "Castle", 66), + ContactBase("E", "E", 111) + ) + val updatePeople = List("Joe", "Jan", "James", "Dale", "Caboose") + def includeInUpdate(name: String): Boolean = updatePeople.contains(name) + def includeInUpdate(c: ContactBase): Boolean = includeInUpdate(c.firstName) + val updateBase = + dataBase.filter(includeInUpdate(_)).map(r => r.copy(lastName = r.lastName + "U")) + val expectBase = dataBase.map { r => + if (includeInUpdate(r)) r.copy(lastName = r.lastName + "U") else r + } + + trait Adaptable { + type Row + def makeData(c: ContactBase): Row + implicit class AdaptOps(list: List[ContactBase]) { + def adapt: List[Row] = list.map(makeData(_)) + } + lazy val updateData = updateBase.adapt + lazy val expect = expectBase.adapt + lazy val data = dataBase.adapt + } + + object `Ex 1 - Simple Contact` extends Adaptable { + case class Contact(firstName: String, lastName: String, age: Int) + type Row = Contact + override def makeData(c: ContactBase): Contact = Contact(c.firstName, c.lastName, c.age) + + val insert = quote { + liftQuery(data: List[Contact]).foreach(ps => query[Contact].insertValue(ps)) + } + val update = quote { + liftQuery(updateData: List[Contact]).foreach(ps => + query[Contact].filter(p => p.firstName == ps.firstName).updateValue(ps)) + } + val get = quote(query[Contact]) + } + + object `Ex 1.1 - Simple Contact With Lift` extends Adaptable { + case class Contact(firstName: String, lastName: String, age: Int) + type Row = Contact + override def makeData(c: ContactBase): Contact = Contact(c.firstName, c.lastName, c.age) + + val insert = quote { + liftQuery(data: List[Contact]).foreach(ps => query[Contact].insertValue(ps)) + } + val update = quote { + liftQuery(updateData: List[Contact]).foreach(ps => + query[Contact].filter(p => p.firstName == ps.firstName && p.firstName == lift("Joe")).updateValue(ps)) + } + val get = quote(query[Contact]) + override lazy val expect = data.map(p => if (p.firstName == "Joe") p.copy(lastName = p.lastName + "U") else p) + } + + object `Ex 1.2 - Simple Contact Mixed Lifts` extends Adaptable { + case class Contact(firstName: String, lastName: String, age: Int) + type Row = Contact + override def makeData(c: ContactBase): Contact = Contact(c.firstName, c.lastName, c.age) + + val insert = quote { + liftQuery(data: List[Contact]).foreach(ps => query[Contact].insertValue(ps)) + } + val update = quote { + liftQuery(updateData: List[Contact]).foreach(ps => + query[Contact] + .filter(p => p.firstName == ps.firstName && (p.firstName == lift("Joe") || p.firstName == lift("Jan"))) + .update(_.lastName -> (ps.lastName + lift(" Jr.")))) + } + val get = quote(query[Contact]) + override lazy val expect = data.map(p => if (p.firstName == "Joe" || p.firstName == "Jan") p.copy(lastName = p.lastName + "U Jr.") else p) + } + + object `Ex 1.3 - Simple Contact with Multi-Lift-Kinds` extends Adaptable { + case class Contact(firstName: String, lastName: String, age: Int) + type Row = Contact + override def makeData(c: ContactBase): Contact = Contact(c.firstName, c.lastName, c.age) + + val insert = quote { + liftQuery(data: List[Contact]).foreach(ps => query[Contact].insertValue(ps)) + } + val update = quote { + liftQuery(updateData: List[Contact]).foreach(ps => + query[Contact] + .filter(p => p.firstName == ps.firstName && (p.firstName == lift("Joe") || liftQuery(List("Dale", "Caboose")).contains(p.firstName))) + .updateValue(ps)) + } + val get = quote(query[Contact]) + override lazy val expect = data.map(p => if (p.firstName == "Joe" || p.firstName == "Dale" || p.firstName == "Caboose") p.copy(lastName = p.lastName + "U") else p) + } + + object `Ex 2 - Optional Embedded with Renames` extends Adaptable { + case class Name(first: String, last: String) extends Embedded + case class ContactTable(name: Option[Name], age: Int) + type Row = ContactTable + override def makeData(c: ContactBase): ContactTable = ContactTable(Some(Name(c.firstName, c.lastName)), c.age) + + val contacts = quote { + querySchema[ContactTable]("Contact", _.name.map(_.first) -> "firstName", _.name.map(_.last) -> "lastName") + } + + val insert = quote { + liftQuery(data: List[ContactTable]).foreach(ps => contacts.insertValue(ps)) + } + val update = quote { + liftQuery(updateData: List[ContactTable]).foreach(ps => + contacts + .filter(p => p.name.map(_.first) == ps.name.map(_.first)) + .update(_.name.map(_.last) -> ps.name.map(_.last))) + } + val get = quote(contacts) + } + + object `Ex 3 - Deep Embedded Optional` extends Adaptable { + case class FirstName(firstName: Option[String]) extends Embedded + case class LastName(lastName: Option[String]) extends Embedded + case class Name(first: FirstName, last: LastName) extends Embedded + case class Contact(name: Option[Name], age: Int) + type Row = Contact + override def makeData(c: ContactBase): Contact = Contact(Some(Name(FirstName(Option(c.firstName)), LastName(Option(c.lastName)))), c.age) + + val insert = quote { + liftQuery(data: List[Contact]).foreach(ps => query[Contact].insertValue(ps)) + } + val update = quote { + liftQuery(updateData: List[Contact]).foreach(ps => + query[Contact] + .filter(p => p.name.map(_.first.firstName) == ps.name.map(_.first.firstName)) + .update(_.name.map(_.last.lastName) -> ps.name.map(_.last.lastName))) + } + val get = quote(query[Contact]) + } + + object `Ex 4 - Returning` extends Adaptable { + case class Contact(firstName: String, lastName: String, age: Int) + type Row = Contact + override def makeData(c: ContactBase): Contact = Contact(c.firstName, c.lastName, c.age) + + val insert = quote { + liftQuery(data: List[Contact]).foreach(ps => query[Contact].insertValue(ps)) + } + val update = quote { + liftQuery(updateData: List[Contact]).foreach(ps => + query[Contact].filter(p => p.firstName == ps.firstName).updateValue(ps).returning(_.age)) + } + val expectedReturn = updateData.map(_.age) + val get = quote(query[Contact]) + } + + object `Ex 4 - Returning Multiple` extends Adaptable { + case class Contact(firstName: String, lastName: String, age: Int) + type Row = Contact + override def makeData(c: ContactBase): Contact = Contact(c.firstName, c.lastName, c.age) + + val insert = quote { + liftQuery(data: List[Contact]).foreach(ps => query[Contact].insertValue(ps)) + } + val update = quote { + liftQuery(updateData: List[Contact]).foreach(ps => + query[Contact].filter(p => p.firstName == ps.firstName).updateValue(ps).returning(r => (r.lastName, r.age))) + } + val expectedReturn = updateData.map(r => (r.lastName, r.age)) + val get = quote(query[Contact]) + } + + object `Ex 5 - Append Data` extends Adaptable { + case class Contact(firstName: String, lastName: String, age: Int) + type Row = Contact + override def makeData(c: ContactBase): Contact = Contact(c.firstName, c.lastName, c.age) + val insert = quote { + liftQuery(data: List[Contact]).foreach(ps => query[Contact].insertValue(ps)) + } + val updateDataSpecific = List( + Contact("_A", "_B", 22), + Contact("_AA", "_BB", 22) + ) + val update = quote { + liftQuery(updateDataSpecific: List[Contact]).foreach(ps => + query[Contact] + .filter(p => liftQuery(updatePeople).contains(p.firstName)) + .update(pa => pa.firstName -> (pa.firstName + ps.firstName), pb => pb.lastName -> (pb.lastName + ps.lastName))) + } + + val expectSpecific = (data: List[Contact]) + .map(r => { + if (includeInUpdate(r.firstName)) { + // Not sure why the 1nd part i.e. _AA, _BB is not tacked on yet. Something odd about how postgres processes updates + // Note that this happens even with a batch-group-size of 1 + r.copy(firstName = s"${r.firstName}_A", lastName = s"${r.lastName}_B") + } else + r + }) + val get = quote(query[Contact]) + } + + object `Ex 6 - Append Data No Condition` extends Adaptable { + case class Contact(firstName: String, lastName: String, age: Int) + type Row = Contact + override def makeData(c: ContactBase): Contact = Contact(c.firstName, c.lastName, c.age) + val insert = quote { + liftQuery(data: List[Contact]).foreach(ps => query[Contact].insertValue(ps)) + } + val updateDataSpecific = List( + Contact("_A", "_B", 22), + Contact("_AA", "_BB", 22) + ) + val update = quote { + liftQuery(updateDataSpecific: List[Contact]).foreach(ps => + query[Contact].update(pa => pa.firstName -> (pa.firstName + ps.firstName), pb => pb.lastName -> (pb.lastName + ps.lastName))) + } + + // Not sure why the 1nd part i.e. _AA, _BB is not tacked on yet. Something odd about how postgres processes updates + // Note that this happens even with a batch-group-size of 1 + val expectSpecific = (data: List[Contact]).map(r => r.copy(firstName = s"${r.firstName}_A", lastName = s"${r.lastName}_B")) + val get = quote(query[Contact]) + } +} \ No newline at end of file diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/base/BatchValuesSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/base/BatchValuesSpec.scala index 2eb2144b9d..a3f8707473 100644 --- a/quill-sql/src/test/scala/io/getquill/context/sql/base/BatchValuesSpec.scala +++ b/quill-sql/src/test/scala/io/getquill/context/sql/base/BatchValuesSpec.scala @@ -45,4 +45,18 @@ trait BatchValuesSpec extends Spec with BeforeAndAfterEach { def get = quote { query[Product] } def result = productsOriginal } + + object `Ex 3 - Batch Insert Mixed` { + val products = makeProducts(20) + val batchSize = 40 + def op = quote { + liftQuery(products).foreach(p => query[Product].insert(_.id -> p.id, _.description -> lift("BlahBlah"), _.sku -> p.sku)) + } + def opExt = quote { + (transform: Insert[Product] => Insert[Product]) => + liftQuery(products).foreach(p => transform(query[Product].insert(_.id -> p.id, _.description -> lift("BlahBlah"), _.sku -> p.sku))) + } + def get = quote { query[Product] } + def result = products.map(_.copy(description = "BlahBlah")) + } }