diff --git a/Cargo.toml b/Cargo.toml index bedff9a..d81c8f5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,14 +19,15 @@ laplace_rs = {version = "0.2.0", git = "https://github.com/samply/laplace-rs.git # Logging tracing = { version = "0.1.37", default_features = false } -tracing-subscriber = { version = "0.3.11", default_features = false, features = ["env-filter", "fmt"] } +tracing-subscriber = { version = "0.3.11", default_features = false, features = ["env-filter", "ansi"] } # Global variables once_cell = "1.18" # Command Line Interface -clap = { version = "4.0", default_features = false, features = ["std", "env", "derive", "help"] } +clap = { version = "4", default_features = false, features = ["std", "env", "derive", "help", "color"] } rand = { default-features = false, version = "0.8.5" } +futures-util = { version = "0.3", default-features = false, features = ["std"] } [dev-dependencies] tokio-test = "0.4.2" diff --git a/Dockerfile b/Dockerfile index 2dbc13a..62987bb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,4 +16,3 @@ FROM gcr.io/distroless/cc-debian12 ARG COMPONENT COPY --from=chmodder /app/$COMPONENT /usr/local/bin/samply ENTRYPOINT [ "/usr/local/bin/focus" ] - diff --git a/dev/docker-compose.yml b/dev/docker-compose.yml index 4396d18..b4a73c7 100644 --- a/dev/docker-compose.yml +++ b/dev/docker-compose.yml @@ -86,6 +86,7 @@ services: BEAM_PROXY_URL: http://proxy1:8081 RETRY_COUNT: 30 OBFUSCATE: "no" + RUST_LOG: "debug,hyper=info" blaze: image: samply/blaze volumes: diff --git a/dev/focusdev b/dev/focusdev index 618541b..bd09418 100755 --- a/dev/focusdev +++ b/dev/focusdev @@ -86,7 +86,7 @@ function build() { function build_docker() { BACK2=$(pwd) cd $SD - docker compose build --build-arg TARGETARCH=$ARCH + docker compose build --build-arg TARGETARCH=$ARCH --build-arg COMPONENT=focus cd $BACK2 } diff --git a/resources/cql/ITCC_STRAT_AGE_CLASS_STRATIFIER b/resources/cql/ITCC_STRAT_AGE_CLASS_STRATIFIER new file mode 100644 index 0000000..3ab0641 --- /dev/null +++ b/resources/cql/ITCC_STRAT_AGE_CLASS_STRATIFIER @@ -0,0 +1,2 @@ +define function DiagnosisAge(condition FHIR.Condition): +condition.onset.value \ No newline at end of file diff --git a/resources/cql/PRISM_STRAT_AGE_STRATIFIER_BBMRI b/resources/cql/PRISM_STRAT_AGE_STRATIFIER_BBMRI new file mode 100644 index 0000000..7c9c6c2 --- /dev/null +++ b/resources/cql/PRISM_STRAT_AGE_STRATIFIER_BBMRI @@ -0,0 +1,2 @@ +define AgeClass: + (AgeInYears()) diff --git a/src/main.rs b/src/main.rs index f00382e..f190233 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,17 +12,21 @@ mod intermediate_rep; mod task_processing; mod util; +use base64::engine::general_purpose; use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; -use beam_lib::{MsgId, TaskRequest, TaskResult}; +use beam_lib::{TaskRequest, TaskResult}; +use futures_util::future::BoxFuture; +use futures_util::FutureExt; use laplace_rs::ObfCache; -use task_processing::TaskQueue; use tokio::sync::Mutex; +use crate::blaze::parse_blaze_query; +use crate::config::EndpointType; use crate::util::{is_cql_tampered_with, obfuscate_counts_mr}; use crate::{config::CONFIG, errors::FocusError}; use blaze::CqlQuery; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::ops::DerefMut; use std::process::ExitCode; use std::str; @@ -103,64 +107,110 @@ pub async fn main() -> ExitCode { } async fn main_loop() -> ExitCode { - // TODO: The report cache init should be an fn on the cache - let report_cache: ReportCache = ReportCache::new(); - - let mut seen_tasks = Default::default(); - let mut task_queue = task_processing::spawn_task_workers(report_cache); + let endpoint_service_available: fn() -> BoxFuture<'static, bool> = match CONFIG.endpoint_type { + EndpointType::Blaze => || blaze::check_availability().boxed(), + EndpointType::Omop => || async { true }.boxed(), // TODO health check + }; let mut failures = 0; - while failures < CONFIG.retry_count { - if failures > 0 { - warn!( - "Retrying connection (attempt {}/{})", - failures + 1, + while !(beam::check_availability().await && endpoint_service_available().await) { + failures += 1; + if failures >= CONFIG.retry_count { + error!( + "Encountered too many errors -- exiting after {} attempts.", CONFIG.retry_count ); - tokio::time::sleep(Duration::from_secs(2)).await; + return ExitCode::from(22); } - if !(beam::check_availability().await) { - failures += 1; - } - if CONFIG.endpoint_type == config::EndpointType::Blaze { - if !(blaze::check_availability().await) { - failures += 1; - } - } else if CONFIG.endpoint_type == config::EndpointType::Omop { + tokio::time::sleep(Duration::from_secs(2)).await; + warn!( + "Retrying connection (attempt {}/{})", + failures, + CONFIG.retry_count + ); + }; + let report_cache = Arc::new(Mutex::new(ReportCache::new())); + let obf_cache = Arc::new(Mutex::new(ObfCache { + cache: Default::default(), + })); + task_processing::process_tasks(move |task| { + let obf_cache = obf_cache.clone(); + let report_cache = report_cache.clone(); + process_task(task, obf_cache, report_cache).boxed_local() + }).await; + ExitCode::FAILURE +} - //TODO health check - } +async fn process_task( + task: &BeamTask, + obf_cache: Arc>, + report_cache: Arc>, +) -> Result { + debug!("Processing task {}", task.id); - if let Err(e) = process_tasks(&mut task_queue, &mut seen_tasks).await { - warn!("Encountered the following error, while processing tasks: {e}"); - failures += 1; - } else { - failures = 0; - } + let metadata: Metadata = serde_json::from_value(task.metadata.clone()).unwrap_or(Metadata { + project: "default_obfuscation".to_string(), + execute: true, + }); + + if metadata.project == "focus-healthcheck" { + return Ok(beam::beam_result::succeeded( + CONFIG.beam_app_id_long.clone(), + vec![task.from.clone()], + task.id, + "healthy".into() + )); } - error!( - "Encountered too many errors -- exiting after {} attempts.", - CONFIG.retry_count - ); - ExitCode::from(22) -} -async fn process_tasks( - task_queue: &mut TaskQueue, - seen: &mut HashSet, -) -> Result<(), FocusError> { - debug!("Start processing tasks..."); - let tasks = beam::retrieve_tasks().await?; - for task in tasks { - if seen.contains(&task.id) { - continue; + if metadata.project == "exporter" { + let body = &task.body; + return Ok(run_exporter_query(task, body, metadata.execute).await)?; + } + + if CONFIG.endpoint_type == EndpointType::Blaze { + let query = parse_blaze_query(task)?; + if query.lang == "cql" { + // TODO: Change query.lang to an enum + + Ok(run_cql_query(task, &query, obf_cache, report_cache, metadata.project).await)? + } else { + warn!("Can't run queries with language {} in Blaze", query.lang); + Ok(beam::beam_result::perm_failed( + CONFIG.beam_app_id_long.clone(), + vec![task.from.clone()], + task.id, + format!( + "Can't run queries with language {} and/or endpoint type {}", + query.lang, CONFIG.endpoint_type + ), + )) } - seen.insert(task.id); - task_queue - .send(task) - .await - .expect("Receiver is never dropped"); + } else if CONFIG.endpoint_type == EndpointType::Omop { + let decoded = util::base64_decode(&task.body)?; + let intermediate_rep_query: intermediate_rep::IntermediateRepQuery = + serde_json::from_slice(&decoded).map_err(|e| FocusError::ParsingError(e.to_string()))?; + //TODO check that the language is ast + let query_decoded = general_purpose::STANDARD + .decode(intermediate_rep_query.query) + .map_err(FocusError::DecodeError)?; + let ast: ast::Ast = + serde_json::from_slice(&query_decoded).map_err(|e| FocusError::ParsingError(e.to_string()))?; + + Ok(run_intermediate_rep_query(task, ast).await)? + } else { + warn!( + "Can't run queries with endpoint type {}", + CONFIG.endpoint_type + ); + Ok(beam::beam_result::perm_failed( + CONFIG.beam_app_id_long.clone(), + vec![task.from.clone()], + task.id, + format!( + "Can't run queries with endpoint type {}", + CONFIG.endpoint_type + ), + )) } - Ok(()) } async fn run_cql_query( diff --git a/src/task_processing.rs b/src/task_processing.rs index fc5f703..cb7afaf 100644 --- a/src/task_processing.rs +++ b/src/task_processing.rs @@ -1,79 +1,86 @@ -use std::{sync::Arc, collections::HashMap, time::Duration}; +use std::{rc::Rc, time::Duration}; -use base64::{engine::general_purpose, Engine as _}; -use laplace_rs::ObfCache; -use tokio::sync::{mpsc, Semaphore, Mutex}; -use tracing::{error, warn, debug, info, Instrument, info_span}; +use futures_util::{future::LocalBoxFuture, FutureExt, StreamExt}; +use tracing::{debug, error, info_span, warn, Instrument}; -use crate::{ReportCache, errors::FocusError, beam, BeamTask, BeamResult, run_exporter_query, config::{EndpointType, CONFIG}, run_cql_query, intermediate_rep, ast, run_intermediate_rep_query, Metadata, blaze::parse_blaze_query, util}; +use crate::{beam, errors::FocusError, BeamResult, BeamTask}; const NUM_WORKERS: usize = 3; -const WORKER_BUFFER: usize = 32; -pub type TaskQueue = mpsc::Sender; - -pub fn spawn_task_workers(report_cache: ReportCache) -> TaskQueue { - let (tx, mut rx) = mpsc::channel::(WORKER_BUFFER); - - let obf_cache = Arc::new(Mutex::new(ObfCache { - cache: HashMap::new(), - })); - - let report_cache: Arc> = Arc::new(Mutex::new(report_cache)); - - tokio::spawn(async move { - let semaphore = Arc::new(Semaphore::new(NUM_WORKERS)); - while let Some(task) = rx.recv().await { - let permit = semaphore.clone().acquire_owned().await.unwrap(); - let local_report_cache = report_cache.clone(); - let local_obf_cache = obf_cache.clone(); - tokio::spawn(async move { - let span = info_span!("task handling", %task.id); - handle_beam_task(task, local_obf_cache, local_report_cache).instrument(span).await; - drop(permit) - }); +pub async fn process_tasks(task_hanlder: F) +where + F: Fn(&BeamTask) -> LocalBoxFuture<'_, Result> + Clone + 'static, +{ + let on_task_claimed = |res: &Result| { + if let Err(e) = res { + warn!("Failed to claim task: {e}"); + } else { + debug!("Successfully claimed task"); } - }); - - tx + }; + futures_util::stream::repeat_with(beam::retrieve_tasks) + .filter_map(|v| async { + match v.await { + Ok(mut ts) => ts.pop(), + Err(e) => { + warn!("Failed to get tasks from beam: {e}"); + tokio::time::sleep(Duration::from_secs(10)).await; + None + } + } + }) + .then(move |t| { + let id = t.id; + let span = info_span!("task", %id); + let span_for_handler = span.clone(); + let on_task = task_hanlder.clone(); + let task = Rc::new(t); + let t1 = Rc::clone(&task); + let t2 = Rc::clone(&task); + #[allow(clippy::async_yields_async)] + async move { + let mut task_claiming = std::pin::pin!(beam::claim_task(&t1)); + let mut task_processing = async move { on_task(&t2).await }.boxed_local(); + tokio::select! { + task_processed = &mut task_processing => { + debug!("Proccessed task before it was claimed"); + answer_task_result(&task, task_processed).await; + futures_util::future::ready(()).boxed_local() + }, + task_claimed = &mut task_claiming => { + on_task_claimed(&task_claimed); + task_processing + .then(move |res| async move { answer_task_result(&task, res).await }) + .instrument(span_for_handler) + .boxed_local() + } + } + } + .instrument(span) + }) + .buffer_unordered(NUM_WORKERS) + .for_each(|_| async {}) + .await } -async fn handle_beam_task(task: BeamTask, local_obf_cache: Arc>, local_report_cache: Arc>) { - let task_claiming = beam::claim_task(&task); - let mut task_processing = std::pin::pin!(process_task(&task, local_obf_cache, local_report_cache)); - let task_result = tokio::select! { - // If task task processing happens before claiming is done drop the task claiming future - task_processed = &mut task_processing => { - task_processed - }, - task_claimed = task_claiming => { - if let Err(e) = task_claimed { - warn!("Failed to claim task: {e}"); - } else { - debug!("Successfully claimed task"); - }; - task_processing.await - } - }; +async fn answer_task_result(task: &BeamTask, task_result: Result) { let result = match task_result { Ok(res) => res, Err(e) => { warn!("Failed to execute query: {e}"); - if let Err(e) = beam::fail_task(&task, e.user_facing_error()).await { + if let Err(e) = beam::fail_task(task, e.user_facing_error()).await { warn!("Failed to report failure to beam: {e}"); } return; } }; - const MAX_TRIES: u32 = 3600; + const MAX_TRIES: u32 = 150; for attempt in 0..MAX_TRIES { match beam::answer_task(&result).await { Ok(_) => break, Err(FocusError::ConfigurationError(s)) => { - error!( - "FATAL: Unable to report back to Beam due to a configuration issue: {s}" - ); + error!("FATAL: Unable to report back to Beam due to a configuration issue: {s}"); } Err(FocusError::UnableToAnswerTask(e)) => { warn!("Unable to report task result to Beam: {e}. Retrying (attempt {attempt}/{MAX_TRIES})."); @@ -86,75 +93,3 @@ async fn handle_beam_task(task: BeamTask, local_obf_cache: Arc>, } } -async fn process_task( - task: &BeamTask, - obf_cache: Arc>, - report_cache: Arc>, -) -> Result { - info!("Processing task {}", task.id); - - let metadata: Metadata = serde_json::from_value(task.metadata.clone()).unwrap_or(Metadata { - project: "default_obfuscation".to_string(), - execute: true, - }); - - if metadata.project == "focus-healthcheck" { - return Ok(beam::beam_result::succeeded( - CONFIG.beam_app_id_long.clone(), - vec![task.from.clone()], - task.id, - "healthy".into() - )); - } - - if metadata.project == "exporter" { - let body = &task.body; - return Ok(run_exporter_query(task, body, metadata.execute).await)?; - } - - if CONFIG.endpoint_type == EndpointType::Blaze { - let query = parse_blaze_query(task)?; - if query.lang == "cql" { - // TODO: Change query.lang to an enum - - Ok(run_cql_query(task, &query, obf_cache, report_cache, metadata.project).await)? - } else { - warn!("Can't run queries with language {} in Blaze", query.lang); - Ok(beam::beam_result::perm_failed( - CONFIG.beam_app_id_long.clone(), - vec![task.from.clone()], - task.id, - format!( - "Can't run queries with language {} and/or endpoint type {}", - query.lang, CONFIG.endpoint_type - ), - )) - } - } else if CONFIG.endpoint_type == EndpointType::Omop { - let decoded = util::base64_decode(&task.body)?; - let intermediate_rep_query: intermediate_rep::IntermediateRepQuery = - serde_json::from_slice(&decoded).map_err(|e| FocusError::ParsingError(e.to_string()))?; - //TODO check that the language is ast - let query_decoded = general_purpose::STANDARD - .decode(intermediate_rep_query.query) - .map_err(FocusError::DecodeError)?; - let ast: ast::Ast = - serde_json::from_slice(&query_decoded).map_err(|e| FocusError::ParsingError(e.to_string()))?; - - Ok(run_intermediate_rep_query(task, ast).await)? - } else { - warn!( - "Can't run queries with endpoint type {}", - CONFIG.endpoint_type - ); - Ok(beam::beam_result::perm_failed( - CONFIG.beam_app_id_long.clone(), - vec![task.from.clone()], - task.id, - format!( - "Can't run queries with endpoint type {}", - CONFIG.endpoint_type - ), - )) - } -}