From 0377cd612e190f651c276145531eb285cf7927ba Mon Sep 17 00:00:00 2001 From: Michael Hall Date: Mon, 22 Jul 2024 15:56:01 +1000 Subject: [PATCH] fix: more robust validation of db directory --- src/download.rs | 27 +++++++++++++++++++++++++-- src/lib.rs | 38 +++++++++++++++++++++++++++++++++++++- src/main.rs | 9 +++++++-- 3 files changed, 69 insertions(+), 5 deletions(-) diff --git a/src/download.rs b/src/download.rs index e73c29c..9ee0e17 100644 --- a/src/download.rs +++ b/src/download.rs @@ -145,8 +145,19 @@ mod tests { use std::path::PathBuf; use tempfile::TempDir; + pub fn check_internet_connection(timeout: std::time::Duration) -> bool { + use std::net::{TcpStream, SocketAddr}; + + let addr = "8.8.8.8:53".parse::().unwrap(); + TcpStream::connect_timeout(&addr, timeout).is_ok() + } + #[test] fn test_download_and_extract_tarball() { + // Skip the test if there is no internet connection + if !check_internet_connection(std::time::Duration::from_secs(2)) { + return; + } // Create a temporary directory to store the extracted files let temp_dir = TempDir::new().unwrap(); let output_path = temp_dir.path().join("output"); @@ -173,6 +184,10 @@ mod tests { #[test] fn test_download_and_extract_tarball_md5_mismatch() { + // Skip the test if there is no internet connection + if !check_internet_connection(std::time::Duration::from_secs(2)) { + return; + } // Create a temporary directory to store the extracted files let temp_dir = TempDir::new().unwrap(); let output_path = temp_dir.path().join("output"); @@ -195,6 +210,10 @@ mod tests { #[test] fn test_download_failure() { + // Skip the test if there is no internet connection + if !check_internet_connection(std::time::Duration::from_secs(2)) { + return; + } // Create a temporary directory to store the downloaded files let temp_dir = TempDir::new().unwrap(); let output_path = temp_dir.path().join("output"); @@ -217,13 +236,17 @@ mod tests { #[test] fn test_extraction_failure() { + // Skip the test if there is no internet connection + if !check_internet_connection(std::time::Duration::from_secs(2)) { + return; + } // Create a temporary directory to store the downloaded files let temp_dir = TempDir::new().unwrap(); let output_path = temp_dir.path().join("output"); // Download and extract a tarball with invalid format - let url = "https://raw.githubusercontent.com/mbhall88/rasusa/main/Cargo.toml"; - let md5 = "77c811c1264306e607aff057420cf354"; + let url = "https://raw.githubusercontent.com/mbhall88/rasusa/fa7e87b843419151cc4716c670adbb28544979b1/Cargo.toml"; + let md5 = "95143b02c21cc9ce1980645d2db69937"; let result = download_and_extract_tarball(url, &output_path, md5); // Assert that the function returns an ExtractionFailed error diff --git a/src/lib.rs b/src/lib.rs index d23b845..d8157ea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,7 @@ pub mod download; use serde::Deserialize; use std::ffi::OsStr; use std::io::{self, Write}; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::process::Command; #[derive(Deserialize)] @@ -63,6 +63,42 @@ pub fn check_path_exists + ?Sized>(s: &S) -> Result` - Ok with the valid path if the files are found, Err otherwise. +pub fn validate_db_directory(path: &Path) -> Result { + let required_files = ["hash.k2d", "opts.k2d", "taxo.k2d"]; + let files_str = required_files.join(", "); + + // Check if the path is a directory and contains the required files + if path.is_dir() && required_files.iter().all(|file| path.join(file).exists()) { + return Ok(path.to_path_buf()); + } + + // Check inside a 'db' subdirectory + let db_path = path.join("db"); + if db_path.is_dir() + && required_files + .iter() + .all(|file| db_path.join(file).exists()) + { + return Ok(db_path); + } + + Err(format!( + "Required files ({}) not found in {:?} or its 'db' subdirectory", + files_str, path + )) +} + + #[cfg(test)] mod tests { use super::*; diff --git a/src/main.rs b/src/main.rs index 5b5bf31..14c1ed8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,7 +5,9 @@ use clap::Parser; use env_logger::Builder; use lazy_static::lazy_static; use log::{debug, error, info, warn, LevelFilter}; -use nohuman::{check_path_exists, download::download_database, CommandRunner}; +use nohuman::{ + check_path_exists, download::download_database, validate_db_directory, CommandRunner, +}; lazy_static! { static ref DEFAULT_DB_LOCATION: String = { @@ -124,7 +126,10 @@ fn main() -> Result<()> { let temp_kraken_output = tempfile::NamedTempFile::new().context("Failed to create temporary kraken output file")?; let threads = args.threads.to_string(); - let db = args.database.join("db").to_string_lossy().to_string(); + let db = validate_db_directory(&args.database) + .map_err(|e| anyhow::anyhow!(e))? + .to_string_lossy() + .to_string(); let mut kraken_cmd = vec![ "--threads", &threads,