From f3dd8cb2904a6f9b5b8724553292339372365bf7 Mon Sep 17 00:00:00 2001 From: Alexander Ioffe Date: Mon, 9 May 2022 18:38:56 -0400 Subject: [PATCH] Integrating Doobie Support (#98) --- build.sbt | 16 +- build/setup_bigdata.sh | 2 +- build/setup_databases.sh | 8 +- build/setup_db_scripts.sh | 62 +++-- build/setup_local.sh | 13 +- build/setup_mysql_postgres_databases.sh | 4 +- .../io/getquill/doobie/DoobieContext.scala | 32 +++ .../getquill/doobie/DoobieContextBase.scala | 237 ++++++++++++++++ .../doobie/PostgresDoobieContextSuite.scala | 106 ++++++++ .../io/getquill/doobie/issue/Issue1067.scala | 35 +++ .../context/BatchQueryExecution.scala | 253 +++++++++++++++--- .../scala/io/getquill/context/LiftMacro.scala | 6 + .../io/getquill/metaprog/ExprModel.scala | 3 +- .../scala/io/getquill/BatchActionTest.scala | 34 +++ .../src/test/sql/postgres-doobie-schema.sql | 87 ++++++ 15 files changed, 815 insertions(+), 83 deletions(-) create mode 100644 quill-doobie/src/main/scala/io/getquill/doobie/DoobieContext.scala create mode 100644 quill-doobie/src/main/scala/io/getquill/doobie/DoobieContextBase.scala create mode 100644 quill-doobie/src/test/scala/io/getquill/doobie/PostgresDoobieContextSuite.scala create mode 100644 quill-doobie/src/test/scala/io/getquill/doobie/issue/Issue1067.scala create mode 100644 quill-sql/src/test/sql/postgres-doobie-schema.sql diff --git a/build.sbt b/build.sbt index f4074eb20..be2be1be8 100644 --- a/build.sbt +++ b/build.sbt @@ -49,7 +49,7 @@ lazy val sqlTestModules = Seq[sbt.ClasspathDep[sbt.ProjectReference]]( ) lazy val dbModules = Seq[sbt.ClasspathDep[sbt.ProjectReference]]( - `quill-jdbc`, `quill-zio`, `quill-jdbc-zio`, `quill-caliban` + `quill-jdbc`, `quill-doobie`, `quill-zio`, `quill-jdbc-zio`, `quill-caliban` ) lazy val jasyncModules = Seq[sbt.ClasspathDep[sbt.ProjectReference]]( @@ -169,6 +169,20 @@ lazy val `quill-jdbc` = .settings(jdbcTestingSettings: _*) .dependsOn(`quill-sql` % "compile->compile;test->test") +ThisBuild / libraryDependencySchemes += "org.typelevel" %% "cats-effect" % "always" +lazy val `quill-doobie` = + (project in file("quill-doobie")) + .settings(commonSettings: _*) + .settings(releaseSettings: _*) + .settings(jdbcTestingSettings: _*) + .settings( + libraryDependencies ++= Seq( + "org.tpolecat" %% "doobie-core" % "1.0.0-RC2", + "org.tpolecat" %% "doobie-postgres" % "1.0.0-RC2" % Test + ) + ) + .dependsOn(`quill-jdbc` % "compile->compile;test->test") + lazy val `quill-jasync` = (project in file("quill-jasync")) .settings(commonSettings: _*) diff --git a/build/setup_bigdata.sh b/build/setup_bigdata.sh index 4aba2e435..ff6bdf7b5 100755 --- a/build/setup_bigdata.sh +++ b/build/setup_bigdata.sh @@ -10,6 +10,6 @@ time docker-compose up -d cassandra orientdb # setup cassandra in docker send_script cassandra $CASSANDRA_SCRIPT cassandra-schema.cql send_script cassandra ./build/setup_db_scripts.sh setup_db_scripts.sh -time docker-compose exec -T cassandra bash -c ". setup_db_scripts.sh && setup_cassandra cassandra-schema.cql 127.0.0.1" +time docker-compose exec -T cassandra bash -c ". setup_db_scripts.sh && setup_cassandra 127.0.0.1 cassandra-schema.cql" echo "Databases are ready!" 
\ No newline at end of file diff --git a/build/setup_databases.sh b/build/setup_databases.sh index 5a2fb94e2..68408b2df 100755 --- a/build/setup_databases.sh +++ b/build/setup_databases.sh @@ -18,17 +18,17 @@ echo "### Sourcing DB Scripts ###" # run setup scripts for local databases echo "### Running Setup for sqlite ###" -time setup_sqlite $SQLITE_SCRIPT 127.0.0.1 +time setup_sqlite 127.0.0.1 echo "### Running Setup for mysql ###" -time setup_mysql $MYSQL_SCRIPT 127.0.0.1 13306 +time setup_mysql 127.0.0.1 13306 echo "### Running Setup for postgres ###" -time setup_postgres $POSTGRES_SCRIPT 127.0.0.1 15432 +time setup_postgres 127.0.0.1 15432 echo "### Running Setup for sqlserver ###" # setup sqlserver in docker send_script sqlserver $SQL_SERVER_SCRIPT sqlserver-schema.sql send_script sqlserver ./build/setup_db_scripts.sh setup_db_scripts.sh -time docker-compose exec -T sqlserver bash -c ". setup_db_scripts.sh && setup_sqlserver sqlserver-schema.sql 127.0.0.1" +time docker-compose exec -T sqlserver bash -c ". setup_db_scripts.sh && setup_sqlserver 127.0.0.1 sqlserver-schema.sql" # Can't do absolute paths here so need to do relative mkdir sqlline/ diff --git a/build/setup_db_scripts.sh b/build/setup_db_scripts.sh index fbd29a157..092c50794 100755 --- a/build/setup_db_scripts.sh +++ b/build/setup_db_scripts.sh @@ -3,6 +3,7 @@ export SQLITE_SCRIPT=quill-jdbc/src/test/resources/sql/sqlite-schema.sql export MYSQL_SCRIPT=quill-sql/src/test/sql/mysql-schema.sql export POSTGRES_SCRIPT=quill-sql/src/test/sql/postgres-schema.sql +export POSTGRES_DOOBIE_SCRIPT=quill-sql/src/test/sql/postgres-doobie-schema.sql export SQL_SERVER_SCRIPT=quill-sql/src/test/sql/sqlserver-schema.sql export ORACLE_SCRIPT=quill-sql/src/test/sql/oracle-schema.sql export CASSANDRA_SCRIPT=quill-cassandra/src/test/cql/cassandra-schema.cql @@ -24,22 +25,22 @@ function setup_sqlite() { echo "Removing Previous sqlite DB File (if any)" rm -f $DB_FILE echo "Creating sqlite DB File" - echo "(with the $1 script)" - sqlite3 $DB_FILE < $1 + echo "(with the $SQLITE_SCRIPT script)" + sqlite3 $DB_FILE < $SQLITE_SCRIPT echo "Setting permissions on sqlite DB File" chmod a+rw $DB_FILE - # # DB File in quill-jdbc-monix - # DB_FILE=quill-jdbc-monix/quill_test.db - # rm -f $DB_FILE - # sqlite3 $DB_FILE < $1 - # chmod a+rw $DB_FILE + # # DB File in quill-jdbc-monix + # DB_FILE=quill-jdbc-monix/quill_test.db + # rm -f $DB_FILE + # sqlite3 $DB_FILE < $SQLITE_SCRIPT + # chmod a+rw $DB_FILE echo "Sqlite ready!" } function setup_mysql() { - port=$3 + port=$2 password='' if [ -z "$port" ]; then echo "MySQL Port not defined. 
Setting to default: 3306 " @@ -48,7 +49,7 @@ function setup_mysql() { echo "MySQL Port specified as $port" fi - connection=$2 + connection=$1 MYSQL_ROOT_PASSWORD=root echo "Waiting for MySql" @@ -68,7 +69,7 @@ function setup_mysql() { echo "MySql: Create quill_test" mysql --protocol=tcp --host=$connection --password="$MYSQL_ROOT_PASSWORD" --port=$port -u root -e "CREATE DATABASE quill_test;" echo "MySql: Write Schema to quill_test" - mysql --protocol=tcp --host=$connection --password="$MYSQL_ROOT_PASSWORD" --port=$port -u root quill_test < $1 + mysql --protocol=tcp --host=$connection --password="$MYSQL_ROOT_PASSWORD" --port=$port -u root quill_test < $MYSQL_SCRIPT echo "MySql: Create finagle user" mysql --protocol=tcp --host=$connection --password="$MYSQL_ROOT_PASSWORD" --port=$port -u root -e "CREATE USER 'finagle'@'%' IDENTIFIED BY 'finagle';" echo "MySql: Grant finagle user" @@ -78,7 +79,8 @@ function setup_mysql() { } function setup_postgres() { - port=$3 + port=$2 + host=$1 if [ -z "$port" ]; then echo "Postgres Port not defined. Setting to default: 5432" port="5432" @@ -86,45 +88,49 @@ function setup_postgres() { echo "Postgres Port specified as $port" fi echo "Waiting for Postgres" - until psql --host $2 --port $port --username postgres -c "select 1" &> /dev/null; do - echo "## Tapping Postgres Connection> psql --host $2 --port $port --username postgres -c 'select 1'" - psql --host $2 --port $port --username postgres -c "select 1" || true + until psql --host $host --port $port --username postgres -c "select 1" &> /dev/null; do + echo "## Tapping Postgres Connection> psql --host $host --port $port --username postgres -c 'select 1'" + psql --host $host --port $port --username postgres -c "select 1" || true sleep 5; done echo "Connected to Postgres" echo "Postgres: Create codegen_test" - psql --host $2 --port $port -U postgres -c "CREATE DATABASE codegen_test" + psql --host $host --port $port -U postgres -c "CREATE DATABASE codegen_test" echo "Postgres: Create quill_test" - psql --host $2 --port $port -U postgres -c "CREATE DATABASE quill_test" + psql --host $host --port $port -U postgres -c "CREATE DATABASE quill_test" echo "Postgres: Write Schema to quill_test" - psql --host $2 --port $port -U postgres -d quill_test -a -q -f $1 + psql --host $host --port $port -U postgres -d quill_test -a -q -f $POSTGRES_SCRIPT + echo "Postgres: Create doobie_test" + psql --host $host --port $port -U postgres -c "CREATE DATABASE doobie_test" + echo "Postgres: Write Schema to doobie_test" + psql --host $host --port $port -U postgres -d doobie_test -a -q -f $POSTGRES_DOOBIE_SCRIPT } function setup_cassandra() { - host=$(get_host $2) + host=$(get_host $1) echo "Waiting for Cassandra" - until cqlsh $2 -e "describe cluster" &> /dev/null; do + until cqlsh $1 -e "describe cluster" &> /dev/null; do sleep 5; done echo "Connected to Cassandra" - cqlsh $2 -f $1 + cqlsh $1 -f $2 } function setup_sqlserver() { - host=$(get_host $2) + host=$(get_host $1) echo "Waiting for SqlServer" - until /opt/mssql-tools/bin/sqlcmd -S $2 -U SA -P "QuillRocks!" -Q "select 1" &> /dev/null; do + until /opt/mssql-tools/bin/sqlcmd -S $1 -U SA -P "QuillRocks!" -Q "select 1" &> /dev/null; do sleep 5; done echo "Connected to SqlServer" - /opt/mssql-tools/bin/sqlcmd -S $2 -U SA -P "QuillRocks!" -Q "CREATE DATABASE codegen_test" - /opt/mssql-tools/bin/sqlcmd -S $2 -U SA -P "QuillRocks!" -Q "CREATE DATABASE alpha" - /opt/mssql-tools/bin/sqlcmd -S $2 -U SA -P "QuillRocks!" 
-Q "CREATE DATABASE bravo" - /opt/mssql-tools/bin/sqlcmd -S $2 -U SA -P "QuillRocks!" -Q "CREATE DATABASE quill_test" - /opt/mssql-tools/bin/sqlcmd -S $2 -U SA -P "QuillRocks!" -d quill_test -i $1 + /opt/mssql-tools/bin/sqlcmd -S $1 -U SA -P "QuillRocks!" -Q "CREATE DATABASE codegen_test" + /opt/mssql-tools/bin/sqlcmd -S $1 -U SA -P "QuillRocks!" -Q "CREATE DATABASE alpha" + /opt/mssql-tools/bin/sqlcmd -S $1 -U SA -P "QuillRocks!" -Q "CREATE DATABASE bravo" + /opt/mssql-tools/bin/sqlcmd -S $1 -U SA -P "QuillRocks!" -Q "CREATE DATABASE quill_test" + /opt/mssql-tools/bin/sqlcmd -S $1 -U SA -P "QuillRocks!" -d quill_test -i $2 } # Do a simple necat poll to make sure the oracle database is ready. @@ -132,7 +138,7 @@ function setup_sqlserver() { # by the container and docker-compose steps. function setup_oracle() { - while ! nc -z $2 1521; do + while ! nc -z $1 1521; do echo "Waiting for Oracle" sleep 2; done; diff --git a/build/setup_local.sh b/build/setup_local.sh index 67c07c4e4..d37a517df 100755 --- a/build/setup_local.sh +++ b/build/setup_local.sh @@ -6,13 +6,14 @@ set -e . /app/build/setup_db_scripts.sh -time setup_mysql $MYSQL_SCRIPT mysql -time setup_postgres $POSTGRES_SCRIPT postgres -time setup_cassandra $CASSANDRA_SCRIPT cassandra -time setup_sqlserver $SQL_SERVER_SCRIPT sqlserver -time setup_oracle $ORACLE_SCRIPT oracle +time setup_mysql mysql +time setup_postgres postgres +time setup_cassandra cassandra $CASSANDRA_SCRIPT +# SQL Server needs to be passed different script paths based on environment. Therefore it has a 2nd arg. +time setup_sqlserver sqlserver $SQL_SERVER_SCRIPT +time setup_oracle oracle # TODO Move this back up to the top. This is failing for now but want mysql to succeed -time setup_sqlite $SQLITE_SCRIPT +time setup_sqlite echo "Databases are ready!" \ No newline at end of file diff --git a/build/setup_mysql_postgres_databases.sh b/build/setup_mysql_postgres_databases.sh index 75e164f1c..cfe2b4c22 100755 --- a/build/setup_mysql_postgres_databases.sh +++ b/build/setup_mysql_postgres_databases.sh @@ -15,7 +15,7 @@ docker ps . build/setup_db_scripts.sh # run setup scripts for local databases -time setup_mysql $MYSQL_SCRIPT 127.0.0.1 13306 -time setup_postgres $POSTGRES_SCRIPT 127.0.0.1 15432 +time setup_mysql 127.0.0.1 13306 +time setup_postgres 127.0.0.1 15432 echo "Postgres and MySQL Databases are ready!" 
\ No newline at end of file diff --git a/quill-doobie/src/main/scala/io/getquill/doobie/DoobieContext.scala b/quill-doobie/src/main/scala/io/getquill/doobie/DoobieContext.scala new file mode 100644 index 000000000..50c363def --- /dev/null +++ b/quill-doobie/src/main/scala/io/getquill/doobie/DoobieContext.scala @@ -0,0 +1,32 @@ +package io.getquill.doobie + +import io.getquill._ +import io.getquill.context.jdbc._ + +object DoobieContext { + + class H2[N <: NamingStrategy](val naming: N) + extends DoobieContextBase[H2Dialect, N] + with H2JdbcTypes[N] + + class MySQL[N <: NamingStrategy](val naming: N) + extends DoobieContextBase[MySQLDialect, N] + with MysqlJdbcTypes[N] + + class Oracle[N <: NamingStrategy](val naming: N) + extends DoobieContextBase[OracleDialect, N] + with OracleJdbcTypes[N] + + class Postgres[N <: NamingStrategy](val naming: N) + extends DoobieContextBase[PostgresDialect, N] + with PostgresJdbcTypes[N] + + class SQLite[N <: NamingStrategy](val naming: N) + extends DoobieContextBase[SqliteDialect, N] + with SqliteJdbcTypes[N] + + class SQLServer[N <: NamingStrategy](val naming: N) + extends DoobieContextBase[SQLServerDialect, N] + with SqlServerJdbcTypes[N] + +} diff --git a/quill-doobie/src/main/scala/io/getquill/doobie/DoobieContextBase.scala b/quill-doobie/src/main/scala/io/getquill/doobie/DoobieContextBase.scala new file mode 100644 index 000000000..6aa72a05b --- /dev/null +++ b/quill-doobie/src/main/scala/io/getquill/doobie/DoobieContextBase.scala @@ -0,0 +1,237 @@ +package io.getquill.doobie + +import cats.data.Nested +import cats.syntax.all._ +import cats.free.Free +import doobie.free.connection.ConnectionOp +import doobie.{Query => DQuery, _} +import doobie.implicits._ +import doobie.util.query.DefaultChunkSize +import fs2.Stream +import io.getquill.context.sql.idiom.SqlIdiom +import io.getquill.context.ContextVerbStream +import io.getquill.context.ExecutionInfo +import java.sql.Connection +import scala.util.Success +import scala.util.Try +import doobie.enumerated.AutoGeneratedKeys +import io.getquill.ReturnAction.ReturnColumns +import io.getquill.ReturnAction.ReturnNothing +import io.getquill.ReturnAction.ReturnRecord +import io.getquill._ +import io.getquill.context.jdbc.JdbcContextBase +import io.getquill.util.ContextLogger +import scala.language.implicitConversions +import io.getquill.context.jdbc.JdbcContextTypes +import io.getquill.context.ProtoContext +import scala.annotation.targetName + +/** Base trait from which vendor-specific variants are derived. 
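+ *
+ *  Illustrative usage (a sketch; `xa` is an assumed doobie Transactor and `Person`
+ *  an assumed mapped table, neither defined in this file):
+ *  {{{
+ *  val ctx = new DoobieContext.Postgres(Literal)
+ *  import ctx._
+ *  val program: ConnectionIO[List[Person]] = ctx.run(query[Person])
+ *  val people: IO[List[Person]] = program.transact(xa)
+ *  }}}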
*/ +trait DoobieContextBase[Dialect <: SqlIdiom, Naming <: NamingStrategy] + extends JdbcContextTypes[Dialect, Naming] + with ProtoContext[Dialect, Naming] + with ContextVerbStream[Dialect, Naming] { + + override type Result[A] = ConnectionIO[A] + override type RunQueryResult[A] = List[A] + override type RunQuerySingleResult[A] = A + override type StreamResult[A] = Stream[ConnectionIO, A] + override type RunActionResult = Long + override type RunActionReturningResult[A] = A + override type RunBatchActionResult = List[Long] + override type RunBatchActionReturningResult[A] = List[A] + + override type Runner = Unit + override protected def context: Runner = () + + @targetName("runQueryDefault") + inline def run[T](inline quoted: Quoted[Query[T]]): ConnectionIO[List[T]] = InternalApi.runQueryDefault(quoted) + @targetName("runQuery") + inline def run[T](inline quoted: Quoted[Query[T]], inline wrap: OuterSelectWrap): ConnectionIO[List[T]] = InternalApi.runQuery(quoted, wrap) + @targetName("runQuerySingle") + inline def run[T](inline quoted: Quoted[T]): ConnectionIO[T] = InternalApi.runQuerySingle(quoted) + @targetName("runAction") + inline def run[E](inline quoted: Quoted[Action[E]]): ConnectionIO[Long] = InternalApi.runAction(quoted) + @targetName("runActionReturning") + inline def run[E, T](inline quoted: Quoted[ActionReturning[E, T]]): ConnectionIO[T] = InternalApi.runActionReturning[E, T](quoted) + @targetName("runBatchAction") + inline def run[I, A <: Action[I] & QAC[I, Nothing]](inline quoted: Quoted[BatchAction[A]]): ConnectionIO[List[Long]] = InternalApi.runBatchAction(quoted) + @targetName("runBatchActionReturning") + inline def run[I, T, A <: Action[I] & QAC[I, T]](inline quoted: Quoted[BatchAction[A]]): ConnectionIO[List[T]] = InternalApi.runBatchActionReturning(quoted) + + // Logging behavior should be identical to JdbcContextBase.scala, which includes a couple calls + // to log.underlying below. 
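+  // (Bind-parameter logging is gated by ContextLogger: run with -Dquill.binds.log=true
+  //  to see the bound values, as the new PostgresDoobieContextSuite does via sys.props.)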
+ private val log: ContextLogger = new ContextLogger("DoobieContext") + + private def useConnection[A](f: Connection => PreparedStatementIO[A]): PreparedStatementIO[A] = + FPS.getConnection.flatMap(f) + + private def prepareAndLog( + sql: String, + p: Prepare, + )( + implicit connection: Connection + ): PreparedStatementIO[Unit] = FPS.raw(p(_, connection)).flatMap { case (params, _) => + FPS.delay(log.logQuery(sql, params)) + } + + override def executeQuery[A]( + sql: String, + prepare: Prepare = identityPrepare, + extractor: Extractor[A] = identityExtractor, + )( + info: ExecutionInfo, + dc: Runner + ): ConnectionIO[List[A]] = + HC.prepareStatement(sql) { + useConnection { implicit connection => + prepareAndLog(sql, prepare) *> + HPS.executeQuery { + HRS.list(extractor) + } + } + } + + override def executeQuerySingle[A]( + sql: String, + prepare: Prepare = identityPrepare, + extractor: Extractor[A] = identityExtractor, + )( + info: ExecutionInfo, + dc: Runner + ): ConnectionIO[A] = + HC.prepareStatement(sql) { + useConnection { implicit connection => + prepareAndLog(sql, prepare) *> + HPS.executeQuery { + HRS.getUnique(extractor) + } + } + } + + def streamQuery[A]( + fetchSize: Option[Int], + sql: String, + prepare: Prepare = identityPrepare, + extractor: Extractor[A] = identityExtractor, + )( + info: ExecutionInfo, + dc: Runner + ): Stream[ConnectionIO, A] = + for { + connection <- Stream.eval(FC.raw(identity)) + result <- + HC.stream( + sql, + prepareAndLog(sql, prepare)(connection), + fetchSize.getOrElse(DefaultChunkSize), + )(extractorToRead(extractor)(connection)) + } yield result + + override def executeAction( + sql: String, + prepare: Prepare = identityPrepare, + )(info: ExecutionInfo, dc: Runner): ConnectionIO[Long] = + HC.prepareStatement(sql) { + useConnection { implicit connection => + prepareAndLog(sql, prepare) *> + HPS.executeUpdate.map(_.toLong) + } + } + + private def prepareConnections[A](returningBehavior: ReturnAction) = + returningBehavior match { + case ReturnColumns(columns) => (sql: String) => HC.prepareStatementS[A](sql, columns)(_) + case ReturnRecord => + (sql: String) => HC.prepareStatement[A](sql, AutoGeneratedKeys.ReturnGeneratedKeys)(_) + case ReturnNothing => (sql: String) => HC.prepareStatement[A](sql)(_) + } + + override def executeActionReturning[A]( + sql: String, + prepare: Prepare = identityPrepare, + extractor: Extractor[A], + returningBehavior: ReturnAction, + )( + info: ExecutionInfo, + dc: Runner, + ): ConnectionIO[A] = + prepareConnections[A](returningBehavior)(sql) { + useConnection { implicit connection => + prepareAndLog(sql, prepare) *> + FPS.executeUpdate *> + HPS.getGeneratedKeys(HRS.getUnique(extractor)) + } + } + + private def prepareBatchAndLog( + sql: String, + p: Prepare, + )( + implicit connection: Connection + ): PreparedStatementIO[Unit] = + FPS.raw(p(_, connection)) flatMap { case (params, _) => + FPS.delay(log.logBatchItem(sql, params)) + } + + override def executeBatchAction( + groups: List[BatchGroup] + )( + info: ExecutionInfo, + dc: Runner + ): ConnectionIO[List[Long]] = groups.flatTraverse { case BatchGroup(sql, preps) => + HC.prepareStatement(sql) { + useConnection { implicit connection => + for { + _ <- FPS.delay(log.underlying.debug("Batch: {}", sql)) + _ <- preps.traverse(prepareBatchAndLog(sql, _) *> FPS.addBatch) + r <- Nested(HPS.executeBatch).value.map(_.map(_.toLong)) + } yield r + } + } + } + + override def executeBatchActionReturning[A]( + groups: List[BatchGroupReturning], + extractor: Extractor[A], + )( + info: 
ExecutionInfo, + dc: Runner + ): ConnectionIO[List[A]] = groups.flatTraverse { + case BatchGroupReturning(sql, returningBehavior, preps) => + prepareConnections(returningBehavior)(sql) { + + useConnection { implicit connection => + for { + _ <- FPS.delay(log.underlying.debug("Batch: {}", sql)) + _ <- preps.traverse(prepareBatchAndLog(sql, _) *> FPS.addBatch) + _ <- HPS.executeBatch + r <- HPS.getGeneratedKeys(HRS.list(extractor)) + } yield r + } + } + } + + // Turn an extractor into a `Read` so we can use the existing resultset. + private implicit def extractorToRead[A]( + ex: Extractor[A] + )( + implicit connection: Connection + ): Read[A] = new Read[A](Nil, (rs, _) => ex(rs, connection)) + + // Nothing to do here. + override def close(): Unit = () + + // Dotty Quill does not support probing yet. + // override def probe(statement: String): Try[_] = Success(()) + + // Don't need this for our particular override + // override protected def withConnection[A](f: Connection => ConnectionIO[A]): ConnectionIO[A] = ??? + + protected val effect = null + + def wrap[T](t: => T): Free[ConnectionOp, T] = Free.pure(t) + def push[A, B](result: Free[ConnectionOp, A])(f: A => B): Free[ConnectionOp, B] = result.map(f(_)) + def seq[A](list: List[Free[ConnectionOp, A]]): Free[ConnectionOp, List[A]] = + list.sequence[[L] =>> Free[ConnectionOp, L], A] +} diff --git a/quill-doobie/src/test/scala/io/getquill/doobie/PostgresDoobieContextSuite.scala b/quill-doobie/src/test/scala/io/getquill/doobie/PostgresDoobieContextSuite.scala new file mode 100644 index 000000000..853ab6f34 --- /dev/null +++ b/quill-doobie/src/test/scala/io/getquill/doobie/PostgresDoobieContextSuite.scala @@ -0,0 +1,106 @@ +package io.getquill.doobie + +import cats.effect._ +import cats.syntax.all._ +import doobie._ +import doobie.implicits._ +import io.getquill._ +import org.scalatest.freespec.AnyFreeSpec +import org.scalatest.matchers.must.Matchers + +class PostgresDoobieContextSuite extends AnyFreeSpec with Matchers { + + // Logging should appear in test output + sys.props.put("quill.binds.log", "true") + sys.props.put("org.slf4j.simpleLogger.defaultLogLevel", "debug") + + import cats.effect.unsafe.implicits.global + + // A transactor that always rolls back. 
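+  // (Transactor.after is a lens onto the transactor's post-transaction step, which is
+  //  normally HC.commit; pointing it at HC.rollback below means no test ever commits,
+  //  keeping the suite repeatable against the shared doobie_test data.)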
+ lazy val xa = Transactor + .after + .set( + Transactor.fromDriverManager[IO]( + "org.postgresql.Driver", + s"jdbc:postgresql://${System.getenv("POSTGRES_HOST")}:${System.getenv("POSTGRES_PORT")}/doobie_test", + "postgres", + System.getenv("POSTGRES_PASSWORD") + ), + HC.rollback, + ) + + val dc = new DoobieContext.Postgres[Literal](Literal) + + import dc._ + + case class Country(code: String, name: String, population: Int) + + "executeQuery should correctly select a country" in { + inline def stmt = quote(query[Country].filter(_.code == "GBR")) + val actual = dc.run(stmt).transact(xa).unsafeRunSync() + val expected = List(Country("GBR", "United Kingdom", 59623400)) + actual mustEqual expected + } + + "executeQuerySingle should correctly select a constant" in { + inline def stmt = quote(42) + val actual = dc.run(stmt).transact(xa).unsafeRunSync() + val expected = 42 + actual mustEqual expected + } + + "streamQuery should correctly stream a bunch of countries" in { + inline def stmt = quote(query[Country]) + val actual = dc.stream(stmt, 2).transact(xa).as(1).compile.foldMonoid.unsafeRunSync() + val expected = 4 // this many countries total + actual mustEqual expected + } + + "executeAction should correctly update a bunch of countries" in { + inline def stmt = quote(query[Country].filter(_.name like "U%").update(_.name -> "foo")) + val actual = dc.run(stmt).transact(xa).unsafeRunSync() + val expected = 2L // this many countries start with 'U' + actual mustEqual expected + } + + "executeBatchAction should correctly do multiple updates" in { + val list = List("U%", "I%") + inline def stmt = quote { + liftQuery(list).foreach { pat => + query[Country].filter(_.name like pat).update(_.name -> "foo") + } + } + val actual = dc.run(stmt).transact(xa).unsafeRunSync() + val expected = List(2L, 1L) + actual mustEqual expected + } + + // For these last two we need a new table with an auto-generated id, so we'll do a temp table. 
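+  // (ON COMMIT DROP plus the always-rollback transactor above guarantees the temp
+  //  table never outlives the test's transaction.)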
+ val create: ConnectionIO[Unit] = + sql""" + CREATE TEMPORARY TABLE QuillTest ( + id SERIAL, + value VARCHAR(42) + ) ON COMMIT DROP + """.update.run.void + + case class QuillTest(id: Int, value: String) + + "executeActionReturning should correctly retrieve a generated key" in { + inline def stmt = quote(query[QuillTest].insertValue(lift(QuillTest(0, "Joe"))).returningGenerated(_.id)) + val actual = (create *> dc.run(stmt)).transact(xa).unsafeRunSync() + val expected = 1 + actual mustEqual expected + } + + "executeBatchActionReturning should correctly retrieve a list of generated keys" in { + val values = List(QuillTest(0, "Foo"), QuillTest(0, "Bar"), QuillTest(0, "Baz")) + inline def stmt = quote { + liftQuery(values).foreach(a => query[QuillTest].insertValue(a).returningGenerated(_.id)) + } + val actual = (create *> dc.run(stmt)).transact(xa).unsafeRunSync() + val expected = List(1, 2, 3) + actual mustEqual expected + } + +} \ No newline at end of file diff --git a/quill-doobie/src/test/scala/io/getquill/doobie/issue/Issue1067.scala b/quill-doobie/src/test/scala/io/getquill/doobie/issue/Issue1067.scala new file mode 100644 index 000000000..243ac240d --- /dev/null +++ b/quill-doobie/src/test/scala/io/getquill/doobie/issue/Issue1067.scala @@ -0,0 +1,35 @@ +package io.getquill.doobie.issue + +import cats.effect._ +import doobie._ +import doobie.implicits._ +import io.getquill._ +import org.scalatest.freespec.AnyFreeSpec +import org.scalatest.matchers.must.Matchers +import io.getquill.doobie.DoobieContext + +// https://github.com/tpolecat/doobie/issues/1067 +class Issue1067 extends AnyFreeSpec with Matchers { + + import cats.effect.unsafe.implicits.global + + lazy val xa = Transactor.fromDriverManager[IO]( + "org.postgresql.Driver", + s"jdbc:postgresql://${System.getenv("POSTGRES_HOST")}:${System.getenv("POSTGRES_PORT")}/doobie_test", + "postgres", + System.getenv("POSTGRES_PASSWORD") + ) + + val dc = new DoobieContext.Postgres(Literal) + import dc._ + + case class Country(name: String, indepYear: Option[Short]) + + "Issue1067 - correctly select many countries, with a null in last position" in { + val stmt = quote(query[Country]) + val actual = dc.run(stmt).transact(xa).unsafeRunSync() + actual.count(_.indepYear.isDefined) mustEqual 3 + actual.count(_.indepYear.isEmpty) mustEqual 1 + } + +} \ No newline at end of file diff --git a/quill-sql/src/main/scala/io/getquill/context/BatchQueryExecution.scala b/quill-sql/src/main/scala/io/getquill/context/BatchQueryExecution.scala index 7da1958b9..6743515c2 100644 --- a/quill-sql/src/main/scala/io/getquill/context/BatchQueryExecution.scala +++ b/quill-sql/src/main/scala/io/getquill/context/BatchQueryExecution.scala @@ -54,11 +54,16 @@ import _root_.io.getquill.norm.BetaReduction import io.getquill.context.Execution.ElaborationBehavior import io.getquill.quat.Quat import io.getquill.quat.QuatMaking +import io.getquill.metaprog.EagerListPlanterExpr +import io.getquill.metaprog.EagerPlanterExpr // trait BatchContextOperation[I, T, A <: QAC[I, T] with Action[I], D <: Idiom, N <: NamingStrategy, PrepareRow, ResultRow, Session, Res](val idiom: D, val naming: N) { // def execute(sql: String, prepare: List[(PrepareRow, Session) => (List[Any], PrepareRow)], extractor: Extraction[ResultRow, Session, T], executionInfo: ExecutionInfo): Res // } +case class Shard[T](members: List[T]): + def add(t: T) = Shard(members :+ t) + private[getquill] enum BatchActionType: case Insert case Update @@ -76,7 +81,7 @@ object PrepareBatchComponents: import Execution._ import 
BatchQueryExecutionModel._

- def apply[I, PrepareRow](unliftedAst: Ast, caseClassAst: ast.CaseClass, extractionBehavior: BatchExtractBehavior): Either[String, (Ast, BatchActionType)] = {
+ def apply[I, PrepareRow](unliftedAst: Ast, foreachIdentAst: ast.Ast, extractionBehavior: BatchExtractBehavior): Either[String, (Ast, BatchActionType)] = {
 // putting this in a block since I don't want to externally import these packages
 import io.getquill.ast._
 val componentsOrError =
@@ -102,6 +107,12 @@
 case other =>
 Left(s"Malformed batch entity: ${other}. Batch insertion entities must have the form Returning/ReturningGenerated(Insert(Entity, Nil: List[Assignment]), _, _)")

+ // TODO NOT NECESSARILY A CASE CLASS AND NOT NECESSARILY THE TYPE BEING INSERTED. NEED TO TEST WITH DIFF
+ // TYPES BEING ENCODED AND INSERTED TOO
+ // e.g. liftQuery(Vip, Vip).foreach(v => query[Person].insertValue(Person(v.name, v.age)))
+ // OR WITH SCALARS
+ // e.g. liftQuery(1, 2).foreach(i => query[Person].insertValue(Person("Joe", i)))
+ // (continue to beta-reduce out the foreach-ident if an error has not happened)
 componentsOrError.map { (foreachIdent, actionQueryAstRaw, bType) =>
 // The primary idea that drives batch query execution is the realization that you
@@ -115,7 +126,7 @@
 // actionQueryAstRaw: liftQuery(people).foreach(p => query[Person].filter(pf => pf.id == ScalarTag(A)).update(_.name == ScalarTag(B)))
 // this will ultimately yield a query that looks like: UPDATE Person SET name = ? WHERE id = ? and for each person entity
 // the corresponding values will be plugged in
- val actionQueryAst = BetaReduction(actionQueryAstRaw, foreachIdent -> caseClassAst)
+ val actionQueryAst = BetaReduction(actionQueryAstRaw, foreachIdent -> foreachIdentAst)
 // println(s"==== Reduced AST: ${io.getquill.util.Messages.qprint(actionQueryAst)}")
 (actionQueryAst, bType)
 }
@@ -248,9 +259,11 @@
 def prepareLifts(): (ast.CaseClass, List[Expr[InjectableEagerPlanter[_, PrepareRow, Session]]]) =
 // Use some custom functionality in the lift macro to prepare the case class and the injectable lifts
 // e.g. if T is Person(name: String, age: Int) and we do liftQuery(people:List[Person]).foreach(p => query[Person].insertValue(p))
- // ast = CaseClass(name -> lift(UUID1), age -> lift(UUID2))
- // lifts = List(InjectableEagerLift(p.name, UUID1), InjectableEagerLift(p.age, UUID2))
+ // Then:
+ // ast = CaseClass(name -> lift(UUID1), age -> lift(UUID2))
+ // lifts = List(InjectableEagerLift(p.name, UUID1), InjectableEagerLift(p.age, UUID2))
 val (caseClassAst, perRowLifts) = LiftMacro.liftInjectedProduct[I, PrepareRow, Session]
+ // println(s"Case class AST: ${io.getquill.util.Messages.qprint(caseClassAst)}")
 // println("========= CaseClass =========\n" + io.getquill.util.Messages.qprint(caseClassAst))
 // Assuming that all lifts of the batch query are injectable
@@ -305,53 +318,103 @@
 end applyDynamic

+ enum ExpansionType:
+ case Entities(entities: Expr[Iterable[_]])
+ case Values(values: Expr[List[Any]], encoder: Expr[GenericEncoder[Any, PrepareRow, Session]])
+
 def apply(): Expr[Res] =
 UntypeExpr(quoted) match
 case QuotedExpr.UprootableWithLifts(QuotedExpr(quoteAst, _, _), planters) =>
- // isolate the list that went into the liftQuery i.e. the liftQuery(entities)
- val entities =
- planters match
- case List(EagerEntitiesPlanterExpr(_, entities)) => entities
- case _ => report.throwError(s"Invalid liftQuery clause: ${planters}. Must be a single EagerEntitiesPlanter", quoted)
 val unliftedAst = Unlifter(quoteAst)
- // for Person(name, age) it would be (CaseClass(name->lift(A), age->lift(B), List(InjectableEagerLift(A), InjectableEagerLift(B))))
- val (caseClass, perRowLifts) = prepareLifts()
- val (actionQueryAst, batchActionType) =
- PrepareBatchComponents[I, PrepareRow](unliftedAst, caseClass, extractionBehavior) match
- case Right(value) => value
- case Left(error) =>
- report.throwError(error)
- val expandedQuotation = expandQuotation(Lifter(actionQueryAst), batchActionType, perRowLifts)
+ val comps = BatchStatic[I, PrepareRow, Session](unliftedAst, planters, extractionBehavior)
+ val expandedQuotation = expandQuotation(Lifter(comps.actionQueryAst), comps.batchActionType, comps.perRowLifts)
+
+ def expandLiftQueryMembers(filteredPerRowLifts: List[PlanterExpr[?, ?, ?]], entities: Expr[Iterable[?]]) =
+ '{
+ $entities.map(entity =>
+ ${
+ // Since things like returningGenerated can exclude lifts...
+ // For example:
+ // query[Person].insert(_.id -> lift(1), _.name -> lift("Joe")).returningGenerated(_.id))
+ // becomes something like Quoted(query[Person].insert(_.id -> lift(A), _.name -> lift(B)).returningGenerated(_.id)), lifts: List(ScalarTag(A, 1), ScalarTag(B, "Joe")))
+ // but since we are excluding the person.id column (this is done in the transformation phase NormalizeReturning which is in SqlNormalization in the quill-sql-portable module)
+ // actually we want only the ScalarTag(B), so we need to get the list of lift tags (in tokens) once the Dialect has serialized the query,
+ // which correctly orders the list of lifts. A similar issue happens with insertMeta and updateMeta.
+ // we need a pre-filtered, and ordered list of lifts. The StaticTranslationMacro internally has done that so we can take the lifts from there, although they need to be cast.
+ // This is safe because they are just the lifts that we have already had from the `injectableLifts` list
+ // TODO If all the lists are not InjectableEagerPlanterExpr, then we need to find out which ones are not and not inject them
+ val injectedLifts = filteredPerRowLifts.asInstanceOf[List[InjectableEagerPlanterExpr[_, _, _]]].map(lift => lift.inject('entity))
+ val injectedLiftsExpr = Expr.ofList(injectedLifts)
+ // val prepare = '{ (row: PrepareRow, session: Session) => LiftsExtractor.apply[PrepareRow, Session]($injectedLiftsExpr, row, session) }
+ // prepare
+ injectedLiftsExpr
+ }
+ )
+ }

 StaticTranslationMacro[I, T, D, N](expandedQuotation, ElaborationBehavior.Skip, topLevelQuat) match
 case Some(state @ StaticState(query, filteredPerRowLifts, _, _)) =>
 // create an extractor for returning actions
 val extractor = MakeExtractor[ResultRow, Session, T, T].static(state, identityConverter, extractionBehavior)
- val prepares =
- '{
- $entities.map(entity =>
- ${
- // Since things like returningGenerated can exclude lifts...
- // For example:
- // query[Person].insert(_.id -> lift(1), _.name -> lift("Joe")).returningGenerated(_.id))
- // becomes something like Quoted(query[Person].insert(_.id -> lift(A), _.name -> lift(B)).returningGenerated(_.id)), lifts: List(ScalarTag(A, 1), ScalarTag(B, "Joe")))
- // but since we are excluding the person.id column (this is done in the transformation phase NormalizeReturning which is in SqlNormalization in the quill-sql-portable module)
- // actually we only want only the ScalarTag(B) so we need to get the list of lift tags (in tokens) once the Dialect has serialized the query
- // which correctly order the list of lifts. A similar issue happens with insertMeta and updateMeta.
- // we need a pre-filtered, and ordered list of lifts. The StaticTranslationMacro interanally has done that so we can take the lifts from there although they need to be casted.
- // This is safe because they are just the lifts taht we have already had from the `injectableLifts` list
- // TODO If all the lists are not InjectableEagerPlanterExpr, then we need to find out which ones are not and not inject them
- val injectedLifts = filteredPerRowLifts.asInstanceOf[List[InjectableEagerPlanterExpr[_, _, _]]].map(lift => lift.inject('entity))
- val injectedLiftsExpr = Expr.ofList(injectedLifts)
- val prepare = '{ (row: PrepareRow, session: Session) => LiftsExtractor.apply[PrepareRow, Session]($injectedLiftsExpr, row, session) }
- prepare
- }
- )
- }
- '{ $batchContextOperation.execute(ContextOperation.Argument(${ Expr(query.basicQuery) }, $prepares.toArray, $extractor, ExecutionInfo(ExecutionType.Static, ${ Lifter(state.ast) }, ${ Lifter.quat(topLevelQuat) }), None)) }
+ // In an expression we could have a whole bunch of different lifts
+ // liftQuery([Person1, Person2 <- these are EagerEntitiesPlanterExpr])
+ // .filter(p => p.id == lift(somethingElse) <- another lift expression)
+ // etc...
+ // So we need to go through all of them and expand
+ // For example, say that we have:
+ // liftQuery([Joe, Jim]).foreach(p => query[Person].filter(p => p.id == lift(somethingElse)))
+ // That means our lifts need to be:
+ // lift(Joe.name), lift(Joe.age), lift(somethingElse)
+ // lift(Jim.name), lift(Jim.age), lift(somethingElse)
+ //
+ // So first we expand the primary planter list into a list of lists. Then add all additional lifts
+ // into each list. We are assuming that the primary planter (i.e. the liftQuery thing)
+ // is the 1st in the batch query
+ val primaryPlanterLifts =
+ comps.primaryPlanter match
+ case BatchStatic.PlanterKind.PrimaryEntitiesList(entitiesPlanter) =>
+ val exp = expandLiftQueryMembers(filteredPerRowLifts, entitiesPlanter.expr)
+ '{ $exp.toList }
+
+ case BatchStatic.PlanterKind.PrimaryScalarList(scalarsPlanter) =>
+ val exp = expandLiftQueryMembers(filteredPerRowLifts, scalarsPlanter.expr)
+ '{ $exp.toList }
+
+ // At this point here is what the lifts look like:
+ // List(
+ // List(lift(Joe.name), lift(Joe.age))
+ // List(lift(Jim.name), lift(Jim.age))
+ // )
+ // We need to make them into:
+ // List(
+ // List(lift(Joe.name), lift(Joe.age), lift(somethingElse)) <- per-entity lifts of Joe
+ // List(lift(Jim.name), lift(Jim.age), lift(somethingElse)) <- per-entity lifts of Jim
+ // )
+ val otherPlanters =
+ Expr.ofList(
+ comps.categorizedPlanters.drop(1).map {
+ case BatchStatic.PlanterKind.Other(planter: EagerListPlanterExpr[_, _, _]) =>
+ planter.asInstanceOf[EagerListPlanterExpr[?, ?, ?]].plant
+ case other =>
+ report.throwError(s"Invalid planter: ${other}")
+ }
+ )
+ val combinedPlanters =
+ '{ $primaryPlanterLifts.map(perEntityPlanters => perEntityPlanters ++ $otherPlanters) }
+
+ // println(s"====================== PreparesList: ${Format.Expr(combinedPlanters)} =================")
+
+ val prepares = '{
+ $combinedPlanters.map(perRowList =>
+ (row: PrepareRow, session: Session) =>
+ LiftsExtractor.apply[PrepareRow, Session](perRowList, row, session)
+ )
+ }
+
+ '{
+ $batchContextOperation.execute(ContextOperation.Argument(${ Expr(query.basicQuery) }, $prepares.toArray, $extractor, ExecutionInfo(ExecutionType.Static, ${ Lifter(state.ast) }, ${ Lifter.quat(topLevelQuat) }), None))
+ }

 case None =>
 // TODO report via trace debug
@@ -395,3 +458,113 @@
 new RunQuery[I, T, A, ResultRow, PrepareRow, Session, D, N, Ctx, Res](quoted, ctx).apply()

 end BatchQueryExecution
+
+object BatchStatic:
+ case class Components[PrepareRow, Session](
+ actionQueryAst: Ast,
+ batchActionType: BatchActionType,
+ perRowLifts: List[Expr[InjectableEagerPlanter[?, PrepareRow, Session]]],
+ categorizedPlanters: List[PlanterKind],
+ primaryPlanter: PlanterKind.PrimaryEntitiesList | PlanterKind.PrimaryScalarList
+ )
+
+ sealed trait PlanterKind
+ object PlanterKind:
+ case class PrimaryEntitiesList(planter: EagerEntitiesPlanterExpr[?, ?, ?]) extends PlanterKind
+ case class PrimaryScalarList(planter: EagerListPlanterExpr[?, ?, ?]) extends PlanterKind
+ case class Other(planter: PlanterExpr[?, ?, ?]) extends PlanterKind
+
+ // Given: Person(name, age)
+ // For the query:
+ // liftQuery(List(Person("Joe", 123))).foreach(p => query[Person].insertValue(p))
+ // it would be (CaseClass(name->lift(A), age->lift(B)), BatchActionType.Insert, List(InjectableEagerLift(A), InjectableEagerLift(B))))
+ // Same thing regardless of what kind of object is in the insert:
+ // liftQuery(List("foo")).foreach(name => query[Person].update(_.name -> name))
+ // it would be (CaseClass(name->lift(A), age->lift(B)), BatchActionType.Update, List(InjectableEagerLift(A), InjectableEagerLift(B))))
+ //
+ // That is why it is important to find the actual EagerEntitiesPlanterExpr (i.e. the part defined by `query[Person]`). That
+ // way we know the actual entity that needs to be lifted.
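+ //
+ // An illustrative (hypothetical) categorization: for
+ //   liftQuery(List(1, 2)).foreach(i => query[Person].filter(p => p.id == i).update(_.age -> lift(111)))
+ // the planter fold in `apply` below yields:
+ //   primaryPlanter      = PlanterKind.PrimaryScalarList(<liftQuery(List(1, 2))>)
+ //   categorizedPlanters = List(PrimaryScalarList(<liftQuery(List(1, 2))>), Other(<lift(111)>))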
+ def apply[I: Type, PrepareRow: Type, Session: Type](ast: Ast, planters: List[PlanterExpr[?, ?, ?]], extractionBehavior: BatchQueryExecutionModel.BatchExtractBehavior)(using Quotes) =
+ import quotes.reflect._
+ val (primaryPlanter, categorizedPlanters) =
+ planters.foldLeft((Option.empty[PlanterKind.PrimaryEntitiesList | PlanterKind.PrimaryScalarList], List.empty[PlanterKind])) {
+ case ((None, list), planter: EagerEntitiesPlanterExpr[?, ?, ?]) =>
+ val planterKind = PlanterKind.PrimaryEntitiesList(planter)
+ (Some(planterKind), list :+ planterKind)
+ case ((None, list), planter: EagerListPlanterExpr[?, ?, ?]) =>
+ val planterKind = PlanterKind.PrimaryScalarList(planter)
+ (Some(planterKind), list :+ planterKind)
+ case ((primary @ Some(_), list), planter) =>
+ (primary, list :+ PlanterKind.Other(planter))
+ // i.e. we have not found the primary planter yet but hit some other planter first
+ // (this should not be possible because nothing can come before the liftQuery), so fail
+ case ((primary @ None, list), planter) =>
+ report.throwError("Invalid planter traversal")
+ } match {
+ case (Some(primary), categorizedPlanters) => (primary, categorizedPlanters)
+ case (None, _) => report.throwError(s"Could not find an entities list-lift (i.e. liftQuery(entities/scalars) in liftQuery(...).foreach()) in lifts: ${planters.map(p => Format.Expr(p.plant))}")
+ }
+
+ // TODO check that there are no EagerEntitiesPlanterExpr other than in the primary planter
+ val (actionQueryAst, batchActionType, perRowLifts) =
+ primaryPlanter match {
+ // In the case of liftQuery(entities)
+ case PlanterKind.PrimaryEntitiesList(planter) =>
+ planter.tpe match
+ case '[t] => PrepareLiftQueryComponents[t, I, PrepareRow, Session](ast, extractionBehavior)
+
+ // In the case of liftQuery(scalars)
+ // Note, we could potentially have other liftQuery(scalars) clauses later in the query, for example:
+ // liftQuery(List("Joe","Jack","Jill")).foreach(name => query[Person].filter(p => liftQuery(List(1,2,3) /*ids of Joe,Jack,Jill respectively*/).contains(p.id)).update(_.name -> name))
+ // Therefore we cannot assume that there is only one.
+ case PlanterKind.PrimaryScalarList(planter) =>
+ planter.tpe match
+ case '[t] => PrepareLiftQueryComponents[t, I, PrepareRow, Session](ast, extractionBehavior)
+ }
+ Components[PrepareRow, Session](actionQueryAst, batchActionType, perRowLifts, categorizedPlanters, primaryPlanter)
+ end apply
+
+end BatchStatic
+
+object PrepareLiftQueryComponents:
+ def apply[LiftQueryT: Type, I: Type, PrepareRow: Type, Session: Type](ast: Ast, extractionBehavior: BatchQueryExecutionModel.BatchExtractBehavior)(using Quotes) =
+ import quotes.reflect._
+ val (foreachReplacementAst, perRowLifts) = prepareLiftQueryLifts[LiftQueryT, PrepareRow, Session]
+ val (actionQueryAst, batchActionType) =
+ PrepareBatchComponents[I, PrepareRow](ast, foreachReplacementAst, extractionBehavior) match
+ case Right(value) => value
+ case Left(error) =>
+ report.throwError(error)
+ (actionQueryAst, batchActionType, perRowLifts)
+
+ private def prepareLiftQueryLifts[LiftQueryT: Type, PrepareRow: Type, Session: Type](using Quotes): (ast.Ast, List[Expr[InjectableEagerPlanter[_, PrepareRow, Session]]]) =
+ import quotes.reflect._
+ // Use some custom functionality in the lift macro to prepare the case class and the injectable lifts
+ // e.g.
if T is Person(name: String, age: Int) and we do liftQuery(people:List[Person]).foreach(p => query[Person].insertValue(p)) + // Then: + // ast = CaseClass(name -> lift(UUID1), age -> lift(UUID2)) // NOTE: lift in the AST means a ScalarTag + // lifts = List(InjectableEagerLift(p.name, UUID1), InjectableEagerLift(p.age, UUID2)) + // e.g. if T is String and we do liftQuery(people:List[String]).foreach(p => query[Person].insertValue(Person(p, 123))) + // Then: + // ast = lift(UUID1) // I.e. ScalarTag(UUID1) since lift in the AST means a ScalarTag + // lifts = List(InjectableEagerLift(p, UUID1)) + val (ast, fieldLifts) = + QuatMaking.ofType[LiftQueryT] match + case _: Quat.Product => + LiftMacro.liftInjectedProduct[LiftQueryT, PrepareRow, Session] + case other => + val (ast, lift) = LiftMacro.liftInjectedScalar[LiftQueryT, PrepareRow, Session] + (ast, List(lift)) + + // println(s"Case class AST: ${io.getquill.util.Messages.qprint(caseClassAst)}") + // println("========= CaseClass =========\n" + io.getquill.util.Messages.qprint(caseClassAst)) + // Assuming that all lifts of the batch query are injectable + + fieldLifts.foreach { + case PlanterExpr.Uprootable(expr @ InjectableEagerPlanterExpr(_, _, _)) => expr + case PlanterExpr.Uprootable(expr) => + report.throwError(s"wrong kind of uprootable ${(expr)}") + case other => report.throwError(s"The lift expression ${Format(Printer.TreeStructure.show(other.asTerm))} is not valid for batch queries because it is not injectable") + } + (ast, fieldLifts) + end prepareLiftQueryLifts +end PrepareLiftQueryComponents diff --git a/quill-sql/src/main/scala/io/getquill/context/LiftMacro.scala b/quill-sql/src/main/scala/io/getquill/context/LiftMacro.scala index ec5db9956..72ebfb466 100644 --- a/quill-sql/src/main/scala/io/getquill/context/LiftMacro.scala +++ b/quill-sql/src/main/scala/io/getquill/context/LiftMacro.scala @@ -80,6 +80,12 @@ object LiftMacro { } } + private[getquill] def liftInjectedScalar[T, PrepareRow, Session](using qctx: Quotes, tpe: Type[T], prepareRowTpe: Type[PrepareRow], sessionTpe: Type[Session]): (ScalarTag, Expr[InjectableEagerPlanter[_, PrepareRow, Session]]) = { + import qctx.reflect._ + val uuid = java.util.UUID.randomUUID.toString + (ScalarTag(uuid), injectableLiftValue[T, PrepareRow, Session]('{ (t: T) => t }, uuid)) + } + // TODO Injected => Injectable private[getquill] def liftInjectedProduct[T, PrepareRow, Session](using qctx: Quotes, tpe: Type[T], prepareRowTpe: Type[PrepareRow], sessionTpe: Type[Session]): (CaseClass, List[Expr[InjectableEagerPlanter[_, PrepareRow, Session]]]) = { import qctx.reflect._ diff --git a/quill-sql/src/main/scala/io/getquill/metaprog/ExprModel.scala b/quill-sql/src/main/scala/io/getquill/metaprog/ExprModel.scala index e590421fc..ac0a68b29 100644 --- a/quill-sql/src/main/scala/io/getquill/metaprog/ExprModel.scala +++ b/quill-sql/src/main/scala/io/getquill/metaprog/ExprModel.scala @@ -43,7 +43,8 @@ sealed trait PlanterExpr[T: scala.quoted.Type, PrepareRow: scala.quoted.Type, Se def plant(using Quotes): Expr[Planter[T, PrepareRow, Session]] // TODO Change to 'replant' ? 
def nestInline(using Quotes)(call: Option[quotes.reflect.Tree], bindings: List[quotes.reflect.Definition]): PlanterExpr[T, PrepareRow, Session]

-case class EagerListPlanterExpr[T: Type, PrepareRow: Type, Session: Type](uid: String, expr: Expr[List[T]], encoder: Expr[GenericEncoder[T, PrepareRow, Session]])(using Type[Query[T]]) extends PlanterExpr[Query[T], PrepareRow, Session]:
+case class EagerListPlanterExpr[T, PrepareRow: Type, Session: Type](uid: String, expr: Expr[List[T]], encoder: Expr[GenericEncoder[T, PrepareRow, Session]])(using val tpe: Type[T], queryTpe: Type[Query[T]])
+ extends PlanterExpr[Query[T], PrepareRow, Session]:
 def plant(using Quotes): Expr[EagerListPlanter[T, PrepareRow, Session]] =
 '{ EagerListPlanter[T, PrepareRow, Session]($expr, $encoder, ${ Expr(uid) }) }
 def nestInline(using Quotes)(call: Option[quotes.reflect.Tree], bindings: List[quotes.reflect.Definition]) =
diff --git a/quill-sql/src/test/scala/io/getquill/BatchActionTest.scala b/quill-sql/src/test/scala/io/getquill/BatchActionTest.scala
index 8ae13f47a..ada37bd10 100644
--- a/quill-sql/src/test/scala/io/getquill/BatchActionTest.scala
+++ b/quill-sql/src/test/scala/io/getquill/BatchActionTest.scala
@@ -80,6 +80,40 @@ class BatchActionTest extends Spec with Inside with SuperContext[PostgresDialect
 mirror.triple mustEqual ("INSERT INTO Person (id,name,age) VALUES (?, ?, ?)", List(List(1, "Joe", 123), List(2, "Jill", 456)), Static)
 }

+ case class Vip(vipId: Int, vipName: String, vipAge: Int, other: String)
+ "insert - different-objects" in {
+ val vips = List(Vip(1, "Joe", 123, "Something"), Vip(2, "Jill", 456, "Something"))
+ val mirror = ctx.run { liftQuery(vips).foreach(v => query[Person].insertValue(Person(v.vipId, v.vipName, v.vipAge))) }
+ mirror.triple mustEqual ("INSERT INTO Person (id,name,age) VALUES (?, ?, ?)", List(List(1, "Joe", 123), List(2, "Jill", 456)), Static)
+ }
+
+ "update - liftQuery scalars" in {
+ val mirror = ctx.run { liftQuery(List(1, 2, 3)).foreach(i => query[Person].filter(p => p.id == i).update(_.age -> 111)) }
+ mirror.triple mustEqual ("UPDATE Person SET age = 111 WHERE id = ?", List(List(1), List(2), List(3)), Static)
+ }
+
+ // NOTE: Not going to be supported in the near term because the type from liftQuery is very hard to reconstruct,
+ // though we should emit a warning that the entity in liftQuery([Ent]).foreach(query[Ent]) is not allowed to differ
+ // "update - liftQuery scalars - dynamic" in {
+ // val updateDynamic = quote {
+ // (i: Int) => query[Person].filter(p => p.id == i).update(_.age -> 111)
+ // }
+ // val mirror = ctx.run { liftQuery(List(1, 2, 3)).foreach(i => updateDynamic(i)) }
+ // mirror.triple mustEqual ("UPDATE Person SET age = 111 WHERE id = ?", List(List(1), List(2), List(3)), Static)
+ // }
+
+ // NOTE: This kind of AST expansion is not supported yet
+ // "update - extra lift" in {
+ // val mirror = ctx.run { liftQuery(people).foreach(p => query[Person].filter(p => p.id == lift(123) /*see if p symbol override works*/ ).insertValue(p)) }
+ // mirror.triple mustEqual ("INSERT INTO Person (id,name,age) VALUES (?, ?, ?)", List(List(1, "Joe", 123), List(2, "Jill", 456)), Static)
+ // }
+
+ // NOTE: This kind of AST expansion is not supported yet
+ // "update - extra list lift" in {
+ // val mirror = ctx.run { liftQuery(people).foreach(p => query[Person].filter(p => liftQuery(List(1,2,3)).contains(p.id)).insertValue(p)) }
+ // mirror.triple mustEqual ("INSERT INTO Person (id,name,age) VALUES (?, ?, ?)", List(List(1, "Joe", 123), List(2, "Jill", 456)), Static)
// } + "insert with function splice" in { val mirror = ctx.run { liftQuery(people).foreach(p => insertPeople(p)) } mirror.triple mustEqual ("INSERT INTO Person (id,name,age) VALUES (?, ?, ?)", List(List(1, "Joe", 123), List(2, "Jill", 456)), Static) diff --git a/quill-sql/src/test/sql/postgres-doobie-schema.sql b/quill-sql/src/test/sql/postgres-doobie-schema.sql new file mode 100644 index 000000000..3573fe9fb --- /dev/null +++ b/quill-sql/src/test/sql/postgres-doobie-schema.sql @@ -0,0 +1,87 @@ +-- +-- The sample data used in the world database is Copyright Statistics +-- Finland, http://www.stat.fi/worldinfigures. +-- + +CREATE TABLE IF NOT EXISTS city ( + id integer NOT NULL, + name varchar NOT NULL, + countrycode character(3) NOT NULL, + district varchar NOT NULL, + population integer NOT NULL +); + +CREATE TABLE IF NOT EXISTS country ( + code character(3) NOT NULL, + name varchar NOT NULL, + continent varchar NOT NULL, + region varchar NOT NULL, + surfacearea real NOT NULL, + indepyear smallint, + population integer NOT NULL, + lifeexpectancy real, + gnp numeric(10,2), + gnpold numeric(10,2), + localname varchar NOT NULL, + governmentform varchar NOT NULL, + headofstate varchar, + capital integer, + code2 character(2) NOT NULL --, + -- TODO: we can do this with CREATE DOMAIN + -- CONSTRAINT country_continent_check CHECK ((((((((continent = 'Asia'::text) OR (continent = 'Europe'::text)) OR (continent = 'North America'::text)) OR (continent = 'Africa'::text)) OR (continent = 'Oceania'::text)) OR (continent = 'Antarctica'::text)) OR (continent = 'South America'::text))); +); + +CREATE TABLE IF NOT EXISTS countrylanguage ( + countrycode character(3) NOT NULL, + language varchar NOT NULL, + isofficial boolean NOT NULL, + percentage real NOT NULL +); + +-- COPY city (id, name, countrycode, district, population) FROM stdin; +INSERT INTO city VALUES (206, 'São Paulo', 'BRA', 'São Paulo', 9968485); +INSERT INTO city VALUES (207, 'Rio de Janeiro', 'BRA', 'Rio de Janeiro', 5598953); +INSERT INTO city VALUES (208, 'Salvador', 'BRA', 'Bahia', 2302832); +INSERT INTO city VALUES (456, 'London', 'GBR', 'England', 7285000); +INSERT INTO city VALUES (457, 'Birmingham', 'GBR', 'England', 1013000); +INSERT INTO city VALUES (458, 'Glasgow', 'GBR', 'Scotland', 619680); +INSERT INTO city VALUES (3793, 'New York', 'USA', 'New York', 8008278); +INSERT INTO city VALUES (3794, 'Los Angeles', 'USA', 'California', 3694820); +INSERT INTO city VALUES (3795, 'Chicago', 'USA', 'Illinois', 2896016); +INSERT INTO city VALUES (3796, 'Houston', 'USA', 'Texas', 1953631); +INSERT INTO city VALUES (3797, 'Philadelphia', 'USA', 'Pennsylvania', 1517550); +INSERT INTO city VALUES (3798, 'Phoenix', 'USA', 'Arizona', 1321045); +INSERT INTO city VALUES (1450, 'Jerusalem', 'ISR', 'Jerusalem', 633700); +INSERT INTO city VALUES (1451, 'Tel Aviv-Jaffa', 'ISR', 'Tel Aviv', 348100); +INSERT INTO city VALUES (1454, 'Beerseba', 'ISR', 'Ha Darom', 163700); + + +-- +-- Data for Name: country; Type: TABLE DATA; Schema: public; Owner: chriskl +-- +INSERT INTO country VALUES ('GBR', 'United Kingdom', 'Europe', 'British Islands', 242900, null, 59623400, 77.699997, 1378330.00, 1296830.00, 'United Kingdom', 'Constitutional Monarchy', 'Elisabeth II', 456, 'GB'); +INSERT INTO country VALUES ('USA', 'United States', 'North America', 'North America', 9363520, 1776, 278357000, 77.099998, 8510700.00, 8110900.00, 'United States', 'Federal Republic', 'George W. 
Bush', 3813, 'US'); +INSERT INTO country VALUES ('ISR', 'Israel', 'Asia', 'Middle East', 21056, 1948, 6217000, 78.599998, 97477.00, 98577.00, 'Yisra’el/Isra’il', 'Republic', 'Moshe Katzav', 1450, 'IL'); +INSERT INTO country VALUES ('BRA', 'Brazil', 'South America', 'South America', 8547403, 1822, 170115000, 62.900002, 776739.00, 804108.00, 'Brasil', 'Federal Republic', 'Fernando Henrique Cardoso', 211, 'BR'); +-- +-- Data for Name: countrylanguage; Type: TABLE DATA; Schema: public; Owner: chriskl +-- +INSERT INTO countrylanguage VALUES ('GBR', 'English', true, 97.300003); +INSERT INTO countrylanguage VALUES ('ISR', 'Hebrew', true, 63.099998); +INSERT INTO countrylanguage VALUES ('USA', 'English', true, 86.199997); +INSERT INTO countrylanguage VALUES ('BRA', 'Portuguese', true, 97.5); + +ALTER TABLE city + ADD CONSTRAINT city_pkey PRIMARY KEY (id); + +ALTER TABLE country + ADD CONSTRAINT country_pkey PRIMARY KEY (code); + +ALTER TABLE countrylanguage + ADD CONSTRAINT countrylanguage_pkey PRIMARY KEY (countrycode, language); + +ALTER TABLE country + ADD CONSTRAINT country_capital_fkey FOREIGN KEY (capital) REFERENCES city(id); + +ALTER TABLE countrylanguage + ADD CONSTRAINT countrylanguage_countrycode_fkey FOREIGN KEY (countrycode) REFERENCES country(code);
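
A minimal end-to-end sketch of what the new quill-doobie module enables (illustrative only: the JDBC coordinates and the Person table are placeholder assumptions, not anything this patch creates):

    import cats.effect.IO
    import cats.effect.unsafe.implicits.global
    import doobie._
    import doobie.implicits._
    import io.getquill._
    import io.getquill.doobie.DoobieContext

    object QuillDoobieExample:
      // Placeholder connection details - point these at a real Postgres instance.
      val xa = Transactor.fromDriverManager[IO](
        "org.postgresql.Driver", "jdbc:postgresql://localhost:5432/quill_test", "postgres", ""
      )

      val ctx = new DoobieContext.Postgres(Literal)
      import ctx._

      case class Person(id: Int, name: String, age: Int)

      // Quill statements become ordinary ConnectionIO values, so a batch insert and a
      // follow-up query compose into a single transaction.
      def program: ConnectionIO[(List[Long], List[Person])] =
        for
          inserted <- ctx.run(liftQuery(List(Person(1, "Joe", 123), Person(2, "Jill", 456)))
                        .foreach(p => query[Person].insertValue(p)))
          adults <- ctx.run(query[Person].filter(p => p.age > 18))
        yield (inserted, adults)

      def main(args: Array[String]): Unit =
        println(program.transact(xa).unsafeRunSync())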