From ac210f50fc4374e5cec21234b52cd8794df9bb52 Mon Sep 17 00:00:00 2001 From: Daniel Boline Date: Sat, 23 Nov 2024 19:57:11 -0500 Subject: [PATCH] refactoring --- Cargo.toml | 14 +- .../V11__authorized_users_created_deleted.sql | 2 + src/logged_user.rs | 63 ++++-- src/models.rs | 203 ++++++++++++------ src/security_log_http.rs | 5 +- 5 files changed, 188 insertions(+), 99 deletions(-) create mode 100644 migrations/V11__authorized_users_created_deleted.sql diff --git a/Cargo.toml b/Cargo.toml index 6973aad..24d4505 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "security_log_analysis_rust" -version = "0.11.10" +version = "0.12.0" authors = ["Daniel Boline "] edition = "2018" @@ -14,12 +14,12 @@ Analyze Auth Logs.""" [dependencies] anyhow = "1.0" -authorized_users = { git = "https://github.com/ddboline/auth_server_rust.git", tag="0.11.15"} +authorized_users = { git = "https://github.com/ddboline/auth_server_rust.git", tag="0.12.0"} aws-config = {version="1.0", features=["behavior-version-latest"]} aws-sdk-s3 = "1.1" aws-sdk-ses = "1.1" bytes = "1.0" -cached = {version="0.53", features=["async", "async_tokio_rt_multi_thread"]} +cached = {version="0.54", features=["async", "async_tokio_rt_multi_thread"]} chrono = "0.4" clap = {version="4.0", features=["derive"]} deadpool = {version = "0.12", features=["serde", "rt_tokio_1"]} @@ -40,7 +40,7 @@ itertools = "0.13" log = "0.4" maplit = "1.0" parking_lot = "0.12" -polars = {version="0.43", features=["temporal", "parquet", "lazy"]} +polars = {version="0.44", features=["temporal", "parquet", "lazy"]} postgres_query = {git = "https://github.com/ddboline/rust-postgres-query", tag = "0.3.8", features=["deadpool"]} postgres-types = {version="0.2", features=["with-time-0_3", "with-uuid-1", "with-serde_json-1"]} rand = "0.8" @@ -51,13 +51,13 @@ serde = { version="1.0", features=["derive"]} serde_json = "1.0" serde_yml = "0.0.12" smallvec = "1.6" -stack-string = { git = "https://github.com/ddboline/stack-string-rs.git", features=["postgres_types", "rweb-openapi"], tag="0.9.3" } +stack-string = { git = "https://github.com/ddboline/stack-string-rs.git", features=["postgres_types", "rweb-openapi"], tag="1.0.2" } stdout-channel = "0.6" -thiserror = "1.0" +thiserror = "2.0" time = {version="0.3", features=["serde-human-readable", "macros", "formatting"]} time-tz = {version="2.0", features=["system"]} tokio-postgres = {version="0.7", features=["with-time-0_3", "with-uuid-1", "with-serde_json-1"]} -tokio = {version="1.38", features=["rt", "macros", "rt-multi-thread"]} +tokio = {version="1.41", features=["rt", "macros", "rt-multi-thread"]} rweb = {git = "https://github.com/ddboline/rweb.git", features=["openapi"], default-features=false, tag="0.15.2"} rweb-helper = { git = "https://github.com/ddboline/rweb_helper.git", tag="0.5.3" } uuid = { version = "1.0", features = ["serde", "v4"] } diff --git a/migrations/V11__authorized_users_created_deleted.sql b/migrations/V11__authorized_users_created_deleted.sql new file mode 100644 index 0000000..8c461f1 --- /dev/null +++ b/migrations/V11__authorized_users_created_deleted.sql @@ -0,0 +1,2 @@ +ALTER TABLE authorized_users ADD COLUMN created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(); +ALTER TABLE authorized_users ADD COLUMN deleted_at TIMESTAMP WITH TIME ZONE; diff --git a/src/logged_user.rs b/src/logged_user.rs index 6913708..d3beedd 100644 --- a/src/logged_user.rs +++ b/src/logged_user.rs @@ -1,10 +1,10 @@ pub use authorized_users::{ - get_random_key, get_secrets, token::Token, AuthorizedUser, AUTHORIZED_USERS, JWT_SECRET, - KEY_LENGTH, LOGIN_HTML, SECRET_KEY, TRIGGER_DB_UPDATE, + get_random_key, get_secrets, token::Token, AuthorizedUser as ExternalUser, AUTHORIZED_USERS, + JWT_SECRET, KEY_LENGTH, LOGIN_HTML, SECRET_KEY, TRIGGER_DB_UPDATE, }; use futures::TryStreamExt; use log::debug; -use maplit::hashset; +use maplit::hashmap; use rweb::{ filters::{cookie::cookie, BoxedFilter}, Filter, FromRequest, Rejection, Schema, @@ -13,10 +13,12 @@ use rweb_helper::UuidWrapper; use serde::{Deserialize, Serialize}; use stack_string::StackString; use std::{ + collections::HashMap, convert::{TryFrom, TryInto}, env::var, str::FromStr, }; +use time::OffsetDateTime; use uuid::Uuid; use crate::{errors::ServiceError as Error, models::AuthorizedUsers, pgpool::PgPool}; @@ -63,8 +65,8 @@ impl FromRequest for LoggedUser { } } -impl From for LoggedUser { - fn from(user: AuthorizedUser) -> Self { +impl From for LoggedUser { + fn from(user: ExternalUser) -> Self { Self { email: user.email, session: user.session.into(), @@ -99,21 +101,46 @@ impl FromStr for LoggedUser { /// # Errors /// Return error if db query fails pub async fn fill_from_db(pool: &PgPool) -> Result<(), Error> { - debug!("{:?}", *TRIGGER_DB_UPDATE); - let users = if TRIGGER_DB_UPDATE.check() { - AuthorizedUsers::get_authorized_users(pool) - .await? - .map_ok(|user| user.email) - .try_collect() - .await? - } else { - AUTHORIZED_USERS.get_users() - }; if let Ok("true") = var("TESTENV").as_ref().map(String::as_str) { - AUTHORIZED_USERS.update_users(hashset! {"user@test".into()}); + AUTHORIZED_USERS.update_users(hashmap! { + "user@test".into() => ExternalUser { + email: "user@test".into(), + session: Uuid::new_v4(), + secret_key: StackString::default(), + created_at: Some(OffsetDateTime::now_utc()) + } + }); + return Ok(()); + } + let (created_at, deleted_at) = AuthorizedUsers::get_most_recent(pool).await?; + let most_recent_user_db = created_at.max(deleted_at); + let existing_users = AUTHORIZED_USERS.get_users(); + let most_recent_user = existing_users.values().map(|i| i.created_at).max(); + debug!("most_recent_user_db {most_recent_user_db:?} most_recent_user {most_recent_user:?}"); + if most_recent_user_db.is_some() + && most_recent_user.is_some() + && most_recent_user_db <= most_recent_user + { + return Ok(()); } - AUTHORIZED_USERS.update_users(users); - debug!("{:?}", *AUTHORIZED_USERS); + let result: Result, _> = AuthorizedUsers::get_authorized_users(pool) + .await? + .map_ok(|u| { + ( + u.email.clone(), + ExternalUser { + email: u.email, + session: Uuid::new_v4(), + secret_key: StackString::default(), + created_at: Some(u.created_at), + }, + ) + }) + .try_collect() + .await; + let users = result?; + AUTHORIZED_USERS.update_users(users); + debug!("AUTHORIZED_USERS {:?}", *AUTHORIZED_USERS); Ok(()) } diff --git a/src/models.rs b/src/models.rs index 9c903f3..ed57426 100644 --- a/src/models.rs +++ b/src/models.rs @@ -199,6 +199,60 @@ pub struct IntrusionLog { pub username: Option, } +#[derive(Clone, Copy)] +struct IntrusionLogFilteredQueryOptions<'a> { + select_str: &'a str, + order_str: &'a str, + service: &'a Option, + server: &'a Option, + min_datetime: &'a Option, + max_datetime: &'a Option, + offset: Option, + limit: Option, +} + +impl Default for IntrusionLogFilteredQueryOptions<'_> { + fn default() -> Self { + Self { + select_str: "", + order_str: "", + service: &None, + server: &None, + min_datetime: &None, + max_datetime: &None, + offset: None, + limit: None, + } + } +} + +#[derive(Clone, Copy)] +struct SystemdMessagesQueryOptions<'a> { + select_str: &'a str, + order_str: &'a str, + log_level: &'a Option, + log_unit: &'a Option, + min_timestamp: &'a Option, + max_timestamp: &'a Option, + offset: Option, + limit: Option, +} + +impl Default for SystemdMessagesQueryOptions<'_> { + fn default() -> Self { + Self { + select_str: "", + order_str: "", + log_level: &None, + log_unit: &None, + min_timestamp: &None, + max_timestamp: &None, + offset: None, + limit: None, + } + } +} + impl IntrusionLog { /// # Errors /// Return error if db query fails @@ -267,31 +321,24 @@ impl IntrusionLog { Ok(result.map(Into::into)) } - fn get_intrusion_log_filtered_query<'a>( - select_str: &'a str, - order_str: &'a str, - service: &'a Option, - server: &'a Option, - min_datetime: &'a Option, - max_datetime: &'a Option, - offset: Option, - limit: Option, - ) -> Result, PgError> { + fn get_intrusion_log_filtered_query( + options: IntrusionLogFilteredQueryOptions, + ) -> Result { let mut bindings = Vec::new(); let mut constraints = Vec::new(); - if let Some(service) = &service { + if let Some(service) = &options.service { constraints.push(format_sstr!("service=$service")); bindings.push(("service", service as Parameter)); } - if let Some(server) = &server { + if let Some(server) = &options.server { constraints.push(format_sstr!("server=$server")); bindings.push(("server", server as Parameter)); } - if let Some(min_datetime) = &min_datetime { + if let Some(min_datetime) = &options.min_datetime { constraints.push(format_sstr!("datetime >= $min_datetime")); bindings.push(("min_datetime", min_datetime as Parameter)); } - if let Some(max_datetime) = &max_datetime { + if let Some(max_datetime) = &options.max_datetime { constraints.push(format_sstr!("datetine <= $max_datetime")); bindings.push(("max_datetime", max_datetime as Parameter)); } @@ -300,6 +347,8 @@ impl IntrusionLog { } else { format_sstr!("WHERE {}", constraints.join(" AND ")) }; + let select_str = options.select_str; + let order_str = options.order_str; let mut query = format_sstr!( r#" SELECT {select_str} FROM intrusion_log @@ -307,10 +356,10 @@ impl IntrusionLog { {order_str} "#, ); - if let Some(offset) = &offset { + if let Some(offset) = &options.offset { query.push_str(&format_sstr!(" OFFSET {offset}")); } - if let Some(limit) = &limit { + if let Some(limit) = &options.limit { query.push_str(&format_sstr!(" LIMIT {limit}")); } bindings.shrink_to_fit(); @@ -332,16 +381,16 @@ impl IntrusionLog { let service = service.map(Service::to_str).map(Into::into); let server = server.map(Host::to_str).map(Into::into); - let query = Self::get_intrusion_log_filtered_query( - "*", - "ORDER BY datetime DESC", - &service, - &server, - &min_datetime, - &max_datetime, + let query = Self::get_intrusion_log_filtered_query(IntrusionLogFilteredQueryOptions { + select_str: "*", + order_str: "ORDER BY datetime DESC", + service: &service, + server: &server, + min_datetime: &min_datetime, + max_datetime: &max_datetime, offset, limit, - )?; + })?; let conn = pool.get().await?; query.fetch_streaming(&conn).await.map_err(Into::into) } @@ -363,16 +412,14 @@ impl IntrusionLog { let service = service.map(Service::to_str).map(Into::into); let server = server.map(Host::to_str).map(Into::into); - let query = Self::get_intrusion_log_filtered_query( - "count(*)", - "", - &service, - &server, - &min_datetime, - &max_datetime, - None, - None, - )?; + let query = Self::get_intrusion_log_filtered_query(IntrusionLogFilteredQueryOptions { + select_str: "count(*)", + service: &service, + server: &server, + min_datetime: &min_datetime, + max_datetime: &max_datetime, + ..IntrusionLogFilteredQueryOptions::default() + })?; let conn = pool.get().await?; let count: Count = query.fetch_one(&conn).await?; @@ -422,6 +469,7 @@ impl IntrusionLog { #[derive(FromSqlRow, Clone, Debug)] pub struct AuthorizedUsers { pub email: StackString, + pub created_at: OffsetDateTime, } impl AuthorizedUsers { @@ -430,10 +478,32 @@ impl AuthorizedUsers { pub async fn get_authorized_users( pool: &PgPool, ) -> Result>, Error> { - let query = query!("SELECT * FROM authorized_users"); + let query = query!("SELECT * FROM authorized_users WHERE deleted_at IS NULL"); let conn = pool.get().await?; query.fetch_streaming(&conn).await.map_err(Into::into) } + + /// # Errors + /// Returns error if db query fails + pub async fn get_most_recent( + pool: &PgPool, + ) -> Result<(Option, Option), Error> { + #[derive(FromSqlRow)] + struct CreatedDeleted { + created_at: Option, + deleted_at: Option, + } + + let query = query!( + "SELECT max(created_at) as created_at, max(deleted_at) as deleted_at FROM users" + ); + let conn = pool.get().await?; + let result: Option = query.fetch_opt(&conn).await?; + match result { + Some(result) => Ok((result.created_at, result.deleted_at)), + None => Ok((None, None)), + } + } } /// # Errors @@ -683,31 +753,22 @@ impl SystemdLogMessages { query.execute(&conn).await.map_err(Into::into) } - fn get_systemd_messages_query<'a>( - select_str: &'a str, - order_str: &'a str, - log_level: &'a Option, - log_unit: &'a Option<&str>, - min_timestamp: &'a Option, - max_timestamp: &'a Option, - offset: Option, - limit: Option, - ) -> Result, PgError> { + fn get_systemd_messages_query(options: SystemdMessagesQueryOptions) -> Result { let mut constraints = Vec::new(); let mut bindings = Vec::new(); - if let Some(log_level) = log_level { + if let Some(log_level) = options.log_level { constraints.push(format_sstr!("log_level=$log_level")); bindings.push(("log_level", log_level as Parameter)); } - if let Some(log_unit) = log_unit { + if let Some(log_unit) = options.log_unit { constraints.push(format_sstr!("log_unit=$log_unit")); bindings.push(("log_unit", log_unit as Parameter)); } - if let Some(min_timestamp) = min_timestamp { + if let Some(min_timestamp) = options.min_timestamp { constraints.push(format_sstr!("log_timestamp > $min_timestamp")); bindings.push(("min_timestamp", min_timestamp as Parameter)); } - if let Some(max_timestamp) = max_timestamp { + if let Some(max_timestamp) = options.max_timestamp { constraints.push(format_sstr!("log_timestamp > $max_timestamp")); bindings.push(("max_timestamp", max_timestamp as Parameter)); } @@ -716,6 +777,8 @@ impl SystemdLogMessages { } else { format_sstr!("WHERE {}", constraints.join(" AND ")) }; + let select_str = options.select_str; + let order_str = options.order_str; let mut query = format_sstr!( r#" SELECT {select_str} FROM systemd_log_messages @@ -723,10 +786,10 @@ impl SystemdLogMessages { {order_str} "#, ); - if let Some(offset) = offset { + if let Some(offset) = options.offset { query.push_str(&format_sstr!(" OFFSET {offset}")); } - if let Some(limit) = limit { + if let Some(limit) = options.limit { query.push_str(&format_sstr!(" LIMIT {limit}")); } bindings.shrink_to_fit(); @@ -739,7 +802,7 @@ impl SystemdLogMessages { pub async fn get_total( pool: &PgPool, log_level: Option, - log_unit: Option<&str>, + log_unit: &Option, min_timestamp: Option, max_timestamp: Option, ) -> Result { @@ -748,16 +811,14 @@ impl SystemdLogMessages { count: i64, } - let query = Self::get_systemd_messages_query( - "count(*)", - "", - &log_level, - &log_unit, - &min_timestamp, - &max_timestamp, - None, - None, - )?; + let query = Self::get_systemd_messages_query(SystemdMessagesQueryOptions { + select_str: "count(*)", + log_level: &log_level, + log_unit, + min_timestamp: &min_timestamp, + max_timestamp: &max_timestamp, + ..SystemdMessagesQueryOptions::default() + })?; let conn = pool.get().await?; let count: Count = query.fetch_one(&conn).await?; @@ -769,22 +830,22 @@ impl SystemdLogMessages { pub async fn get_systemd_messages( pool: &PgPool, log_level: Option, - log_unit: Option<&str>, + log_unit: &Option, min_timestamp: Option, max_timestamp: Option, offset: Option, limit: Option, ) -> Result>, Error> { - let query = Self::get_systemd_messages_query( - "*", - "ORDER BY log_timestamp", - &log_level, - &log_unit, - &min_timestamp, - &max_timestamp, + let query = Self::get_systemd_messages_query(SystemdMessagesQueryOptions { + select_str: "*", + order_str: "ORDER BY log_timestamp", + log_level: &log_level, + log_unit, + min_timestamp: &min_timestamp, + max_timestamp: &max_timestamp, offset, limit, - )?; + })?; let conn = pool.get().await?; query.fetch_streaming(&conn).await.map_err(Into::into) } diff --git a/src/security_log_http.rs b/src/security_log_http.rs index 68a05d1..4bac71e 100644 --- a/src/security_log_http.rs +++ b/src/security_log_http.rs @@ -541,11 +541,10 @@ async fn get_log_messages( let min_date: Option = query.min_date.map(Into::into); let max_date: Option = query.max_date.map(Into::into); let log_level = query.log_level; - let log_unit: Option<&str> = query.log_unit.as_ref().map(Into::into); let total = SystemdLogMessages::get_total( &data.pool, log_level, - log_unit, + &query.log_unit, min_date.map(Into::into), max_date.map(Into::into), ) @@ -562,7 +561,7 @@ async fn get_log_messages( let data: Vec<_> = SystemdLogMessages::get_systemd_messages( &data.pool, log_level, - log_unit, + &query.log_unit, min_date.map(Into::into), max_date.map(Into::into), Some(offset),