Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use a single SQL query for any number of json-api query pairs #10344

Merged
merged 15 commits into from
Jul 23, 2021
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions ledger-service/db-backend/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,19 @@ da_scala_test(
size = "medium",
srcs = glob(["src/test/scala/**/*.scala"]),
scala_deps = [
"@maven//:org_scalacheck_scalacheck",
"@maven//:com_chuusai_shapeless",
"@maven//:org_scalatest_scalatest",
"@maven//:org_scalatestplus_scalacheck_1_14",
"@maven//:org_scalaz_scalaz_core",
"@maven//:org_tpolecat_doobie_core",
"@maven//:org_tpolecat_doobie_free",
"@maven//:org_typelevel_cats_core",
"@maven//:org_typelevel_cats_effect",
"@maven//:org_typelevel_cats_free",
"@maven//:org_typelevel_cats_kernel",
],
scalacopts = lf_scalacopts,
# data = ["//docs:quickstart-model.dar"],
deps = [
":db-backend",
"//libs-scala/scala-utils",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@ import nonempty.NonEmptyReturningOps._
import doobie._
import doobie.implicits._
import scala.annotation.nowarn
import scala.collection.immutable.{Iterable, Seq => ISeq}
import scala.collection.immutable.{Seq => ISeq}
import scalaz.{@@, Cord, Foldable, Functor, OneAnd, Tag, \/, -\/, \/-}
import scalaz.Digit._0
import scalaz.Id.Id
import scalaz.syntax.foldable._
import scalaz.syntax.functor._
import scalaz.syntax.std.option._
import scalaz.std.stream.unfold
import scalaz.syntax.std.string._
import scalaz.std.AllInstances._
import spray.json._
import cats.instances.list._
Expand Down Expand Up @@ -231,22 +230,17 @@ sealed abstract class Queries {
selectContractsMultiTemplate(parties, ISeq((tpid, predicate)), MatchedQueryMarker.Unused)
.map(_ copy (templateId = ()))

/** Make the smallest number of queries from `queries` that still indicates
/** Make a query that may indicate
* which query or queries produced each contract.
*
* A contract cannot be produced more than once from a given resulting query,
* but may be produced more than once from different queries. In each case, the
* `templateId` of the resulting [[DBContract]] is actually the 0-based index
* into the `queries` argument that produced the contract.
*/
private[http] def selectContractsMultiTemplate[T[_], Mark](
private[http] def selectContractsMultiTemplate[Mark](
parties: OneAnd[Set, String],
queries: ISeq[(SurrogateTpId, Fragment)],
trackMatchIndices: MatchedQueryMarker[T, Mark],
trackMatchIndices: MatchedQueryMarker[Mark],
)(implicit
log: LogHandler,
ipol: SqlInterpol,
): T[Query0[DBContract[Mark, JsValue, JsValue, Vector[String]]]]
): Query0[DBContract[Mark, JsValue, JsValue, Vector[String]]]

private[http] final def fetchById(
parties: OneAnd[Set, String],
Expand Down Expand Up @@ -292,6 +286,10 @@ object Queries {
val SurrogateTpId = Tag.of[SurrogateTpIdTag]
type SurrogateTpId = Long @@ SurrogateTpIdTag // matches tpid (BIGINT) above

sealed trait MatchedQueriesTag
val MatchedQueries = Tag.of[MatchedQueriesTag]
type MatchedQueries = NonEmpty[ISeq[Int]] @@ MatchedQueriesTag

// NB: #, order of arguments must match createContractsTable
final case class DBContract[+TpId, +CK, +PL, +Prt](
contractId: String,
Expand Down Expand Up @@ -333,18 +331,14 @@ object Queries {
final case class DoMagicSetup(create: Fragment) extends InitDdl
}

/** Whether selectContractsMultiTemplate computes a matchedQueries marker,
* and whether it may compute >1 query to run.
/** Whether selectContractsMultiTemplate computes a matchedQueries marker.
*
* @tparam T The traversable of queries that result.
* @tparam Mark The "marker" indicating which query matched.
* @tparam Mark The "marker" indicating which queries matched.
*/
private[http] sealed abstract class MatchedQueryMarker[T[_], Mark]
extends Product
with Serializable
private[http] sealed abstract class MatchedQueryMarker[Mark] extends Product with Serializable
private[http] object MatchedQueryMarker {
case object ByInt extends MatchedQueryMarker[Seq, Int]
case object Unused extends MatchedQueryMarker[Id, SurrogateTpId]
case object ByInt extends MatchedQueryMarker[MatchedQueries]
case object Unused extends MatchedQueryMarker[SurrogateTpId]
}

/** Path to a location in a JSON tree. */
Expand Down Expand Up @@ -381,27 +375,8 @@ object Queries {
private[this] def intersperse[A](oaa: OneAnd[Vector, A], a: A): OneAnd[Vector, A] =
OneAnd(oaa.head, oaa.tail.flatMap(Vector(a, _)))

// Like groupBy but split into n maps where n is the longest list under groupBy.
private[dbbackend] def uniqueSets[A, B](iter: Iterable[(A, B)]): Seq[NonEmpty[Map[A, B]]] =
unfold(
iter
.groupBy1(_._1)
.transform((_, i) => i.toList): Map[A, NonEmpty[List[(_, B)]]]
) {
case NonEmpty(m) =>
Some {
val hd = m transform { (_, abs) =>
val (_, b) +-: _ = abs
b
}
val tl = m collect { case (a, _ +-: NonEmpty(tl)) => (a, tl) }
(hd, tl)
}
case _ => None
}

private[dbbackend] def caseLookup[SelEq: Put, Then: Put](
m: Map[SelEq, Then],
private[this] def caseLookupFragment[SelEq: Put](
m: Map[SelEq, Fragment],
selector: Fragment,
): Fragment =
fr"CASE" ++ {
Expand All @@ -412,6 +387,58 @@ object Queries {
concatFragment(OneAnd(when, whens))
} ++ fr"ELSE NULL END"

private[dbbackend] def caseLookup[SelEq: Put, Then: Put](
m: Map[SelEq, Then],
selector: Fragment,
): Fragment =
caseLookupFragment(m transform { (_, e) => fr"$e" }, selector)

// an expression that yields a comma-terminated/separated list of SQL-side
// string conversions of `Ix`es indicating which tpid/query pairs matched
private[dbbackend] def projectedIndex[Ix: Put](
queries: ISeq[((SurrogateTpId, Fragment), Ix)],
tpidSelector: Fragment,
): Fragment = {
import Implicits._
caseLookupFragment(
queries.groupBy1(_._1._1).transform {
case (_, (_, ix) +-: ISeq()) => fr"${ix: Ix}||''"
case (_, tqixes) =>
concatFragment(
intersperse(
tqixes.toVector.toOneAnd.map { case ((_, q), ix) =>
fr"(CASE WHEN ($q) THEN ${ix: Ix}||',' ELSE '' END)"
},
fr"||",
)
)
},
selector = tpidSelector,
)
}

import doobie.util.invariant.InvalidValue

@throws[InvalidValue[_, _]]
private[this] def assertReadProjectedIndex(from: Option[String]): NonEmpty[ISeq[Int]] = {
def invalid(reason: String) = {
import cats.instances.option._, cats.instances.string._
throw InvalidValue[Option[String], ISeq[Int]](from, reason = reason)
}
(from.cata(
{ s =>
val matches = s split ',' collect {
case e if e.nonEmpty => e.parseInt.fold(err => invalid(err.getMessage), identity)
}
matches.toSeq
},
ISeq.empty,
)) match {
case NonEmpty(matches) => matches
case _ => invalid("matched row, but no matching index found; this indicates a query bug")
}
}

private[http] val Postgres: Aux[SqlInterpolation.StringArray] = PostgresQueries
private[http] val Oracle: Aux[SqlInterpolation.Unused] = OracleQueries

Expand All @@ -426,6 +453,9 @@ object Queries {

implicit val `SurrogateTpId meta`: Meta[SurrogateTpId] =
SurrogateTpId subst Meta[Long]

implicit val `MatchedQueries get`: Read[MatchedQueries] =
MatchedQueries subst (Read[Option[String]] map assertReadProjectedIndex)
}

private[dbbackend] object CompatImplicits {
Expand Down Expand Up @@ -500,30 +530,32 @@ private object PostgresQueries extends Queries {
).updateMany(dbcs)
}

private[http] override def selectContractsMultiTemplate[T[_], Mark](
private[http] override def selectContractsMultiTemplate[Mark](
parties: OneAnd[Set, String],
queries: ISeq[(SurrogateTpId, Fragment)],
trackMatchIndices: MatchedQueryMarker[T, Mark],
trackMatchIndices: MatchedQueryMarker[Mark],
)(implicit
log: LogHandler,
ipol: SqlInterpol,
): T[Query0[DBContract[Mark, JsValue, JsValue, Vector[String]]]] = {
): Query0[DBContract[Mark, JsValue, JsValue, Vector[String]]] = {
val partyVector = parties.toVector
def query(preds: OneAnd[Vector, (SurrogateTpId, Fragment)], findMark: SurrogateTpId => Mark) = {
val assocedPreds = preds.map { case (tpid, predicate) =>
@nowarn("msg=parameter value evidence.* is never used")
def query[Mark0: Read](tpid: Fragment, preds: NonEmpty[Vector[(SurrogateTpId, Fragment)]]) = {
val assocedPreds = preds.toOneAnd.map { case (tpid, predicate) =>
sql"(tpid = $tpid AND (" ++ predicate ++ sql"))"
}
val unionPred = joinFragment(assocedPreds, sql" OR ")
import ipol.{gas, pas}
val q = sql"""SELECT contract_id, tpid, key, payload, signatories, observers, agreement_text
val q =
sql"""SELECT contract_id, $tpid tpid, key, payload, signatories, observers, agreement_text
FROM contract AS c
WHERE (signatories && $partyVector::text[] OR observers && $partyVector::text[])
AND (""" ++ unionPred ++ sql")"
q.query[(String, SurrogateTpId, JsValue, JsValue, Vector[String], Vector[String], String)]
q.query[(String, Mark0, JsValue, JsValue, Vector[String], Vector[String], String)]
.map { case (cid, tpid, key, payload, signatories, observers, agreement) =>
DBContract(
contractId = cid,
templateId = findMark(tpid),
templateId = tpid,
key = key,
payload = payload,
signatories = signatories,
Expand All @@ -533,21 +565,16 @@ private object PostgresQueries extends Queries {
}
}

val NonEmpty(nequeries) = queries.toVector
trackMatchIndices match {
case MatchedQueryMarker.ByInt =>
type Ix = Int
uniqueSets(queries.zipWithIndex map { case ((tpid, pred), ix) => (tpid, (pred, ix)) }).map {
preds: NonEmpty[Map[SurrogateTpId, (Fragment, Ix)]] =>
val predHd +-: predTl = preds.toVector
val predsList = OneAnd(predHd, predTl).map { case (tpid, (predicate, _)) =>
(tpid, predicate)
}
query(predsList, tpid => preds(tpid)._2)
}
query[MatchedQueries](
tpid = projectedIndex(queries.zipWithIndex, tpidSelector = fr"tpid"),
nequeries,
)

case MatchedQueryMarker.Unused =>
val predHd +: predTl = queries.toVector
query(OneAnd(predHd, predTl), identity)
query[SurrogateTpId](tpid = fr"tpid", nequeries)
}
}

Expand Down Expand Up @@ -660,19 +687,19 @@ private object OracleQueries extends Queries {
)
}

private[http] override def selectContractsMultiTemplate[T[_], Mark](
private[http] override def selectContractsMultiTemplate[Mark](
parties: OneAnd[Set, String],
queries: ISeq[(SurrogateTpId, Fragment)],
trackMatchIndices: MatchedQueryMarker[T, Mark],
trackMatchIndices: MatchedQueryMarker[Mark],
)(implicit
log: LogHandler,
ipol: SqlInterpol,
): T[Query0[DBContract[Mark, JsValue, JsValue, Vector[String]]]] = {
): Query0[DBContract[Mark, JsValue, JsValue, Vector[String]]] = {

// we effectively shadow Mark because Scala 2.12 doesn't quite get
// that it should use the GADT type equality otherwise
@nowarn("msg=parameter value evidence.* is never used")
def queryByCondition[Mark0: Get](
def queryByCondition[Mark0: Read](
tpid: Fragment,
queryConditions: NonEmpty[ISeq[(SurrogateTpId, Fragment)]],
): Query0[DBContract[Mark0, JsValue, JsValue, Vector[String]]] = {
Expand All @@ -694,7 +721,7 @@ private object OracleQueries extends Queries {
signatories, observers, agreement_text,
row_number() over (PARTITION BY c.contract_id ORDER BY c.contract_id) AS rownumber
FROM contract c
LEFT JOIN contract_stakeholders cst ON (c.contract_id = cst.contract_id)
JOIN contract_stakeholders cst ON (c.contract_id = cst.contract_id)
WHERE (${Fragments.in(fr"cst.stakeholder", parties)})
AND ($queriesCondition)"""
val q = sql"SELECT $outerSelectList FROM ($dupQ) WHERE rownumber = 1"
Expand All @@ -714,18 +741,12 @@ private object OracleQueries extends Queries {
}
}

val NonEmpty(nequeries) = queries
trackMatchIndices match {
case MatchedQueryMarker.ByInt =>
type Ix = Int
// TODO we may UNION the resulting queries and aggregate the Ixes SQL-side,
// but this will probably necessitate the same PostgreSQL-side
uniqueSets(queries.zipWithIndex.map { case ((tpid, pred), ix) => (tpid, (pred, ix)) }).map {
preds: NonEmpty[Map[SurrogateTpId, (Fragment, Ix)]] =>
val tpid = caseLookup(preds.transform((_, predIx) => predIx._2), fr"cst.tpid")
queryByCondition[Int](tpid, preds.transform((_, predIx) => predIx._1).toVector)
}
val tpid = projectedIndex(queries.zipWithIndex, tpidSelector = fr"cst.tpid")
queryByCondition[MatchedQueries](tpid, nequeries)
case MatchedQueryMarker.Unused =>
val NonEmpty(nequeries) = queries
queryByCondition[SurrogateTpId](fr"cst.tpid", nequeries)
}
}
Expand Down
Loading