diff --git a/src/main.rs b/src/main.rs index 1077d74..9c7037b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -45,10 +45,15 @@ async fn upload_file( } let mut buffer = Vec::new(); - let _ = file.read_to_end(&mut buffer); + if let Ok(mut f) = File::open(file_name.clone()) { + let _ = f.read_to_end(&mut buffer); + } else { + return HttpResponse::new(StatusCode::INTERNAL_SERVER_ERROR); + } // Verify HMAC if let Err(response) = verify_hmac(&req, &buffer, secret_key) { + let _ = std::fs::remove_file(file_name); return response; } @@ -104,10 +109,11 @@ fn verify_hmac( #[actix_web::main] async fn main() -> std::io::Result<()> { std::fs::create_dir_all("uploads").unwrap(); + let secret_key = std::env::var("SECRET_KEY").unwrap(); let app_state = web::Data::new(AppState { file_map: Mutex::new(HashMap::new()), - secret_key: "super-long-super-secret-unique-key-goes-here".to_string(), + secret_key, }); HttpServer::new(move || { @@ -125,57 +131,12 @@ async fn main() -> std::io::Result<()> { mod tests { use super::*; use actix_web::dev::ServiceResponse; - use actix_web::{body::to_bytes, test, web, App}; + use actix_web::{test, web, App}; use hex::encode; use sha2::Sha256; - use std::io::Write; - use tempfile::NamedTempFile; - - #[actix_rt::test] - #[ignore] - async fn test_file_upload_and_retrieval() { - let app = test::init_service( - App::new() - .app_data(web::Data::new(AppState { - file_map: Mutex::new(HashMap::new()), - secret_key: "test-test-super-long-super-secret-unique-key-goes-here" - .to_string(), - })) - .service(web::resource("/upload").route(web::post().to(upload_file))) - .service(web::resource("/files/{id}").route(web::get().to(get_file))), - ) - .await; - - let mut temp_file = NamedTempFile::new().unwrap(); - writeln!(temp_file, "Not much of a message history bundle yet.").unwrap(); - - // Read the file's contents into a Vec - let file_contents = std::fs::read(temp_file.path()).unwrap(); - let payload = web::Bytes::from(file_contents); - - // Create the request with the file's bytes as the payload - let req = test::TestRequest::post() - .uri("/upload") - .set_payload(payload) - .to_request(); - - let resp = test::call_service(&app, req).await; - assert!(resp.status().is_success()); - - // Extract the file ID from the response - let body = to_bytes(resp.into_body()).await.unwrap(); - let file_id_str = std::str::from_utf8(&body).unwrap(); - let file_id = Uuid::parse_str(file_id_str).unwrap(); - - // Attempt to retrieve the file using its ID - let req = test::TestRequest::get() - .uri(&format!("/files/{}", file_id)) - .to_request(); - let resp = test::call_service(&app, req).await; - - assert!(resp.status().is_success()); - } + const SECRET_KEY: &[u8] = b"TEST_SECRET_KEY"; + // Helper function to create a HMAC signature fn create_hmac_signature(secret_key: &[u8], data: &[u8]) -> String { let mut mac = @@ -187,12 +148,13 @@ mod tests { // Tests the HMAC verification logic #[actix_rt::test] async fn test_hmac_verification() { - let secret_key = b"your-secret-key"; let app = test::init_service(App::new().route( "/test", - web::post().to(|req: HttpRequest, body: web::Bytes| async { + web::post().to(|req: HttpRequest, body: web::Bytes| async move { // Attempt to verify the HMAC - match verify_hmac(&req, &body, &secret_key) { + let req = req.clone(); + let body = body.clone(); + match verify_hmac(&req, &body, SECRET_KEY) { Ok(_) => HttpResponse::Ok().finish(), Err(err) => err, } @@ -204,7 +166,7 @@ mod tests { let incorrect_payload = b"incorrect payload"; // Create a correct HMAC for the correct payload - let correct_hmac = create_hmac_signature(secret_key, correct_payload); + let correct_hmac = create_hmac_signature(SECRET_KEY, correct_payload); // Simulate sending a request with the correct HMAC let req = test::TestRequest::post() @@ -220,7 +182,7 @@ mod tests { ); // Create an incorrect HMAC for the purpose of testing - let incorrect_hmac = create_hmac_signature(secret_key, incorrect_payload); + let incorrect_hmac = create_hmac_signature(SECRET_KEY, incorrect_payload); // Simulate sending a request with the incorrect HMAC let req = test::TestRequest::post()