Skip to content

Commit

Permalink
Migration DB unit testing for PostgreSQL v111 to v116 [DPP-756] (#12517)
Browse files Browse the repository at this point in the history
* Adding minimal dynamic framework of independent DB migration unit testing
* Implementing (relevant parts of) PG 111 and 116 DB schema
* Adding test scenarios to test data migration for string interning and create-filter table population

changelog_begin
changelog_end
  • Loading branch information
nmarton-da authored Jan 21, 2022
1 parent 7218e6f commit 15c0ad7
Show file tree
Hide file tree
Showing 5 changed files with 840 additions and 0 deletions.
1 change: 1 addition & 0 deletions ledger/participant-integration-api/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ da_scala_library(
"@maven//:io_netty_netty_common",
"@maven//:io_netty_netty_handler",
"@maven//:io_netty_netty_transport",
"@maven//:org_flywaydb_flyway_core",
"@maven//:org_mockito_mockito_core",
"@maven//:org_scalatest_scalatest_compatible",
],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Copyright (c) 2022 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.platform.store.migration

import java.sql.{Connection, ResultSet}

import com.daml.platform.store.DbType
import com.daml.platform.store.FlywayMigrations.locations
import javax.sql.DataSource
import org.flywaydb.core.Flyway

import scala.util.Using

/** Helpers for unit-testing Flyway DB migrations: run migrations up to a target
  * version, then insert/fetch rows through a small table-schema DSL to verify
  * how data was transformed by the migration.
  */
object MigrationTestSupport {

  /** Runs all Flyway migrations up to and including `version` against the given data source. */
  def migrateTo(version: String)(implicit dataSource: DataSource, dbType: DbType): Unit = {
    val flyway = Flyway
      .configure()
      .locations(locations(dbType): _*)
      .dataSource(dataSource)
      .target(version)
      .load()
    flyway.migrate()
    ()
  }

  /** Two-way mapping between a JDBC column and its Scala-side test representation. */
  trait DbDataType {
    /** Reads the column at 1-based `index` from the current row of `resultSet`. */
    def get(resultSet: ResultSet, index: Int): Any

    /** Renders `value` as a SQL literal suitable for splicing into an INSERT statement. */
    def put(value: Any): String

    /** Nullable variant of this type: NULL maps to None, a value to Some(value). */
    def optional: DbDataType = DbDataType.Optional(this)
  }

  object DbDataType {
    /** Wraps a data type so that SQL NULL round-trips as None and non-NULL as Some. */
    case class Optional(delegate: DbDataType) extends DbDataType {
      override def get(resultSet: ResultSet, index: Int): Any =
        Option(resultSet.getObject(index)).map(_ => delegate.get(resultSet, index))

      override def put(value: Any): String =
        value.asInstanceOf[Option[Any]].fold("null")(delegate.put)
    }
  }

  /** Describes a table under test: its name, a column to order fetched rows by,
    * and the name-to-type mapping of the columns to read and write.
    */
  case class TableSchema(
      tableName: String,
      orderByColumn: String,
      columns: Map[String, DbDataType],
  ) {
    /** Column names in one fixed order, shared by `insert` and `fetchTable`. */
    val columnsList: List[String] = columns.keySet.toList

    /** Schema with the given columns added (later entries win on name clashes). */
    def ++(entries: (String, DbDataType)*): TableSchema =
      copy(columns = entries.foldLeft(columns)(_ + _))

    /** Schema with the given columns removed. */
    def --(cols: String*): TableSchema =
      copy(columns = cols.foldLeft(columns)(_ - _))
  }

  object TableSchema {
    def apply(tableName: String, orderByColumn: String)(
        entries: (String, DbDataType)*
    ): TableSchema =
      TableSchema(tableName, orderByColumn, entries.toMap)
  }

  /** A single table row as a column-name -> value mapping. */
  type Row = Map[String, Any]

  def row(entries: (String, Any)*): Row = Map[String, Any](entries: _*)

  implicit class RowOps(val r: Row) extends AnyVal {
    /** Returns a copy of the row with `key` replaced by `f` applied to its current value. */
    def updateIn[T](key: String)(f: T => Any): Row = r.updated(key, f(r(key).asInstanceOf[T]))
  }

  implicit class VectorRowOps(val r: Vector[Row]) extends AnyVal {
    /** Applies `updateIn` to every row in the vector. */
    def updateInAll[T](key: String)(f: T => Any): Vector[Row] = r.map(row => row.updateIn(key)(f))
  }

  /** Inserts the given rows into the schema's table.
    * Each row must provide exactly the schema's columns; values are rendered
    * as SQL literals by the column's [[DbDataType]].
    */
  def insert(tableSchema: TableSchema, rows: Row*)(implicit connection: Connection): Unit =
    rows.foreach { currentRow =>
      // A row must define exactly the schema's columns, no more and no less.
      assert(tableSchema.columns.keySet == currentRow.keySet)
      val renderedValues =
        tableSchema.columnsList.map(column => tableSchema.columns(column).put(currentRow(column)))
      val insertStatement =
        s"""INSERT INTO ${tableSchema.tableName}
           |(${tableSchema.columnsList.mkString(", ")})
           |VALUES (${renderedValues.mkString(", ")})""".stripMargin
      Using.resource(connection.createStatement())(_.execute(insertStatement))
      ()
    }

  /** Reads the whole table described by the schema, ordered by its `orderByColumn`. */
  def fetchTable(tableSchema: TableSchema)(implicit connection: Connection): Vector[Row] = {
    val query =
      s"""SELECT ${tableSchema.columnsList.mkString(", ")}
         |FROM ${tableSchema.tableName}
         |ORDER BY ${tableSchema.orderByColumn}
         |""".stripMargin
    Using.resource(connection.createStatement()) { statement =>
      Using.resource(statement.executeQuery(query)) { resultSet =>
        // Drain the result set: next() advances the cursor once per consumed element.
        Iterator
          .continually(resultSet)
          .takeWhile(_.next())
          .map(readRow(tableSchema, _))
          .toVector
      }
    }
  }

  // Converts the current result-set row into a Row using the schema's column order.
  private def readRow(tableSchema: TableSchema, resultSet: ResultSet): Row =
    tableSchema.columnsList.zipWithIndex.map { case (column, i) =>
      // JDBC column indices are 1-based.
      column -> tableSchema.columns(column).get(resultSet, i + 1)
    }.toMap
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) 2022 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.platform.store.migration.postgres

import java.sql.Connection

import com.daml.logging.LoggingContext
import com.daml.platform.store.DbType
import com.daml.platform.store.backend.StorageBackendFactory
import com.daml.testing.postgresql.PostgresAroundEach
import javax.sql.DataSource
import org.scalatest.Suite

import scala.annotation.tailrec
import scala.util.{Failure, Success, Try}

/** Test fixture mixin that provides a live JDBC connection and data source to a
  * per-test PostgreSQL database.
  *
  * NOTE(review): PostgresAroundEach presumably provisions a fresh database for each
  * test case (hence the availability retry below) — confirm against that trait.
  * `conn`, `dataSource` and `dbType` are exposed as implicits so test code can call
  * the MigrationTestSupport helpers without threading them explicitly.
  */
trait PostgresConnectionSupport extends PostgresAroundEach {
  self: Suite =>

  // Mutable fixture state: (re)assigned in beforeEach before every test case.
  implicit var conn: Connection = _
  implicit val dbType: DbType = DbType.Postgres
  private val dataSourceBackend = StorageBackendFactory.of(dbType).createDataSourceStorageBackend
  implicit var dataSource: DataSource = _

  override def beforeEach(): Unit = {
    import scala.util.control.NonFatal
    super.beforeEach()
    dataSource = dataSourceBackend.createDataSource(postgresDatabase.url)(LoggingContext.ForTesting)
    // The database may not accept connections right away, so probe for availability
    // with up to 20 attempts, one second apart.
    conn = retry(20, 1000) {
      val probe = dataSource.getConnection
      try {
        dataSourceBackend.checkDatabaseAvailable(probe)
        probe
      } catch {
        case NonFatal(e) =>
          // Close the probe connection on failure; otherwise every failed attempt
          // would leak one connection.
          probe.close()
          throw e
      }
    }
  }

  override protected def afterEach(): Unit = {
    // Close our connection before super tears the database down.
    conn.close()
    super.afterEach()
  }

  // Retries `t` up to `max` more times, sleeping `sleep` millis between attempts;
  // rethrows the last failure once the attempts are exhausted.
  @tailrec
  private def retry[T](max: Int, sleep: Long)(t: => T): T =
    Try(t) match {
      case Success(value) => value
      case Failure(_) if max > 0 =>
        Thread.sleep(sleep)
        retry(max - 1, sleep)(t)
      case Failure(exception) => throw exception
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright (c) 2022 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.platform.store.migration.postgres

import java.sql.ResultSet

import com.daml.platform.store.migration.MigrationTestSupport.DbDataType

/** PostgreSQL implementations of [[DbDataType]]: each object knows how to read a
  * column value from a JDBC result set and how to render a Scala value as a
  * PostgreSQL SQL literal for test INSERT statements.
  */
object PostgresDbDataType {

  // Escapes a value for embedding in a single-quoted SQL string literal:
  // the SQL standard doubles embedded single quotes.
  private def escapeSqlString(s: String): String = s.replace("'", "''")

  case object Integer extends DbDataType {
    override def get(resultSet: ResultSet, index: Int): Any = resultSet.getInt(index)
    override def put(value: Any): String = value.asInstanceOf[Int].toString
  }

  case object BigInt extends DbDataType {
    override def get(resultSet: ResultSet, index: Int): Any = resultSet.getLong(index)
    override def put(value: Any): String = value.asInstanceOf[Long].toString
  }

  case object Str extends DbDataType {
    override def get(resultSet: ResultSet, index: Int): Any = resultSet.getString(index)
    // Fix: double embedded single quotes so values like "O'Neil" render as valid SQL.
    override def put(value: Any): String = s"'${escapeSqlString(value.asInstanceOf[String])}'"
  }

  case object Bool extends DbDataType {
    override def get(resultSet: ResultSet, index: Int): Any = resultSet.getBoolean(index)
    override def put(value: Any): String = value.asInstanceOf[Boolean].toString
  }

  /** BYTEA columns, represented Scala-side as Vector[Byte] and rendered as an
    * escape-string hex literal (E'\\xDEADBEEF').
    */
  case object Bytea extends DbDataType {
    override def get(resultSet: ResultSet, index: Int): Any = resultSet.getBytes(index).toVector
    // Fix: mask to the unsigned byte value before hex-rendering. The previous
    // `_.toInt.toHexString` sign-extended negative bytes (e.g. -1 -> "ffffffff"
    // instead of "ff"), producing a corrupted bytea literal; f"%02x" also makes
    // the one-digit zero-padding explicit.
    override def put(value: Any): String = value
      .asInstanceOf[Vector[Byte]]
      .map(b => f"${b & 0xff}%02x")
      .mkString("E'\\\\x", "", "'")
  }

  case object StringArray extends DbDataType {
    override def get(resultSet: ResultSet, index: Int): Any =
      resultSet
        .getArray(index)
        .getArray
        .asInstanceOf[Array[String]]
        .toVector

    // Fix: escape each element's single quotes, consistent with Str.put.
    override def put(value: Any): String =
      value
        .asInstanceOf[Vector[String]]
        .map(x => s"'${escapeSqlString(x)}'")
        .mkString("ARRAY[", ", ", "]::TEXT[]")
  }

  case object IntArray extends DbDataType {
    override def get(resultSet: ResultSet, index: Int): Any =
      resultSet
        .getArray(index)
        .getArray
        .asInstanceOf[Array[java.lang.Integer]]
        .toVector
        .map(_.intValue()) // unbox to Scala Int

    override def put(value: Any): String =
      value
        .asInstanceOf[Vector[Int]]
        .map(_.toString)
        .mkString("ARRAY[", ", ", "]::INTEGER[]")
  }
}
Loading

0 comments on commit 15c0ad7

Please sign in to comment.