Add support for --sse, --sse-kms-key-id flags under a feature flag #715

Merged: 3 commits, Feb 9, 2024
23 changes: 23 additions & 0 deletions mountpoint-s3/src/cli.rs
@@ -26,6 +26,8 @@ use regex::Regex;

use crate::build_info;
use crate::data_cache::{CacheLimit, DiskDataCache, DiskDataCacheConfig, ManagedCacheDir};
#[cfg(feature = "sse_kms")]
use crate::fs::ServerSideEncryption;
use crate::fs::{CacheConfig, S3FilesystemConfig, S3Personality};
use crate::fuse::session::FuseSession;
use crate::fuse::S3FuseFilesystem;
@@ -268,6 +270,23 @@ pub struct CliArgs {
help_heading = ADVANCED_OPTIONS_HEADER,
)]
pub user_agent_prefix: Option<String>,

#[cfg(feature = "sse_kms")]
#[clap(
long,
help = "Server-side encryption algorithm to use when uploading new objects",
help_heading = BUCKET_OPTIONS_HEADER,
value_parser = clap::builder::PossibleValuesParser::new(["aws:kms", "aws:kms:dsse"]))]
pub sse: Option<String>,

#[cfg(feature = "sse_kms")]
#[clap(
long,
help = "AWS Key Management Service (KMS) key ID to use with KMS server-side encryption when uploading new objects",
help_heading = BUCKET_OPTIONS_HEADER,
requires = "sse",
)]
pub sse_kms_key_id: Option<String>,
}

#[derive(Debug, Clone)]
@@ -608,6 +627,10 @@
filesystem_config.allow_delete = args.allow_delete;
filesystem_config.allow_overwrite = args.allow_overwrite;
filesystem_config.s3_personality = s3_personality;
#[cfg(feature = "sse_kms")]
{
filesystem_config.server_side_encryption = ServerSideEncryption::new(args.sse, args.sse_kms_key_id);
}

let prefetcher_config = Default::default();

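Both flags only exist when Mountpoint is compiled with the `sse_kms` feature (for example `cargo build --features sse_kms`). The snippet below is a standalone sketch, not part of this PR, of the clap pattern used above: a value-restricted `--sse` flag plus an `--sse-kms-key-id` flag that clap only accepts together with `--sse`. It assumes a clap derive setup comparable to the crate's, and the names (`DemoArgs`) are illustrative.

```rust
// Standalone sketch of the clap pattern used in cli.rs (illustrative only).
use clap::Parser;

#[derive(Parser, Debug)]
struct DemoArgs {
    /// Server-side encryption algorithm to use when uploading new objects
    #[clap(long, value_parser = clap::builder::PossibleValuesParser::new(["aws:kms", "aws:kms:dsse"]))]
    sse: Option<String>,

    /// AWS KMS key ID; rejected unless --sse is also given
    #[clap(long, requires = "sse")]
    sse_kms_key_id: Option<String>,
}

fn main() {
    // `demo --sse-kms-key-id my-key` fails with a missing-requirement error,
    // while `demo --sse aws:kms --sse-kms-key-id my-key` parses successfully.
    let args = DemoArgs::parse();
    println!("{args:?}");
}
```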
37 changes: 36 additions & 1 deletion mountpoint-s3/src/fs.rs
@@ -378,6 +378,8 @@ pub struct S3FilesystemConfig {
pub storage_class: Option<String>,
/// S3 personality (for different S3 semantics)
pub s3_personality: S3Personality,
/// Server-side encryption configuration to be used when creating new S3 objects
pub server_side_encryption: ServerSideEncryption,
}

impl Default for S3FilesystemConfig {
@@ -396,6 +398,7 @@
allow_overwrite: false,
storage_class: None,
s3_personality: S3Personality::Standard,
server_side_encryption: Default::default(),
}
}
}
@@ -421,6 +424,34 @@ impl S3Personality {
}
}

/// Server-side encryption configuration for newly created objects
#[derive(Debug, Clone, Default)]
pub struct ServerSideEncryption {
sse_type: Option<String>,
sse_kms_key_id: Option<String>,
}

impl ServerSideEncryption {
/// Construct SSE settings from raw values provided via CLI
pub fn new(sse_type: Option<String>, sse_kms_key_id: Option<String>) -> Self {
// TODO: compute checksum
Self {
sse_type,
sse_kms_key_id,
}
}

/// String representation of the SSE type as expected by the S3 API
pub fn sse_type(&self) -> Option<String> {
self.sse_type.clone()
}

/// AWS KMS Key ID, if provided
pub fn key_id(&self) -> Option<String> {
self.sse_kms_key_id.clone()
}
}

#[derive(Debug)]
pub struct S3Filesystem<Client, Prefetcher>
where
@@ -462,7 +493,11 @@ where

let client = Arc::new(client);

let uploader = Uploader::new(client.clone(), config.storage_class.to_owned());
let uploader = Uploader::new(
client.clone(),
config.storage_class.to_owned(),
config.server_side_encryption.clone(),
);

Self {
config,
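Since `ServerSideEncryption` is re-exported from the crate root (see the lib.rs change below), it can be exercised directly. The following is a minimal usage sketch based only on the constructor and getters shown above; the key ARN is a placeholder.

```rust
use mountpoint_s3::ServerSideEncryption;

fn main() {
    // Default means "no SSE settings": both getters return None, and the
    // uploader passes those None values straight through to the put-object parameters.
    let none = ServerSideEncryption::default();
    assert_eq!(none.sse_type(), None);
    assert_eq!(none.key_id(), None);

    // Mirrors what cli.rs builds from --sse / --sse-kms-key-id.
    let sse = ServerSideEncryption::new(
        Some("aws:kms".to_owned()),
        Some("arn:aws:kms:us-east-1:111122223333:key/EXAMPLE".to_owned()),
    );
    assert_eq!(sse.sse_type().as_deref(), Some("aws:kms"));
    assert!(sse.key_id().is_some());
}
```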
2 changes: 1 addition & 1 deletion mountpoint-s3/src/lib.rs
@@ -14,7 +14,7 @@ pub mod prefix;
mod sync;
mod upload;

pub use fs::{S3Filesystem, S3FilesystemConfig};
pub use fs::{S3Filesystem, S3FilesystemConfig, ServerSideEncryption};

/// Enable tracing and CRT logging when running unit tests.
#[cfg(test)]
28 changes: 22 additions & 6 deletions mountpoint-s3/src/upload.rs
@@ -10,6 +10,7 @@ use thiserror::Error;
use tracing::error;

use crate::checksums::combine_checksums;
use crate::fs::ServerSideEncryption;

type PutRequestError<Client> = ObjectClientError<PutObjectError, <Client as ObjectClient>::ClientError>;

@@ -25,12 +26,21 @@
struct UploaderInner<Client> {
client: Arc<Client>,
storage_class: Option<String>,
server_side_encryption: ServerSideEncryption,
}

impl<Client: ObjectClient> Uploader<Client> {
/// Create a new [Uploader] that will make requests to the given client.
pub fn new(client: Arc<Client>, storage_class: Option<String>) -> Self {
let inner = UploaderInner { client, storage_class };
pub fn new(
client: Arc<Client>,
storage_class: Option<String>,
server_side_encryption: ServerSideEncryption,
) -> Self {
let inner = UploaderInner {
client,
storage_class,
server_side_encryption,
};
Self { inner: Arc::new(inner) }
}

@@ -79,6 +89,8 @@ impl<Client: ObjectClient> UploadRequest<Client> {
if let Some(storage_class) = &inner.storage_class {
params = params.storage_class(storage_class.clone());
}
params = params.server_side_encryption(inner.server_side_encryption.sse_type());
params = params.ssekms_key_id(inner.server_side_encryption.key_id());

let request = inner.client.put_object(bucket, key, &params).await?;
let maximum_upload_size = inner.client.part_size().map(|ps| ps * MAX_S3_MULTIPART_UPLOAD_PARTS);
@@ -204,7 +216,7 @@ mod tests {
part_size: 32,
..Default::default()
}));
let uploader = Uploader::new(client.clone(), None);
let uploader = Uploader::new(client.clone(), None, ServerSideEncryption::default());
let request = uploader.put(bucket, key).await.unwrap();

assert!(!client.contains_key(key));
@@ -228,7 +240,11 @@
part_size: 32,
..Default::default()
}));
let uploader = Uploader::new(client.clone(), Some(storage_class.to_owned()));
let uploader = Uploader::new(
client.clone(),
Some(storage_class.to_owned()),
ServerSideEncryption::default(),
);

let mut request = uploader.put(bucket, key).await.unwrap();

@@ -277,7 +293,7 @@
put_failures,
));

let uploader = Uploader::new(failure_client.clone(), None);
let uploader = Uploader::new(failure_client.clone(), None, ServerSideEncryption::default());

// First request fails on first write.
{
@@ -318,7 +334,7 @@
part_size: PART_SIZE,
..Default::default()
}));
let uploader = Uploader::new(client.clone(), None);
let uploader = Uploader::new(client.clone(), None, ServerSideEncryption::default());
let mut request = uploader.put(bucket, key).await.unwrap();

let successful_writes = PART_SIZE * MAX_S3_MULTIPART_UPLOAD_PARTS / write_size;
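The two new `params` lines above are what actually attach the SSE settings to every `put_object` call. A hypothetical unit test in the same style as the existing ones might look like the sketch below; it is not part of this PR, it assumes the `MockClient`/`MockClientConfig` types used elsewhere in this test module, and the key ID is a placeholder.

```rust
#[tokio::test]
async fn put_with_sse_kms_settings() {
    let bucket = "bucket";
    let key = "hello";
    let client = Arc::new(MockClient::new(MockClientConfig {
        bucket: bucket.to_owned(),
        part_size: 32,
        ..Default::default()
    }));

    // Placeholder key ID; no real KMS key is needed against the mock client.
    let sse = ServerSideEncryption::new(Some("aws:kms".to_owned()), Some("test-key-id".to_owned()));
    let uploader = Uploader::new(client.clone(), None, sse);

    // Only checks that an upload configured with SSE settings still completes.
    let request = uploader.put(bucket, key).await.unwrap();
    request.complete().await.unwrap();
    assert!(client.contains_key(key));
}
```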
21 changes: 20 additions & 1 deletion mountpoint-s3/tests/common/fuse.rs
@@ -9,8 +9,11 @@ use mountpoint_s3::fuse::S3FuseFilesystem;
use mountpoint_s3::prefetch::{Prefetch, PrefetcherConfig};
use mountpoint_s3::prefix::Prefix;
use mountpoint_s3::S3FilesystemConfig;
use mountpoint_s3_client::config::S3ClientAuthConfig;
use mountpoint_s3_client::types::PutObjectParams;
use mountpoint_s3_client::ObjectClient;
use mountpoint_s3_crt::auth::credentials::{CredentialsProvider, CredentialsProviderStaticOptions};
use mountpoint_s3_crt::common::allocator::Allocator;
use tempfile::TempDir;

pub trait TestClient: Send {
@@ -46,6 +49,7 @@ pub struct TestSessionConfig {
pub part_size: usize,
pub filesystem_config: S3FilesystemConfig,
pub prefetcher_config: PrefetcherConfig,
pub auth_config: S3ClientAuthConfig,
}

impl Default for TestSessionConfig {
@@ -54,10 +58,24 @@
part_size: 8 * 1024 * 1024,
filesystem_config: Default::default(),
prefetcher_config: Default::default(),
auth_config: Default::default(),
}
}
}

impl TestSessionConfig {
pub fn with_credentials(mut self, credentials: aws_sdk_s3::config::Credentials) -> Self {
let auth_config = CredentialsProviderStaticOptions {
access_key_id: credentials.access_key_id(),
secret_access_key: credentials.secret_access_key(),
session_token: credentials.session_token(),
};
let credentials_provider = CredentialsProvider::new_static(&Allocator::default(), auth_config).unwrap();
self.auth_config = S3ClientAuthConfig::Provider(credentials_provider);
self
}
}

fn create_fuse_session<Client, Prefetcher>(
client: Client,
prefetcher: Prefetcher,
@@ -256,7 +274,8 @@ pub mod s3_session {

let client_config = S3ClientConfig::default()
.part_size(test_config.part_size)
.endpoint_config(EndpointConfig::new(&region));
.endpoint_config(EndpointConfig::new(&region))
.auth_config(test_config.auth_config);
let client = S3CrtClient::new(client_config).unwrap();
let runtime = client.event_loop_group();
let prefetcher = default_prefetch(runtime, test_config.prefetcher_config);
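The new `auth_config` field and `with_credentials` builder let integration tests mount with scoped-down STS credentials, produced by the helpers added to tests/common/s3.rs below (`aws_sdk_sts::config::Credentials` and `aws_sdk_s3::config::Credentials` are the same re-exported type, so the hand-off works). A hypothetical call site, assuming those helpers are in scope, might look like this sketch.

```rust
// Hypothetical test helper: build a session config that signs requests
// with credentials scoped down to the given IAM policy document.
fn scoped_session_config(policy: &str) -> TestSessionConfig {
    // get_scoped_down_credentials is async, so tests run it via tokio_block_on.
    let credentials = tokio_block_on(get_scoped_down_credentials(policy));
    // The returned config carries a static CRT credentials provider, which the
    // S3CrtClient created by the test session then uses for signing.
    TestSessionConfig::default().with_credentials(credentials)
}
```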
54 changes: 54 additions & 0 deletions mountpoint-s3/tests/common/s3.rs
@@ -1,5 +1,6 @@
use aws_config::{BehaviorVersion, Region};
use aws_sdk_s3::primitives::ByteStream;
use aws_sdk_sts::config::Credentials;
use futures::Future;
use rand::RngCore;
use rand_chacha::rand_core::OsRng;
@@ -83,6 +84,10 @@ pub async fn get_test_sdk_client(region: &str) -> aws_sdk_s3::Client {
aws_sdk_s3::Client::from_conf(s3_config.build())
}

pub fn get_test_kms_key_id() -> String {
std::env::var("KMS_TEST_KEY_ID").expect("Set KMS_TEST_KEY_ID to run integration tests")
}

pub fn create_objects(bucket: &str, prefix: &str, region: &str, key: &str, value: &[u8]) {
let sdk_client = tokio_block_on(get_test_sdk_client(region));
let full_key = format!("{prefix}{key}");
@@ -105,3 +110,52 @@ pub fn tokio_block_on<F: Future>(future: F) -> F::Output {
.unwrap();
runtime.block_on(future)
}

/// Detect if running on GitHub Actions (GHA) and, if so,
/// emit masking strings so that credentials are not accidentally printed in logs.
fn mask_aws_creds_if_on_gha(credentials: &Credentials) {
if std::env::var_os("GITHUB_ACTIONS").is_some() {
// GitHub Actions isn't aware of these credential strings, since we source them inside the tests.
// If we think we're running in a GitHub Actions environment, print a mask command for each value to stdout.
// https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#masking-a-value-in-a-log
println!("::add-mask::{}", credentials.access_key_id());
println!("::add-mask::{}", credentials.secret_access_key());
if let Some(token) = credentials.session_token() {
println!("::add-mask::{}", token);
}
}
}

pub async fn get_test_sdk_sts_client() -> aws_sdk_sts::Client {
let config = aws_config::defaults(BehaviorVersion::latest())
.region(Region::new(get_test_region()))
.load()
.await;
aws_sdk_sts::Client::new(&config)
}

pub async fn get_scoped_down_credentials(policy: &str) -> Credentials {
let sts_client = get_test_sdk_sts_client().await;
let nonce = OsRng.next_u64();
let assume_role_response = sts_client
.assume_role()
.role_arn(get_subsession_iam_role())
.role_session_name(format!("mountpoint-s3-tests-{nonce}"))
.policy(policy)
.send()
.await
.expect("assume_role with valid ARN and policy should succeed");
let credentials = assume_role_response
.credentials()
.expect("credentials should be present if assume_role succeeded")
.to_owned();
let credentials = Credentials::new(
credentials.access_key_id(),
credentials.secret_access_key(),
Some(credentials.session_token().to_owned()),
None,
"scoped_down_sts_creds",
);
mask_aws_creds_if_on_gha(&credentials);
credentials
}
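To tie the pieces together, an SSE-enabled integration test would read the key from `KMS_TEST_KEY_ID` via `get_test_kms_key_id` and hand it to the filesystem config, mirroring what the CLI does with `--sse aws:kms --sse-kms-key-id <key>`. A hedged sketch of that glue follows; it is not part of this PR and the helper name is illustrative.

```rust
use mountpoint_s3::{S3FilesystemConfig, ServerSideEncryption};

// Illustrative helper: build a filesystem config that encrypts new objects
// with the KMS key configured for the integration-test environment.
fn sse_kms_filesystem_config() -> S3FilesystemConfig {
    let key_id = get_test_kms_key_id(); // panics unless KMS_TEST_KEY_ID is set
    let mut config = S3FilesystemConfig::default();
    config.server_side_encryption = ServerSideEncryption::new(Some("aws:kms".to_owned()), Some(key_id));
    config
}
```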