Skip to content

Commit

Permalink
Merge pull request #617 from GDATASoftwareAG/rust/from_buf
Browse files Browse the repository at this point in the history
add for_buf function
  • Loading branch information
secana authored Oct 7, 2024
2 parents fe52d3d + 88ce03c commit a4434cd
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 109 deletions.
92 changes: 40 additions & 52 deletions rust/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@
use crate::error::{Error, VResult};
use crate::message::{
MessageType, UploadUrl, Verdict, VerdictRequest, VerdictRequestForStream, VerdictRequestForUrl,
VerdictResponse,
MessageType, UploadUrl, Verdict, VerdictRequest, VerdictRequestFile, VerdictRequestForStream,
VerdictRequestForUrl, VerdictResponse,
};
use crate::options::Options;
use crate::sha256::Sha256;
use crate::vaas_verdict::VaasVerdict;
use crate::CancellationToken;
use bytes::Bytes;
use futures::future::join_all;
use reqwest::{Body, Url, Version};
use reqwest::{Body, Response, Url, Version};
use serde::Serialize;
use std::convert::TryFrom;
use std::ops::Deref;
use std::path::{Path, PathBuf};
Expand Down Expand Up @@ -81,7 +82,7 @@ impl Connection {
self.options.use_cache,
self.options.use_hash_lookup,
);
let response = Self::for_url_request(
let response = Self::for_request(
request,
self.ws_writer.clone(),
&mut self.result_channel.subscribe(),
Expand Down Expand Up @@ -111,7 +112,7 @@ impl Connection {
sha256: &Sha256,
ct: &CancellationToken,
) -> VResult<VaasVerdict> {
let request = VerdictRequest::new(
let request = VerdictRequestFile::new(
sha256,
self.session_id.clone(),
self.options.use_cache,
Expand Down Expand Up @@ -144,9 +145,9 @@ impl Connection {
self.options.use_cache,
self.options.use_hash_lookup,
);
let guid = request.guid().to_string();
let guid = request.guid.to_string();

let response = Self::for_stream_request(
let response = Self::for_request(
request,
self.ws_writer.clone(),
&mut self.result_channel.subscribe(),
Expand Down Expand Up @@ -189,14 +190,20 @@ impl Connection {

/// Request a verdict for a file.
pub async fn for_file(&self, file: &Path, ct: &CancellationToken) -> VResult<VaasVerdict> {
let sha256 = Sha256::try_from(file)?;
let request = VerdictRequest::new(
let buf = tokio::fs::read(file).await?;
self.for_buf(buf, ct).await
}

/// Request a verdict for a buffer.
pub async fn for_buf(&self, buf: Vec<u8>, ct: &CancellationToken) -> VResult<VaasVerdict> {
let sha256 = Sha256::from(buf.as_slice());
let request = VerdictRequestFile::new(
&sha256,
self.session_id.clone(),
self.options.use_cache,
self.options.use_hash_lookup,
);
let guid = request.guid().to_string();
let guid = request.guid.to_string();

let response = Self::for_request(
request,
Expand All @@ -210,7 +217,7 @@ impl Connection {
match verdict {
Verdict::Unknown { upload_url } => {
Self::handle_unknown(
file,
buf,
&guid,
response,
upload_url,
Expand All @@ -224,7 +231,7 @@ impl Connection {
}

async fn handle_unknown(
file: &Path,
buf: Vec<u8>,
guid: &str,
response: VerdictResponse,
upload_url: UploadUrl,
Expand All @@ -235,17 +242,9 @@ impl Connection {
.upload_token
.as_ref()
.ok_or(Error::MissingAuthToken)?;
let response = upload_file(file, upload_url, auth_token).await?;

if response.status() != 200 {
return Err(Error::FailedUploadFile(
response.status(),
response.text().await.expect("failed to get payload"),
));
}
let response = upload_buf(buf, upload_url, auth_token).await?;

let resp = Self::wait_for_response(guid, result_channel, ct).await?;
VaasVerdict::try_from(resp)
Self::handle_result(guid, response, result_channel, ct).await
}

async fn handle_unknown_stream<S>(
Expand All @@ -268,10 +267,22 @@ impl Connection {
.ok_or(Error::MissingAuthToken)?;
let response = upload_stream(stream, content_length, upload_url, auth_token).await?;

Self::handle_result(guid, response, result_channel, ct).await
}

async fn handle_result(
guid: &str,
response: Response,
result_channel: &mut ResultChannelRx,
ct: &CancellationToken,
) -> Result<VaasVerdict, Error> {
if response.status() != 200 {
return Err(Error::FailedUploadFile(
response.status(),
response.text().await.expect("failed to get payload"),
response
.text()
.await
.unwrap_or("failed to get payload".to_string()),
));
}

Expand All @@ -290,30 +301,8 @@ impl Connection {
join_all(req).await
}

async fn for_request(
request: VerdictRequest,
ws_writer: WebSocketWriter,
result_channel: &mut ResultChannelRx,
ct: &CancellationToken,
) -> VResult<VerdictResponse> {
let guid = request.guid().to_string();
ws_writer.lock().await.send_text(request.to_json()?).await?;
Self::wait_for_response(&guid, result_channel, ct).await
}

async fn for_url_request(
request: VerdictRequestForUrl,
ws_writer: WebSocketWriter,
result_channel: &mut ResultChannelRx,
ct: &CancellationToken,
) -> VResult<VerdictResponse> {
let guid = request.guid().to_string();
ws_writer.lock().await.send_text(request.to_json()?).await?;
Self::wait_for_response(&guid, result_channel, ct).await
}

async fn for_stream_request(
request: VerdictRequestForStream,
async fn for_request<T: VerdictRequest + Serialize>(
request: T,
ws_writer: WebSocketWriter,
result_channel: &mut ResultChannelRx,
ct: &CancellationToken,
Expand Down Expand Up @@ -396,14 +385,13 @@ impl Connection {
}
}

async fn upload_file(
file: &Path,
async fn upload_buf(
buf: Vec<u8>,
upload_url: UploadUrl,
auth_token: &str,
) -> VResult<reqwest::Response> {
let body = tokio::fs::File::open(&file).await?;
let content_length = body.metadata().await?.len() as usize;
upload_internal(body, content_length, upload_url, auth_token).await
let content_length = buf.len();
upload_internal(buf, content_length, upload_url, auth_token).await
}

async fn upload_stream<S>(
Expand Down
12 changes: 2 additions & 10 deletions rust/src/message/message_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,7 @@ mod tests {
.to_string();

let message_type = MessageType::try_from(&msg).unwrap();

let is_correct_type = match message_type {
MessageType::VerdictResponse(_) => true,
_ => false,
};
let is_correct_type = matches!(message_type, MessageType::VerdictResponse(_));

assert!(is_correct_type);
}
Expand All @@ -67,11 +63,7 @@ mod tests {
.to_string();

let message_type = MessageType::try_from(&msg);

let is_correct_type = match message_type {
Err(Error::ErrorResponse(_)) => true,
_ => false,
};
let is_correct_type = matches!(message_type, Err(Error::ErrorResponse(_)));

assert!(is_correct_type);
}
Expand Down
2 changes: 2 additions & 0 deletions rust/src/message/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ mod open_id_connect_token_response;
mod upload_url;
mod verdict;
mod verdict_request;
mod verdict_request_for_file;
mod verdict_request_for_stream;
mod verdict_request_for_url;
mod verdict_response;
Expand All @@ -21,6 +22,7 @@ pub(super) use open_id_connect_token_response::OpenIdConnectTokenResponse;
pub(super) use upload_url::UploadUrl;
pub use verdict::Verdict;
pub(super) use verdict_request::VerdictRequest;
pub(super) use verdict_request_for_file::VerdictRequestFile;
pub(super) use verdict_request_for_stream::VerdictRequestForStream;
pub(super) use verdict_request_for_url::VerdictRequestForUrl;
pub(super) use verdict_response::VerdictResponse;
38 changes: 9 additions & 29 deletions rust/src/message/verdict_request.rs
Original file line number Diff line number Diff line change
@@ -1,33 +1,13 @@
use crate::message::kind::Kind;
use crate::{error::VResult, sha256::Sha256};
use serde::{Deserialize, Serialize};
use crate::error::VResult;
use serde::Serialize;

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct VerdictRequest {
pub sha256: String,
pub kind: Kind,
pub guid: String,
pub session_id: String,
pub use_hash_lookup: bool,
pub use_cache: bool,
}

impl VerdictRequest {
pub fn new(sha256: &Sha256, session_id: String, use_cache: bool, use_hash_lookup: bool) -> Self {
Self {
guid: uuid::Uuid::new_v4().to_string(),
sha256: sha256.to_string(),
kind: Kind::VerdictRequest,
session_id,
use_cache,
use_hash_lookup,
}
}

pub fn to_json(&self) -> VResult<String> {
pub trait VerdictRequest {
fn to_json(&self) -> VResult<String>
where
Self: Serialize,
{
serde_json::to_string(self).map_err(|e| e.into())
}
pub fn guid(&self) -> &str {
&self.guid
}

fn guid(&self) -> &str;
}
39 changes: 39 additions & 0 deletions rust/src/message/verdict_request_for_file.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
use crate::message::kind::Kind;
use crate::sha256::Sha256;
use serde::{Deserialize, Serialize};

use super::VerdictRequest;

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct VerdictRequestFile {
pub sha256: String,
pub kind: Kind,
pub guid: String,
pub session_id: String,
pub use_hash_lookup: bool,
pub use_cache: bool,
}

impl VerdictRequestFile {
pub fn new(
sha256: &Sha256,
session_id: String,
use_cache: bool,
use_hash_lookup: bool,
) -> Self {
Self {
guid: uuid::Uuid::new_v4().to_string(),
sha256: sha256.to_string(),
kind: Kind::VerdictRequest,
session_id,
use_cache,
use_hash_lookup,
}
}
}

impl VerdictRequest for VerdictRequestFile {
fn guid(&self) -> &str {
&self.guid
}
}
9 changes: 4 additions & 5 deletions rust/src/message/verdict_request_for_stream.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::error::VResult;
use super::VerdictRequest;
use crate::message::kind::Kind;
use serde::{Deserialize, Serialize};

Expand All @@ -21,11 +21,10 @@ impl VerdictRequestForStream {
use_shed,
}
}
}

pub fn to_json(&self) -> VResult<String> {
serde_json::to_string(self).map_err(|e| e.into())
}
pub fn guid(&self) -> &str {
impl VerdictRequest for VerdictRequestForStream {
fn guid(&self) -> &str {
&self.guid
}
}
9 changes: 4 additions & 5 deletions rust/src/message/verdict_request_for_url.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::error::VResult;
use super::VerdictRequest;
use crate::message::kind::Kind;
use reqwest::Url;
use serde::{Deserialize, Serialize};
Expand All @@ -24,11 +24,10 @@ impl VerdictRequestForUrl {
use_shed,
}
}
}

pub fn to_json(&self) -> VResult<String> {
serde_json::to_string(self).map_err(|e| e.into())
}
pub fn guid(&self) -> &str {
impl VerdictRequest for VerdictRequestForUrl {
fn guid(&self) -> &str {
&self.guid
}
}
26 changes: 18 additions & 8 deletions rust/src/sha256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,23 @@ use std::{convert::TryFrom, fmt, ops::Deref};
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct Sha256(String);

impl From<&[u8]> for Sha256 {
fn from(value: &[u8]) -> Self {
use sha2::Digest;
use std::fmt::Write;

let mut hasher = sha2::Sha256::new();
hasher.update(value);
let result = hasher.finalize();

let hex_string = result.iter().fold(String::new(), |mut output, b| {
let _ = write!(output, "{b:02x}");
output
});
Self(hex_string)
}
}

impl TryFrom<&str> for Sha256 {
type Error = crate::error::Error;

Expand Down Expand Up @@ -46,15 +63,8 @@ impl TryFrom<&PathBuf> for Sha256 {
type Error = crate::error::Error;

fn try_from(value: &PathBuf) -> Result<Self, Self::Error> {
use sha2::Digest;
let bytes = std::fs::read(value)?;

let mut hasher = sha2::Sha256::new();
hasher.update(&bytes);
let result = hasher.finalize();

let hex_string = result.iter().map(|b| format!("{b:02x}")).collect();
Ok(Self(hex_string))
Ok(Self::from(bytes.as_slice()))
}
}

Expand Down

0 comments on commit a4434cd

Please sign in to comment.