diff --git a/src/System.Diagnostics.DiagnosticSource/src/System/Diagnostics/HttpHandlerDiagnosticListener.cs b/src/System.Diagnostics.DiagnosticSource/src/System/Diagnostics/HttpHandlerDiagnosticListener.cs index bbfb87c05521..3520b2e0b92d 100644 --- a/src/System.Diagnostics.DiagnosticSource/src/System/Diagnostics/HttpHandlerDiagnosticListener.cs +++ b/src/System.Diagnostics.DiagnosticSource/src/System/Diagnostics/HttpHandlerDiagnosticListener.cs @@ -87,7 +87,7 @@ private void Initialize() } } -#region private helper classes + #region private helper classes private class HashtableWrapper : Hashtable, IEnumerable { @@ -603,25 +603,52 @@ private void RaiseRequestEvent(HttpWebRequest request) activity.Start(); } - request.Headers.Add(RequestIdHeaderName, activity.Id); - // we expect baggage to be empty or contain a few items - using (IEnumerator> e = activity.Baggage.GetEnumerator()) + if (activity.IdFormat == ActivityIdFormat.W3C) { - if (e.MoveNext()) + // do not inject header if it was injected already + // perhaps tracing systems wants to override it + if (request.Headers.Get(TraceParentHeaderName) == null) { - StringBuilder baggage = new StringBuilder(); - do + request.Headers.Add(TraceParentHeaderName, activity.Id); + + var traceState = activity.TraceStateString; + if (traceState != null) + { + request.Headers.Add(TraceStateHeaderName, traceState); + } + } + } + else + { + // do not inject header if it was injected already + // perhaps tracing systems wants to override it + if (request.Headers.Get(RequestIdHeaderName) == null) + { + request.Headers.Add(RequestIdHeaderName, activity.Id); + } + } + + if (request.Headers.Get(CorrelationContextHeaderName) == null) + { + // we expect baggage to be empty or contain a few items + using (IEnumerator> e = activity.Baggage.GetEnumerator()) + { + if (e.MoveNext()) { - KeyValuePair item = e.Current; - baggage.Append(item.Key).Append('=').Append(item.Value).Append(','); + StringBuilder baggage = new StringBuilder(); + do + { + KeyValuePair item = e.Current; + baggage.Append(item.Key).Append('=').Append(item.Value).Append(','); + } + while (e.MoveNext()); + baggage.Remove(baggage.Length - 1, 1); + request.Headers.Add(CorrelationContextHeaderName, baggage.ToString()); } - while (e.MoveNext()); - baggage.Remove(baggage.Length - 1, 1); - request.Headers.Add(CorrelationContextHeaderName, baggage.ToString()); } } - // There is no gurantee that Activity.Current will flow to the Response, so let's stop it here + // There is no guarantee that Activity.Current will flow to the Response, so let's stop it here activity.Stop(); } } @@ -631,7 +658,7 @@ private void RaiseResponseEvent(HttpWebRequest request, HttpWebResponse response // Response event could be received several times for the same request in case it was redirected // IsLastResponse checks if response is the last one (no more redirects will happen) // based on response StatusCode and number or redirects done so far - if (request.Headers[RequestIdHeaderName] != null && IsLastResponse(request, response.StatusCode)) + if (request.Headers.Get(RequestIdHeaderName) != null && IsLastResponse(request, response.StatusCode)) { // only send Stop if request was instrumented this.Write(RequestStopName, new { Request = request, Response = response }); @@ -643,7 +670,7 @@ private void RaiseResponseEvent(HttpWebRequest request, HttpStatusCode statusCod // Response event could be received several times for the same request in case it was redirected // IsLastResponse checks if response is the last one (no more redirects will happen) // based on response StatusCode and number or redirects done so far - if (request.Headers[RequestIdHeaderName] != null && IsLastResponse(request, statusCode)) + if (request.Headers.Get(RequestIdHeaderName) != null && IsLastResponse(request, statusCode)) { this.Write(RequestStopExName, new { Request = request, StatusCode = statusCode, Headers = headers }); } @@ -653,10 +680,10 @@ private bool IsLastResponse(HttpWebRequest request, HttpStatusCode statusCode) { if (request.AllowAutoRedirect) { - if (statusCode == HttpStatusCode.Ambiguous || // 300 - statusCode == HttpStatusCode.Moved || // 301 - statusCode == HttpStatusCode.Redirect || // 302 - statusCode == HttpStatusCode.RedirectMethod || // 303 + if (statusCode == HttpStatusCode.Ambiguous || // 300 + statusCode == HttpStatusCode.Moved || // 301 + statusCode == HttpStatusCode.Redirect || // 302 + statusCode == HttpStatusCode.RedirectMethod || // 303 statusCode == HttpStatusCode.RedirectKeepVerb || // 307 (int)statusCode == 308) // 308 Permanent Redirect is not in netfx yet, and so has to be specified this way. { @@ -696,7 +723,7 @@ private static void PrepareReflectionObjects() s_connectionType == null || s_writeListField == null || s_httpResponseAccessor == null || - s_autoRedirectsAccessor == null || + s_autoRedirectsAccessor == null || s_coreResponseDataType == null || s_coreStatusCodeAccessor == null || s_coreHeadersAccessor == null) @@ -727,7 +754,7 @@ private static Func CreateFieldGetter(string fie if (field != null) { string methodName = field.ReflectedType.FullName + ".get_" + field.Name; - DynamicMethod getterMethod = new DynamicMethod(methodName, typeof(TField), new [] { typeof(TClass) }, true); + DynamicMethod getterMethod = new DynamicMethod(methodName, typeof(TField), new[] { typeof(TClass) }, true); ILGenerator generator = getterMethod.GetILGenerator(); generator.Emit(OpCodes.Ldarg_0); generator.Emit(OpCodes.Ldfld, field); @@ -749,7 +776,7 @@ private static Func CreateFieldGetter(Type classType, st if (field != null) { string methodName = classType.FullName + ".get_" + field.Name; - DynamicMethod getterMethod = new DynamicMethod(methodName, typeof(TField), new [] { typeof(object) }, true); + DynamicMethod getterMethod = new DynamicMethod(methodName, typeof(TField), new[] { typeof(object) }, true); ILGenerator generator = getterMethod.GetILGenerator(); generator.Emit(OpCodes.Ldarg_0); generator.Emit(OpCodes.Castclass, classType); @@ -766,7 +793,7 @@ private static Func CreateFieldGetter(Type classType, st internal static HttpHandlerDiagnosticListener s_instance = new HttpHandlerDiagnosticListener(); -#region private fields + #region private fields private const string DiagnosticListenerName = "System.Net.Http.Desktop"; private const string ActivityName = "System.Net.Http.Desktop.HttpRequestOut"; private const string RequestStartName = "System.Net.Http.Desktop.HttpRequestOut.Start"; @@ -775,6 +802,8 @@ private static Func CreateFieldGetter(Type classType, st private const string InitializationFailed = "System.Net.Http.InitializationFailed"; private const string RequestIdHeaderName = "Request-Id"; private const string CorrelationContextHeaderName = "Correlation-Context"; + private const string TraceParentHeaderName = "traceparent"; + private const string TraceStateHeaderName = "tracestate"; // Fields for controlling initialization of the HttpHandlerDiagnosticListener singleton private bool initialized = false; diff --git a/src/System.Diagnostics.DiagnosticSource/tests/ActivityTests.cs b/src/System.Diagnostics.DiagnosticSource/tests/ActivityTests.cs index 2047294e9530..06d4a92aa80e 100644 --- a/src/System.Diagnostics.DiagnosticSource/tests/ActivityTests.cs +++ b/src/System.Diagnostics.DiagnosticSource/tests/ActivityTests.cs @@ -288,7 +288,7 @@ public static bool IdIsW3CFormat(string id) return false; if (id[52] != '-') return false; - return Regex.IsMatch(id, "^[0-9a-f][0-9a-f]-[0-9a-f]*-[0-9a-f]*-[0-9a-f][0-9a-f]$"); + return Regex.IsMatch(id, "^[0-9a-f]{2}-[0-9a-f]{32}-[0-9a-f]{16}-[0-9a-f]{2}$"); } public static bool IsLowerCaseHex(string s) @@ -341,13 +341,13 @@ public void ActivityTraceIdTests() // Use in Dictionary (this does assume we have no collisions in IDs over 100 tries (very good). var dict = new Dictionary(); - for(int i = 0; i < 100; i++) + for (int i = 0; i < 100; i++) { var newId7 = ActivityTraceId.CreateRandom(); dict[newId7] = newId7.ToHexString(); } int ctr = 0; - foreach(string value in dict.Values) + foreach (string value in dict.Values) { string valueInDict; Assert.True(dict.TryGetValue(ActivityTraceId.CreateFromString(value.AsSpan()), out valueInDict)); diff --git a/src/System.Diagnostics.DiagnosticSource/tests/HttpHandlerDiagnosticListenerTests.cs b/src/System.Diagnostics.DiagnosticSource/tests/HttpHandlerDiagnosticListenerTests.cs index b947e552d0c6..48acf6bb24dd 100644 --- a/src/System.Diagnostics.DiagnosticSource/tests/HttpHandlerDiagnosticListenerTests.cs +++ b/src/System.Diagnostics.DiagnosticSource/tests/HttpHandlerDiagnosticListenerTests.cs @@ -8,6 +8,7 @@ using System.Net; using System.Net.Http; using System.Reflection; +using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; using Xunit; @@ -131,6 +132,8 @@ public async Task TestBasicReceiveAndResponseEvents() HttpWebRequest startRequest = ReadPublicProperty(startEvent.Value, "Request"); Assert.NotNull(startRequest); Assert.NotNull(startRequest.Headers["Request-Id"]); + Assert.Null(startRequest.Headers["traceparent"]); + Assert.Null(startRequest.Headers["tracestate"]); KeyValuePair stopEvent; Assert.True(eventRecords.Records.TryDequeue(out stopEvent)); @@ -142,6 +145,142 @@ public async Task TestBasicReceiveAndResponseEvents() } } + [OuterLoop] + [Fact] + public async Task TestW3CHeaders() + { + try + { + using (var eventRecords = new EventObserverAndRecorder()) + { + Activity.DefaultIdFormat = ActivityIdFormat.W3C; + Activity.ForceDefaultIdFormat = true; + // Send a random Http request to generate some events + using (var client = new HttpClient()) + { + (await client.GetAsync(Configuration.Http.RemoteEchoServer)).Dispose(); + } + + // Check to make sure: The first record must be a request, the next record must be a response. + KeyValuePair startEvent; + Assert.True(eventRecords.Records.TryDequeue(out startEvent)); + Assert.Equal("System.Net.Http.Desktop.HttpRequestOut.Start", startEvent.Key); + HttpWebRequest startRequest = ReadPublicProperty(startEvent.Value, "Request"); + Assert.NotNull(startRequest); + + var traceparent = startRequest.Headers["traceparent"]; + Assert.NotNull(traceparent); + Assert.True(Regex.IsMatch(traceparent, "^[0-9a-f][0-9a-f]-[0-9a-f]{32}-[0-9a-f]{16}-[0-9a-f][0-9a-f]$")); + Assert.Null(startRequest.Headers["tracestate"]); + Assert.Null(startRequest.Headers["Request-Id"]); + } + } + finally + { + CleanUp(); + } + } + + [OuterLoop] + [Fact] + public async Task TestW3CHeadersTraceStateAndCorrelationContext() + { + try + { + using (var eventRecords = new EventObserverAndRecorder()) + { + var parent = new Activity("w3c activity"); + parent.SetParentId(ActivityTraceId.CreateRandom(), ActivitySpanId.CreateRandom()); + parent.TraceStateString = "some=state"; + parent.AddBaggage("k", "v"); + parent.Start(); + + // Send a random Http request to generate some events + using (var client = new HttpClient()) + { + (await client.GetAsync(Configuration.Http.RemoteEchoServer)).Dispose(); + } + + parent.Stop(); + + // Check to make sure: The first record must be a request, the next record must be a response. + Assert.True(eventRecords.Records.TryDequeue(out var evnt)); + Assert.Equal("System.Net.Http.Desktop.HttpRequestOut.Start", evnt.Key); + HttpWebRequest startRequest = ReadPublicProperty(evnt.Value, "Request"); + Assert.NotNull(startRequest); + + var traceparent = startRequest.Headers["traceparent"]; + var tracestate = startRequest.Headers["tracestate"]; + var correlationContext = startRequest.Headers["Correlation-Context"]; + Assert.NotNull(traceparent); + Assert.Equal("some=state", tracestate); + Assert.Equal("k=v", correlationContext); + Assert.True(traceparent.StartsWith($"00-{parent.TraceId.ToHexString()}-")); + Assert.True(Regex.IsMatch(traceparent, "^[0-9a-f]{2}-[0-9a-f]{32}-[0-9a-f]{16}-[0-9a-f]{2}$")); + Assert.Null(startRequest.Headers["Request-Id"]); + } + } + finally + { + CleanUp(); + } + } + + + [OuterLoop] + [Fact] + public async Task DoNotInjectRequestIdWhenPresent() + { + using (var eventRecords = new EventObserverAndRecorder()) + { + // Send a random Http request to generate some events + using (var client = new HttpClient()) + using (var request = new HttpRequestMessage(HttpMethod.Get, Configuration.Http.RemoteEchoServer)) + { + request.Headers.Add("Request-Id", "|rootId.1."); + (await client.SendAsync(request)).Dispose(); + } + + // Check to make sure: The first record must be a request, the next record must be a response. + Assert.True(eventRecords.Records.TryDequeue(out var evnt)); + HttpWebRequest startRequest = ReadPublicProperty(evnt.Value, "Request"); + Assert.NotNull(startRequest); + Assert.Equal("|rootId.1.", startRequest.Headers["Request-Id"]); + } + } + + [OuterLoop] + [Fact] + public async Task DoNotInjectTraceParentWhenPresent() + { + try + { + using (var eventRecords = new EventObserverAndRecorder()) + { + Activity.DefaultIdFormat = ActivityIdFormat.W3C; + Activity.ForceDefaultIdFormat = true; + // Send a random Http request to generate some events + using (var client = new HttpClient()) + using (var request = new HttpRequestMessage(HttpMethod.Get, Configuration.Http.RemoteEchoServer)) + { + request.Headers.Add("traceparent", "00-abcdef0123456789abcdef0123456789-abcdef0123456789-01"); + (await client.SendAsync(request)).Dispose(); + } + + // Check to make sure: The first record must be a request, the next record must be a response. + Assert.True(eventRecords.Records.TryDequeue(out var evnt)); + HttpWebRequest startRequest = ReadPublicProperty(evnt.Value, "Request"); + Assert.NotNull(startRequest); + + Assert.Equal("00-abcdef0123456789abcdef0123456789-abcdef0123456789-01", startRequest.Headers["traceparent"]); + } + } + finally + { + CleanUp(); + } + } + /// /// Test to make sure we get both request and response events. /// @@ -237,7 +376,7 @@ await Assert.ThrowsAsync( public async Task TestCanceledRequest() { CancellationTokenSource cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); - using (var eventRecords = new EventObserverAndRecorder( _ => { cts.Cancel();})) + using (var eventRecords = new EventObserverAndRecorder(_ => { cts.Cancel(); })) { using (var client = new HttpClient()) { @@ -486,6 +625,18 @@ public void TestMultipleConcurrentRequests() } } + + public void CleanUp() + { + Activity.DefaultIdFormat = ActivityIdFormat.Hierarchical; + Activity.ForceDefaultIdFormat = false; + + while (Activity.Current != null) + { + Activity.Current.Stop(); + } + } + private static T ReadPublicProperty(object obj, string propertyName) { Type type = obj.GetType();