Skip to content

Commit

Permalink
use a single SQL query for any number of json-api query pairs (#10344)
Browse files Browse the repository at this point in the history
* new projection for aggregated matched-queries

We can redo all the template-ID matches (and payload query matches, if
needed) in the SELECT projection clause to emit a list of matchedQueries
indices SQL-side.

CHANGELOG_BEGIN
CHANGELOG_END

* selectContractsMultiTemplate always returns one query

* factoring

* remove multiquery deduplication from ContractDao

* test simplest case of projectedIndex; remove uniqueSets tests

* remove uniqueSets

* add more test cases for the 3 main varieties of potential inputs

* remove uniqueSets tests that were commented for reference

* remove unneeded left-join

* scala 2.12 port

* port Map test order to 2.12

* use SortedMap so the Scala version tests are unified

- suggested by @cocreature; thanks
  • Loading branch information
S11001001 authored Jul 23, 2021
1 parent 6a16684 commit 17709b5
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 118 deletions.
15 changes: 12 additions & 3 deletions ledger-service/db-backend/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,19 @@ load(
"da_scala_library",
"da_scala_test",
"lf_scalacopts",
"silencer_plugin",
)

da_scala_library(
name = "db-backend",
srcs = glob(["src/main/scala/**/*.scala"]),
plugins = [
silencer_plugin,
],
scala_deps = [
"@maven//:com_chuusai_shapeless",
"@maven//:io_spray_spray_json",
"@maven//:org_scala_lang_modules_scala_collection_compat",
"@maven//:org_scalaz_scalaz_core",
"@maven//:org_tpolecat_doobie_core",
"@maven//:org_tpolecat_doobie_free",
Expand Down Expand Up @@ -45,15 +50,19 @@ da_scala_test(
size = "medium",
srcs = glob(["src/test/scala/**/*.scala"]),
scala_deps = [
"@maven//:org_scalacheck_scalacheck",
"@maven//:com_chuusai_shapeless",
"@maven//:org_scalatest_scalatest",
"@maven//:org_scalatestplus_scalacheck_1_14",
"@maven//:org_scalaz_scalaz_core",
"@maven//:org_tpolecat_doobie_core",
"@maven//:org_tpolecat_doobie_free",
"@maven//:org_typelevel_cats_core",
"@maven//:org_typelevel_cats_effect",
"@maven//:org_typelevel_cats_free",
"@maven//:org_typelevel_cats_kernel",
],
scalacopts = lf_scalacopts,
# data = ["//docs:quickstart-model.dar"],
deps = [
":db-backend",
"//libs-scala/scala-utils",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ import nonempty.NonEmptyReturningOps._
import doobie._
import doobie.implicits._
import scala.annotation.nowarn
import scala.collection.immutable.{Iterable, Seq => ISeq}
import scala.collection.compat._
import scala.collection.immutable.{Seq => ISeq, SortedMap}
import scalaz.{@@, Cord, Functor, OneAnd, Tag, \/, -\/, \/-}
import scalaz.Digit._0
import scalaz.Id.Id
import scalaz.syntax.foldable._
import scalaz.syntax.functor._
import scalaz.syntax.std.option._
import scalaz.std.stream.unfold
import scalaz.syntax.std.string._
import scalaz.std.AllInstances._
import spray.json._
import cats.instances.list._
Expand Down Expand Up @@ -241,22 +241,17 @@ sealed abstract class Queries {
selectContractsMultiTemplate(parties, ISeq((tpid, predicate)), MatchedQueryMarker.Unused)
.map(_ copy (templateId = ()))

/** Make the smallest number of queries from `queries` that still indicates
/** Make a query that may indicate
* which query or queries produced each contract.
*
* A contract cannot be produced more than once from a given resulting query,
* but may be produced more than once from different queries. In each case, the
* `templateId` of the resulting [[DBContract]] is actually the 0-based index
* into the `queries` argument that produced the contract.
*/
private[http] def selectContractsMultiTemplate[T[_], Mark](
private[http] def selectContractsMultiTemplate[Mark](
parties: OneAnd[Set, String],
queries: ISeq[(SurrogateTpId, Fragment)],
trackMatchIndices: MatchedQueryMarker[T, Mark],
trackMatchIndices: MatchedQueryMarker[Mark],
)(implicit
log: LogHandler,
ipol: SqlInterpol,
): T[Query0[DBContract[Mark, JsValue, JsValue, Vector[String]]]]
): Query0[DBContract[Mark, JsValue, JsValue, Vector[String]]]

private[http] final def fetchById(
parties: OneAnd[Set, String],
Expand Down Expand Up @@ -302,6 +297,10 @@ object Queries {
val SurrogateTpId = Tag.of[SurrogateTpIdTag]
type SurrogateTpId = Long @@ SurrogateTpIdTag // matches tpid (BIGINT) above

sealed trait MatchedQueriesTag
val MatchedQueries = Tag.of[MatchedQueriesTag]
type MatchedQueries = NonEmpty[ISeq[Int]] @@ MatchedQueriesTag

// NB: #, order of arguments must match createContractsTable
final case class DBContract[+TpId, +CK, +PL, +Prt](
contractId: String,
Expand Down Expand Up @@ -343,18 +342,14 @@ object Queries {
final case class DoMagicSetup(create: Fragment) extends InitDdl
}

/** Whether selectContractsMultiTemplate computes a matchedQueries marker,
* and whether it may compute >1 query to run.
/** Whether selectContractsMultiTemplate computes a matchedQueries marker.
*
* @tparam T The traversable of queries that result.
* @tparam Mark The "marker" indicating which query matched.
* @tparam Mark The "marker" indicating which queries matched.
*/
private[http] sealed abstract class MatchedQueryMarker[T[_], Mark]
extends Product
with Serializable
private[http] sealed abstract class MatchedQueryMarker[Mark] extends Product with Serializable
private[http] object MatchedQueryMarker {
case object ByInt extends MatchedQueryMarker[Seq, Int]
case object Unused extends MatchedQueryMarker[Id, SurrogateTpId]
case object ByInt extends MatchedQueryMarker[MatchedQueries]
case object Unused extends MatchedQueryMarker[SurrogateTpId]
}

/** Path to a location in a JSON tree. */
Expand Down Expand Up @@ -391,27 +386,8 @@ object Queries {
private[this] def intersperse[A](oaa: OneAnd[Vector, A], a: A): OneAnd[Vector, A] =
OneAnd(oaa.head, oaa.tail.flatMap(Vector(a, _)))

// Like groupBy but split into n maps where n is the longest list under groupBy.
private[dbbackend] def uniqueSets[A, B](iter: Iterable[(A, B)]): Seq[NonEmpty[Map[A, B]]] =
unfold(
iter
.groupBy1(_._1)
.transform((_, i) => i.toList): Map[A, NonEmpty[List[(_, B)]]]
) {
case NonEmpty(m) =>
Some {
val hd = m transform { (_, abs) =>
val (_, b) +-: _ = abs
b
}
val tl = m collect { case (a, _ +-: NonEmpty(tl)) => (a, tl) }
(hd, tl)
}
case _ => None
}

private[dbbackend] def caseLookup[SelEq: Put, Then: Put](
m: Map[SelEq, Then],
private[this] def caseLookupFragment[SelEq: Put](
m: Map[SelEq, Fragment],
selector: Fragment,
): Fragment =
fr"CASE" ++ {
Expand All @@ -422,6 +398,60 @@ object Queries {
concatFragment(OneAnd(when, whens))
} ++ fr"ELSE NULL END"

private[dbbackend] def caseLookup[SelEq: Put, Then: Put](
m: Map[SelEq, Then],
selector: Fragment,
): Fragment =
caseLookupFragment(m transform { (_, e) => fr"$e" }, selector)

// an expression that yields a comma-terminated/separated list of SQL-side
// string conversions of `Ix`es indicating which tpid/query pairs matched
private[dbbackend] def projectedIndex[Ix: Put](
queries: ISeq[((SurrogateTpId, Fragment), Ix)],
tpidSelector: Fragment,
): Fragment = {
import Implicits._
caseLookupFragment(
// SortedMap is only used so the tests are consistent; the SQL semantics
// don't care what order this map is in
SortedMap.from(queries.groupBy1(_._1._1)).transform {
case (_, (_, ix) +-: ISeq()) => fr"${ix: Ix}||''"
case (_, tqixes) =>
concatFragment(
intersperse(
tqixes.toVector.toOneAnd.map { case ((_, q), ix) =>
fr"(CASE WHEN ($q) THEN ${ix: Ix}||',' ELSE '' END)"
},
fr"||",
)
)
},
selector = tpidSelector,
)
}

import doobie.util.invariant.InvalidValue

@throws[InvalidValue[_, _]]
private[this] def assertReadProjectedIndex(from: Option[String]): NonEmpty[ISeq[Int]] = {
def invalid(reason: String) = {
import cats.instances.option._, cats.instances.string._
throw InvalidValue[Option[String], ISeq[Int]](from, reason = reason)
}
from.cata(
{ s =>
val matches = s split ',' collect {
case e if e.nonEmpty => e.parseInt.fold(err => invalid(err.getMessage), identity)
}
matches.to(ISeq)
},
ISeq.empty,
) match {
case NonEmpty(matches) => matches
case _ => invalid("matched row, but no matching index found; this indicates a query bug")
}
}

private[http] val Postgres: Aux[SqlInterpolation.StringArray] = PostgresQueries
private[http] val Oracle: Aux[SqlInterpolation.Unused] = OracleQueries

Expand All @@ -436,6 +466,12 @@ object Queries {

implicit val `SurrogateTpId meta`: Meta[SurrogateTpId] =
SurrogateTpId subst Meta[Long]

implicit val `SurrogateTpId ordering`: Ordering[SurrogateTpId] =
SurrogateTpId subst implicitly[Ordering[Long]]

implicit val `MatchedQueries get`: Read[MatchedQueries] =
MatchedQueries subst (Read[Option[String]] map assertReadProjectedIndex)
}

private[dbbackend] object CompatImplicits {
Expand Down Expand Up @@ -512,30 +548,32 @@ private object PostgresQueries extends Queries {
).updateMany(dbcs)
}

private[http] override def selectContractsMultiTemplate[T[_], Mark](
private[http] override def selectContractsMultiTemplate[Mark](
parties: OneAnd[Set, String],
queries: ISeq[(SurrogateTpId, Fragment)],
trackMatchIndices: MatchedQueryMarker[T, Mark],
trackMatchIndices: MatchedQueryMarker[Mark],
)(implicit
log: LogHandler,
ipol: SqlInterpol,
): T[Query0[DBContract[Mark, JsValue, JsValue, Vector[String]]]] = {
): Query0[DBContract[Mark, JsValue, JsValue, Vector[String]]] = {
val partyVector = parties.toVector
def query(preds: OneAnd[Vector, (SurrogateTpId, Fragment)], findMark: SurrogateTpId => Mark) = {
val assocedPreds = preds.map { case (tpid, predicate) =>
@nowarn("msg=parameter value evidence.* is never used")
def query[Mark0: Read](tpid: Fragment, preds: NonEmpty[Vector[(SurrogateTpId, Fragment)]]) = {
val assocedPreds = preds.toOneAnd.map { case (tpid, predicate) =>
sql"(tpid = $tpid AND (" ++ predicate ++ sql"))"
}
val unionPred = joinFragment(assocedPreds, sql" OR ")
import ipol.{gas, pas}
val q = sql"""SELECT contract_id, tpid, key, payload, signatories, observers, agreement_text
val q =
sql"""SELECT contract_id, $tpid tpid, key, payload, signatories, observers, agreement_text
FROM contract AS c
WHERE (signatories && $partyVector::text[] OR observers && $partyVector::text[])
AND (""" ++ unionPred ++ sql")"
q.query[(String, SurrogateTpId, JsValue, JsValue, Vector[String], Vector[String], String)]
q.query[(String, Mark0, JsValue, JsValue, Vector[String], Vector[String], String)]
.map { case (cid, tpid, key, payload, signatories, observers, agreement) =>
DBContract(
contractId = cid,
templateId = findMark(tpid),
templateId = tpid,
key = key,
payload = payload,
signatories = signatories,
Expand All @@ -545,21 +583,16 @@ private object PostgresQueries extends Queries {
}
}

val NonEmpty(nequeries) = queries.toVector
trackMatchIndices match {
case MatchedQueryMarker.ByInt =>
type Ix = Int
uniqueSets(queries.zipWithIndex map { case ((tpid, pred), ix) => (tpid, (pred, ix)) }).map {
preds: NonEmpty[Map[SurrogateTpId, (Fragment, Ix)]] =>
val predHd +-: predTl = preds.toVector
val predsList = OneAnd(predHd, predTl).map { case (tpid, (predicate, _)) =>
(tpid, predicate)
}
query(predsList, tpid => preds(tpid)._2)
}
query[MatchedQueries](
tpid = projectedIndex(queries.zipWithIndex, tpidSelector = fr"tpid"),
nequeries,
)

case MatchedQueryMarker.Unused =>
val predHd +: predTl = queries.toVector
query(OneAnd(predHd, predTl), identity)
query[SurrogateTpId](tpid = fr"tpid", nequeries)
}
}

Expand Down Expand Up @@ -674,19 +707,19 @@ private object OracleQueries extends Queries {
)
}

private[http] override def selectContractsMultiTemplate[T[_], Mark](
private[http] override def selectContractsMultiTemplate[Mark](
parties: OneAnd[Set, String],
queries: ISeq[(SurrogateTpId, Fragment)],
trackMatchIndices: MatchedQueryMarker[T, Mark],
trackMatchIndices: MatchedQueryMarker[Mark],
)(implicit
log: LogHandler,
ipol: SqlInterpol,
): T[Query0[DBContract[Mark, JsValue, JsValue, Vector[String]]]] = {
): Query0[DBContract[Mark, JsValue, JsValue, Vector[String]]] = {

// we effectively shadow Mark because Scala 2.12 doesn't quite get
// that it should use the GADT type equality otherwise
@nowarn("msg=parameter value evidence.* is never used")
def queryByCondition[Mark0: Get](
def queryByCondition[Mark0: Read](
tpid: Fragment,
queryConditions: NonEmpty[ISeq[(SurrogateTpId, Fragment)]],
): Query0[DBContract[Mark0, JsValue, JsValue, Vector[String]]] = {
Expand All @@ -708,7 +741,7 @@ private object OracleQueries extends Queries {
signatories, observers, agreement_text,
row_number() over (PARTITION BY c.contract_id ORDER BY c.contract_id) AS rownumber
FROM contract c
LEFT JOIN contract_stakeholders cst ON (c.contract_id = cst.contract_id)
JOIN contract_stakeholders cst ON (c.contract_id = cst.contract_id)
WHERE (${Fragments.in(fr"cst.stakeholder", parties)})
AND ($queriesCondition)"""
val q = sql"SELECT $outerSelectList FROM ($dupQ) WHERE rownumber = 1"
Expand All @@ -728,18 +761,12 @@ private object OracleQueries extends Queries {
}
}

val NonEmpty(nequeries) = queries
trackMatchIndices match {
case MatchedQueryMarker.ByInt =>
type Ix = Int
// TODO we may UNION the resulting queries and aggregate the Ixes SQL-side,
// but this will probably necessitate the same PostgreSQL-side
uniqueSets(queries.zipWithIndex.map { case ((tpid, pred), ix) => (tpid, (pred, ix)) }).map {
preds: NonEmpty[Map[SurrogateTpId, (Fragment, Ix)]] =>
val tpid = caseLookup(preds.transform((_, predIx) => predIx._2), fr"cst.tpid")
queryByCondition[Int](tpid, preds.transform((_, predIx) => predIx._1).toVector)
}
val tpid = projectedIndex(queries.zipWithIndex, tpidSelector = fr"cst.tpid")
queryByCondition[MatchedQueries](tpid, nequeries)
case MatchedQueryMarker.Unused =>
val NonEmpty(nequeries) = queries
queryByCondition[SurrogateTpId](fr"cst.tpid", nequeries)
}
}
Expand Down
Loading

0 comments on commit 17709b5

Please sign in to comment.