Skip to content

Commit

Permalink
fix: more robust validation of db directory
Browse files Browse the repository at this point in the history
  • Loading branch information
mbhall88 committed Jul 22, 2024
1 parent 87a36d3 commit 0377cd6
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 5 deletions.
27 changes: 25 additions & 2 deletions src/download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,19 @@ mod tests {
use std::path::PathBuf;
use tempfile::TempDir;

pub fn check_internet_connection(timeout: std::time::Duration) -> bool {
use std::net::{TcpStream, SocketAddr};

let addr = "8.8.8.8:53".parse::<SocketAddr>().unwrap();
TcpStream::connect_timeout(&addr, timeout).is_ok()
}

#[test]
fn test_download_and_extract_tarball() {
// Skip the test if there is no internet connection
if !check_internet_connection(std::time::Duration::from_secs(2)) {
return;
}
// Create a temporary directory to store the extracted files
let temp_dir = TempDir::new().unwrap();
let output_path = temp_dir.path().join("output");
Expand All @@ -173,6 +184,10 @@ mod tests {

#[test]
fn test_download_and_extract_tarball_md5_mismatch() {
// Skip the test if there is no internet connection
if !check_internet_connection(std::time::Duration::from_secs(2)) {
return;
}
// Create a temporary directory to store the extracted files
let temp_dir = TempDir::new().unwrap();
let output_path = temp_dir.path().join("output");
Expand All @@ -195,6 +210,10 @@ mod tests {

#[test]
fn test_download_failure() {
// Skip the test if there is no internet connection
if !check_internet_connection(std::time::Duration::from_secs(2)) {
return;
}
// Create a temporary directory to store the downloaded files
let temp_dir = TempDir::new().unwrap();
let output_path = temp_dir.path().join("output");
Expand All @@ -217,13 +236,17 @@ mod tests {

#[test]
fn test_extraction_failure() {
// Skip the test if there is no internet connection
if !check_internet_connection(std::time::Duration::from_secs(2)) {
return;
}
// Create a temporary directory to store the downloaded files
let temp_dir = TempDir::new().unwrap();
let output_path = temp_dir.path().join("output");

// Download and extract a tarball with invalid format
let url = "https://raw.githubusercontent.com/mbhall88/rasusa/main/Cargo.toml";
let md5 = "77c811c1264306e607aff057420cf354";
let url = "https://raw.githubusercontent.com/mbhall88/rasusa/fa7e87b843419151cc4716c670adbb28544979b1/Cargo.toml";
let md5 = "95143b02c21cc9ce1980645d2db69937";
let result = download_and_extract_tarball(url, &output_path, md5);

// Assert that the function returns an ExtractionFailed error
Expand Down
38 changes: 37 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ pub mod download;
use serde::Deserialize;
use std::ffi::OsStr;
use std::io::{self, Write};
use std::path::PathBuf;
use std::path::{Path, PathBuf};
use std::process::Command;

#[derive(Deserialize)]
Expand Down Expand Up @@ -63,6 +63,42 @@ pub fn check_path_exists<S: AsRef<OsStr> + ?Sized>(s: &S) -> Result<PathBuf, Str
}
}

/// Checks if the specified path is a directory and contains the required kraken2 db files.
/// If not found, checks inside a 'db' subdirectory.
///
/// # Arguments
///
/// * `path` - A path to check for the required kraken2 db files.
///
/// # Returns
///
/// * `Result<PathBuf, String>` - Ok with the valid path if the files are found, Err otherwise.
pub fn validate_db_directory(path: &Path) -> Result<PathBuf, String> {
let required_files = ["hash.k2d", "opts.k2d", "taxo.k2d"];
let files_str = required_files.join(", ");

// Check if the path is a directory and contains the required files
if path.is_dir() && required_files.iter().all(|file| path.join(file).exists()) {
return Ok(path.to_path_buf());
}

// Check inside a 'db' subdirectory
let db_path = path.join("db");
if db_path.is_dir()
&& required_files
.iter()
.all(|file| db_path.join(file).exists())
{
return Ok(db_path);
}

Err(format!(
"Required files ({}) not found in {:?} or its 'db' subdirectory",
files_str, path
))
}


#[cfg(test)]
mod tests {
use super::*;
Expand Down
9 changes: 7 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use clap::Parser;
use env_logger::Builder;
use lazy_static::lazy_static;
use log::{debug, error, info, warn, LevelFilter};
use nohuman::{check_path_exists, download::download_database, CommandRunner};
use nohuman::{
check_path_exists, download::download_database, validate_db_directory, CommandRunner,
};

lazy_static! {
static ref DEFAULT_DB_LOCATION: String = {
Expand Down Expand Up @@ -124,7 +126,10 @@ fn main() -> Result<()> {
let temp_kraken_output =
tempfile::NamedTempFile::new().context("Failed to create temporary kraken output file")?;
let threads = args.threads.to_string();
let db = args.database.join("db").to_string_lossy().to_string();
let db = validate_db_directory(&args.database)
.map_err(|e| anyhow::anyhow!(e))?
.to_string_lossy()
.to_string();
let mut kraken_cmd = vec![
"--threads",
&threads,
Expand Down

0 comments on commit 0377cd6

Please sign in to comment.