Skip to content

Commit

Permalink
use mapped mutex guard (#177)
Browse files Browse the repository at this point in the history
  • Loading branch information
maxcountryman authored Mar 17, 2024
1 parent e27ecce commit bc0d0f9
Showing 1 changed file with 43 additions and 51 deletions.
94 changes: 43 additions & 51 deletions tower-sessions-core/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use base64::{engine::general_purpose::URL_SAFE_NO_PAD, DecodeError, Engine as _}
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::Value;
use time::{Duration, OffsetDateTime};
use tokio::sync::{Mutex, MutexGuard};
use tokio::sync::{MappedMutexGuard, Mutex, MutexGuard};

use crate::{session_store, SessionStore};

Expand Down Expand Up @@ -97,21 +97,21 @@ impl Session {
}

#[tracing::instrument(skip(self), err)]
async fn get_record(&self) -> Result<MutexGuard<Option<Record>>> {
async fn get_record(&self) -> Result<MappedMutexGuard<Record>> {
let mut record_guard = self.record.lock().await;
let session_id = *self.session_id.lock();

// Lazily load the record.
// Lazily load the record since `None` here indicates we have no yet loaded it.
if record_guard.is_none() {
tracing::trace!("record not loaded from store; loading");

*record_guard = if let Some(session_id) = session_id {
match self.store.load(&session_id).await.map_err(Error::Store)? {
Some(mut loaded_record) => {
let session_id = *self.session_id.lock();
*record_guard = Some(if let Some(session_id) = session_id {
match self.store.load(&session_id).await? {
Some(loaded_record) => {
tracing::trace!("record found in store");
loaded_record.expiry_date = self.expiry_date();
Some(loaded_record)
loaded_record
}

None => {
// A well-behaved user agent should not send session cookies after
// expiration. Even so it's possible for an expired session to be removed
Expand All @@ -120,16 +120,19 @@ impl Session {
// malicious behavior.
tracing::warn!("possibly suspicious activity: record not found in store");
*self.session_id.lock() = None;
Some(self.create_record())
self.create_record()
}
}
} else {
tracing::trace!("session id not found");
Some(self.create_record())
}
self.create_record()
})
}

Ok(record_guard)
Ok(MutexGuard::map(record_guard, |opt| {
opt.as_mut()
.expect("Record should always be `Option::Some` at this point")
}))
}

/// Inserts a `impl Serialize` value into the session.
Expand Down Expand Up @@ -207,14 +210,13 @@ impl Session {
/// - If the session has not been hydrated and loading from the store fails,
/// we fail with [`Error::Store`].
pub async fn insert_value(&self, key: &str, value: Value) -> Result<Option<Value>> {
Ok(self.get_record().await?.as_mut().and_then(|record| {
if record.data.get(key) != Some(&value) {
self.is_modified.store(true, atomic::Ordering::Release);
record.data.insert(key.to_string(), value)
} else {
None
}
}))
let mut record_guard = self.get_record().await?;
Ok(if record_guard.data.get(key) != Some(&value) {
self.is_modified.store(true, atomic::Ordering::Release);
record_guard.data.insert(key.to_string(), value)
} else {
None
})
}

/// Gets a value from the store.
Expand Down Expand Up @@ -275,11 +277,8 @@ impl Session {
/// - If the session has not been hydrated and loading from the store fails,
/// we fail with [`Error::Store`].
pub async fn get_value(&self, key: &str) -> Result<Option<Value>> {
Ok(self
.get_record()
.await?
.as_ref()
.and_then(|record| record.data.get(key).cloned()))
let record_guard = self.get_record().await?;
Ok(record_guard.data.get(key).cloned())
}

/// Removes a value from the store, retuning the value of the key if it was
Expand Down Expand Up @@ -346,10 +345,9 @@ impl Session {
/// - If the session has not been hydrated and loading from the store fails,
/// we fail with [`Error::Store`].
pub async fn remove_value(&self, key: &str) -> Result<Option<Value>> {
Ok(self.get_record().await?.as_mut().and_then(|record| {
self.is_modified.store(true, atomic::Ordering::Release);
record.data.remove(key)
}))
let mut record_guard = self.get_record().await?;
self.is_modified.store(true, atomic::Ordering::Release);
Ok(record_guard.data.remove(key))
}

/// Clears the session of all data but does not delete it from the store.
Expand Down Expand Up @@ -649,24 +647,21 @@ impl Session {
/// - If saving to the store fails, we fail with [`Error::Store`].
#[tracing::instrument(skip(self), err)]
pub async fn save(&self) -> Result<()> {
// N.B.: `get_record` will create a new record if one isn't found in the store.
if let Some(record) = self.get_record().await?.as_mut() {
record.expiry_date = self.expiry_date();

{
let mut session_id_guard = self.session_id.lock();
if session_id_guard.is_none() {
// Generate a new ID here since e.g. flush may have been called, which will
// not directly update the record ID.
let id = Id::default();
*session_id_guard = Some(id);
record.id = id;
}
let mut record_guard = self.get_record().await?;
record_guard.expiry_date = self.expiry_date();
{
let mut session_id_guard = self.session_id.lock();
if session_id_guard.is_none() {
// Generate a new ID here since e.g. flush may have been called, which will
// not directly update the record ID.
let id = Id::default();
*session_id_guard = Some(id);
record_guard.id = id;
}

self.store.save(record).await.map_err(Error::Store)?;
}

self.store.save(&record_guard).await.map_err(Error::Store)?;

Ok(())
}

Expand Down Expand Up @@ -829,13 +824,10 @@ impl Session {
/// with [`Error::Store`].
pub async fn cycle_id(&self) -> Result<()> {
let mut record_guard = self.get_record().await?;
let Some(record) = record_guard.as_mut() else {
return Ok(());
};

let old_session_id = record.id;
record.id = Id::default();
*self.session_id.lock() = Some(record.id);
let old_session_id = record_guard.id;
record_guard.id = Id::default();
*self.session_id.lock() = Some(record_guard.id);

self.store
.delete(&old_session_id)
Expand Down

0 comments on commit bc0d0f9

Please sign in to comment.