diff --git a/rust/src/connection.rs b/rust/src/connection.rs index d82ad2e8..3e97282f 100644 --- a/rust/src/connection.rs +++ b/rust/src/connection.rs @@ -2,8 +2,8 @@ use crate::error::{Error, VResult}; use crate::message::{ - MessageType, UploadUrl, Verdict, VerdictRequest, VerdictRequestForStream, VerdictRequestForUrl, - VerdictResponse, + MessageType, UploadUrl, Verdict, VerdictRequest, VerdictRequestFile, VerdictRequestForStream, + VerdictRequestForUrl, VerdictResponse, }; use crate::options::Options; use crate::sha256::Sha256; @@ -11,7 +11,8 @@ use crate::vaas_verdict::VaasVerdict; use crate::CancellationToken; use bytes::Bytes; use futures::future::join_all; -use reqwest::{Body, Url, Version}; +use reqwest::{Body, Response, Url, Version}; +use serde::Serialize; use std::convert::TryFrom; use std::ops::Deref; use std::path::{Path, PathBuf}; @@ -81,7 +82,7 @@ impl Connection { self.options.use_cache, self.options.use_hash_lookup, ); - let response = Self::for_url_request( + let response = Self::for_request( request, self.ws_writer.clone(), &mut self.result_channel.subscribe(), @@ -111,7 +112,7 @@ impl Connection { sha256: &Sha256, ct: &CancellationToken, ) -> VResult { - let request = VerdictRequest::new( + let request = VerdictRequestFile::new( sha256, self.session_id.clone(), self.options.use_cache, @@ -144,9 +145,9 @@ impl Connection { self.options.use_cache, self.options.use_hash_lookup, ); - let guid = request.guid().to_string(); + let guid = request.guid.to_string(); - let response = Self::for_stream_request( + let response = Self::for_request( request, self.ws_writer.clone(), &mut self.result_channel.subscribe(), @@ -189,14 +190,20 @@ impl Connection { /// Request a verdict for a file. pub async fn for_file(&self, file: &Path, ct: &CancellationToken) -> VResult { - let sha256 = Sha256::try_from(file)?; - let request = VerdictRequest::new( + let buf = tokio::fs::read(file).await?; + self.for_buf(buf, ct).await + } + + /// Request a verdict for a buffer. + pub async fn for_buf(&self, buf: Vec, ct: &CancellationToken) -> VResult { + let sha256 = Sha256::from(buf.as_slice()); + let request = VerdictRequestFile::new( &sha256, self.session_id.clone(), self.options.use_cache, self.options.use_hash_lookup, ); - let guid = request.guid().to_string(); + let guid = request.guid.to_string(); let response = Self::for_request( request, @@ -210,7 +217,7 @@ impl Connection { match verdict { Verdict::Unknown { upload_url } => { Self::handle_unknown( - file, + buf, &guid, response, upload_url, @@ -224,7 +231,7 @@ impl Connection { } async fn handle_unknown( - file: &Path, + buf: Vec, guid: &str, response: VerdictResponse, upload_url: UploadUrl, @@ -235,17 +242,9 @@ impl Connection { .upload_token .as_ref() .ok_or(Error::MissingAuthToken)?; - let response = upload_file(file, upload_url, auth_token).await?; - - if response.status() != 200 { - return Err(Error::FailedUploadFile( - response.status(), - response.text().await.expect("failed to get payload"), - )); - } + let response = upload_buf(buf, upload_url, auth_token).await?; - let resp = Self::wait_for_response(guid, result_channel, ct).await?; - VaasVerdict::try_from(resp) + Self::handle_result(guid, response, result_channel, ct).await } async fn handle_unknown_stream( @@ -268,10 +267,22 @@ impl Connection { .ok_or(Error::MissingAuthToken)?; let response = upload_stream(stream, content_length, upload_url, auth_token).await?; + Self::handle_result(guid, response, result_channel, ct).await + } + + async fn handle_result( + guid: &str, + response: Response, + result_channel: &mut ResultChannelRx, + ct: &CancellationToken, + ) -> Result { if response.status() != 200 { return Err(Error::FailedUploadFile( response.status(), - response.text().await.expect("failed to get payload"), + response + .text() + .await + .unwrap_or("failed to get payload".to_string()), )); } @@ -290,30 +301,8 @@ impl Connection { join_all(req).await } - async fn for_request( - request: VerdictRequest, - ws_writer: WebSocketWriter, - result_channel: &mut ResultChannelRx, - ct: &CancellationToken, - ) -> VResult { - let guid = request.guid().to_string(); - ws_writer.lock().await.send_text(request.to_json()?).await?; - Self::wait_for_response(&guid, result_channel, ct).await - } - - async fn for_url_request( - request: VerdictRequestForUrl, - ws_writer: WebSocketWriter, - result_channel: &mut ResultChannelRx, - ct: &CancellationToken, - ) -> VResult { - let guid = request.guid().to_string(); - ws_writer.lock().await.send_text(request.to_json()?).await?; - Self::wait_for_response(&guid, result_channel, ct).await - } - - async fn for_stream_request( - request: VerdictRequestForStream, + async fn for_request( + request: T, ws_writer: WebSocketWriter, result_channel: &mut ResultChannelRx, ct: &CancellationToken, @@ -396,14 +385,13 @@ impl Connection { } } -async fn upload_file( - file: &Path, +async fn upload_buf( + buf: Vec, upload_url: UploadUrl, auth_token: &str, ) -> VResult { - let body = tokio::fs::File::open(&file).await?; - let content_length = body.metadata().await?.len() as usize; - upload_internal(body, content_length, upload_url, auth_token).await + let content_length = buf.len(); + upload_internal(buf, content_length, upload_url, auth_token).await } async fn upload_stream( diff --git a/rust/src/message/message_type.rs b/rust/src/message/message_type.rs index 8e4cc90f..4f89e9f0 100644 --- a/rust/src/message/message_type.rs +++ b/rust/src/message/message_type.rs @@ -46,11 +46,7 @@ mod tests { .to_string(); let message_type = MessageType::try_from(&msg).unwrap(); - - let is_correct_type = match message_type { - MessageType::VerdictResponse(_) => true, - _ => false, - }; + let is_correct_type = matches!(message_type, MessageType::VerdictResponse(_)); assert!(is_correct_type); } @@ -67,11 +63,7 @@ mod tests { .to_string(); let message_type = MessageType::try_from(&msg); - - let is_correct_type = match message_type { - Err(Error::ErrorResponse(_)) => true, - _ => false, - }; + let is_correct_type = matches!(message_type, Err(Error::ErrorResponse(_))); assert!(is_correct_type); } diff --git a/rust/src/message/mod.rs b/rust/src/message/mod.rs index b834ff2d..5ef4be92 100644 --- a/rust/src/message/mod.rs +++ b/rust/src/message/mod.rs @@ -9,6 +9,7 @@ mod open_id_connect_token_response; mod upload_url; mod verdict; mod verdict_request; +mod verdict_request_for_file; mod verdict_request_for_stream; mod verdict_request_for_url; mod verdict_response; @@ -21,6 +22,7 @@ pub(super) use open_id_connect_token_response::OpenIdConnectTokenResponse; pub(super) use upload_url::UploadUrl; pub use verdict::Verdict; pub(super) use verdict_request::VerdictRequest; +pub(super) use verdict_request_for_file::VerdictRequestFile; pub(super) use verdict_request_for_stream::VerdictRequestForStream; pub(super) use verdict_request_for_url::VerdictRequestForUrl; pub(super) use verdict_response::VerdictResponse; diff --git a/rust/src/message/verdict_request.rs b/rust/src/message/verdict_request.rs index 4b3c65f5..fd277726 100644 --- a/rust/src/message/verdict_request.rs +++ b/rust/src/message/verdict_request.rs @@ -1,33 +1,13 @@ -use crate::message::kind::Kind; -use crate::{error::VResult, sha256::Sha256}; -use serde::{Deserialize, Serialize}; +use crate::error::VResult; +use serde::Serialize; -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct VerdictRequest { - pub sha256: String, - pub kind: Kind, - pub guid: String, - pub session_id: String, - pub use_hash_lookup: bool, - pub use_cache: bool, -} - -impl VerdictRequest { - pub fn new(sha256: &Sha256, session_id: String, use_cache: bool, use_hash_lookup: bool) -> Self { - Self { - guid: uuid::Uuid::new_v4().to_string(), - sha256: sha256.to_string(), - kind: Kind::VerdictRequest, - session_id, - use_cache, - use_hash_lookup, - } - } - - pub fn to_json(&self) -> VResult { +pub trait VerdictRequest { + fn to_json(&self) -> VResult + where + Self: Serialize, + { serde_json::to_string(self).map_err(|e| e.into()) } - pub fn guid(&self) -> &str { - &self.guid - } + + fn guid(&self) -> &str; } diff --git a/rust/src/message/verdict_request_for_file.rs b/rust/src/message/verdict_request_for_file.rs new file mode 100644 index 00000000..0285a156 --- /dev/null +++ b/rust/src/message/verdict_request_for_file.rs @@ -0,0 +1,39 @@ +use crate::message::kind::Kind; +use crate::sha256::Sha256; +use serde::{Deserialize, Serialize}; + +use super::VerdictRequest; + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct VerdictRequestFile { + pub sha256: String, + pub kind: Kind, + pub guid: String, + pub session_id: String, + pub use_hash_lookup: bool, + pub use_cache: bool, +} + +impl VerdictRequestFile { + pub fn new( + sha256: &Sha256, + session_id: String, + use_cache: bool, + use_hash_lookup: bool, + ) -> Self { + Self { + guid: uuid::Uuid::new_v4().to_string(), + sha256: sha256.to_string(), + kind: Kind::VerdictRequest, + session_id, + use_cache, + use_hash_lookup, + } + } +} + +impl VerdictRequest for VerdictRequestFile { + fn guid(&self) -> &str { + &self.guid + } +} diff --git a/rust/src/message/verdict_request_for_stream.rs b/rust/src/message/verdict_request_for_stream.rs index 746505aa..233b40a3 100644 --- a/rust/src/message/verdict_request_for_stream.rs +++ b/rust/src/message/verdict_request_for_stream.rs @@ -1,4 +1,4 @@ -use crate::error::VResult; +use super::VerdictRequest; use crate::message::kind::Kind; use serde::{Deserialize, Serialize}; @@ -21,11 +21,10 @@ impl VerdictRequestForStream { use_shed, } } +} - pub fn to_json(&self) -> VResult { - serde_json::to_string(self).map_err(|e| e.into()) - } - pub fn guid(&self) -> &str { +impl VerdictRequest for VerdictRequestForStream { + fn guid(&self) -> &str { &self.guid } } diff --git a/rust/src/message/verdict_request_for_url.rs b/rust/src/message/verdict_request_for_url.rs index 501102c3..a9f10df6 100644 --- a/rust/src/message/verdict_request_for_url.rs +++ b/rust/src/message/verdict_request_for_url.rs @@ -1,4 +1,4 @@ -use crate::error::VResult; +use super::VerdictRequest; use crate::message::kind::Kind; use reqwest::Url; use serde::{Deserialize, Serialize}; @@ -24,11 +24,10 @@ impl VerdictRequestForUrl { use_shed, } } +} - pub fn to_json(&self) -> VResult { - serde_json::to_string(self).map_err(|e| e.into()) - } - pub fn guid(&self) -> &str { +impl VerdictRequest for VerdictRequestForUrl { + fn guid(&self) -> &str { &self.guid } } diff --git a/rust/src/sha256.rs b/rust/src/sha256.rs index 467f7a99..97c3a142 100644 --- a/rust/src/sha256.rs +++ b/rust/src/sha256.rs @@ -19,6 +19,23 @@ use std::{convert::TryFrom, fmt, ops::Deref}; #[derive(Debug, PartialEq, Eq, Clone, Hash)] pub struct Sha256(String); +impl From<&[u8]> for Sha256 { + fn from(value: &[u8]) -> Self { + use sha2::Digest; + use std::fmt::Write; + + let mut hasher = sha2::Sha256::new(); + hasher.update(value); + let result = hasher.finalize(); + + let hex_string = result.iter().fold(String::new(), |mut output, b| { + let _ = write!(output, "{b:02x}"); + output + }); + Self(hex_string) + } +} + impl TryFrom<&str> for Sha256 { type Error = crate::error::Error; @@ -46,15 +63,8 @@ impl TryFrom<&PathBuf> for Sha256 { type Error = crate::error::Error; fn try_from(value: &PathBuf) -> Result { - use sha2::Digest; let bytes = std::fs::read(value)?; - - let mut hasher = sha2::Sha256::new(); - hasher.update(&bytes); - let result = hasher.finalize(); - - let hex_string = result.iter().map(|b| format!("{b:02x}")).collect(); - Ok(Self(hex_string)) + Ok(Self::from(bytes.as_slice())) } }