Skip to content

Commit

Permalink
move session state into inner struct (#189)
Browse files Browse the repository at this point in the history
This wraps session state into an inner struct, reducing the required
Arcs.
  • Loading branch information
maxcountryman authored Apr 2, 2024
1 parent 2a542d8 commit d7b7e13
Showing 1 changed file with 57 additions and 39 deletions.
96 changes: 57 additions & 39 deletions tower-sessions-core/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,33 @@ pub enum Error {
Store(#[from] session_store::Error),
}

/// A session which allows HTTP applications to associate key-value pairs with
/// visitors.
#[derive(Debug, Clone)]
pub struct Session {
#[derive(Debug)]
struct Inner {
// This will be `None` when:
//
// 1. We have not been provided a session cookie or have failed to parse it,
// 2. The store has not found the session.
//
// Sync lock, see: https://docs.rs/tokio/latest/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
session_id: Arc<parking_lot::Mutex<Option<Id>>>,

store: Arc<dyn SessionStore>,
session_id: parking_lot::Mutex<Option<Id>>,

// A lazy representation of the session's value, hydrated on a just-in-time basis. A
// `None` value indicates we have not tried to access it yet. After access, it will always
// contain `Some(Record)`.
record: Arc<Mutex<Option<Record>>>,
record: Mutex<Option<Record>>,

// Sync lock, see: https://docs.rs/tokio/latest/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
expiry: Arc<parking_lot::Mutex<Option<Expiry>>>,
expiry: parking_lot::Mutex<Option<Expiry>>,

is_modified: Arc<AtomicBool>,
is_modified: AtomicBool,
}

/// A session which allows HTTP applications to associate key-value pairs with
/// visitors.
#[derive(Debug, Clone)]
pub struct Session {
store: Arc<dyn SessionStore>,
inner: Arc<Inner>,
}

impl Session {
Expand All @@ -83,12 +87,16 @@ impl Session {
store: Arc<impl SessionStore>,
expiry: Option<Expiry>,
) -> Self {
let inner = Inner {
session_id: parking_lot::Mutex::new(session_id),
record: Mutex::new(None), // `None` indicates we have not loaded from store.
expiry: parking_lot::Mutex::new(expiry),
is_modified: AtomicBool::new(false),
};

Self {
session_id: Arc::new(parking_lot::Mutex::new(session_id)),
store,
record: Arc::new(Mutex::new(None)), // `None` indicates we have not loaded from store.
expiry: Arc::new(parking_lot::Mutex::new(expiry)),
is_modified: Arc::new(AtomicBool::new(false)),
inner: Arc::new(inner),
}
}

Expand All @@ -98,13 +106,13 @@ impl Session {

#[tracing::instrument(skip(self), err)]
async fn get_record(&self) -> Result<MappedMutexGuard<Record>> {
let mut record_guard = self.record.lock().await;
let mut record_guard = self.inner.record.lock().await;

// Lazily load the record since `None` here indicates we have no yet loaded it.
if record_guard.is_none() {
tracing::trace!("record not loaded from store; loading");

let session_id = *self.session_id.lock();
let session_id = *self.inner.session_id.lock();
*record_guard = Some(if let Some(session_id) = session_id {
match self.store.load(&session_id).await? {
Some(loaded_record) => {
Expand All @@ -119,7 +127,7 @@ impl Session {
// be relatively uncommon and as such entering this branch could indicate
// malicious behavior.
tracing::warn!("possibly suspicious activity: record not found in store");
*self.session_id.lock() = None;
*self.inner.session_id.lock() = None;
self.create_record()
}
}
Expand Down Expand Up @@ -212,7 +220,9 @@ impl Session {
pub async fn insert_value(&self, key: &str, value: Value) -> Result<Option<Value>> {
let mut record_guard = self.get_record().await?;
Ok(if record_guard.data.get(key) != Some(&value) {
self.is_modified.store(true, atomic::Ordering::Release);
self.inner
.is_modified
.store(true, atomic::Ordering::Release);
record_guard.data.insert(key.to_string(), value)
} else {
None
Expand Down Expand Up @@ -346,7 +356,9 @@ impl Session {
/// we fail with [`Error::Store`].
pub async fn remove_value(&self, key: &str) -> Result<Option<Value>> {
let mut record_guard = self.get_record().await?;
self.is_modified.store(true, atomic::Ordering::Release);
self.inner
.is_modified
.store(true, atomic::Ordering::Release);
Ok(record_guard.data.remove(key))
}

Expand Down Expand Up @@ -386,16 +398,18 @@ impl Session {
/// # });
/// ```
pub async fn clear(&self) {
let mut record_guard = self.record.lock().await;
let mut record_guard = self.inner.record.lock().await;
if let Some(record) = record_guard.as_mut() {
record.data.clear();
} else if let Some(session_id) = *self.session_id.lock() {
} else if let Some(session_id) = *self.inner.session_id.lock() {
let mut new_record = self.create_record();
new_record.id = session_id;
*record_guard = Some(new_record);
}

self.is_modified.store(true, atomic::Ordering::Release);
self.inner
.is_modified
.store(true, atomic::Ordering::Release);
}

/// Returns `true` if there is no session ID and the session is empty.
Expand Down Expand Up @@ -450,13 +464,13 @@ impl Session {
/// # });
/// ```
pub async fn is_empty(&self) -> bool {
let record_guard = self.record.lock().await;
let record_guard = self.inner.record.lock().await;

// N.B.: Session IDs are `None` if:
//
// 1. The cookie was not provided or otherwise could not be parsed,
// 2. Or the session could not be loaded from the store.
let session_id = self.session_id.lock();
let session_id = self.inner.session_id.lock();

let Some(record) = record_guard.as_ref() else {
return session_id.is_none();
Expand Down Expand Up @@ -484,7 +498,7 @@ impl Session {
/// assert_eq!(id, session.id());
/// ```
pub fn id(&self) -> Option<Id> {
*self.session_id.lock()
*self.inner.session_id.lock()
}

/// Get the session expiry.
Expand All @@ -502,7 +516,7 @@ impl Session {
/// assert_eq!(session.expiry(), None);
/// ```
pub fn expiry(&self) -> Option<Expiry> {
*self.expiry.lock()
*self.inner.expiry.lock()
}

/// Set `expiry` to the given value.
Expand All @@ -527,8 +541,10 @@ impl Session {
/// assert_eq!(session.expiry(), Some(expiry));
/// ```
pub fn set_expiry(&self, expiry: Option<Expiry>) {
*self.expiry.lock() = expiry;
self.is_modified.store(true, atomic::Ordering::Release);
*self.inner.expiry.lock() = expiry;
self.inner
.is_modified
.store(true, atomic::Ordering::Release);
}

/// Get session expiry as `OffsetDateTime`.
Expand All @@ -551,7 +567,7 @@ impl Session {
/// assert!(session.expiry_date() < expected_expiry.saturating_add(Duration::seconds(1)));
/// ```
pub fn expiry_date(&self) -> OffsetDateTime {
let expiry = self.expiry.lock();
let expiry = self.inner.expiry.lock();
match *expiry {
Some(Expiry::OnInactivity(duration)) => {
OffsetDateTime::now_utc().saturating_add(duration)
Expand Down Expand Up @@ -614,7 +630,7 @@ impl Session {
/// # });
/// ```
pub fn is_modified(&self) -> bool {
self.is_modified.load(atomic::Ordering::Acquire)
self.inner.is_modified.load(atomic::Ordering::Acquire)
}

/// Saves the session record to the store.
Expand Down Expand Up @@ -658,9 +674,9 @@ impl Session {
// In either case, we must create a new session via the store interface.
//
// Potential ID collisions must be handled by session store implementers.
if self.session_id.lock().is_none() {
if self.inner.session_id.lock().is_none() {
self.store.create(&mut record_guard).await?;
*self.session_id.lock() = Some(record_guard.id);
*self.inner.session_id.lock() = Some(record_guard.id);
} else {
self.store.save(&record_guard).await?;
}
Expand Down Expand Up @@ -699,13 +715,13 @@ impl Session {
/// - If loading from the store fails, we fail with [`Error::Store`].
#[tracing::instrument(skip(self), err)]
pub async fn load(&self) -> Result<()> {
let session_id = *self.session_id.lock();
let session_id = *self.inner.session_id.lock();
let Some(ref id) = session_id else {
tracing::warn!("called load with no session id");
return Ok(());
};
let loaded_record = self.store.load(id).await.map_err(Error::Store)?;
let mut record_guard = self.record.lock().await;
let mut record_guard = self.inner.record.lock().await;
*record_guard = loaded_record;
Ok(())
}
Expand Down Expand Up @@ -738,7 +754,7 @@ impl Session {
/// - If deleting from the store fails, we fail with [`Error::Store`].
#[tracing::instrument(skip(self), err)]
pub async fn delete(&self) -> Result<()> {
let session_id = *self.session_id.lock();
let session_id = *self.inner.session_id.lock();
let Some(ref session_id) = session_id else {
tracing::warn!("called delete with no session id");
return Ok(());
Expand Down Expand Up @@ -780,7 +796,7 @@ impl Session {
pub async fn flush(&self) -> Result<()> {
self.clear().await;
self.delete().await?;
*self.session_id.lock() = None;
*self.inner.session_id.lock() = None;
Ok(())
}

Expand Down Expand Up @@ -829,15 +845,17 @@ impl Session {

let old_session_id = record_guard.id;
record_guard.id = Id::default();
*self.session_id.lock() = None; // Setting `None` ensures `save` invokes the store's
// `create` method.
*self.inner.session_id.lock() = None; // Setting `None` ensures `save` invokes the store's
// `create` method.

self.store
.delete(&old_session_id)
.await
.map_err(Error::Store)?;

self.is_modified.store(true, atomic::Ordering::Release);
self.inner
.is_modified
.store(true, atomic::Ordering::Release);

Ok(())
}
Expand Down

0 comments on commit d7b7e13

Please sign in to comment.