Skip to content

Commit

Permalink
[FEAT] Add S3Config.from_env functionality (#2137)
Browse files Browse the repository at this point in the history
Adds a constructor for S3Config: `S3Config.from_env`.

This constructor creates an S3Config from the current environment,
leveraging our current code for auto-discovering things such as: region,
credentials and anonymous mode.

Closes: #2139

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
  • Loading branch information
jaychia and Jay Chia authored Apr 16, 2024
1 parent d153668 commit 897b384
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 4 deletions.
7 changes: 7 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,11 @@ class S3Config:
"""Replaces values if provided, returning a new S3Config"""
...

@staticmethod
def from_env() -> S3Config:
"""Creates an S3Config, retrieving credentials and configurations from the current environtment"""
...

class AzureConfig:
"""
I/O configuration for accessing Azure Blob Storage.
Expand Down Expand Up @@ -530,6 +535,8 @@ class IOConfig:
"""
Recreate an IOConfig from a JSON string.
"""
...

def replace(
self, s3: S3Config | None = None, azure: AzureConfig | None = None, gcs: GCSConfig | None = None
) -> IOConfig:
Expand Down
15 changes: 15 additions & 0 deletions src/common/io-config/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,21 @@ impl S3Config {
}
}

/// Creates an S3Config from the current environment, auto-discovering variables such as
/// credentials, regions and more.
#[staticmethod]
pub fn from_env(py: Python) -> PyResult<Self> {
let io_config_from_env_func = py
.import(pyo3::intern!(py, "daft"))?
.getattr(pyo3::intern!(py, "daft"))?
.getattr(pyo3::intern!(py, "s3_config_from_env"))?;
io_config_from_env_func.call0().map(|pyany| {
pyany
.extract()
.expect("s3_config_from_env function must return S3Config")
})
}

pub fn __repr__(&self) -> PyResult<String> {
Ok(format!("{}", self.config))
}
Expand Down
17 changes: 15 additions & 2 deletions src/daft-io/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ pub use common_io_config::python::{AzureConfig, GCSConfig, IOConfig};
pub use py::register_modules;

mod py {
use crate::{get_io_client, get_runtime, parse_url, stats::IOStatsContext};
use crate::{get_io_client, get_runtime, parse_url, s3_like, stats::IOStatsContext};
use common_error::DaftResult;
use futures::TryStreamExt;
use pyo3::{
Expand Down Expand Up @@ -66,11 +66,24 @@ mod py {
Ok(crate::set_io_pool_num_threads(num_threads as usize))
}

/// Creates an S3Config from the current environment, auto-discovering variables such as
/// credentials, regions and more.
#[pyfunction]
fn s3_config_from_env(py: Python) -> PyResult<common_io_config::python::S3Config> {
let s3_config: DaftResult<common_io_config::S3Config> = py.allow_threads(|| {
let runtime = get_runtime(false)?;
let runtime_handle = runtime.handle();
let _rt_guard = runtime_handle.enter();
runtime_handle.block_on(async { Ok(s3_like::s3_config_from_env().await?) })
});
Ok(common_io_config::python::S3Config { config: s3_config? })
}

pub fn register_modules(py: Python, parent: &PyModule) -> PyResult<()> {
common_io_config::python::register_modules(py, parent)?;
parent.add_function(wrap_pyfunction!(io_glob, parent)?)?;
parent.add_function(wrap_pyfunction!(set_io_pool_num_threads, parent)?)?;

parent.add_function(wrap_pyfunction!(s3_config_from_env, parent)?)?;
Ok(())
}
}
40 changes: 38 additions & 2 deletions src/daft-io/src/s3_like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,33 @@ impl From<Error> for super::Error {
}
}

/// Retrieves an S3Config from the environment by leveraging the AWS SDK's credentials chain
pub(crate) async fn s3_config_from_env() -> super::Result<S3Config> {
let default_s3_config = S3Config::default();
let (anonymous, s3_conf) = build_s3_conf(&default_s3_config, None).await?;
let creds = s3_conf
.credentials_cache()
.provide_cached_credentials()
.await
.with_context(|_| UnableToLoadCredentialsSnafu {})?;
let key_id = Some(creds.access_key_id().to_string());
let access_key = Some(creds.secret_access_key().to_string());
let session_token = creds.session_token().map(|t| t.to_string());
let region_name = s3_conf.region().map(|r| r.to_string());
Ok(S3Config {
// Do not perform auto-discovery of endpoint_url. This is possible, but requires quite a bit
// of work that our current implementation of `build_s3_conf` does not yet do. See smithy-rs code:
// https://github.com/smithy-lang/smithy-rs/blob/94ecd38c2518583042796b2b45c37947237e31dd/aws/rust-runtime/aws-config/src/lib.rs#L824-L849
endpoint_url: None,
region_name,
key_id,
session_token,
access_key,
anonymous,
..default_s3_config
})
}

/// Helper to parse S3 URLs, returning (scheme, bucket, key)
fn parse_url(uri: &str) -> super::Result<(String, String, String)> {
let parsed = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?;
Expand Down Expand Up @@ -247,10 +274,10 @@ fn handle_https_client_settings(
Ok(builder)
}

async fn build_s3_client(
async fn build_s3_conf(
config: &S3Config,
credentials_cache: Option<SharedCredentialsCache>,
) -> super::Result<(bool, s3::Client)> {
) -> super::Result<(bool, s3::Config)> {
const DEFAULT_REGION: Region = Region::from_static("us-east-1");

let mut anonymous = config.anonymous;
Expand Down Expand Up @@ -405,6 +432,15 @@ async fn build_s3_client(
} else {
s3_conf
};

Ok((anonymous, s3_conf))
}

async fn build_s3_client(
config: &S3Config,
credentials_cache: Option<SharedCredentialsCache>,
) -> super::Result<(bool, s3::Client)> {
let (anonymous, s3_conf) = build_s3_conf(config, credentials_cache).await?;
Ok((anonymous, s3::Client::from_conf(s3_conf)))
}

Expand Down

0 comments on commit 897b384

Please sign in to comment.