Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Cosmos] Refactoring how the resource type and link are passed down to the auth policy #1861

Merged
merged 7 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 39 additions & 67 deletions sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,37 @@ use crate::{
constants,
models::{ContainerProperties, Item, QueryResults},
options::{QueryOptions, ReadContainerOptions},
pipeline::{CosmosPipeline, ResourceType},
utils::AppendPathSegments,
pipeline::CosmosPipeline,
resource_context::{ResourceLink, ResourceType},
DeleteContainerOptions, ItemOptions, PartitionKey, Query, QueryPartitionStrategy,
};

use azure_core::{Context, Method, Pager, Request, Response};
use serde::{de::DeserializeOwned, Serialize};
use typespec_client_core::http::PagerResult;
use url::Url;

/// A client for working with a specific container in a Cosmos DB account.
///
/// You can get a `Container` by calling [`DatabaseClient::container_client()`](crate::clients::DatabaseClient::container_client()).
pub struct ContainerClient {
container_url: Url,
link: ResourceLink,
items_link: ResourceLink,
pipeline: CosmosPipeline,
}

impl ContainerClient {
pub(crate) fn new(pipeline: CosmosPipeline, database_url: &Url, container_name: &str) -> Self {
let container_url = database_url.with_path_segments(["colls", container_name]);
pub(crate) fn new(
pipeline: CosmosPipeline,
database_link: &ResourceLink,
container_id: &str,
) -> Self {
let link = database_link
.feed(ResourceType::Containers)
.item(container_id);
let items_link = link.feed(ResourceType::Items);

Self {
container_url,
link,
items_link,
pipeline,
}
}
Expand Down Expand Up @@ -58,9 +65,10 @@ impl ContainerClient {
// REASON: This is a documented public API so prefixing with '_' is undesirable.
options: Option<ReadContainerOptions>,
) -> azure_core::Result<Response<ContainerProperties>> {
let mut req = Request::new(self.container_url.clone(), Method::Get);
let url = self.pipeline.url(&self.link);
let mut req = Request::new(url, Method::Get);
self.pipeline
.send(Context::new(), &mut req, ResourceType::Containers)
.send(Context::new(), &mut req, self.link.clone())
.await
}

Expand All @@ -76,9 +84,10 @@ impl ContainerClient {
// REASON: This is a documented public API so prefixing with '_' is undesirable.
options: Option<DeleteContainerOptions>,
) -> azure_core::Result<Response> {
let mut req = Request::new(self.container_url.clone(), Method::Delete);
let url = self.pipeline.url(&self.link);
let mut req = Request::new(url, Method::Delete);
self.pipeline
.send(Context::new(), &mut req, ResourceType::Containers)
.send(Context::new(), &mut req, self.link.clone())
.await
}

Expand Down Expand Up @@ -126,12 +135,12 @@ impl ContainerClient {
// REASON: This is a documented public API so prefixing with '_' is undesirable.
options: Option<ItemOptions>,
) -> azure_core::Result<Response<Item<T>>> {
let url = self.container_url.with_path_segments(["docs"]);
let url = self.pipeline.url(&self.items_link);
let mut req = Request::new(url, Method::Post);
req.insert_headers(&partition_key.into())?;
req.set_json(&item)?;
self.pipeline
.send(Context::new(), &mut req, ResourceType::Items)
.send(Context::new(), &mut req, self.items_link.clone())
.await
}

Expand Down Expand Up @@ -181,15 +190,12 @@ impl ContainerClient {
// REASON: This is a documented public API so prefixing with '_' is undesirable.
options: Option<ItemOptions>,
) -> azure_core::Result<Response<Item<T>>> {
let url = self
.container_url
.with_path_segments(["docs", item_id.as_ref()]);
let link = self.items_link.item(item_id);
let url = self.pipeline.url(&link);
let mut req = Request::new(url, Method::Put);
req.insert_headers(&partition_key.into())?;
req.set_json(&item)?;
self.pipeline
.send(Context::new(), &mut req, ResourceType::Items)
.await
self.pipeline.send(Context::new(), &mut req, link).await
}

/// Creates or replaces an item in the container.
Expand Down Expand Up @@ -239,13 +245,13 @@ impl ContainerClient {
// REASON: This is a documented public API so prefixing with '_' is undesirable.
options: Option<ItemOptions>,
) -> azure_core::Result<Response<Item<T>>> {
let url = self.container_url.with_path_segments(["docs"]);
let url = self.pipeline.url(&self.items_link);
let mut req = Request::new(url, Method::Post);
req.insert_header(constants::IS_UPSERT, "true");
req.insert_headers(&partition_key.into())?;
req.set_json(&item)?;
self.pipeline
.send(Context::new(), &mut req, ResourceType::Items)
.send(Context::new(), &mut req, self.items_link.clone())
.await
}

Expand Down Expand Up @@ -288,14 +294,11 @@ impl ContainerClient {
// REASON: This is a documented public API so prefixing with '_' is undesirable.
options: Option<ItemOptions>,
) -> azure_core::Result<Response<Item<T>>> {
let url = self
.container_url
.with_path_segments(["docs", item_id.as_ref()]);
let link = self.items_link.item(item_id);
let url = self.pipeline.url(&link);
let mut req = Request::new(url, Method::Get);
req.insert_headers(&partition_key.into())?;
self.pipeline
.send(Context::new(), &mut req, ResourceType::Items)
.await
self.pipeline.send(Context::new(), &mut req, link).await
}

/// Deletes an item from the container.
Expand Down Expand Up @@ -326,14 +329,11 @@ impl ContainerClient {
// REASON: This is a documented public API so prefixing with '_' is undesirable.
options: Option<ItemOptions>,
) -> azure_core::Result<Response> {
let url = self
.container_url
.with_path_segments(["docs", item_id.as_ref()]);
let link = self.items_link.item(item_id);
let url = self.pipeline.url(&link);
let mut req = Request::new(url, Method::Delete);
req.insert_headers(&partition_key.into())?;
self.pipeline
.send(Context::new(), &mut req, ResourceType::Items)
.await
self.pipeline.send(Context::new(), &mut req, link).await
}

/// Executes a single-partition query against items in the container.
Expand Down Expand Up @@ -399,40 +399,12 @@ impl ContainerClient {
// REASON: This is a documented public API so prefixing with '_' is undesirable.
options: Option<QueryOptions>,
) -> azure_core::Result<Pager<QueryResults<T>>> {
let mut url = self.container_url.clone();
url.append_path_segments(["docs"]);
let mut base_req = Request::new(url, Method::Post);

base_req.insert_header(constants::QUERY, "True");
base_req.add_mandatory_header(&constants::QUERY_CONTENT_TYPE);

let url = self.pipeline.url(&self.items_link);
let mut base_request = Request::new(url, Method::Post);
let QueryPartitionStrategy::SinglePartition(partition_key) = partition_key.into();
base_req.insert_headers(&partition_key)?;
base_request.insert_headers(&partition_key)?;

base_req.set_json(&query.into())?;

// We have to double-clone here.
// First we clone the pipeline to pass it in to the closure
let pipeline = self.pipeline.clone();
Ok(Pager::from_callback(move |continuation| {
// Then we have to clone it again to pass it in to the async block.
// This is because Pageable can't borrow any data, it has to own it all.
// That's probably good, because it means a Pageable can outlive the client that produced it, but it requires some extra cloning.
let pipeline = pipeline.clone();
let mut req = base_req.clone();
async move {
if let Some(continuation) = continuation {
req.insert_header(constants::CONTINUATION, continuation);
}

let response = pipeline
.send(Context::new(), &mut req, ResourceType::Items)
.await?;
Ok(PagerResult::from_response_header(
response,
&constants::CONTINUATION,
))
}
}))
self.pipeline
.send_query_request(query.into(), base_request, self.items_link.clone())
}
}
27 changes: 14 additions & 13 deletions sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
use crate::{
clients::DatabaseClient,
models::{DatabaseProperties, DatabaseQueryResults, Item},
pipeline::{AuthorizationPolicy, CosmosPipeline, ResourceType},
utils::AppendPathSegments,
pipeline::{AuthorizationPolicy, CosmosPipeline},
resource_context::{ResourceLink, ResourceType},
CosmosClientOptions, CreateDatabaseOptions, Query, QueryDatabasesOptions,
};
use azure_core::{credentials::TokenCredential, Context, Method, Request, Response, Url};

use serde::Serialize;
use std::sync::Arc;

Expand All @@ -19,8 +18,8 @@ use azure_core::credentials::Secret;
/// Client for Azure Cosmos DB.
#[derive(Debug, Clone)]
pub struct CosmosClient {
endpoint: Url,
pub(crate) pipeline: CosmosPipeline,
databases_link: ResourceLink,
pipeline: CosmosPipeline,

#[allow(dead_code)]
options: CosmosClientOptions,
Expand Down Expand Up @@ -51,8 +50,9 @@ impl CosmosClient {
) -> azure_core::Result<Self> {
let options = options.unwrap_or_default();
Ok(Self {
endpoint: endpoint.as_ref().parse()?,
databases_link: ResourceLink::root(ResourceType::Databases),
pipeline: CosmosPipeline::new(
endpoint.as_ref().parse()?,
AuthorizationPolicy::from_token_credential(credential),
options.client_options.clone(),
),
Expand Down Expand Up @@ -83,8 +83,9 @@ impl CosmosClient {
) -> azure_core::Result<Self> {
let options = options.unwrap_or_default();
Ok(Self {
endpoint: endpoint.as_ref().parse()?,
databases_link: ResourceLink::root(ResourceType::Databases),
pipeline: CosmosPipeline::new(
endpoint.as_ref().parse()?,
AuthorizationPolicy::from_shared_key(key.into()),
options.client_options.clone(),
),
Expand All @@ -97,12 +98,12 @@ impl CosmosClient {
/// # Arguments
/// * `id` - The ID of the database.
pub fn database_client(&self, id: impl AsRef<str>) -> DatabaseClient {
DatabaseClient::new(self.pipeline.clone(), &self.endpoint, id.as_ref())
DatabaseClient::new(self.pipeline.clone(), id.as_ref())
}

/// Gets the endpoint of the database account this client is connected to.
pub fn endpoint(&self) -> &Url {
&self.endpoint
&self.pipeline.endpoint
}

/// Executes a query against databases in the account.
Expand Down Expand Up @@ -136,11 +137,11 @@ impl CosmosClient {
// REASON: This is a documented public API so prefixing with '_' is undesirable.
options: Option<QueryDatabasesOptions>,
) -> azure_core::Result<azure_core::Pager<DatabaseQueryResults>> {
let url = self.endpoint.with_path_segments(["dbs"]);
let url = self.pipeline.url(&self.databases_link);
let base_request = Request::new(url, azure_core::Method::Post);

self.pipeline
.send_query_request(query.into(), base_request, ResourceType::Databases)
.send_query_request(query.into(), base_request, self.databases_link.clone())
}

/// Creates a new database.
Expand All @@ -163,12 +164,12 @@ impl CosmosClient {
id: String,
}

let url = self.endpoint.with_path_segments(["dbs"]);
let url = self.pipeline.url(&self.databases_link);
let mut req = Request::new(url, Method::Post);
req.set_json(&RequestBody { id })?;

self.pipeline
.send(Context::new(), &mut req, ResourceType::Databases)
.send(Context::new(), &mut req, self.databases_link.clone())
.await
}
}
Loading