Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(db): Allow creating owned Postgres connections #2654

Merged
merged 3 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 20 additions & 15 deletions core/lib/db_connection/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::{
collections::HashMap,
fmt, io,
marker::PhantomData,
panic::Location,
sync::{
atomic::{AtomicUsize, Ordering},
Mutex,
Arc, Mutex, Weak,
},
time::{Instant, SystemTime},
};
Expand Down Expand Up @@ -98,14 +99,14 @@ impl TracedConnections {
}
}

struct PooledConnection<'a> {
struct PooledConnection {
connection: PoolConnection<Postgres>,
tags: Option<ConnectionTags>,
created_at: Instant,
traced: Option<(&'a TracedConnections, usize)>,
traced: (Weak<TracedConnections>, usize),
}

impl fmt::Debug for PooledConnection<'_> {
impl fmt::Debug for PooledConnection {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter
.debug_struct("PooledConnection")
Expand All @@ -115,7 +116,7 @@ impl fmt::Debug for PooledConnection<'_> {
}
}

impl Drop for PooledConnection<'_> {
impl Drop for PooledConnection {
fn drop(&mut self) {
if let Some(tags) = &self.tags {
let lifetime = self.created_at.elapsed();
Expand All @@ -132,15 +133,17 @@ impl Drop for PooledConnection<'_> {
);
}
}
if let Some((connections, id)) = self.traced {
connections.mark_as_dropped(id);

let (traced_connections, id) = &self.traced;
if let Some(connections) = traced_connections.upgrade() {
connections.mark_as_dropped(*id);
}
}
}

