From e02999d34d7926fddb4833c39c23f56fddb34d3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20G=C3=B6ttsche?= Date: Sat, 26 Oct 2024 16:40:50 +0200 Subject: [PATCH] Avoid ID clashes Currently the the ID for a new paste is randomly generated in the caller of the database insert() function. Then the insert() function tries to insert a new row into the database with that passed ID. There can however already exists a paste in the database with the same ID leading to an insert failure, due to a constraint violation due to the PRIMARY KEY attribute. Checking prior the the INSERT via a SELECT query would open the window for a race condition. A failure to push a new paste is quite severe, since the user might have spent some some to format the input. Generate the ID in a loop inside, until the INSERT succeeds. --- src/db.rs | 85 ++++++++++++++++++++++++++++++++-------------- src/id.rs | 8 ++--- src/routes/form.rs | 32 +++++++---------- src/routes/json.rs | 18 +++------- 4 files changed, 78 insertions(+), 65 deletions(-) diff --git a/src/db.rs b/src/db.rs index 2a7cf08..19eff94 100644 --- a/src/db.rs +++ b/src/db.rs @@ -247,32 +247,68 @@ impl Database { Ok(Self { conn }) } - /// Insert `entry` under `id` into the database and optionally set owner to `uid`. - pub async fn insert(&self, id: Id, entry: write::Entry) -> Result<(), Error> { + /// Insert `entry` with a new generated `id` into the database and optionally set owner to `uid`. + pub async fn insert(&self, entry: write::Entry) -> Result { let conn = self.conn.clone(); - let id = id.as_u32(); let write::DatabaseEntry { entry, data, nonce } = entry.compress().await?.encrypt().await?; - spawn_blocking(move || match entry.expires { - None => conn.lock().execute( - "INSERT INTO entries (id, uid, data, burn_after_reading, nonce) VALUES (?1, ?2, ?3, ?4, ?5)", - params![id, entry.uid, data, entry.burn_after_reading, nonce], - ), - Some(expires) => conn.lock().execute( - "INSERT INTO entries (id, uid, data, burn_after_reading, nonce, expires) VALUES (?1, ?2, ?3, ?4, ?5, datetime('now', ?6))", - params![ - id, - entry.uid, - data, - entry.burn_after_reading, - nonce, - format!("{expires} seconds") - ], - ), + let id = spawn_blocking(move || { + const COUNTER_LIMIT: u32 = 10; + let mut counter = 0; + + let mut rng = rand::thread_rng(); + + loop { + let id: Id = rand::Rng::gen::(&mut rng).into(); + let id_inner = id.as_u32(); + + let result = match entry.expires { + None => conn.lock().execute( + "INSERT INTO entries (id, uid, data, burn_after_reading, nonce) VALUES (?1, ?2, ?3, ?4, ?5)", + params![id_inner, entry.uid, data, entry.burn_after_reading, nonce], + ), + Some(expires) => conn.lock().execute( + "INSERT INTO entries (id, uid, data, burn_after_reading, nonce, expires) VALUES (?1, ?2, ?3, ?4, ?5, datetime('now', ?6))", + params![ + id_inner, + entry.uid, + data, + entry.burn_after_reading, + nonce, + format!("{expires} seconds") + ], + ), + }; + + match result { + Err(rusqlite::Error::SqliteFailure(rusqlite::ffi::Error { code, extended_code }, Some(ref _message))) + if code == rusqlite::ErrorCode::ConstraintViolation && extended_code == rusqlite::ffi::SQLITE_CONSTRAINT_PRIMARYKEY && counter < COUNTER_LIMIT => { + /* Retry if ID is already existent */ + counter += 1; + continue; + }, + Err(err) => { + if counter >= COUNTER_LIMIT { + tracing::error!("Failed to generate ID after {counter} retries"); + } + + break Err(err) + }, + Ok(rows) => { + debug_assert!(rows == 1); + + if counter > 4 { + tracing::warn!("Required {counter} retries to generate new ID"); + } + + break Ok(id) + }, + } + } }) .await??; - Ok(()) + Ok(id) } /// Get entire entry for `id`. @@ -383,8 +419,7 @@ mod tests { ..Default::default() }; - let id = Id::from(1234); - db.insert(id, entry).await?; + let id = db.insert(entry).await?; let entry = db.get(id, None).await?; assert_eq!(entry.text, "hello world"); @@ -406,8 +441,7 @@ mod tests { ..Default::default() }; - let id = Id::from(1234); - db.insert(id, entry).await?; + let id = db.insert(entry).await?; tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; @@ -422,8 +456,7 @@ mod tests { async fn delete() -> Result<(), Box> { let db = new_db()?; - let id = Id::from(1234); - db.insert(id, write::Entry::default()).await?; + let id = db.insert(write::Entry::default()).await?; assert!(db.get(id, None).await.is_ok()); assert!(db.delete(id).await.is_ok()); diff --git a/src/id.rs b/src/id.rs index 313044f..956523e 100644 --- a/src/id.rs +++ b/src/id.rs @@ -1,4 +1,3 @@ -use crate::db::write::Entry; use crate::errors::Error; use std::fmt; use std::str::FromStr; @@ -23,11 +22,8 @@ impl Id { } /// Generate a URL path from the string representation and `entry`'s extension. - pub fn to_url_path(self, entry: &Entry) -> String { - entry - .extension - .as_ref() - .map_or_else(|| format!("{self}"), |ext| format!("{self}.{ext}")) + pub fn to_url_path(self, extension: Option<&str>) -> String { + extension.map_or_else(|| format!("{self}"), |ext| format!("{self}.{ext}")) } } diff --git a/src/routes/form.rs b/src/routes/form.rs index d0bbe87..6d6c02b 100644 --- a/src/routes/form.rs +++ b/src/routes/form.rs @@ -2,12 +2,10 @@ use std::num::NonZeroU32; use crate::db::write; use crate::env::BASE_PATH; -use crate::id::Id; use crate::{pages, AppState, Error}; use axum::extract::{Form, State}; use axum::response::Redirect; use axum_extra::extract::cookie::{Cookie, SignedCookieJar}; -use rand::Rng; use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Deserialize)] @@ -44,14 +42,6 @@ pub async fn insert( jar: SignedCookieJar, Form(entry): Form, ) -> Result<(SignedCookieJar, Redirect), pages::ErrorResponse<'static>> { - let id: Id = tokio::task::spawn_blocking(|| { - let mut rng = rand::thread_rng(); - rng.gen::() - }) - .await - .map_err(Error::from)? - .into(); - // Retrieve uid from cookie or generate a new one. let uid = if let Some(cookie) = jar.get("uid") { cookie @@ -65,22 +55,24 @@ pub async fn insert( let mut entry: write::Entry = entry.into(); entry.uid = Some(uid); - let mut url = id.to_url_path(&entry); - - let burn_after_reading = entry.burn_after_reading.unwrap_or(false); - if burn_after_reading { - url = format!("burn/{url}"); - } - - let url_with_base = BASE_PATH.join(&url); - if let Some(max_exp) = state.max_expiration { entry.expires = entry .expires .map_or_else(|| Some(max_exp), |value| Some(value.min(max_exp))); } - state.db.insert(id, entry).await?; + let burn = entry.burn_after_reading.unwrap_or(false); + let extension = entry.extension.clone(); + + let id = state.db.insert(entry).await?; + + let mut url = id.to_url_path(extension.as_deref()); + + if burn { + url = format!("burn/{url}"); + } + + let url_with_base = BASE_PATH.join(&url); let jar = jar.add(Cookie::new("uid", uid.to_string())); Ok((jar, Redirect::to(&url_with_base))) diff --git a/src/routes/json.rs b/src/routes/json.rs index a6962af..575cd6e 100644 --- a/src/routes/json.rs +++ b/src/routes/json.rs @@ -2,12 +2,10 @@ use std::num::NonZeroU32; use crate::db::write; use crate::env::BASE_PATH; -use crate::errors::{Error, JsonErrorResponse}; -use crate::id::Id; +use crate::errors::JsonErrorResponse; use crate::AppState; use axum::extract::State; use axum::Json; -use rand::Rng; use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Deserialize)] @@ -41,14 +39,6 @@ pub async fn insert( state: State, Json(entry): Json, ) -> Result, JsonErrorResponse> { - let id: Id = tokio::task::spawn_blocking(|| { - let mut rng = rand::thread_rng(); - rng.gen::() - }) - .await - .map_err(Error::from)? - .into(); - let mut entry: write::Entry = entry.into(); if let Some(max_exp) = state.max_expiration { @@ -57,9 +47,11 @@ pub async fn insert( .map_or_else(|| Some(max_exp), |value| Some(value.min(max_exp))); } - let url = id.to_url_path(&entry); + let extension = entry.extension.clone(); + + let id = state.db.insert(entry).await?; + let url = id.to_url_path(extension.as_deref()); let path = BASE_PATH.join(&url); - state.db.insert(id, entry).await?; Ok(Json::from(RedirectResponse { path })) }