diff --git a/Cargo.toml b/Cargo.toml
index 8fd5b53..9df319d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,31 +15,16 @@ identifier = "com.netease.nmc.demo"
 copyright = "Copyright (c) Jane Doe 2016. All rights reserved."
 
 [dependencies]
-aws-config = { version = "1.1.1", features = ["behavior-version-latest"] }
-aws-credential-types = "1.1.1"
-aws-sdk-s3 = { version = "1.10.0", features = ["behavior-version-latest", "rt-tokio"], default-features = false }
-aws-smithy-runtime-api = { version = "1.1.1", features = ["client"] }
-aws-smithy-runtime = { version = "1.1.1", features = ["tls-rustls"] }
-aws-smithy-types = { version = "1.1.1", features = ["http-body-0-4-x"] }
-rustls = "0.21.9"
-bytes = "1.5.0"
-chrono = "0.4.31"
-http = "0.2.8"
-http-body = "0.4.6"
-pin-project = "1.1.3"
+rust-s3 = { version = "0.34.0-rc4", features = ["no-verify-ssl"] }
 serde = { version = "1.0.193", features = ["derive"] }
 serde_json = "1.0.108"
 tokio = { version = "1.35.0", features = ["full"] }
-rustls-pemfile = "1.0.3"
 log = "0.4.20"
 flexi_logger = "0.27.3"
 sysinfo = "0.30.1"
 clap = { version = "4.4.11", features = ["derive"] }
 urlencoding = "2.1.3"
-
-[dependencies.hyper-rustls]
-version = "0.24.2"
-features = ["http2"]
+tokio-stream = "0.1.14"
 
 [lib]
 path = "src/lib.rs"
diff --git a/src/basic.rs b/src/basic.rs
index a5634eb..c931669 100644
--- a/src/basic.rs
+++ b/src/basic.rs
@@ -1,20 +1,7 @@
-use aws_config::{retry::RetryConfig, Region};
-use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider};
-use aws_sdk_s3::{config::Credentials, Client};
-use aws_smithy_runtime::client::http::hyper_014::HyperClientBuilder;
-use bytes::Bytes;
-use http_body::{Body, SizeHint};
-use log::info;
-use rustls::{Certificate, RootCertStore};
-use rustls_pemfile::certs;
 use serde::{Deserialize, Serialize};
-use std::{
-    fs::File,
-    io::BufReader,
-    pin::Pin,
-    sync::{Arc, Mutex},
-    task::{Context, Poll},
-};
+use std::sync::{Arc, Mutex};
+use s3::Bucket;
+use s3::creds::Credentials;
 
 /// result callback
 /// - `success` - true if success, false if failed
@@ -25,139 +12,6 @@ pub type ResultCallback = Box;
 /// - `progress` - progress of upload or download, in percentage
 pub type ProgressCallback = Arc<Mutex<dyn Fn(f32) + Send>>;
 
-#[derive(Debug)]
-pub struct NeS3Credential {
-    access_key_id: String,
-    secret_access_key: String,
-    session_token: String,
-}
-
-impl NeS3Credential {
-    pub fn new(access_key_id: String, secret_access_key: String, session_token: String) -> Self {
-        Self {
-            access_key_id,
-            secret_access_key,
-            session_token,
-        }
-    }
-
-    async fn load_credentials(&self) -> aws_credential_types::provider::Result {
-        Ok(Credentials::new(
-            self.access_key_id.clone(),
-            self.secret_access_key.clone(),
-            Some(self.session_token.clone()),
-            None,
-            "NeS3Credential",
-        ))
-    }
-}
-
-impl ProvideCredentials for NeS3Credential {
-    fn provide_credentials<'a>(
-        &'a self,
-    ) -> aws_credential_types::provider::future::ProvideCredentials<'a>
-    where
-        Self: 'a,
-    {
-        aws_credential_types::provider::future::ProvideCredentials::new(self.load_credentials())
-    }
-}
-
-// ProgressTracker prints information as the upload progresses.
-struct ProgressTracker {
-    bytes_written: Arc<Mutex<u64>>,
-    content_length: u64,
-    progress_callback: ProgressCallback,
-    last_callback_time: std::time::Instant,
-}
-impl ProgressTracker {
-    fn track(&mut self, len: u64) {
-        let mut bytes_written = self.bytes_written.lock().unwrap();
-        *bytes_written += len;
-        let progress = *bytes_written as f32 / self.content_length as f32 * 100.0;
-        let progress_callback = self.progress_callback.lock().unwrap();
-        if std::time::Instant::now() - self.last_callback_time
-            < std::time::Duration::from_millis(500)
-            && progress < 100.0
-        {
-            return;
-        }
-        self.last_callback_time = std::time::Instant::now();
-        progress_callback(progress);
-    }
-}
-
-// snippet-start:[s3.rust.put-object-progress-body]
-// A ProgressBody to wrap any http::Body with upload progress information.
-#[pin_project::pin_project]
-pub struct ProgressBody<InnerBody> {
-    #[pin]
-    inner: InnerBody,
-    // progress_tracker is a separate field so it can be accessed as &mut.
-    progress_tracker: ProgressTracker,
-}
-
-impl<InnerBody> ProgressBody<InnerBody>
-where
-    InnerBody: Body<Data = Bytes, Error = aws_smithy_types::body::Error>,
-{
-    pub fn new(
-        body: InnerBody,
-        bytes_written: Arc<Mutex<u64>>,
-        content_length: u64,
-        progress_callback: ProgressCallback,
-    ) -> Self {
-        Self {
-            inner: body,
-            progress_tracker: ProgressTracker {
-                bytes_written,
-                content_length,
-                progress_callback,
-                last_callback_time: std::time::Instant::now(),
-            },
-        }
-    }
-}
-
-impl<InnerBody> Body for ProgressBody<InnerBody>
-where
-    InnerBody: Body<Data = Bytes, Error = aws_smithy_types::body::Error>,
-{
-    type Data = Bytes;
-
-    type Error = aws_smithy_types::body::Error;
-
-    // Our poll_data delegates to the inner poll_data, but needs a project() to
-    // get there. When the poll has data, it updates the progress_tracker.
-    fn poll_data(
-        self: Pin<&mut Self>,
-        cx: &mut Context<'_>,
-    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
-        let this = self.project();
-        match this.inner.poll_data(cx) {
-            Poll::Ready(Some(Ok(data))) => {
-                this.progress_tracker.track(data.len() as u64);
-                Poll::Ready(Some(Ok(data)))
-            }
-            Poll::Ready(None) => Poll::Ready(None),
-            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
-            Poll::Pending => Poll::Pending,
-        }
-    }
-
-    // Delegate utilities to inner and progress_tracker.
-    fn poll_trailers(
-        self: Pin<&mut Self>,
-        cx: &mut Context<'_>,
-    ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
-        self.project().inner.poll_trailers(cx)
-    }
-
-    fn size_hint(&self) -> http_body::SizeHint {
-        SizeHint::with_exact(self.progress_tracker.content_length)
-    }
-}
-
 #[derive(Debug, Serialize, Deserialize)]
 pub struct S3Params {
     pub(crate) bucket: String,
@@ -173,56 +27,29 @@ pub struct S3Params {
     pub(crate) ca_cert_path: Option<String>,
 }
 
-fn load_ca_cert(path: &String) -> Result<RootCertStore, Box<dyn std::error::Error>> {
-    let file = File::open(path)?;
-    let mut reader = BufReader::new(file);
-    let mut root_store = RootCertStore::empty();
-    for cert in certs(&mut reader)? {
-        root_store.add(&Certificate(cert))?;
-    }
-    Ok(root_store)
-}
-
-pub fn create_s3_client(params: &S3Params) -> Result<Client, Box<dyn std::error::Error>> {
-    let mut region = Region::new("ap-southeast-1");
-    if let Some(region_str) = &params.region {
-        region = Region::new(region_str.clone());
-    }
-    let credential = NeS3Credential::new(
-        params.access_key_id.clone(),
-        params.secret_access_key.clone(),
-        params.session_token.clone(),
-    );
-    let mut builder = aws_config::SdkConfig::builder()
-        .region(region)
-        .credentials_provider(SharedCredentialsProvider::new(credential))
-        // Set max attempts.
-        // If tries is 1, there are no retries.
-        .retry_config(RetryConfig::standard().with_max_attempts(params.tries.unwrap_or(1)));
-    if params.ca_cert_path.is_some() {
-        info!(
-            "use custom ca certs, path: {}",
-            params.ca_cert_path.as_ref().unwrap()
-        );
-        let root_store = load_ca_cert(params.ca_cert_path.as_ref().unwrap())?;
-        let config = rustls::ClientConfig::builder()
-            .with_safe_defaults()
-            .with_root_certificates(root_store)
-            .with_no_client_auth();
-        let tls_connector = hyper_rustls::HttpsConnectorBuilder::new()
-            .with_tls_config(config)
-            .https_only()
-            .enable_http1()
-            .enable_http2()
-            .build();
-        let hyper_client = HyperClientBuilder::new().build(tls_connector);
-        builder.set_http_client(Some(hyper_client));
-    }
-    if params.endpoint.as_ref().is_some_and(|url| !url.is_empty()) {
-        let endpoint = params.endpoint.as_ref().unwrap().clone();
-        builder.set_endpoint_url(Some(endpoint));
-    }
-    let shared_config = builder.build();
-    // Construct an S3 client with customized retry configuration.
-    Ok(Client::new(&shared_config))
+pub fn get_s3_bucket(params: &S3Params) -> Result<Bucket, Box<dyn std::error::Error>> {
+    let mut region: s3::Region = match params.region {
+        Some(ref r) => r.parse()?,
+        None => s3::Region::UsEast1,
+    };
+    if params.endpoint.as_ref().is_some_and(|e| !e.is_empty()) {
+        region = s3::Region::Custom {
+            region: params.region.clone().unwrap_or_default(),
+            endpoint: params.endpoint.as_ref().unwrap().to_string(),
+        };
+    }
+    let mut bucket = Bucket::new(
+        params.bucket.as_str(),
+        region,
+        // Static credentials built from the caller-supplied keys and session token.
+        Credentials::new(
+            Some(params.access_key_id.as_str()),
+            Some(params.secret_access_key.as_str()),
+            None,
+            Some(params.session_token.as_str()),
+            None,
+        )?,
+    )?;
+    bucket.add_header("x-amz-meta-token", params.security_token.as_str());
+    Ok(bucket)
 }
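
Aside: the helper above leans on three rust-s3 building blocks — FromStr for named regions, Region::Custom for arbitrary endpoints, and explicit Credentials. A minimal standalone sketch, assuming the rust-s3 0.34 API as used in this patch; the function name and all literal values are placeholders, not part of the change:

    use s3::{creds::Credentials, Bucket, Region};

    fn example_bucket() -> Result<Bucket, Box<dyn std::error::Error>> {
        // Named AWS regions parse from their string form ...
        let _named: Region = "ap-southeast-1".parse()?;
        // ... while non-AWS endpoints (MinIO, Ceph, ...) need Region::Custom.
        let region = Region::Custom {
            region: "us-east-1".to_string(),
            endpoint: "https://s3.example.com".to_string(),
        };
        // The slots after the two keys are security token, session token, profile.
        let creds = Credentials::new(Some("ACCESS_KEY"), Some("SECRET_KEY"), None, None, None)?;
        Ok(Bucket::new("example-bucket", region, creds)?)
    }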
diff --git a/src/download.rs b/src/download.rs
index b7cecca..558c452 100644
--- a/src/download.rs
+++ b/src/download.rs
@@ -1,20 +1,13 @@
 use crate::basic;
-use std::{fs::File, io::Write};
+use log::info;
+use tokio::io::AsyncWriteExt;
 
 pub async fn get_object(params: &basic::S3Params) -> Result<u64, Box<dyn std::error::Error>> {
-    let client = basic::create_s3_client(&params)?;
-    let mut file = File::create(&params.file_path)?;
-    let mut resp = client
-        .get_object()
-        .bucket(&params.bucket)
-        .key(&params.object)
-        .send()
-        .await?;
-    let mut byte_count = 0_u64;
-    while let Some(bytes) = resp.body.try_next().await? {
-        let bytes_len = bytes.len();
-        file.write_all(&bytes)?;
-        byte_count += bytes_len as u64;
-    }
-    Ok(byte_count)
+    let bucket = basic::get_s3_bucket(params)?;
+    let mut async_output_file = tokio::fs::File::create(&params.file_path).await?;
+    let status_code = bucket.get_object_to_writer(&params.object, &mut async_output_file).await?;
+    info!("get_object status code: {}", status_code);
+    // Flush, then return the bytes that reached disk, preserving the old contract.
+    async_output_file.flush().await?;
+    Ok(async_output_file.metadata().await?.len())
 }
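
Aside: get_object_to_writer streams the body into any tokio AsyncWrite. When an object is small enough to buffer whole, rust-s3 also offers a buffered get_object; a hedged sketch, assuming 0.34's ResponseData accessors — the function name and key are illustrative:

    async fn fetch_to_memory(bucket: &s3::Bucket) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
        // Buffered variant: the whole body lands in memory as a ResponseData.
        let resp = bucket.get_object("path/to/key").await?;
        if resp.status_code() != 200 {
            return Err(format!("download failed with status {}", resp.status_code()).into());
        }
        Ok(resp.bytes().to_vec())
    }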
diff --git a/src/upload.rs b/src/upload.rs
index 4bbca27..a7f747e 100644
--- a/src/upload.rs
+++ b/src/upload.rs
@@ -1,167 +1,15 @@
 use crate::basic;
-use aws_sdk_s3::{
-    operation::create_multipart_upload::CreateMultipartUploadOutput,
-    primitives::{ByteStream, SdkBody},
-    types::{CompletedMultipartUpload, CompletedPart},
-};
-use aws_smithy_runtime_api::http::Request;
-use aws_smithy_types::byte_stream::Length;
 use log::info;
-use std::{
-    convert::Infallible,
-    path::Path,
-    sync::{Arc, Mutex},
-};
-
-const DEFAULT_CHUNK_SIZE: u64 = 1024 * 1024 * 5;
-const MAX_CHUNK_COUNT: u64 = 10000;
-
-#[derive(Debug)]
-struct MultiUploadChunkInfo {
-    chunk_count: u64,
-    chunk_size: u64,
-    size_of_last_chunk: u64,
-}
-
-fn get_multiupload_chunk_info(
-    file_size: u64,
-    default_chunk_size: u64,
-    max_chunk_count: u64,
-) -> MultiUploadChunkInfo {
-    let mut chunk_size = default_chunk_size;
-    if file_size > default_chunk_size * max_chunk_count {
-        chunk_size = file_size / (max_chunk_count - 1).max(1);
-    }
-    let mut size_of_last_chunk = file_size % chunk_size;
-    let mut chunk_count = file_size / chunk_size + 1;
-    if size_of_last_chunk == 0 {
-        size_of_last_chunk = chunk_size;
-        chunk_count = (chunk_count - 1).max(1)
-    }
-    MultiUploadChunkInfo {
-        chunk_count,
-        chunk_size,
-        size_of_last_chunk,
-    }
-}
 
 pub async fn put_object(
     params: &basic::S3Params,
     proress_callback: basic::ProgressCallback,
-) -> Result<u64, Box<dyn std::error::Error>> {
-    let client = basic::create_s3_client(&params)?;
-    let mut metadata_header = std::collections::HashMap::new();
-    metadata_header.insert(
-        "x-amz-meta-token".to_string(),
-        params.security_token.clone(),
-    );
-    let multipart_upload_res: CreateMultipartUploadOutput = client
-        .create_multipart_upload()
-        .bucket(&params.bucket)
-        .key(&params.object)
-        .set_metadata(Some(metadata_header))
-        .send()
-        .await?;
-    let upload_id = multipart_upload_res.upload_id();
-    if upload_id.is_none() {
-        return Err("upload_id is none".into());
-    }
-    let file_path = Path::new(&params.file_path);
-    let upload_id = upload_id.unwrap();
-    let file_size = tokio::fs::metadata(file_path).await?.len();
-    if file_size == 0 {
-        return Err("file size is 0".into());
-    }
-    let upload_chunk_info =
-        get_multiupload_chunk_info(file_size, DEFAULT_CHUNK_SIZE, MAX_CHUNK_COUNT);
-    info!("upload_chunk_info: {:?}", upload_chunk_info);
-    let mut upload_parts = Vec::new();
-    let uploaded_size = Arc::new(Mutex::new(0_u64));
-    for chunk_index in 0..upload_chunk_info.chunk_count {
-        let this_chunk = if upload_chunk_info.chunk_count - 1 == chunk_index {
-            upload_chunk_info.size_of_last_chunk
-        } else {
-            upload_chunk_info.chunk_size
-        };
-        let stream = ByteStream::read_from()
-            .path(file_path)
-            .offset(chunk_index * upload_chunk_info.chunk_size)
-            .length(Length::Exact(this_chunk))
-            .build()
-            .await?;
-        // chunk index needs to start at 0, but part numbers start at 1.
-        // chunk_count is less than MAX_CHUNKS, so this can't overflow.
-        let part_number = (chunk_index as i32) + 1;
-        let uploaded_size = uploaded_size.clone();
-        let proress_callback = proress_callback.clone();
-        let customized = client
-            .upload_part()
-            .key(&params.object)
-            .bucket(&params.bucket)
-            .upload_id(upload_id)
-            .body(stream)
-            .part_number(part_number)
-            .customize()
-            .map_request(
-                move |value: Request| -> Result<Request<SdkBody>, Infallible> {
-                    let uploaded_size = uploaded_size.clone();
-                    let proress_callback = proress_callback.clone();
-                    let value = value.map(move |body| {
-                        let body = basic::ProgressBody::new(
-                            body,
-                            uploaded_size.clone(),
-                            file_size,
-                            proress_callback.clone(),
-                        );
-                        SdkBody::from_body_0_4(body)
-                    });
-                    Ok(value)
-                },
-            );
-        let upload_part = tokio::task::spawn(async move { customized.send().await });
-        upload_parts.push((upload_part, part_number));
-    }
-    let mut upload_part_res_vec: Vec<CompletedPart> = Vec::new();
-    for (handle, part_number) in upload_parts {
-        let upload_part_res = handle.await??;
-        upload_part_res_vec.push(
-            CompletedPart::builder()
-                .e_tag(upload_part_res.e_tag.unwrap_or_default())
-                .part_number(part_number)
-                .build(),
-        );
-    }
-    info!("upload_parts finished");
-    let completed_multipart_upload: CompletedMultipartUpload = CompletedMultipartUpload::builder()
-        .set_parts(Some(upload_part_res_vec))
-        .build();
-    client
-        .complete_multipart_upload()
-        .bucket(&params.bucket)
-        .key(&params.object)
-        .multipart_upload(completed_multipart_upload)
-        .upload_id(upload_id)
-        .send()
-        .await?;
-    Ok(file_size)
-}
-
-#[cfg(test)]
-mod tests {
-    #[test]
-    fn test_get_multiupload_chunk_info() {
-        use super::get_multiupload_chunk_info;
-        let mut chunk_info = get_multiupload_chunk_info(10, 2, 5);
-        assert_eq!(chunk_info.chunk_count, 5);
-        assert_eq!(chunk_info.chunk_size, 2);
-        assert_eq!(chunk_info.size_of_last_chunk, 2);
-        chunk_info = get_multiupload_chunk_info(10, 2, 4);
-        assert_eq!(chunk_info.chunk_count, 4);
-        assert_eq!(chunk_info.chunk_size, 3);
-        assert_eq!(chunk_info.size_of_last_chunk, 1);
-        chunk_info = get_multiupload_chunk_info(10, 2, 1);
-        assert_eq!(chunk_info.chunk_count, 1);
-        assert_eq!(chunk_info.chunk_size, 10);
-        assert_eq!(chunk_info.size_of_last_chunk, 10);
-    }
+) -> Result<(), Box<dyn std::error::Error>> {
+    // NOTE: proress_callback is kept for signature compatibility but is not
+    // invoked here; rust-s3's put_object_stream exposes no per-chunk progress hook.
+    let bucket = basic::get_s3_bucket(params)?;
+    let mut async_input_file = tokio::fs::File::open(&params.file_path).await?;
+    let response = bucket.put_object_stream(&mut async_input_file, &params.object).await?;
+    info!("put_object status code: {}", response.status_code());
+    Ok(())
 }
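
Aside: put_object_stream is what makes this deletion possible — rust-s3 drives the whole multipart flow (initiate, upload parts, complete) internally, replacing the hand-rolled chunk-size math and the CreateMultipartUpload/UploadPart/CompleteMultipartUpload orchestration above. A hedged end-to-end sketch; store_file and the literal values are illustrative, not part of the patch:

    async fn store_file(bucket: &s3::Bucket, path: &str, key: &str) -> Result<(), Box<dyn std::error::Error>> {
        let mut file = tokio::fs::File::open(path).await?;
        // rust-s3 splits the stream into parts and completes the upload itself.
        let response = bucket.put_object_stream(&mut file, key).await?;
        if response.status_code() != 200 {
            return Err(format!("upload failed with status {}", response.status_code()).into());
        }
        Ok(())
    }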