Skip to content

Commit

Permalink
[Cosmos] Refactoring how the resource type and link are passed down t…
Browse files Browse the repository at this point in the history
…o the auth policy (#1861)
  • Loading branch information
analogrelay authored Oct 25, 2024
1 parent 3fa79e3 commit 6a7faef
Show file tree
Hide file tree
Showing 9 changed files with 521 additions and 385 deletions.
106 changes: 39 additions & 67 deletions sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,37 @@ use crate::{
constants,
models::{ContainerProperties, Item, QueryResults},
options::{QueryOptions, ReadContainerOptions},
pipeline::{CosmosPipeline, ResourceType},
utils::AppendPathSegments,
pipeline::CosmosPipeline,
resource_context::{ResourceLink, ResourceType},
DeleteContainerOptions, ItemOptions, PartitionKey, Query, QueryPartitionStrategy,
};

use azure_core::{Context, Method, Pager, Request, Response};
use serde::{de::DeserializeOwned, Serialize};
use typespec_client_core::http::PagerResult;
use url::Url;

/// A client for working with a specific container in a Cosmos DB account.
///
/// You can get a `Container` by calling [`DatabaseClient::container_client()`](crate::clients::DatabaseClient::container_client()).
pub struct ContainerClient {
container_url: Url,
link: ResourceLink,
items_link: ResourceLink,
pipeline: CosmosPipeline,
}

impl ContainerClient {
pub(crate) fn new(pipeline: CosmosPipeline, database_url: &Url, container_name: &str) -> Self {
let container_url = database_url.with_path_segments(["colls", container_name]);
pub(crate) fn new(
pipeline: CosmosPipeline,
database_link: &ResourceLink,
container_id: &str,
) -> Self {
let link = database_link
.feed(ResourceType::Containers)
.item(container_id);
let items_link = link.feed(ResourceType::Items);

Self {
container_url,
link,
items_link,
pipeline,
}
}
Expand Down Expand Up @@ -58,9 +65,10 @@ impl ContainerClient {
// REASON: This is a documented public API so prefixing with '_' is undesirable.
options: Option<ReadContainerOptions>,
) -> azure_core::Result<Response<ContainerProperties>> {
let mut req = Request::new(self.container_url.clone(), Method::Get);
let url = self.pipeline.url(&self.link);
let mut req = Request::new(url, Method::Get);
self.pipeline
.send(Context::new(), &mut req, ResourceType::Containers)
.send(Context::new(), &mut req, self.link.clone())
.await
}

Expand All @@ -76,9 +84,10 @@ impl ContainerClient {
// REASON: This is a documented public API so prefixing with '_' is undesirable.
options: Option<DeleteContainerOptions>,
) -> azure_core::Result<Response> {
let mut req = Request::new(self.container_url.clone(), Method::Delete);
let url = self.pipeline.url(&self.link);
let mut req = Request::new(url, Method::Delete);
self.pipeline
.send(Context::new(), &mut req, ResourceType::Containers)
.send(Context::new(), &mut req, self.link.clone())
.await
}

Expand Down Expand Up @@ -126,12 +135,12 @@ impl ContainerClient {
// REASON: This is a documented public API so prefixing with '_' is undesirable.
options: Option<ItemOptions>,
) -> azure_core::Result<Response<Item<T>>> {
let url = self.container_url.with_path_segments(["docs"]);
let url = self.pipeline.url(&self.items_link);
let mut req = Request::new(url, Method::Post);
req.insert_headers(&partition_key.into())?;
req.set_json(&item)?;
self.pipeline
.send(Context::new(), &mut req, ResourceType::Items)
.send(Context::new(), &mut req, self.items_link.clone())
.await
}

Expand Down Expand Up @@ -181,15 +190,12 @@ impl ContainerClient {
// REASON: This is a documented public API so prefixing with '_' is undesirable.
options: Option<ItemOptions>,
) -> azure_core::Result<Response<Item<T>>> {
let url = self
.container_url
.with_path_segments(["docs", item_id.as_ref()]);
let link = self.items_link.item(item_id);
let url = self.pipeline.url(&link);
let mut req = Request::new(url, Method::Put);
req.insert_headers(&partition_key.into())?;
req.set_json(&item)?;
self.pipeline
.send(Context::new(), &mut req, ResourceType::Items)
.await
self.pipeline.send(Context::new(), &mut req, link).await
}

/// Creates or replaces an item in the container.
Expand Down Expand Up @@ -239,13 +245,13 @@ impl ContainerClient {
// REASON: This is a documented public API so prefixing with '_' is undesirable.
options: Option<ItemOptions>,
) -> azure_core::Result<Response<Item<T>>> {
let url = self.container_url.with_path_segments(["docs"]);
let url = self.pipeline.url(&self.items_link);
let mut req = Request::new(url, Method::Post);
req.insert_header(constants::IS_UPSERT, "true");
req.insert_headers(&partition_key.into())?;
req.set_json(&item)?;
self.pipeline
.send(Context::new(), &mut req, ResourceType::Items)
.send(Context::new(), &mut req, self.items_link.clone())
.await
}

Expand Down Expand Up @@ -288,14 +294,11 @@ impl ContainerClient {
// REASON: This is a documented public API so prefixing with '_' is undesirable.
options: Option<ItemOptions>,
) -> azure_core::Result<Response<Item<T>>> {
let url = self
.container_url
.with_path_segments(["docs", item_id.as_ref()]);
let link = self.items_link.item(item_id);
let url = self.pipeline.url(&link);
let mut req = Request::new(url, Method::Get);
req.insert_headers(&partition_key.into())?;
self.pipeline
.send(Context::new(), &mut req, ResourceType::Items)
.await
self.pipeline.send(Context::new(), &mut req, link).await
}

/// Deletes an item from the container.
Expand Down Expand Up @@ -326,14 +329,11 @@ impl ContainerClient {
// REASON: This is a documented public API so prefixing with '_' is undesirable.
options: Option<ItemOptions>,
) -> azure_core::Result<Response> {
let url = self
.container_url
.with_path_segments(["docs", item_id.as_ref()]);
let link = self.items_link.item(item_id);
let url = self.pipeline.url(&link);
let mut req = Request::new(url, Method::Delete);
req.insert_headers(&partition_key.into())?;
self.pipeline
.send(Context::new(), &mut req, ResourceType::Items)
.await
self.pipeline.send(Context::new(), &mut req, link).await
}

/// Executes a single-partition query against items in the container.
Expand Down Expand Up @@ -399,40 +399,12 @@ impl ContainerClient {
// REASON: This is a documented public API so prefixing with '_' is undesirable.
options: Option<QueryOptions>,
) -> azure_core::Result<Pager<QueryResults<T>>> {
let mut url = self.container_url.clone();
url.append_path_segments(["docs"]);
let mut base_req = Request::new(url, Method::Post);

base_req.insert_header(constants::QUERY, "True");
base_req.add_mandatory_header(&constants::QUERY_CONTENT_TYPE);

let url = self.pipeline.url(&self.items_link);
let mut base_request = Request::new(url, Method::Post);
let QueryPartitionStrategy::SinglePartition(partition_key) = partition_key.into();
base_req.insert_headers(&partition_key)?;
base_request.insert_headers(&partition_key)?;

base_req.set_json(&query.into())?;

// We have to double-clone here.
// First we clone the pipeline to pass it in to the closure
let pipeline = self.pipeline.clone();
Ok(Pager::from_callback(move |continuation| {
// Then we have to clone it again to pass it in to the async block.
// This is because Pageable can't borrow any data, it has to own it all.
// That's probably good, because it means a Pageable can outlive the client that produced it, but it requires some extra cloning.
let pipeline = pipeline.clone();
let mut req = base_req.clone();
async move {
if let Some(continuation) = continuation {
req.insert_header(constants::CONTINUATION, continuation);
}

let response = pipeline
.send(Context::new(), &mut req, ResourceType::Items)
.await?;
Ok(PagerResult::from_response_header(
response,
&constants::CONTINUATION,
))
}
}))
self.pipeline
.send_query_request(query.into(), base_request, self.items_link.clone())
}
}
27 changes: 14 additions & 13 deletions sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
use crate::{
clients::DatabaseClient,
models::{DatabaseProperties, DatabaseQueryResults, Item},
pipeline::{AuthorizationPolicy, CosmosPipeline, ResourceType},
utils::AppendPathSegments,
pipeline::{AuthorizationPolicy, CosmosPipeline},
resource_context::{ResourceLink, ResourceType},
CosmosClientOptions, CreateDatabaseOptions, Query, QueryDatabasesOptions,
};
use azure_core::{credentials::TokenCredential, Context, Method, Request, Response, Url};

use serde::Serialize;
use std::sync::Arc;

Expand All @@ -19,8 +18,8 @@ use azure_core::credentials::Secret;
/// Client for Azure Cosmos DB.
#[derive(Debug, Clone)]
pub struct CosmosClient {
endpoint: Url,
pub(crate) pipeline: CosmosPipeline,
databases_link: ResourceLink,
pipeline: CosmosPipeline,

#[allow(dead_code)]
options: CosmosClientOptions,
Expand Down Expand Up @@ -51,8 +50,9 @@ impl CosmosClient {
) -> azure_core::Result<Self> {
let options = options.unwrap_or_default();
Ok(Self {
endpoint: endpoint.as_ref().parse()?,
databases_link: ResourceLink::root(ResourceType::Databases),
pipeline: CosmosPipeline::new(
endpoint.as_ref().parse()?,
AuthorizationPolicy::from_token_credential(credential),
options.client_options.clone(),
),
Expand Down Expand Up @@ -83,8 +83,9 @@ impl CosmosClient {
) -> azure_core::Result<Self> {
let options = options.unwrap_or_default();
Ok(Self {
endpoint: endpoint.as_ref().parse()?,
databases_link: ResourceLink::root(ResourceType::Databases),
pipeline: CosmosPipeline::new(
endpoint.as_ref().parse()?,
AuthorizationPolicy::from_shared_key(key.into()),
options.client_options.clone(),
),
Expand All @@ -97,12 +98,12 @@ impl CosmosClient {
/// # Arguments
/// * `id` - The ID of the database.
pub fn database_client(&self, id: impl AsRef<str>) -> DatabaseClient {
DatabaseClient::new(self.pipeline.clone(), &self.endpoint, id.as_ref())
DatabaseClient::new(self.pipeline.clone(), id.as_ref())
}

/// Gets the endpoint of the database account this client is connected to.
pub fn endpoint(&self) -> &Url {
&self.endpoint
&self.pipeline.endpoint
}

/// Executes a query against databases in the account.
Expand Down Expand Up @@ -136,11 +137,11 @@ impl CosmosClient {
// REASON: This is a documented public API so prefixing with '_' is undesirable.
options: Option<QueryDatabasesOptions>,
) -> azure_core::Result<azure_core::Pager<DatabaseQueryResults>> {
let url = self.endpoint.with_path_segments(["dbs"]);
let url = self.pipeline.url(&self.databases_link);
let base_request = Request::new(url, azure_core::Method::Post);

self.pipeline
.send_query_request(query.into(), base_request, ResourceType::Databases)
.send_query_request(query.into(), base_request, self.databases_link.clone())
}

/// Creates a new database.
Expand All @@ -163,12 +164,12 @@ impl CosmosClient {
id: String,
}

let url = self.endpoint.with_path_segments(["dbs"]);
let url = self.pipeline.url(&self.databases_link);
let mut req = Request::new(url, Method::Post);
req.set_json(&RequestBody { id })?;

self.pipeline
.send(Context::new(), &mut req, ResourceType::Databases)
.send(Context::new(), &mut req, self.databases_link.clone())
.await
}
}
Loading

0 comments on commit 6a7faef

Please sign in to comment.