diff --git a/opentelemetry-aws/src/lib.rs b/opentelemetry-aws/src/lib.rs index 8f06947f13..6552ae9ec8 100644 --- a/opentelemetry-aws/src/lib.rs +++ b/opentelemetry-aws/src/lib.rs @@ -37,8 +37,11 @@ #[cfg(feature = "trace")] pub mod trace { use opentelemetry::{ + global::{self, Error}, propagation::{text_map_propagator::FieldIter, Extractor, Injector, TextMapPropagator}, - trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId, TraceState}, + trace::{ + SpanContext, SpanId, TraceContextExt, TraceError, TraceFlags, TraceId, TraceState, + }, Context, }; use std::convert::{TryFrom, TryInto}; @@ -125,21 +128,29 @@ pub mod trace { } } - let trace_state: TraceState = TraceState::from_key_value(kv_vec)?; - - if trace_id.to_u128() == 0 { - return Err(()); - } + match TraceState::from_key_value(kv_vec) { + Ok(trace_state) => { + if trace_id.to_u128() == 0 { + return Err(()); + } - let context: SpanContext = SpanContext::new( - trace_id, - parent_segment_id, - sampling_decision, - true, - trace_state, - ); + let context: SpanContext = SpanContext::new( + trace_id, + parent_segment_id, + sampling_decision, + true, + trace_state, + ); - Ok(context) + Ok(context) + } + Err(trace_state_err) => { + global::handle_error(Error::Trace(TraceError::Other(Box::new( + trace_state_err, + )))); + Err(()) //todo: assign an error type instead of using () + } + } } } diff --git a/opentelemetry-jaeger/src/lib.rs b/opentelemetry-jaeger/src/lib.rs index 6b2444e25b..d5d013aea8 100644 --- a/opentelemetry-jaeger/src/lib.rs +++ b/opentelemetry-jaeger/src/lib.rs @@ -200,8 +200,11 @@ mod propagator { //! //! [`Jaeger documentation`]: https://www.jaegertracing.io/docs/1.18/client-libraries/#propagation-format use opentelemetry::{ + global::{self, Error}, propagation::{text_map_propagator::FieldIter, Extractor, Injector, TextMapPropagator}, - trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId, TraceState}, + trace::{ + SpanContext, SpanId, TraceContextExt, TraceError, TraceFlags, TraceId, TraceState, + }, Context, }; use std::borrow::Cow; @@ -324,7 +327,15 @@ mod propagator { .map(|value| (key.to_string(), value.to_string())) }); - TraceState::from_key_value(uber_context_keys) + match TraceState::from_key_value(uber_context_keys) { + Ok(trace_state) => Ok(trace_state), + Err(trace_state_err) => { + global::handle_error(Error::Trace(TraceError::Other(Box::new( + trace_state_err, + )))); + Err(()) //todo: assign an error type instead of using () + } + } } } diff --git a/opentelemetry/src/sdk/trace/tracer.rs b/opentelemetry/src/sdk/trace/tracer.rs index b74d8cb4f0..46d401db7a 100644 --- a/opentelemetry/src/sdk/trace/tracer.rs +++ b/opentelemetry/src/sdk/trace/tracer.rs @@ -361,7 +361,7 @@ mod tests { SamplingResult { decision: SamplingDecision::RecordAndSample, attributes: Vec::new(), - trace_state: trace_state.insert("foo".into(), "notbar".into()).unwrap(), + trace_state: trace_state.insert("foo", "notbar").unwrap(), } } } diff --git a/opentelemetry/src/trace/mod.rs b/opentelemetry/src/trace/mod.rs index 0797bc8276..0b142c3e55 100644 --- a/opentelemetry/src/trace/mod.rs +++ b/opentelemetry/src/trace/mod.rs @@ -189,7 +189,7 @@ pub use self::{ noop::{NoopSpan, NoopSpanExporter, NoopTracer, NoopTracerProvider}, provider::TracerProvider, span::{Span, SpanKind, StatusCode}, - span_context::{SpanContext, SpanId, TraceFlags, TraceId, TraceState}, + span_context::{SpanContext, SpanId, TraceFlags, TraceId, TraceState, TraceStateError}, tracer::{SpanBuilder, Tracer}, }; use crate::sdk::export::ExportError; diff --git a/opentelemetry/src/trace/span_context.rs b/opentelemetry/src/trace/span_context.rs index c226e12536..57398fa405 100644 --- a/opentelemetry/src/trace/span_context.rs +++ b/opentelemetry/src/trace/span_context.rs @@ -16,6 +16,7 @@ use std::collections::VecDeque; use std::fmt; use std::ops::{BitAnd, BitOr, Not}; use std::str::FromStr; +use thiserror::Error; /// Flags that can be set on a [`SpanContext`]. /// @@ -238,16 +239,15 @@ impl TraceState { /// # Examples /// /// ``` - /// use opentelemetry::trace::TraceState; + /// use opentelemetry::trace::{TraceState, TraceStateError}; /// /// let kvs = vec![("foo", "bar"), ("apple", "banana")]; - /// let trace_state: Result = TraceState::from_key_value(kvs); + /// let trace_state: Result = TraceState::from_key_value(kvs); /// /// assert!(trace_state.is_ok()); /// assert_eq!(trace_state.unwrap().header(), String::from("foo=bar,apple=banana")) /// ``` - #[allow(clippy::all)] - pub fn from_key_value(trace_state: T) -> Result + pub fn from_key_value(trace_state: T) -> Result where T: IntoIterator, K: ToString, @@ -257,14 +257,16 @@ impl TraceState { .into_iter() .map(|(key, value)| { let (key, value) = (key.to_string(), value.to_string()); - if !TraceState::valid_key(key.as_str()) || !TraceState::valid_value(value.as_str()) - { - return Err(()); + if !TraceState::valid_key(key.as_str()) { + return Err(TraceStateError::InvalidKey(key)); + } + if !TraceState::valid_value(value.as_str()) { + return Err(TraceStateError::InvalidValue(value)); } Ok((key, value)) }) - .collect::, ()>>()?; + .collect::, TraceStateError>>()?; if ordered_data.is_empty() { Ok(TraceState(None)) @@ -292,13 +294,20 @@ impl TraceState { /// updated key/value is returned. /// /// ['spec']: https://www.w3.org/TR/trace-context/#list - #[allow(clippy::all)] - pub fn insert(&self, key: String, value: String) -> Result { - if !TraceState::valid_key(key.as_str()) || !TraceState::valid_value(value.as_str()) { - return Err(()); + pub fn insert(&self, key: K, value: V) -> Result + where + K: Into, + V: Into, + { + let (key, value) = (key.into(), value.into()); + if !TraceState::valid_key(key.as_str()) { + return Err(TraceStateError::InvalidKey(key)); + } + if !TraceState::valid_value(value.as_str()) { + return Err(TraceStateError::InvalidValue(value)); } - let mut trace_state = self.delete(key.clone())?; + let mut trace_state = self.delete_from_deque(key.clone()); let kvs = trace_state.0.get_or_insert(VecDeque::with_capacity(1)); kvs.push_front((key, value)); @@ -307,26 +316,29 @@ impl TraceState { } /// Removes the given key-value pair from the `TraceState`. If the key is invalid per the - /// [W3 Spec]['spec'] or the key does not exist an `Err` is returned. Else, a new `TraceState` + /// [W3 Spec]['spec'] an `Err` is returned. Else, a new `TraceState` /// with the removed entry is returned. /// + /// If the key is not in `TraceState`. The original `TraceState` will be cloned and returned. /// ['spec']: https://www.w3.org/TR/trace-context/#list - #[allow(clippy::all)] - pub fn delete(&self, key: String) -> Result { + pub fn delete>(&self, key: K) -> Result { + let key = key.into(); if !TraceState::valid_key(key.as_str()) { - return Err(()); + return Err(TraceStateError::InvalidKey(key)); } - let mut owned = self.clone(); - let kvs = owned.0.as_mut().ok_or(())?; + Ok(self.delete_from_deque(key)) + } - if let Some(index) = kvs.iter().position(|x| *x.0 == *key) { - kvs.remove(index); - } else { - return Err(()); + /// Delete key from trace state's deque. The key MUST be valid + fn delete_from_deque(&self, key: String) -> TraceState { + let mut owned = self.clone(); + if let Some(kvs) = owned.0.as_mut() { + if let Some(index) = kvs.iter().position(|x| *x.0 == *key) { + kvs.remove(index); + } } - - Ok(owned) + owned } /// Creates a new `TraceState` header string, delimiting each key and value with a `=` and each @@ -350,7 +362,7 @@ impl TraceState { } impl FromStr for TraceState { - type Err = (); + type Err = TraceStateError; fn from_str(s: &str) -> Result { let list_members: Vec<&str> = s.split_terminator(',').collect(); @@ -358,7 +370,7 @@ impl FromStr for TraceState { for list_member in list_members { match list_member.find('=') { - None => return Err(()), + None => return Err(TraceStateError::InvalidList(list_member.to_string())), Some(separator_index) => { let (key, value) = list_member.split_at(separator_index); key_value_pairs @@ -371,6 +383,23 @@ impl FromStr for TraceState { } } +/// Error returned by `TraceState` operations. +#[derive(Error, Debug)] +#[non_exhaustive] +pub enum TraceStateError { + /// The key is invalid. See https://www.w3.org/TR/trace-context/#key for requirement for keys. + #[error("{0} is not a valid key in TraceState, see https://www.w3.org/TR/trace-context/#key for more details")] + InvalidKey(String), + + /// The value is invalid. See https://www.w3.org/TR/trace-context/#value for requirement for values. + #[error("{0} is not a valid value in TraceState, see https://www.w3.org/TR/trace-context/#value for more details")] + InvalidValue(String), + + /// The value is invalid. See https://www.w3.org/TR/trace-context/#list for requirement for list members. + #[error("{0} is not a valid list member in TraceState, see https://www.w3.org/TR/trace-context/#list for more details")] + InvalidList(String), +} + /// Immutable portion of a `Span` which can be serialized and propagated. /// /// Spans that do not have the `sampled` flag set in their [`TraceFlags`] will @@ -514,7 +543,7 @@ mod tests { let new_key = format!("{}-{}", test_case.0.get(test_case.2).unwrap(), "test"); - let updated_trace_state = test_case.0.insert(test_case.2.into(), new_key.clone()); + let updated_trace_state = test_case.0.insert(test_case.2, new_key.clone()); assert!(updated_trace_state.is_ok()); let updated_trace_state = updated_trace_state.unwrap(); @@ -533,4 +562,12 @@ mod tests { assert!(deleted_trace_state.get(test_case.2).is_none()); } } + + #[test] + fn test_trace_state_insert() { + let trace_state = TraceState::from_key_value(vec![("foo", "bar")]).unwrap(); + let inserted_trace_state = trace_state.insert("testkey", "testvalue").unwrap(); + assert!(trace_state.get("testkey").is_none()); // The original state doesn't change + assert_eq!(inserted_trace_state.get("testkey").unwrap(), "testvalue"); // + } }