Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nest stakeholders in contracts table as JSON arrays #9484

Merged
merged 5 commits into from
Apr 27, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import spray.json._
import cats.instances.list._
import cats.Applicative
import cats.syntax.applicative._
import cats.syntax.apply._
import cats.syntax.functor._

sealed abstract class Queries {
Expand Down Expand Up @@ -581,7 +580,7 @@ private object PostgresQueries extends Queries {
}

private object OracleQueries extends Queries {
import Queries._, Queries.InitDdl.CreateTable
import Queries._
import Implicits._

type SqlInterpol = Queries.SqlInterpolation.Unused
Expand All @@ -608,34 +607,11 @@ private object OracleQueries extends Queries {
protected[this] override def jsonColumn(name: Fragment) =
name ++ sql" CLOB NOT NULL CONSTRAINT ensure_json_" ++ name ++ sql" CHECK (" ++ name ++ sql" IS JSON)"

protected[this] override def contractsTableSignatoriesObservers = sql""

private val createSignatoriesTable = CreateTable(
"signatories",
protected[this] override def contractsTableSignatoriesObservers =
sql"""
CREATE TABLE
signatories
(contract_id """ ++ contractIdType ++ sql""" NOT NULL REFERENCES contract(contract_id) ON DELETE CASCADE
,party """ ++ partyType ++ sql""" NOT NULL
,CONSTRAINT signatories_cid_party_k UNIQUE (contract_id, party)
)
""",
)

private val createObserversTable = CreateTable(
"observers",
sql"""
CREATE TABLE
observers
(contract_id """ ++ contractIdType ++ sql""" NOT NULL REFERENCES contract(contract_id) ON DELETE CASCADE
,party """ ++ partyType ++ sql""" NOT NULL
,CONSTRAINT observers_cid_party_k UNIQUE (contract_id, party)
)
""",
)

protected[this] override def initDatabaseDdls =
super.initDatabaseDdls ++ Seq(createSignatoriesTable, createObserversTable)
,""" ++ jsonColumn(sql"signatories") ++ sql"""
,""" ++ jsonColumn(sql"observers") ++ sql"""
"""

protected[this] type DBContractKey = JsValue

Expand All @@ -645,41 +621,17 @@ private object OracleQueries extends Queries {
protected[this] override def primInsertContracts[F[_]: cats.Foldable: Functor](
dbcs: F[DBContract[SurrogateTpId, DBContractKey, JsValue, Array[String]]]
)(implicit log: LogHandler, ipol: SqlInterpol): ConnectionIO[Int] = {
val r = Update[(String, SurrogateTpId, JsValue, JsValue, String)](
import spray.json.DefaultJsonProtocol._
Update[DBContract[SurrogateTpId, JsValue, JsValue, JsValue]](
"""
INSERT /*+ ignore_row_on_dupkey_index(contract(contract_id)) */
INTO contract (contract_id, tpid, key, payload, agreement_text)
VALUES (?, ?, ?, ?, ?)
""",
logHandler0 = log,
).updateMany(
dbcs
.map { c =>
(c.contractId, c.templateId, c.key, c.payload, c.agreementText)
}
)
import cats.syntax.foldable._, cats.instances.vector._
val r2 = Update[(String, String)](
"""
INSERT /*+ ignore_row_on_dupkey_index(signatories(contract_id, party)) */
INTO signatories (contract_id, party)
VALUES (?, ?)
INTO contract (contract_id, tpid, key, payload, signatories, observers, agreement_text)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
logHandler0 = log,
).updateMany(
dbcs.foldMap(c => c.signatories.view.map(s => (c.contractId, s)).toVector)
dbcs.map(_.mapKeyPayloadParties(identity, identity, _.toJson))
)
val r3 = Update[(String, String)](
"""
INSERT /*+ ignore_row_on_dupkey_index(observers(contract_id, party)) */
INTO observers (contract_id, party)
VALUES (?, ?)
""",
logHandler0 = log,
).updateMany(
dbcs.foldMap(c => c.observers.view.map(s => (c.contractId, s)).toVector)
)
r *> r2 *> r3
}

private[http] override def selectContractsMultiTemplate[T[_], Mark](
Expand All @@ -690,6 +642,7 @@ private object OracleQueries extends Queries {
log: LogHandler,
ipol: SqlInterpol,
): T[Query0[DBContract[Mark, JsValue, JsValue, Vector[String]]]] = {

// we effectively shadow Mark because Scala 2.12 doesn't quite get
// that it should use the GADT type equality otherwise
def queryByCondition[Mark0: Get](
Expand All @@ -705,40 +658,27 @@ private object OracleQueries extends Queries {
fr" OR ",
)
}
import Queries.CompatImplicits.catsReducibleFromFoldable1
val outerSelectList =
sql"""contract_id, template_id, key, payload, agreement_text,
pctSignatories, pctObservers"""
// % is explicitly reserved by specification as a delimiter
val dupQ =
sql"""SELECT c.contract_id contract_id, $tpid template_id, key, payload, agreement_text,
sd.parties pctSignatories, od.parties pctObservers,
row_number() over (PARTITION BY c.contract_id ORDER BY c.contract_id) AS rownumber
FROM (contract c
LEFT JOIN signatories sm ON (c.contract_id = sm.contract_id)
LEFT JOIN observers om ON (c.contract_id = om.contract_id))
LEFT JOIN (SELECT contract_id, LISTAGG(party, '%') parties
FROM signatories GROUP BY contract_id) sd
ON (c.contract_id = sd.contract_id)
LEFT JOIN (SELECT contract_id, LISTAGG(party, '%') parties
FROM observers GROUP BY contract_id) od
ON (c.contract_id = od.contract_id)
WHERE (${Fragments.in(fr"sm.party", parties)}
OR ${Fragments.in(fr"om.party", parties)})
AND $queriesCondition"""
// see https://github.com/digital-asset/daml/issues/9388#issuecomment-820538688
// for a demonstration
val q = sql"SELECT $outerSelectList FROM ($dupQ) WHERE rownumber = 1"
val quotedParties = parties.toVector.map(p => s""""$p"""").mkString(", ")
val partiesQuery = oracleShortPathEscape(
'$' -: Cord.stringToCord("[*]?(@ in (") :+ quotedParties :+ "))"
stefanobaghino-da marked this conversation as resolved.
Show resolved Hide resolved
)
val q =
sql"""SELECT c.contract_id contract_id, $tpid template_id, key, payload, signatories, observers, agreement_text
FROM contract c
WHERE (JSON_EXISTS(signatories, $partiesQuery)
OR JSON_EXISTS(observers, $partiesQuery))
AND $queriesCondition"""
q.query[
(String, Mark0, JsValue, JsValue, Option[String], Option[String], Option[String])
].map { case (cid, tpid, key, payload, agreement, pctSignatories, pctObservers) =>
(String, Mark0, JsValue, JsValue, JsValue, JsValue, Option[String])
].map { case (cid, tpid, key, payload, signatories, observers, agreement) =>
import spray.json.DefaultJsonProtocol._
DBContract(
contractId = cid,
templateId = tpid,
key = key.asJsObject.fields("key"),
payload = payload,
signatories = unpct(pctSignatories),
observers = unpct(pctObservers),
signatories = signatories.convertTo[Vector[String]],
observers = observers.convertTo[Vector[String]],
agreementText = agreement getOrElse "",
)
}
Expand All @@ -760,9 +700,6 @@ private object OracleQueries extends Queries {
}
}

private[this] def unpct(s: Option[String]) =
s.cata(_.split('%').filter(_.nonEmpty).toVector, Vector.empty)

private[http] override def keyEquality(key: JsValue): Fragment = {
import spray.json.DefaultJsonProtocol.JsValueFormat
sql"JSON_EQUAL(key, ${toDBContractKey(key)})"
Expand Down