#[derive(Debug)]
enum ConnectionInner<'a> {
Pooled(PooledConnection<'a>),
Pooled(PooledConnection),
Transaction {
transaction: Transaction<'a, Postgres>,
tags: Option<&'a ConnectionTags>,
Expand All @@ -156,7 +159,7 @@ pub trait DbMarker: 'static + Send + Sync + Clone {}
#[derive(Debug)]
pub struct Connection<'a, DB: DbMarker> {
inner: ConnectionInner<'a>,
_marker: std::marker::PhantomData<DB>,
_marker: PhantomData<DB>,
}

impl<'a, DB: DbMarker> Connection<'a, DB> {
Expand All @@ -166,21 +169,23 @@ impl<'a, DB: DbMarker> Connection<'a, DB> {
pub(crate) fn from_pool(
connection: PoolConnection<Postgres>,
tags: Option<ConnectionTags>,
traced_connections: Option<&'a TracedConnections>,
traced_connections: Option<&Arc<TracedConnections>>,
) -> Self {
let created_at = Instant::now();
let inner = ConnectionInner::Pooled(PooledConnection {
connection,
tags,
created_at,
traced: traced_connections.map(|connections| {
traced: if let Some(connections) = traced_connections {
let id = connections.acquire(tags, created_at);
(connections, id)
}),
(Arc::downgrade(connections), id)
} else {
(Weak::new(), 0)
},
});
Self {
inner,
_marker: Default::default(),
_marker: PhantomData,
}
}

Expand All @@ -196,7 +201,7 @@ impl<'a, DB: DbMarker> Connection<'a, DB> {
};
Ok(Connection {
inner,
_marker: Default::default(),
_marker: PhantomData,
})
}

Expand Down
8 changes: 4 additions & 4 deletions core/lib/db_connection/src/connection_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ impl<DB: DbMarker> ConnectionPool<DB> {
///
/// This method is intended to be used in crucial contexts, where the
/// database access is must-have (e.g. block committer).
pub async fn connection(&self) -> DalResult<Connection<'_, DB>> {
pub async fn connection(&self) -> DalResult<Connection<'static, DB>> {
self.connection_inner(None).await
}

Expand All @@ -361,7 +361,7 @@ impl<DB: DbMarker> ConnectionPool<DB> {
pub fn connection_tagged(
&self,
requester: &'static str,
) -> impl Future<Output = DalResult<Connection<'_, DB>>> + '_ {
) -> impl Future<Output = DalResult<Connection<'static, DB>>> + '_ {
let location = Location::caller();
async move {
let tags = ConnectionTags {
Expand All @@ -375,7 +375,7 @@ impl<DB: DbMarker> ConnectionPool<DB> {
async fn connection_inner(
&self,
tags: Option<ConnectionTags>,
) -> DalResult<Connection<'_, DB>> {
) -> DalResult<Connection<'static, DB>> {
let acquire_latency = CONNECTION_METRICS.acquire.start();
let conn = self.acquire_connection_retried(tags.as_ref()).await?;
let elapsed = acquire_latency.observe();
Expand All @@ -386,7 +386,7 @@ impl<DB: DbMarker> ConnectionPool<DB> {
Ok(Connection::<DB>::from_pool(
conn,
tags,
self.traced_connections.as_deref(),
self.traced_connections.as_ref(),
))
}

Expand Down
3 changes: 1 addition & 2 deletions core/lib/state/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ pub use self::{
},
shadow_storage::ShadowStorage,
storage_factory::{
BatchDiff, OwnedPostgresStorage, OwnedStorage, PgOrRocksdbStorage, ReadStorageFactory,
RocksdbWithMemory,
BatchDiff, OwnedStorage, PgOrRocksdbStorage, ReadStorageFactory, RocksdbWithMemory,
},
};

Expand Down
79 changes: 23 additions & 56 deletions core/lib/state/src/storage_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ use zksync_vm_interface::storage::ReadStorage;

use crate::{PostgresStorage, RocksdbStorage, RocksdbStorageBuilder, StateKeeperColumnFamily};

/// Storage with a static lifetime that can be sent to Tokio tasks etc.
pub type OwnedStorage = PgOrRocksdbStorage<'static>;

/// Factory that can produce storage instances on demand. The storage type is encapsulated as a type param
/// (mostly for testing purposes); the default is [`OwnedStorage`].
#[async_trait]
Expand All @@ -35,8 +38,9 @@ impl ReadStorageFactory for ConnectionPool<Core> {
_stop_receiver: &watch::Receiver<bool>,
l1_batch_number: L1BatchNumber,
) -> anyhow::Result<Option<OwnedStorage>> {
let storage = OwnedPostgresStorage::new(self.clone(), l1_batch_number);
Ok(Some(storage.into()))
let connection = self.connection().await?;
let storage = OwnedStorage::postgres(connection, l1_batch_number).await?;
Ok(Some(storage))
}
}

Expand All @@ -61,31 +65,29 @@ pub struct RocksdbWithMemory {
pub batch_diffs: Vec<BatchDiff>,
}

/// Owned Postgres-backed VM storage for a certain L1 batch.
/// A [`ReadStorage`] implementation that uses either [`PostgresStorage`] or [`RocksdbStorage`]
/// underneath.
#[derive(Debug)]
pub struct OwnedPostgresStorage {
connection_pool: ConnectionPool<Core>,
l1_batch_number: L1BatchNumber,
pub enum PgOrRocksdbStorage<'a> {
/// Implementation over a Postgres connection.
Postgres(PostgresStorage<'a>),
/// Implementation over a RocksDB cache instance.
Rocksdb(RocksdbStorage),
/// Implementation over a RocksDB cache instance with in-memory DB diffs.
RocksdbWithMemory(RocksdbWithMemory),
}

impl OwnedPostgresStorage {
/// Creates a VM storage for the specified batch number.
pub fn new(connection_pool: ConnectionPool<Core>, l1_batch_number: L1BatchNumber) -> Self {
Self {
connection_pool,
l1_batch_number,
}
}

/// Returns a [`ReadStorage`] implementation backed by Postgres
impl PgOrRocksdbStorage<'static> {
/// Creates a Postgres-based storage. Because of the `'static` lifetime requirement, `connection` must be
/// non-transactional.
///
/// # Errors
///
/// Propagates Postgres errors.
pub async fn borrow(&self) -> anyhow::Result<PgOrRocksdbStorage<'_>> {
let l1_batch_number = self.l1_batch_number;
let mut connection = self.connection_pool.connection().await?;

/// Propagates Postgres I/O errors.
pub async fn postgres(
mut connection: Connection<'static, Core>,
l1_batch_number: L1BatchNumber,
) -> anyhow::Result<Self> {
let l2_block_number = if let Some((_, l2_block_number)) = connection
.blocks_dal()
.get_l2_block_range_of_l1_batch(l1_batch_number)
Expand Down Expand Up @@ -114,42 +116,7 @@ impl OwnedPostgresStorage {
.into(),
)
}
}

/// Owned version of [`PgOrRocksdbStorage`]. It is thus possible to send to blocking tasks for VM execution.
#[derive(Debug)]
pub enum OwnedStorage {
/// Readily initialized storage with a static lifetime.
Static(PgOrRocksdbStorage<'static>),
/// Storage that must be `borrow()`ed from.
Lending(OwnedPostgresStorage),
}

impl From<OwnedPostgresStorage> for OwnedStorage {
fn from(storage: OwnedPostgresStorage) -> Self {
Self::Lending(storage)
}
}

impl From<PgOrRocksdbStorage<'static>> for OwnedStorage {
fn from(storage: PgOrRocksdbStorage<'static>) -> Self {
Self::Static(storage)
}
}

/// A [`ReadStorage`] implementation that uses either [`PostgresStorage`] or [`RocksdbStorage`]
/// underneath.
#[derive(Debug)]
pub enum PgOrRocksdbStorage<'a> {
/// Implementation over a Postgres connection.
Postgres(PostgresStorage<'a>),
/// Implementation over a RocksDB cache instance.
Rocksdb(RocksdbStorage),
/// Implementation over a RocksDB cache instance with in-memory DB diffs.
RocksdbWithMemory(RocksdbWithMemory),
}

impl PgOrRocksdbStorage<'static> {
/// Catches up RocksDB synchronously (i.e. assumes the gap is small) and
/// returns a [`ReadStorage`] implementation backed by caught-up RocksDB.
///
Expand Down
2 changes: 1 addition & 1 deletion core/node/api_server/src/web3/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ impl RpcState {
#[track_caller]
pub(crate) fn acquire_connection(
&self,
) -> impl Future<Output = Result<Connection<'_, Core>, Web3Error>> + '_ {
) -> impl Future<Output = Result<Connection<'static, Core>, Web3Error>> + '_ {
self.connection_pool
.connection_tagged("api")
.map_err(|err| err.generalize().into())
Expand Down
Loading
Loading