Skip to content

Commit

Permalink
support the SSE-C encryption
Browse files Browse the repository at this point in the history
  • Loading branch information
Hang Zheng committed Aug 9, 2024
1 parent 2e1b245 commit 1134b07
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 7 deletions.
87 changes: 80 additions & 7 deletions cpp/src/arrow/filesystem/s3fs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@
#include <aws/s3/model/PutObjectRequest.h>
#include <aws/s3/model/PutObjectResult.h>
#include <aws/s3/model/UploadPartRequest.h>
#include <aws/core/utils/HashingUtils.h>
#include <aws/core/utils/base64/Base64.h>

// AWS_SDK_VERSION_{MAJOR,MINOR,PATCH} are available since 1.9.7.
#if defined(AWS_SDK_VERSION_MAJOR) && defined(AWS_SDK_VERSION_MINOR) && \
Expand Down Expand Up @@ -168,9 +170,50 @@ static constexpr const char kAwsEndpointUrlEnvVar[] = "AWS_ENDPOINT_URL";
static constexpr const char kAwsEndpointUrlS3EnvVar[] = "AWS_ENDPOINT_URL_S3";
static constexpr const char kAwsDirectoryContentType[] = "application/x-directory";

template <typename S3RequestType>
void SetSSECustomerKey(S3RequestType& request, const std::string& sse_customer_algorithm,
const std::string& sse_customer_key,
const std::string& sse_customer_key_md5) {
if (!sse_customer_algorithm.empty()) {
request.SetSSECustomerAlgorithm(sse_customer_algorithm);
}
if (!sse_customer_key.empty()) {
request.SetSSECustomerKey(sse_customer_key);
}
if (!sse_customer_key_md5.empty()) {
request.SetSSECustomerKeyMD5(sse_customer_key_md5);
}
}

// -----------------------------------------------------------------------
// S3ProxyOptions implementation

std::string ComputeMD5Base64(const std::string& base64EncodedKey) {
// Decode the Base64-encoded key to get the raw binary key
Aws::Utils::ByteBuffer rawKey = Aws::Utils::HashingUtils::Base64Decode(base64EncodedKey);

// Convert the raw binary key to an Aws::String
Aws::String rawKeyStr(reinterpret_cast<const char*>(rawKey.GetUnderlyingData()),
rawKey.GetLength());

// Compute the MD5 hash of the raw binary key
Aws::Utils::ByteBuffer md5Hash = Aws::Utils::HashingUtils::CalculateMD5(rawKeyStr);

// Base64-encode the MD5 hash
Aws::String awsEncodedHash = Aws::Utils::HashingUtils::Base64Encode(md5Hash);

// Return the Base64-encoded MD5 hash as a std::string
return std::string(awsEncodedHash.begin(), awsEncodedHash.end());
}

/// Set the SSE-C customized key.
void S3Options::SetSSECKey(const std::string& c_key)
{
sse_customer_algorithm = "AES256";
sse_customer_key = c_key;
sse_customer_key_md5 = ComputeMD5Base64(c_key);
}

