diff --git a/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs b/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs index 42a49f5583..d580abe448 100644 --- a/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs +++ b/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs @@ -5,30 +5,37 @@ use crate::{ constants, models::{ContainerProperties, Item, QueryResults}, options::{QueryOptions, ReadContainerOptions}, - pipeline::{CosmosPipeline, ResourceType}, - utils::AppendPathSegments, + pipeline::CosmosPipeline, + resource_context::{ResourceLink, ResourceType}, DeleteContainerOptions, ItemOptions, PartitionKey, Query, QueryPartitionStrategy, }; use azure_core::{Context, Method, Pager, Request, Response}; use serde::{de::DeserializeOwned, Serialize}; -use typespec_client_core::http::PagerResult; -use url::Url; /// A client for working with a specific container in a Cosmos DB account. /// /// You can get a `Container` by calling [`DatabaseClient::container_client()`](crate::clients::DatabaseClient::container_client()). pub struct ContainerClient { - container_url: Url, + link: ResourceLink, + items_link: ResourceLink, pipeline: CosmosPipeline, } impl ContainerClient { - pub(crate) fn new(pipeline: CosmosPipeline, database_url: &Url, container_name: &str) -> Self { - let container_url = database_url.with_path_segments(["colls", container_name]); + pub(crate) fn new( + pipeline: CosmosPipeline, + database_link: &ResourceLink, + container_id: &str, + ) -> Self { + let link = database_link + .feed(ResourceType::Containers) + .item(container_id); + let items_link = link.feed(ResourceType::Items); Self { - container_url, + link, + items_link, pipeline, } } @@ -58,9 +65,10 @@ impl ContainerClient { // REASON: This is a documented public API so prefixing with '_' is undesirable. options: Option, ) -> azure_core::Result> { - let mut req = Request::new(self.container_url.clone(), Method::Get); + let url = self.pipeline.url(&self.link); + let mut req = Request::new(url, Method::Get); self.pipeline - .send(Context::new(), &mut req, ResourceType::Containers) + .send(Context::new(), &mut req, self.link.clone()) .await } @@ -76,9 +84,10 @@ impl ContainerClient { // REASON: This is a documented public API so prefixing with '_' is undesirable. options: Option, ) -> azure_core::Result { - let mut req = Request::new(self.container_url.clone(), Method::Delete); + let url = self.pipeline.url(&self.link); + let mut req = Request::new(url, Method::Delete); self.pipeline - .send(Context::new(), &mut req, ResourceType::Containers) + .send(Context::new(), &mut req, self.link.clone()) .await } @@ -126,12 +135,12 @@ impl ContainerClient { // REASON: This is a documented public API so prefixing with '_' is undesirable. options: Option, ) -> azure_core::Result>> { - let url = self.container_url.with_path_segments(["docs"]); + let url = self.pipeline.url(&self.items_link); let mut req = Request::new(url, Method::Post); req.insert_headers(&partition_key.into())?; req.set_json(&item)?; self.pipeline - .send(Context::new(), &mut req, ResourceType::Items) + .send(Context::new(), &mut req, self.items_link.clone()) .await } @@ -181,15 +190,12 @@ impl ContainerClient { // REASON: This is a documented public API so prefixing with '_' is undesirable. options: Option, ) -> azure_core::Result>> { - let url = self - .container_url - .with_path_segments(["docs", item_id.as_ref()]); + let link = self.items_link.item(item_id); + let url = self.pipeline.url(&link); let mut req = Request::new(url, Method::Put); req.insert_headers(&partition_key.into())?; req.set_json(&item)?; - self.pipeline - .send(Context::new(), &mut req, ResourceType::Items) - .await + self.pipeline.send(Context::new(), &mut req, link).await } /// Creates or replaces an item in the container. @@ -239,13 +245,13 @@ impl ContainerClient { // REASON: This is a documented public API so prefixing with '_' is undesirable. options: Option, ) -> azure_core::Result>> { - let url = self.container_url.with_path_segments(["docs"]); + let url = self.pipeline.url(&self.items_link); let mut req = Request::new(url, Method::Post); req.insert_header(constants::IS_UPSERT, "true"); req.insert_headers(&partition_key.into())?; req.set_json(&item)?; self.pipeline - .send(Context::new(), &mut req, ResourceType::Items) + .send(Context::new(), &mut req, self.items_link.clone()) .await } @@ -288,14 +294,11 @@ impl ContainerClient { // REASON: This is a documented public API so prefixing with '_' is undesirable. options: Option, ) -> azure_core::Result>> { - let url = self - .container_url - .with_path_segments(["docs", item_id.as_ref()]); + let link = self.items_link.item(item_id); + let url = self.pipeline.url(&link); let mut req = Request::new(url, Method::Get); req.insert_headers(&partition_key.into())?; - self.pipeline - .send(Context::new(), &mut req, ResourceType::Items) - .await + self.pipeline.send(Context::new(), &mut req, link).await } /// Deletes an item from the container. @@ -326,14 +329,11 @@ impl ContainerClient { // REASON: This is a documented public API so prefixing with '_' is undesirable. options: Option, ) -> azure_core::Result { - let url = self - .container_url - .with_path_segments(["docs", item_id.as_ref()]); + let link = self.items_link.item(item_id); + let url = self.pipeline.url(&link); let mut req = Request::new(url, Method::Delete); req.insert_headers(&partition_key.into())?; - self.pipeline - .send(Context::new(), &mut req, ResourceType::Items) - .await + self.pipeline.send(Context::new(), &mut req, link).await } /// Executes a single-partition query against items in the container. @@ -399,40 +399,12 @@ impl ContainerClient { // REASON: This is a documented public API so prefixing with '_' is undesirable. options: Option, ) -> azure_core::Result>> { - let mut url = self.container_url.clone(); - url.append_path_segments(["docs"]); - let mut base_req = Request::new(url, Method::Post); - - base_req.insert_header(constants::QUERY, "True"); - base_req.add_mandatory_header(&constants::QUERY_CONTENT_TYPE); - + let url = self.pipeline.url(&self.items_link); + let mut base_request = Request::new(url, Method::Post); let QueryPartitionStrategy::SinglePartition(partition_key) = partition_key.into(); - base_req.insert_headers(&partition_key)?; + base_request.insert_headers(&partition_key)?; - base_req.set_json(&query.into())?; - - // We have to double-clone here. - // First we clone the pipeline to pass it in to the closure - let pipeline = self.pipeline.clone(); - Ok(Pager::from_callback(move |continuation| { - // Then we have to clone it again to pass it in to the async block. - // This is because Pageable can't borrow any data, it has to own it all. - // That's probably good, because it means a Pageable can outlive the client that produced it, but it requires some extra cloning. - let pipeline = pipeline.clone(); - let mut req = base_req.clone(); - async move { - if let Some(continuation) = continuation { - req.insert_header(constants::CONTINUATION, continuation); - } - - let response = pipeline - .send(Context::new(), &mut req, ResourceType::Items) - .await?; - Ok(PagerResult::from_response_header( - response, - &constants::CONTINUATION, - )) - } - })) + self.pipeline + .send_query_request(query.into(), base_request, self.items_link.clone()) } } diff --git a/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client.rs b/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client.rs index 2673b34fb7..da2d982e39 100644 --- a/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client.rs +++ b/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client.rs @@ -4,12 +4,11 @@ use crate::{ clients::DatabaseClient, models::{DatabaseProperties, DatabaseQueryResults, Item}, - pipeline::{AuthorizationPolicy, CosmosPipeline, ResourceType}, - utils::AppendPathSegments, + pipeline::{AuthorizationPolicy, CosmosPipeline}, + resource_context::{ResourceLink, ResourceType}, CosmosClientOptions, CreateDatabaseOptions, Query, QueryDatabasesOptions, }; use azure_core::{credentials::TokenCredential, Context, Method, Request, Response, Url}; - use serde::Serialize; use std::sync::Arc; @@ -19,8 +18,8 @@ use azure_core::credentials::Secret; /// Client for Azure Cosmos DB. #[derive(Debug, Clone)] pub struct CosmosClient { - endpoint: Url, - pub(crate) pipeline: CosmosPipeline, + databases_link: ResourceLink, + pipeline: CosmosPipeline, #[allow(dead_code)] options: CosmosClientOptions, @@ -51,8 +50,9 @@ impl CosmosClient { ) -> azure_core::Result { let options = options.unwrap_or_default(); Ok(Self { - endpoint: endpoint.as_ref().parse()?, + databases_link: ResourceLink::root(ResourceType::Databases), pipeline: CosmosPipeline::new( + endpoint.as_ref().parse()?, AuthorizationPolicy::from_token_credential(credential), options.client_options.clone(), ), @@ -83,8 +83,9 @@ impl CosmosClient { ) -> azure_core::Result { let options = options.unwrap_or_default(); Ok(Self { - endpoint: endpoint.as_ref().parse()?, + databases_link: ResourceLink::root(ResourceType::Databases), pipeline: CosmosPipeline::new( + endpoint.as_ref().parse()?, AuthorizationPolicy::from_shared_key(key.into()), options.client_options.clone(), ), @@ -97,12 +98,12 @@ impl CosmosClient { /// # Arguments /// * `id` - The ID of the database. pub fn database_client(&self, id: impl AsRef) -> DatabaseClient { - DatabaseClient::new(self.pipeline.clone(), &self.endpoint, id.as_ref()) + DatabaseClient::new(self.pipeline.clone(), id.as_ref()) } /// Gets the endpoint of the database account this client is connected to. pub fn endpoint(&self) -> &Url { - &self.endpoint + &self.pipeline.endpoint } /// Executes a query against databases in the account. @@ -136,11 +137,11 @@ impl CosmosClient { // REASON: This is a documented public API so prefixing with '_' is undesirable. options: Option, ) -> azure_core::Result> { - let url = self.endpoint.with_path_segments(["dbs"]); + let url = self.pipeline.url(&self.databases_link); let base_request = Request::new(url, azure_core::Method::Post); self.pipeline - .send_query_request(query.into(), base_request, ResourceType::Databases) + .send_query_request(query.into(), base_request, self.databases_link.clone()) } /// Creates a new database. @@ -163,12 +164,12 @@ impl CosmosClient { id: String, } - let url = self.endpoint.with_path_segments(["dbs"]); + let url = self.pipeline.url(&self.databases_link); let mut req = Request::new(url, Method::Post); req.set_json(&RequestBody { id })?; self.pipeline - .send(Context::new(), &mut req, ResourceType::Databases) + .send(Context::new(), &mut req, self.databases_link.clone()) .await } } diff --git a/sdk/cosmos/azure_data_cosmos/src/clients/database_client.rs b/sdk/cosmos/azure_data_cosmos/src/clients/database_client.rs index b6207afc90..634a5a7265 100644 --- a/sdk/cosmos/azure_data_cosmos/src/clients/database_client.rs +++ b/sdk/cosmos/azure_data_cosmos/src/clients/database_client.rs @@ -4,33 +4,34 @@ use crate::{ clients::ContainerClient, models::{ContainerProperties, ContainerQueryResults, DatabaseProperties, Item}, - pipeline::{CosmosPipeline, ResourceType}, - utils::AppendPathSegments, + options::ReadDatabaseOptions, + pipeline::CosmosPipeline, + resource_context::{ResourceLink, ResourceType}, CreateContainerOptions, DeleteDatabaseOptions, Query, QueryContainersOptions, - ReadDatabaseOptions, }; use azure_core::{Context, Method, Pager, Request, Response}; -use url::Url; - /// A client for working with a specific database in a Cosmos DB account. /// /// You can get a `DatabaseClient` by calling [`CosmosClient::database_client()`](crate::CosmosClient::database_client()). pub struct DatabaseClient { + link: ResourceLink, + containers_link: ResourceLink, database_id: String, - database_url: Url, pipeline: CosmosPipeline, } impl DatabaseClient { - pub(crate) fn new(pipeline: CosmosPipeline, base_url: &Url, database_id: &str) -> Self { + pub(crate) fn new(pipeline: CosmosPipeline, database_id: &str) -> Self { let database_id = database_id.to_string(); - let database_url = base_url.with_path_segments(["dbs", &database_id]); + let link = ResourceLink::root(ResourceType::Databases).item(&database_id); + let containers_link = link.feed(ResourceType::Containers); Self { + link, + containers_link, database_id, - database_url, pipeline, } } @@ -40,7 +41,7 @@ impl DatabaseClient { /// # Arguments /// * `name` - The name of the container. pub fn container_client(&self, name: impl AsRef) -> ContainerClient { - ContainerClient::new(self.pipeline.clone(), &self.database_url, name.as_ref()) + ContainerClient::new(self.pipeline.clone(), &self.link, name.as_ref()) } /// Returns the identifier of the Cosmos database. @@ -73,9 +74,10 @@ impl DatabaseClient { // REASON: This is a documented public API so prefixing with '_' is undesirable. options: Option, ) -> azure_core::Result> { - let mut req = Request::new(self.database_url.clone(), azure_core::Method::Get); + let url = self.pipeline.url(&self.link); + let mut req = Request::new(url, Method::Get); self.pipeline - .send(Context::new(), &mut req, ResourceType::Databases) + .send(Context::new(), &mut req, self.link.clone()) .await } @@ -110,12 +112,11 @@ impl DatabaseClient { // REASON: This is a documented public API so prefixing with '_' is undesirable. options: Option, ) -> azure_core::Result> { - let mut url = self.database_url.clone(); - url.append_path_segments(["colls"]); - let base_request = Request::new(url, azure_core::Method::Post); + let url = self.pipeline.url(&self.containers_link); + let base_request = Request::new(url, Method::Post); self.pipeline - .send_query_request(query.into(), base_request, ResourceType::Containers) + .send_query_request(query.into(), base_request, self.containers_link.clone()) } /// Creates a new container. @@ -133,12 +134,12 @@ impl DatabaseClient { // REASON: This is a documented public API so prefixing with '_' is undesirable. options: Option, ) -> azure_core::Result>> { - let url = self.database_url.with_path_segments(["colls"]); + let url = self.pipeline.url(&self.containers_link); let mut req = Request::new(url, Method::Post); req.set_json(&properties)?; self.pipeline - .send(Context::new(), &mut req, ResourceType::Containers) + .send(Context::new(), &mut req, self.containers_link.clone()) .await } @@ -154,9 +155,10 @@ impl DatabaseClient { // REASON: This is a documented public API so prefixing with '_' is undesirable. options: Option, ) -> azure_core::Result { - let mut req = Request::new(self.database_url.clone(), Method::Delete); + let url = self.pipeline.url(&self.link); + let mut req = Request::new(url, Method::Delete); self.pipeline - .send(Context::new(), &mut req, ResourceType::Databases) + .send(Context::new(), &mut req, self.link.clone()) .await } } diff --git a/sdk/cosmos/azure_data_cosmos/src/lib.rs b/sdk/cosmos/azure_data_cosmos/src/lib.rs index c92fb27061..b817b7873a 100644 --- a/sdk/cosmos/azure_data_cosmos/src/lib.rs +++ b/sdk/cosmos/azure_data_cosmos/src/lib.rs @@ -16,6 +16,7 @@ mod options; mod partition_key; pub(crate) mod pipeline; mod query; +pub(crate) mod resource_context; pub(crate) mod utils; pub mod models; diff --git a/sdk/cosmos/azure_data_cosmos/src/pipeline/authorization_policy.rs b/sdk/cosmos/azure_data_cosmos/src/pipeline/authorization_policy.rs index f6e6e013b5..ede1bfbd12 100644 --- a/sdk/cosmos/azure_data_cosmos/src/pipeline/authorization_policy.rs +++ b/sdk/cosmos/azure_data_cosmos/src/pipeline/authorization_policy.rs @@ -7,37 +7,21 @@ //! Instead, it uses a custom header format, as defined in the [official documentation](https://docs.microsoft.com/en-us/rest/api/cosmos-db/access-control-on-cosmosdb-resources). //! We implement that policy here, because we can't use any standard Azure SDK authentication policy. -use azure_core::credentials::TokenCredential; -use azure_core::date::OffsetDateTime; +#[cfg_attr(not(feature = "key_auth"), allow(unused_imports))] use azure_core::{ - date, + credentials::{Secret, TokenCredential}, + date::{self, OffsetDateTime}, headers::{HeaderValue, AUTHORIZATION, MS_DATE, VERSION}, Context, Policy, PolicyResult, Request, Url, }; use std::sync::Arc; use tracing::trace; -use url::form_urlencoded; -#[cfg(feature = "key_auth")] -use azure_core::{credentials::Secret, hmac::hmac_sha256}; +use crate::{pipeline::signature_target::SignatureTarget, resource_context::ResourceLink}; + +use crate::utils::url_encode; const AZURE_VERSION: &str = "2020-07-15"; -const VERSION_NUMBER: &str = "1.0"; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[allow(dead_code)] // For the variants. Can be removed when we have them all implemented. -pub(crate) enum ResourceType { - Databases, - Containers, - Items, - StoredProcedures, - Users, - Permissions, - Attachments, - PartitionKeyRanges, - UserDefinedFunctions, - Triggers, -} #[derive(Debug, Clone)] enum Credential { @@ -49,14 +33,6 @@ enum Credential { PrimaryKey(Secret), } -#[cfg(feature = "key_auth")] -struct SignatureTarget<'a> { - http_method: &'a azure_core::Method, - resource_type: &'a ResourceType, - resource_link: &'a str, - time_nonce: OffsetDateTime, -} - #[derive(Debug, Clone)] pub struct AuthorizationPolicy { credential: Credential, @@ -93,28 +69,21 @@ impl Policy for AuthorizationPolicy { "Authorization policies cannot be the last policy of a pipeline" ); - let time_nonce = OffsetDateTime::now_utc(); + // x-ms-date and the string used in the signature must be exactly the same, so just generate it here once. + let date_string = date::to_rfc1123(&OffsetDateTime::now_utc()).to_lowercase(); - let resource_link = extract_resource_link(request); - let resource_type: &ResourceType = ctx + let resource_link: &ResourceLink = ctx .value() - .expect("ResourceType must be in the Context at this point"); + .expect("ResourceContext should have been provided by CosmosPipeline"); + let auth = generate_authorization( &self.credential, request.url(), - #[cfg(feature = "key_auth")] - SignatureTarget { - http_method: request.method(), - resource_type, - resource_link: &resource_link, - time_nonce, - }, + SignatureTarget::new(*request.method(), resource_link, &date_string), ) .await?; - trace!(?resource_type, resource_link, "AuthorizationPolicy applied"); - - request.insert_header(MS_DATE, HeaderValue::from(date::to_rfc1123(&time_nonce))); + request.insert_header(MS_DATE, HeaderValue::from(date_string)); request.insert_header(VERSION, HeaderValue::from_static(AZURE_VERSION)); request.insert_header(AUTHORIZATION, HeaderValue::from(auth)); @@ -123,57 +92,6 @@ impl Policy for AuthorizationPolicy { } } -/// This function strips the leading slash and the resource name from the uri of the passed request. -/// It does not strip the resource name if the resource name is not present. This is accomplished in -/// four steps (with eager return): -/// 1. Strip leading slash from the uri of the passed request. -/// 2. Find if the uri ends with a `ENDING_STRING`. If so, strip it and return. Every `ENDING_STRING` -/// starts with a leading slash so this check will not match uri composed **only** of the -/// `ENDING_STRING`. -/// 3. Find if the uri **is** the ending string (without the leading slash). If so return an empty -/// string. This covers the exception of the rule above. -/// 4. Return the received uri unchanged. -fn extract_resource_link(request: &Request) -> String { - static ENDING_STRINGS: &[&str] = &[ - "/dbs", - "/colls", - "/docs", - "/sprocs", - "/users", - "/permissions", - "/attachments", - "/pkranges", - "/udfs", - "/triggers", - ]; - - // This strips the leading slash from the uri of the passed request. - let uri_path = request.path_and_query(); - let uri = uri_path.trim_start_matches('/'); - - // We find the above resource names. If found, we strip it and eagerly return. Note that the - // resource names have a leading slash so the suffix will match `test/users` but not - // `test-users`. - for ending in ENDING_STRINGS { - if let Some(uri_without_ending) = uri.strip_suffix(ending) { - return uri_without_ending.to_string(); - } - } - - // This check handles the uris comprised by resource names only. It will match `users` and - // return an empty string. This is necessary because the previous check included a leading - // slash. - if ENDING_STRINGS - .iter() - .map(|ending| &ending[1..]) // this is safe since every ENDING_STRING starts with a slash - .any(|item| uri == item) - { - String::new() - } else { - uri.to_string() - } -} - /// Generates the 'Authorization' header value based on the provided values. /// /// The specific result format depends on the type of the auth token provided. @@ -189,29 +107,26 @@ fn extract_resource_link(request: &Request) -> String { async fn generate_authorization<'a>( auth_token: &Credential, url: &Url, - #[cfg(feature = "key_auth")] signature_target: SignatureTarget<'a>, + + // Unused unless feature="key_auth", but I don't want to mess with excluding it since it makes call sites more complicated + #[allow(unused_variables)] signature_target: SignatureTarget<'a>, ) -> azure_core::Result { - let (authorization_type, signature) = match auth_token { - Credential::Token(token_credential) => ( - "aad", - token_credential + let token = match auth_token { + Credential::Token(token_credential) => { + let token = token_credential .get_token(&[&scope_from_url(url)]) .await? .token .secret() - .to_string(), - ), + .to_string(); + format!("type=aad&ver=1.0&sig={token}") + } #[cfg(feature = "key_auth")] - Credential::PrimaryKey(key) => { - let string_to_sign = string_to_sign(signature_target); - ("master", hmac_sha256(&string_to_sign, key)?) - } + Credential::PrimaryKey(key) => signature_target.into_authorization(key)?, }; - let str_unencoded = format!("type={authorization_type}&ver={VERSION_NUMBER}&sig={signature}"); - - Ok(form_urlencoded::byte_serialize(str_unencoded.as_bytes()).collect::()) + Ok(url_encode(token)) } /// This function generates the scope string from the passed url. The scope string is used to @@ -222,63 +137,29 @@ fn scope_from_url(url: &Url) -> String { format!("{scheme}://{hostname}/.default") } -/// This function generates a valid authorization string, according to the documentation. -/// In case of authorization problems we can compare the `string_to_sign` generated by Azure against -/// our own. -#[cfg(feature = "key_auth")] -fn string_to_sign(signature_target: SignatureTarget) -> String { - // From official docs: - // StringToSign = - // Verb.toLowerCase() + "\n" + - // ResourceType.toLowerCase() + "\n" + - // ResourceLink + "\n" + - // Date.toLowerCase() + "\n" + - // "" + "\n"; - // Notice the empty string at the end so we need to add two new lines - - format!( - "{}\n{}\n{}\n{}\n\n", - match *signature_target.http_method { - azure_core::Method::Get => "get", - azure_core::Method::Put => "put", - azure_core::Method::Post => "post", - azure_core::Method::Delete => "delete", - azure_core::Method::Head => "head", - azure_core::Method::Trace => "trace", - azure_core::Method::Options => "options", - azure_core::Method::Connect => "connect", - azure_core::Method::Patch => "patch", - _ => "extension", - }, - match signature_target.resource_type { - ResourceType::Databases => "dbs", - ResourceType::Containers => "colls", // The rest API uses the old term "colls" (referring to 'collections') to refer to containers - ResourceType::Items => "docs", // The rest API uses the old term "docs" (referring to 'documents') to refer to items - ResourceType::StoredProcedures => "sprocs", - ResourceType::Users => "users", - ResourceType::Permissions => "permissions", - ResourceType::Attachments => "attachments", - ResourceType::PartitionKeyRanges => "pkranges", - ResourceType::UserDefinedFunctions => "udfs", - ResourceType::Triggers => "triggers", - }, - signature_target.resource_link, - date::to_rfc1123(&signature_target.time_nonce).to_lowercase() - ) -} - #[cfg(test)] mod tests { - #[cfg(feature = "key_auth")] - use azure_core::credentials::AccessToken; + use std::sync::Arc; - use super::*; + use azure_core::{ + credentials::{AccessToken, TokenCredential}, + date, + }; + use time::OffsetDateTime; + use url::Url; + + use crate::{ + pipeline::{ + authorization_policy::{generate_authorization, scope_from_url, Credential}, + signature_target::SignatureTarget, + }, + resource_context::{ResourceLink, ResourceType}, + utils::url_encode, + }; #[derive(Debug)] - #[cfg(feature = "key_auth")] struct TestTokenCredential(String); - #[cfg(feature = "key_auth")] #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] impl TokenCredential for TestTokenCredential { @@ -295,32 +176,10 @@ mod tests { } } - #[test] - #[cfg(feature = "key_auth")] - fn string_to_sign_generates_expected_string_for_signing() { - let time_nonce = date::parse_rfc3339("1900-01-01T01:00:00.000000000+00:00").unwrap(); - - let ret = string_to_sign(SignatureTarget { - http_method: &azure_core::Method::Get, - resource_type: &ResourceType::Databases, - resource_link: "dbs/MyDatabase/colls/MyCollection", - time_nonce, - }); - assert_eq!( - ret, - "get -dbs -dbs/MyDatabase/colls/MyCollection -mon, 01 jan 1900 01:00:00 gmt - -" - ); - } - #[tokio::test] - #[cfg(feature = "key_auth")] async fn generate_authorization_for_token_credential() { let time_nonce = date::parse_rfc3339("1900-01-01T01:00:00.000000000+00:00").unwrap(); + let date_string = date::to_rfc1123(&time_nonce).to_lowercase(); let cred = Arc::new(TestTokenCredential("test_token".to_string())); let auth_token = Credential::Token(cred); @@ -330,20 +189,21 @@ mon, 01 jan 1900 01:00:00 gmt let ret = generate_authorization( &auth_token, &url, - SignatureTarget { - http_method: &azure_core::Method::Get, - resource_type: &ResourceType::Databases, - resource_link: "dbs/MyDatabase/colls/MyCollection", - time_nonce, - }, + SignatureTarget::new( + azure_core::Method::Get, + &ResourceLink::root(ResourceType::Databases) + .item("MyDatabase") + .feed(ResourceType::Containers) + .item("MyCollection"), + &date_string, + ), ) .await .unwrap(); - let expected: String = form_urlencoded::byte_serialize( + let expected: String = url_encode( b"type=aad&ver=1.0&sig=test_token+https://test_account.example.com/.default", - ) - .collect(); + ); assert_eq!(ret, expected); } @@ -352,6 +212,7 @@ mon, 01 jan 1900 01:00:00 gmt #[cfg(feature = "key_auth")] async fn generate_authorization_for_primary_key_0() { let time_nonce = date::parse_rfc3339("1900-01-01T01:00:00.000000000+00:00").unwrap(); + let date_string = date::to_rfc1123(&time_nonce).to_lowercase(); let auth_token = Credential::PrimaryKey( "8F8xXXOptJxkblM1DBXW7a6NMI5oE8NnwPGYBmwxLCKfejOK7B7yhcCHMGvN3PBrlMLIOeol1Hv9RCdzAZR5sg==".into(), @@ -363,20 +224,20 @@ mon, 01 jan 1900 01:00:00 gmt let ret = generate_authorization( &auth_token, &url, - SignatureTarget { - http_method: &azure_core::Method::Get, - resource_type: &ResourceType::Databases, - resource_link: "dbs/MyDatabase/colls/MyCollection", - time_nonce, - }, + SignatureTarget::new( + azure_core::Method::Get, + &ResourceLink::root(ResourceType::Databases) + .item("MyDatabase") + .feed(ResourceType::Containers) + .item("MyCollection"), + &date_string, + ), ) .await .unwrap(); - let expected: String = form_urlencoded::byte_serialize( - b"type=master&ver=1.0&sig=Qkz/r+1N2+PEnNijxGbGB/ADvLsLBQmZ7uBBMuIwf4I=", - ) - .collect(); + let expected: String = + url_encode(b"type=master&ver=1.0&sig=Qkz/r+1N2+PEnNijxGbGB/ADvLsLBQmZ7uBBMuIwf4I="); assert_eq!(ret, expected); } @@ -385,6 +246,7 @@ mon, 01 jan 1900 01:00:00 gmt #[cfg(feature = "key_auth")] async fn generate_authorization_for_primary_key_1() { let time_nonce = date::parse_rfc3339("2017-04-27T00:51:12.000000000+00:00").unwrap(); + let date_string = date::to_rfc1123(&time_nonce).to_lowercase(); let auth_token = Credential::PrimaryKey( "dsZQi3KtZmCv1ljt3VNWNm7sQUF1y5rJfC6kv5JiwvW0EndXdDku/dkKBp8/ufDToSxL".into(), @@ -396,60 +258,21 @@ mon, 01 jan 1900 01:00:00 gmt let ret = generate_authorization( &auth_token, &url, - SignatureTarget { - http_method: &azure_core::Method::Get, - resource_type: &ResourceType::Databases, - resource_link: "dbs/ToDoList", - time_nonce, - }, + SignatureTarget::new( + azure_core::Method::Get, + &ResourceLink::root(ResourceType::Databases).item("ToDoList"), + &date_string, + ), ) .await .unwrap(); - let expected: String = form_urlencoded::byte_serialize( - b"type=master&ver=1.0&sig=KvBM8vONofkv3yKm/8zD9MEGlbu6jjHDJBp4E9c2ZZI=", - ) - .collect(); + let expected: String = + url_encode(b"type=master&ver=1.0&sig=KvBM8vONofkv3yKm/8zD9MEGlbu6jjHDJBp4E9c2ZZI="); assert_eq!(ret, expected); } - #[test] - fn extract_resource_link_specific_db() { - let request = Request::new( - Url::parse("https://example.com/dbs/second").unwrap(), - azure_core::Method::Get, - ); - assert_eq!(&extract_resource_link(&request), "dbs/second"); - } - - #[test] - fn extract_resource_link_dbs_root() { - let request = Request::new( - Url::parse("https://example.com/dbs").unwrap(), - azure_core::Method::Get, - ); - assert_eq!(&extract_resource_link(&request), ""); - } - - #[test] - fn extract_resource_link_collection_nested() { - let request = Request::new( - Url::parse("https://example.com/colls/second/third").unwrap(), - azure_core::Method::Get, - ); - assert_eq!(&extract_resource_link(&request), "colls/second/third"); - } - - #[test] - fn extract_resource_link_collections_root() { - let request = Request::new( - Url::parse("https://.documents.azure.com/dbs/test_db/colls").unwrap(), - azure_core::Method::Get, - ); - assert_eq!(&extract_resource_link(&request), "dbs/test_db"); - } - #[test] fn scope_from_url_extracts_correct_scope() { let scope = scope_from_url(&Url::parse("https://example.com/dbs/test_db/colls").unwrap()); diff --git a/sdk/cosmos/azure_data_cosmos/src/pipeline/mod.rs b/sdk/cosmos/azure_data_cosmos/src/pipeline/mod.rs index 52fe0e7d9a..5465949aa1 100644 --- a/sdk/cosmos/azure_data_cosmos/src/pipeline/mod.rs +++ b/sdk/cosmos/azure_data_cosmos/src/pipeline/mod.rs @@ -2,49 +2,67 @@ // Licensed under the MIT License. mod authorization_policy; +mod signature_target; use std::sync::Arc; -pub(crate) use authorization_policy::{AuthorizationPolicy, ResourceType}; +pub use authorization_policy::AuthorizationPolicy; use azure_core::{Context, Pager, Request}; use serde::de::DeserializeOwned; use typespec_client_core::http::PagerResult; +use url::Url; -use crate::{constants, Query}; +use crate::{constants, resource_context::ResourceLink, Query}; /// Newtype that wraps an Azure Core pipeline to provide a Cosmos-specific pipeline which configures our authorization policy and enforces that a [`ResourceType`] is set on the context. #[derive(Debug, Clone)] -pub struct CosmosPipeline(azure_core::Pipeline); +pub struct CosmosPipeline { + pub endpoint: Url, + pipeline: azure_core::Pipeline, +} impl CosmosPipeline { pub fn new( + endpoint: Url, auth_policy: AuthorizationPolicy, client_options: azure_core::ClientOptions, ) -> Self { - CosmosPipeline(azure_core::Pipeline::new( - option_env!("CARGO_PKG_NAME"), - option_env!("CARGO_PKG_VERSION"), - client_options, - Vec::new(), - vec![Arc::new(auth_policy)], - )) + CosmosPipeline { + endpoint, + pipeline: azure_core::Pipeline::new( + option_env!("CARGO_PKG_NAME"), + option_env!("CARGO_PKG_VERSION"), + client_options, + Vec::new(), + vec![Arc::new(auth_policy)], + ), + } + } + + /// Creates a [`Url`] out of the provided [`ResourceLink`] + /// + /// This is a little backwards, ideally we'd accept [`ResourceLink`] in the [`CosmosPipeline::send`] method, + /// but we need callers to be able to build an [`azure_core::Request`] so they need to be able to get the full URL. + /// This allows the clients to hold a single thing representing the "connection" to a Cosmos DB account though. + pub fn url(&self, link: &ResourceLink) -> Url { + link.url(&self.endpoint) } pub async fn send( &self, ctx: azure_core::Context<'_>, request: &mut azure_core::Request, - resource_type: ResourceType, + resource_link: ResourceLink, ) -> azure_core::Result> { - let ctx = ctx.with_value(resource_type); - self.0.send(&ctx, request).await + let ctx = ctx.with_value(resource_link); + self.pipeline.send(&ctx, request).await } pub fn send_query_request( &self, query: Query, mut base_request: Request, - resource_type: ResourceType, + resource_link: ResourceLink, ) -> azure_core::Result> { base_request.insert_header(constants::QUERY, "True"); base_request.add_mandatory_header(&constants::QUERY_CONTENT_TYPE); @@ -52,8 +70,8 @@ impl CosmosPipeline { // We have to double-clone here. // First we clone the pipeline to pass it in to the closure - let pipeline = self.0.clone(); - let context = Context::new().with_value(resource_type); + let pipeline = self.pipeline.clone(); + let context = Context::new().with_value(resource_link); Ok(Pager::from_callback(move |continuation| { // Then we have to clone it again to pass it in to the async block. // This is because Pageable can't borrow any data, it has to own it all. diff --git a/sdk/cosmos/azure_data_cosmos/src/pipeline/signature_target.rs b/sdk/cosmos/azure_data_cosmos/src/pipeline/signature_target.rs new file mode 100644 index 0000000000..1a3f220b96 --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos/src/pipeline/signature_target.rs @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#[cfg_attr(not(feature = "key_auth"), allow(unused_imports))] +use azure_core::{credentials::Secret, hmac::hmac_sha256, Method}; + +use crate::resource_context::ResourceLink; + +#[cfg_attr(not(feature = "key_auth"), allow(dead_code))] +pub struct SignatureTarget<'a> { + http_method: azure_core::Method, + link: &'a ResourceLink, + date_string: &'a str, +} + +impl<'a> SignatureTarget<'a> { + pub fn new(http_method: Method, link: &'a ResourceLink, date_string: &'a str) -> Self { + SignatureTarget { + http_method, + link, + date_string, + } + } + + #[cfg(feature = "key_auth")] + pub fn into_authorization(self, key: &Secret) -> azure_core::Result { + let string_to_sign = self.into_signable_string(); + // The signature payload is NOT SECRET. The signature IS SECRET, but we can safely log the signature payload (which can be useful for diagnosing auth errors) + tracing::debug!(signature_payload = ?string_to_sign, "generating Cosmos auth signature"); + let signature = hmac_sha256(&string_to_sign, key)?; + Ok(format!("type=master&ver=1.0&sig={signature}")) + } + + /// This function generates a valid authorization string, according to the documentation. + /// In case of authorization problems we can compare the `string_to_sign` generated by Azure against + /// our own. + #[cfg(feature = "key_auth")] + fn into_signable_string(self) -> String { + // From official docs: + // StringToSign = + // Verb.toLowerCase() + "\n" + + // ResourceType.toLowerCase() + "\n" + + // ResourceLink + "\n" + + // Date.toLowerCase() + "\n" + + // "" + "\n"; + // Notice the empty string at the end so we need to add two new lines + + format!( + "{}\n{}\n{}\n{}\n\n", + // Cosmos' signature algorithm requires lower-case methods, so we use our own match instead of the impl of AsRef, which is uppercase. + match self.http_method { + azure_core::Method::Get => "get", + azure_core::Method::Put => "put", + azure_core::Method::Post => "post", + azure_core::Method::Delete => "delete", + azure_core::Method::Head => "head", + azure_core::Method::Trace => "trace", + azure_core::Method::Options => "options", + azure_core::Method::Connect => "connect", + azure_core::Method::Patch => "patch", + _ => "extension", + }, + self.link.resource_type().path_segment(), + self.link.resource_link(), + self.date_string, + ) + } +} + +#[cfg(test)] +#[cfg(feature = "key_auth")] +mod tests { + use azure_core::date; + + use crate::{ + pipeline::signature_target::SignatureTarget, + resource_context::{ResourceLink, ResourceType}, + }; + + // We test the full authorization header in authorization_policy. + // However, testing the signable string here is useful to isolate failures in constructing the string to be signed + #[test] + fn into_signable_string_generates_correct_value() { + let time_nonce = date::parse_rfc3339("1900-01-01T01:00:00.000000000+00:00").unwrap(); + let date_string = date::to_rfc1123(&time_nonce).to_lowercase(); + + let ret = SignatureTarget::new( + azure_core::Method::Get, + &ResourceLink::root(ResourceType::Databases) + .item("MyDatabase") + .feed(ResourceType::Containers) + .item("MyCollection"), + &date_string, + ) + .into_signable_string(); + assert_eq!( + ret, + "get +dbs +dbs/MyDatabase/colls/MyCollection +mon, 01 jan 1900 01:00:00 gmt + +" + ); + } +} diff --git a/sdk/cosmos/azure_data_cosmos/src/resource_context.rs b/sdk/cosmos/azure_data_cosmos/src/resource_context.rs new file mode 100644 index 0000000000..bec0d57f79 --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos/src/resource_context.rs @@ -0,0 +1,233 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use url::Url; + +use crate::utils::url_encode; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[allow(dead_code)] // For the variants. Can be removed when we have them all implemented. +pub enum ResourceType { + Databases, + Containers, + Items, + StoredProcedures, + Users, + Permissions, + PartitionKeyRanges, + UserDefinedFunctions, + Triggers, + Offers, +} + +impl ResourceType { + pub fn path_segment(self) -> &'static str { + match self { + ResourceType::Databases => "dbs", + ResourceType::Containers => "colls", + ResourceType::Items => "docs", + ResourceType::StoredProcedures => "sprocs", + ResourceType::Users => "users", + ResourceType::Permissions => "permissions", + ResourceType::PartitionKeyRanges => "pkranges", + ResourceType::UserDefinedFunctions => "udfs", + ResourceType::Triggers => "triggers", + ResourceType::Offers => "offers", + } + } +} + +/// Represents a "resource link" defining a sub-resource in Azure Cosmos DB +/// +/// This value is URL encoded, and can be [`Url::join`]ed to the endpoint root to produce the full absolute URL for a Cosmos DB resource. +/// It's also intended for use by the signature algorithm used when authenticating with a primary key. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ResourceLink { + parent: Option, + item_id: Option, + resource_type: ResourceType, +} + +impl ResourceLink { + pub fn root(resource_type: ResourceType) -> Self { + Self { + parent: None, + resource_type, + item_id: None, + } + } + + pub fn feed(&self, resource_type: ResourceType) -> Self { + Self { + parent: Some(self.path()), + resource_type, + item_id: None, + } + } + + pub fn item(&self, item_id: impl AsRef) -> Self { + let item_id = url_encode(item_id.as_ref().as_bytes()); + Self { + parent: self.parent.clone(), + resource_type: self.resource_type, + item_id: Some(item_id), + } + } + + /// Gets the resource "link" identified by this link, for use when generating the authentication signature. + /// + /// For links referring to items, this is the full path of the item. + /// For links referring to feeds (for query, create, etc. requests where there is no item ID to reference), this is the path to the PARENT resource. + /// + /// See https://learn.microsoft.com/en-us/rest/api/cosmos-db/access-control-on-cosmosdb-resources#constructkeytoken for more details. + #[cfg_attr(not(feature = "key_auth"), allow(dead_code))] // REASON: Currently only used in key_auth feature but we don't want to conditional-compile it. + pub fn resource_link(&self) -> String { + match self.item_id { + Some(_) => self.path(), + None => self.parent.clone().unwrap_or_default(), + } + } + + /// Gets the [`ResourceType`] identified by this link, for use when generating the authentication signature. + #[cfg_attr(not(feature = "key_auth"), allow(dead_code))] // REASON: Currently only used in key_auth feature but we don't want to conditional-compile it. + pub fn resource_type(&self) -> ResourceType { + self.resource_type + } + + /// Gets the path that must be appended to the root account endpoint to access this resource. + pub fn path(&self) -> String { + match (self.parent.as_ref(), self.item_id.as_ref()) { + (None, Some(item_id)) => { + format!("{}/{}", self.resource_type.path_segment(), item_id) + } + (Some(parent), Some(item_id)) => format!( + "{}/{}/{}", + parent, + self.resource_type.path_segment(), + item_id + ), + (None, None) => self.resource_type.path_segment().to_string(), + (Some(ref parent), None) => format!("{}/{}", parent, self.resource_type.path_segment()), + } + } + + /// Creates a new [`Url`] by joining the provided `endpoint` with the path from [`ResourceLink::path`]. + pub fn url(&self, endpoint: &Url) -> Url { + endpoint + .join(&self.path()) + .expect("ResourceLink should always be url-safe") + } +} + +#[cfg(test)] +mod tests { + use crate::resource_context::{ResourceLink, ResourceType}; + + #[test] + pub fn root_link() { + let link = ResourceLink::root(ResourceType::Databases); + assert_eq!( + ResourceLink { + parent: None, + resource_type: ResourceType::Databases, + item_id: None, + }, + link + ); + assert_eq!( + "https://example.com/dbs", + link.url(&"https://example.com/".parse().unwrap()) + .to_string() + ); + assert_eq!("", link.resource_link()); + assert_eq!(ResourceType::Databases, link.resource_type()); + } + + #[test] + pub fn root_item_link() { + let link = ResourceLink::root(ResourceType::Databases).item("TestDB"); + assert_eq!( + ResourceLink { + parent: None, + resource_type: ResourceType::Databases, + item_id: Some("TestDB".to_string()), + }, + link + ); + assert_eq!( + "https://example.com/dbs/TestDB", + link.url(&"https://example.com/".parse().unwrap()) + .to_string() + ); + assert_eq!("dbs/TestDB", link.resource_link()); + assert_eq!(ResourceType::Databases, link.resource_type()); + } + + #[test] + pub fn child_feed_link() { + let link = ResourceLink::root(ResourceType::Databases) + .item("TestDB") + .feed(ResourceType::Containers); + assert_eq!( + ResourceLink { + parent: Some("dbs/TestDB".to_string()), + resource_type: ResourceType::Containers, + item_id: None, + }, + link + ); + assert_eq!( + "https://example.com/dbs/TestDB/colls", + link.url(&"https://example.com/".parse().unwrap()) + .to_string() + ); + assert_eq!("dbs/TestDB", link.resource_link()); + assert_eq!(ResourceType::Containers, link.resource_type()); + } + + #[test] + pub fn child_item_link() { + let link = ResourceLink::root(ResourceType::Databases) + .item("TestDB") + .feed(ResourceType::Containers) + .item("TestContainer"); + assert_eq!( + ResourceLink { + parent: Some("dbs/TestDB".to_string()), + resource_type: ResourceType::Containers, + item_id: Some("TestContainer".to_string()), + }, + link + ); + assert_eq!( + "https://example.com/dbs/TestDB/colls/TestContainer", + link.url(&"https://example.com/".parse().unwrap()) + .to_string() + ); + assert_eq!("dbs/TestDB/colls/TestContainer", link.resource_link()); + assert_eq!(ResourceType::Containers, link.resource_type()); + } + + #[test] + pub fn resource_links_are_url_encoded() { + let link = ResourceLink::root(ResourceType::Databases) + .item("Test DB") + .feed(ResourceType::Containers) + .item("Test/Container"); + assert_eq!( + ResourceLink { + parent: Some("dbs/Test+DB".to_string()), + resource_type: ResourceType::Containers, + item_id: Some("Test%2FContainer".to_string()) + }, + link + ); + assert_eq!( + "https://example.com/dbs/Test+DB/colls/Test%2FContainer", + link.url(&"https://example.com/".parse().unwrap()) + .to_string() + ); + assert_eq!("dbs/Test+DB/colls/Test%2FContainer", link.resource_link()); + assert_eq!(ResourceType::Containers, link.resource_type()); + } +} diff --git a/sdk/cosmos/azure_data_cosmos/src/utils.rs b/sdk/cosmos/azure_data_cosmos/src/utils.rs index 1f6702fb9e..adc018a383 100644 --- a/sdk/cosmos/azure_data_cosmos/src/utils.rs +++ b/sdk/cosmos/azure_data_cosmos/src/utils.rs @@ -1,26 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -use url::Url; - -/// Appends new path segments to the target [`Url`]. -pub trait AppendPathSegments: Clone { - fn append_path_segments<'a>(&mut self, segments: impl IntoIterator); - - fn with_path_segments<'a>(&self, segments: impl IntoIterator) -> Self { - let mut new = self.clone(); - new.append_path_segments(segments); - new - } -} - -impl AppendPathSegments for Url { - fn append_path_segments<'a>(&mut self, segments: impl IntoIterator) { - let mut path_segments = self - .path_segments_mut() - .expect("the URL must not be a 'cannot-be-a-base' URL"); - for segment in segments { - path_segments.push(segment.as_ref()); - } - } +pub fn url_encode(s: impl AsRef<[u8]>) -> String { + url::form_urlencoded::byte_serialize(s.as_ref()).collect::() }