From 3ad40f62dc5dd38712f2f8e3b2d87f046b1780b7 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Mon, 9 Nov 2020 13:48:18 +0100 Subject: [PATCH] Factor token provider out into a separate trait --- Cargo.toml | 2 + examples/basic.rs | 7 ++- src/lib.rs | 129 +++++++++++++++++++++++++--------------------- 3 files changed, 76 insertions(+), 62 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 49b5b136d..dd696ff53 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,11 +23,13 @@ tracing-opentelemetry = "0.8" tracing-subscriber = "0.2" [dependencies] +async-trait = "0.1.41" derivative = "2.1.1" futures = "0.3" hex = "0.4" http = "0.2" hyper = "0.13" +hyper-rustls = "0.20" log = "0.4" opentelemetry = "0.8" prost = "0.6" diff --git a/examples/basic.rs b/examples/basic.rs index db160d172..699db4369 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -1,5 +1,5 @@ use opentelemetry::{api::Provider, sdk}; -use opentelemetry_stackdriver::StackDriverExporter; +use opentelemetry_stackdriver::{StackDriverExporter, YupAuthorizer}; use tracing::{span, Level}; use tracing_subscriber::prelude::*; @@ -27,9 +27,8 @@ async fn main() { } async fn init_tracing(stackdriver_creds: impl AsRef) { - let exporter = StackDriverExporter::connect(stackdriver_creds, PathBuf::from("tokens.json"), &TokioSpawner, None, 5) - .await - .unwrap(); + let authorizer = YupAuthorizer::new(stackdriver_creds, PathBuf::from("tokens.json")).await.unwrap(); + let exporter = StackDriverExporter::connect(authorizer, &TokioSpawner, None, 5).await.unwrap(); let provider = sdk::Provider::builder().with_simple_exporter(exporter).build(); tracing_subscriber::registry() diff --git a/src/lib.rs b/src/lib.rs index cd3abc127..a797ebc4e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,15 +20,16 @@ // When this PR is merged we should be able to remove this attribute: // https://github.com/danburkert/prost/pull/291 +use async_trait::async_trait; use derivative::Derivative; use futures::stream::StreamExt; -use hyper::client::connect::Connect; use opentelemetry::{ api::core::Value, exporter::trace::{ExportResult, SpanData, SpanExporter}, }; use proto::google::devtools::cloudtrace::v2::BatchWriteSpansRequest; use std::{ + fmt, sync::{ atomic::{AtomicUsize, Ordering}, Arc, @@ -38,7 +39,7 @@ use std::{ use tonic::{ metadata::MetadataValue, transport::{Channel, ClientTlsConfig}, - IntoRequest, Request, + Request, }; use yup_oauth2::authenticator::Authenticator; @@ -88,8 +89,7 @@ pub struct StackDriverExporter { impl StackDriverExporter { /// If `num_concurrent_requests` is set to `0` or `None` then no limit is enforced. pub async fn connect( - credentials_path: impl AsRef, - persistent_token_file: impl Into>, + authenticator: impl Authorizer, spawn: &S, maximum_shutdown_duration: Option, num_concurrent_requests: impl Into>, @@ -97,14 +97,6 @@ impl StackDriverExporter { let num_concurrent_requests = num_concurrent_requests.into(); let uri = http::uri::Uri::from_static("https://cloudtrace.googleapis.com:443"); - let service_account_key = yup_oauth2::read_service_account_key(&credentials_path).await?; - let project_name = service_account_key.project_id.as_ref().ok_or("project_id is missing")?.clone(); - let mut authenticator = yup_oauth2::ServiceAccountAuthenticator::builder(service_account_key); - if let Some(persistent_token_file) = persistent_token_file.into() { - authenticator = authenticator.persist_tokens_to_disk(persistent_token_file); - } - let authenticator = authenticator.build().await?; - let mut rustls_config = rustls::ClientConfig::new(); rustls_config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); rustls_config.set_protocols(&[Vec::from("h2".as_bytes())]); @@ -117,7 +109,6 @@ impl StackDriverExporter { Box::new(Self::export_inner( TraceServiceClient::new(channel), authenticator, - project_name, rx, pending_count.clone(), num_concurrent_requests, @@ -136,36 +127,22 @@ impl StackDriverExporter { self.pending_count.load(Ordering::Relaxed) } - async fn export_inner( + async fn export_inner( client: TraceServiceClient, - authenticator: Authenticator, - project_name: String, + authorizer: impl Authorizer, rx: futures::channel::mpsc::Receiver>>, pending_count: Arc, num_concurrent: impl Into>, - ) where - C: Connect + Clone + Send + Sync + 'static, - { - let authenticator = &authenticator; + ) { + let authorizer = &authorizer; rx.for_each_concurrent(num_concurrent, move |batch| { let mut client = client.clone(); // This clone is cheap and allows for concurrent requests (see https://github.com/hyperium/tonic/issues/285#issuecomment-595880400) - let project_name = project_name.clone(); let pending_count = pending_count.clone(); async move { use proto::google::devtools::cloudtrace::v2::{ span::{time_event::Value, Attributes, TimeEvents}, Span, }; - let scopes = &["https://www.googleapis.com/auth/trace.append"]; - let token = authenticator.token(scopes).await; - log::trace!("Got StackDriver auth token: {:?}", token); - let bearer_token = match token { - Ok(token) => format!("Bearer {}", token.as_str()), - Err(e) => { - log::error!("StackDriver authentication failed {:?}", e); - return; - } - }; let spans = batch .into_iter() @@ -196,7 +173,7 @@ impl StackDriverExporter { Span { name: format!( "projects/{}/traces/{}/spans/{}", - project_name, + authorizer.project_name(), hex::encode(span.span_context.trace_id().to_u128().to_be_bytes()), hex::encode(span.span_context.span_id().to_u64().to_be_bytes()) ), @@ -212,15 +189,21 @@ impl StackDriverExporter { }) .collect::>(); - let req = BatchWriteSpansRequest { - name: format!("projects/{}", project_name), + let mut req = Request::new(BatchWriteSpansRequest { + name: format!("projects/{}", authorizer.project_name()), spans, - }; + }); + + if let Err(e) = authorizer.authorize(&mut req).await { + log::error!("StackDriver authentication failed {}", e); + return; + } + client - .batch_write_spans(AuthenticatedRequest::new(req, &bearer_token)) + .batch_write_spans(req) .await .map_err(|e| { - log::error!("StackDriver push failed {:?}", e); + log::error!("StackDriver push failed {}", e); }) .ok(); pending_count.fetch_sub(1, Ordering::Relaxed); @@ -230,27 +213,6 @@ impl StackDriverExporter { } } -struct AuthenticatedRequest<'a, T> { - inner: T, - auth: &'a str, -} - -impl<'a, T> AuthenticatedRequest<'a, T> { - pub fn new(inner: T, auth: &'a str) -> Self { - Self { inner, auth } - } -} - -impl IntoRequest for AuthenticatedRequest<'_, T> { - fn into_request(self) -> Request { - let mut req = Request::new(self.inner); - req - .metadata_mut() - .insert("authorization", MetadataValue::from_str(&self.auth).unwrap()); - req - } -} - impl SpanExporter for StackDriverExporter { fn export(&self, batch: Vec>) -> ExportResult { match self.tx.clone().try_send(batch) { @@ -278,6 +240,57 @@ impl SpanExporter for StackDriverExporter { } } +#[async_trait] +pub trait Authorizer: Sync + Send + 'static { + type Error: fmt::Display + fmt::Debug + Send; + + fn project_name(&self) -> &str; + async fn authorize(&self, request: &mut Request) -> Result<(), Self::Error>; +} + +pub struct YupAuthorizer { + authenticator: Authenticator>, + project_name: String, +} + +impl YupAuthorizer { + pub async fn new( + credentials_path: impl AsRef, + persistent_token_file: impl Into>, + ) -> Result> { + let service_account_key = yup_oauth2::read_service_account_key(&credentials_path).await?; + let project_name = service_account_key.project_id.as_ref().ok_or("project_id is missing")?.clone(); + let mut authenticator = yup_oauth2::ServiceAccountAuthenticator::builder(service_account_key); + if let Some(persistent_token_file) = persistent_token_file.into() { + authenticator = authenticator.persist_tokens_to_disk(persistent_token_file); + } + + Ok(Self { + authenticator: authenticator.build().await?, + project_name, + }) + } +} + +#[async_trait] +impl Authorizer for YupAuthorizer { + type Error = Box; + + fn project_name(&self) -> &str { + &self.project_name + } + + async fn authorize(&self, req: &mut Request) -> Result<(), Self::Error> { + let scopes = &["https://www.googleapis.com/auth/trace.append"]; + let token = self.authenticator.token(scopes).await?; + req.metadata_mut().insert( + "authorization", + MetadataValue::from_str(&format!("Bearer {}", token.as_str())).unwrap(), + ); + Ok(()) + } +} + fn attribute_value_conversion(v: Value) -> AttributeValue { use proto::google::devtools::cloudtrace::v2::attribute_value; let new_value = match v {