Result<S3ProxyOptions> S3ProxyOptions::FromUri(const Uri& uri) {
S3ProxyOptions options;

Expand Down Expand Up @@ -439,6 +482,9 @@ bool S3Options::Equals(const S3Options& other) const {
background_writes == other.background_writes &&
allow_bucket_creation == other.allow_bucket_creation &&
allow_bucket_deletion == other.allow_bucket_deletion &&
sse_customer_key == other.GetSSECKey() &&
sse_customer_algorithm == other.GetSSECAlgorithm() &&
sse_customer_key_md5 == other.GetSSECKeyMD5() &&
default_metadata_equals && GetAccessKey() == other.GetAccessKey() &&
GetSecretKey() == other.GetSecretKey() &&
GetSessionToken() == other.GetSessionToken());
Expand Down Expand Up @@ -1292,11 +1338,16 @@ Aws::IOStreamFactory AwsWriteableStreamFactory(void* data, int64_t nbytes) {
}

Result<S3Model::GetObjectResult> GetObjectRange(Aws::S3::S3Client* client,
const S3Path& path, int64_t start,
const S3Path& path,
const std::string& sse_customer_key,
const std::string& sse_customer_key_md5,
const std::string& sse_customer_algorithm,
int64_t start,
int64_t length, void* out) {
S3Model::GetObjectRequest req;
req.SetBucket(ToAwsString(path.bucket));
req.SetKey(ToAwsString(path.key));
SetSSECustomerKey(req, sse_customer_algorithm, sse_customer_key, sse_customer_key_md5);
req.SetRange(ToAwsString(FormatRange(start, length)));
req.SetResponseStreamFactory(AwsWriteableStreamFactory(out, length));
return OutcomeToResult("GetObject", client->GetObject(req));
Expand Down Expand Up @@ -1433,11 +1484,14 @@ bool IsDirectory(std::string_view key, const S3Model::HeadObjectResult& result)
class ObjectInputFile final : public io::RandomAccessFile {
public:
ObjectInputFile(std::shared_ptr<S3ClientHolder> holder, const io::IOContext& io_context,
const S3Path& path, int64_t size = kNoSize)
const S3Path& path, int64_t size = kNoSize, const std::string& c_algorithm = "AES256", const std::string& c_key = "", const std::string& c_key_md5 = "")
: holder_(std::move(holder)),
io_context_(io_context),
path_(path),
content_length_(size) {}
content_length_(size),
sse_customer_algorithm(c_algorithm),
sse_customer_key(c_key),
sse_customer_key_md5(c_key_md5) {}

Status Init() {
// Issue a HEAD Object to get the content-length and ensure any
Expand All @@ -1450,6 +1504,7 @@ class ObjectInputFile final : public io::RandomAccessFile {
S3Model::HeadObjectRequest req;
req.SetBucket(ToAwsString(path_.bucket));
req.SetKey(ToAwsString(path_.key));
SetSSECustomerKey(req, sse_customer_algorithm, sse_customer_key, sse_customer_key_md5);

ARROW_ASSIGN_OR_RAISE(auto client_lock, holder_->Lock());
auto outcome = client_lock.Move()->HeadObject(req);
Expand Down Expand Up @@ -1536,7 +1591,7 @@ class ObjectInputFile final : public io::RandomAccessFile {
ARROW_ASSIGN_OR_RAISE(auto client_lock, holder_->Lock());
ARROW_ASSIGN_OR_RAISE(
S3Model::GetObjectResult result,
GetObjectRange(client_lock.get(), path_, position, nbytes, out));
GetObjectRange(client_lock.get(), path_,sse_customer_key, sse_customer_key_md5, sse_customer_algorithm, position, nbytes, out));

auto& stream = result.GetBody();
stream.ignore(nbytes);
Expand Down Expand Up @@ -1584,6 +1639,9 @@ class ObjectInputFile final : public io::RandomAccessFile {
int64_t pos_ = 0;
int64_t content_length_ = kNoSize;
std::shared_ptr<const KeyValueMetadata> metadata_;
std::string sse_customer_algorithm;
std::string sse_customer_key;
std::string sse_customer_key_md5;
};

// Upload size per part. While AWS and Minio support different sizes for each
Expand Down Expand Up @@ -1620,7 +1678,10 @@ class ObjectOutputStream final : public io::OutputStream {
metadata_(metadata),
default_metadata_(options.default_metadata),
background_writes_(options.background_writes),
allow_delayed_open_(options.allow_delayed_open) {}
allow_delayed_open_(options.allow_delayed_open),
sse_customer_algorithm(options.GetSSECAlgorithm()),
sse_customer_key(options.options.GetSSECKey()),
sse_customer_key_md5(options.GetSSECKeyMD5()) {}

~ObjectOutputStream() override {
// For compliance with the rest of the IO stack, Close rather than Abort,
Expand Down Expand Up @@ -1668,6 +1729,8 @@ class ObjectOutputStream final : public io::OutputStream {
S3Model::CreateMultipartUploadRequest req;
req.SetBucket(ToAwsString(path_.bucket));
req.SetKey(ToAwsString(path_.key));
SetSSECustomerKey(req, sse_customer_algorithm, sse_customer_key, sse_customer_key_md5);

RETURN_NOT_OK(SetMetadataInRequest(&req));

auto outcome = client_lock.Move()->CreateMultipartUpload(req);
Expand Down Expand Up @@ -1769,6 +1832,9 @@ class ObjectOutputStream final : public io::OutputStream {
S3Model::CompleteMultipartUploadRequest req;
req.SetBucket(ToAwsString(path_.bucket));
req.SetKey(ToAwsString(path_.key));
SetSSECustomerKey(req, sse_customer_algorithm, sse_customer_key, sse_customer_key_md5);


req.SetUploadId(multipart_upload_id_);
req.SetMultipartUpload(std::move(completed_upload));

Expand Down Expand Up @@ -1950,6 +2016,8 @@ class ObjectOutputStream final : public io::OutputStream {
req.SetKey(ToAwsString(path_.key));
req.SetBody(std::make_shared<StringViewStream>(data, nbytes));
req.SetContentLength(nbytes);
SetSSECustomerKey(req, sse_customer_algorithm, sse_customer_key, sse_customer_key_md5);


if (!background_writes_) {
req.SetBody(std::make_shared<StringViewStream>(data, nbytes));
Expand Down Expand Up @@ -2171,6 +2239,9 @@ class ObjectOutputStream final : public io::OutputStream {
Future<> pending_uploads_completed = Future<>::MakeFinished(Status::OK());
};
std::shared_ptr<UploadState> upload_state_;
std::string sse_customer_algorithm;
std::string sse_customer_key;
std::string sse_customer_key_md5;
};

// This function assumes info->path() is already set
Expand Down Expand Up @@ -2972,7 +3043,8 @@ class S3FileSystem::Impl : public std::enable_shared_from_this<S3FileSystem::Imp

RETURN_NOT_OK(CheckS3Initialized());

auto ptr = std::make_shared<ObjectInputFile>(holder_, fs->io_context(), path);
auto ptr = std::make_shared<ObjectInputFile>(holder_, fs->io_context(), path, kNoSize, fs->options().GetSSECAlgorithm(),
fs->options().GetSSECKey(), fs->options().GetSSECKeyMD5());
RETURN_NOT_OK(ptr->Init());
return ptr;
}
Expand All @@ -2993,7 +3065,8 @@ class S3FileSystem::Impl : public std::enable_shared_from_this<S3FileSystem::Imp
RETURN_NOT_OK(CheckS3Initialized());

auto ptr =
std::make_shared<ObjectInputFile>(holder_, fs->io_context(), path, info.size());
std::make_shared<ObjectInputFile>(holder_, fs->io_context(), path, info.size(), fs->options().GetSSECAlgorithm(),
fs->options().GetSSECKey(), fs->options().GetSSECKeyMD5());
RETURN_NOT_OK(ptr->Init());
return ptr;
}
Expand Down
10 changes: 10 additions & 0 deletions cpp/src/arrow/filesystem/s3fs.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,16 @@ struct ARROW_EXPORT S3Options {
std::string* out_path = NULLPTR);
static Result<S3Options> FromUri(const std::string& uri,
std::string* out_path = NULLPTR);
/// Set the SSE-C customized key.
void SetSSECKey(const std::string& sse_customer_key);
std::string GetSSECKey() const { return sse_customer_key; }
std::string GetSSECAlgorithm() const { return sse_customer_algorithm; }
std::string GetSSECKeyMD5() const { return sse_customer_key_md5; }

private:
std::string sse_customer_algorithm;
std::string sse_customer_key;
std::string sse_customer_key_md5;
};

/// S3-backed FileSystem implementation.
Expand Down

0 comments on commit 1134b07

Please sign in to comment.