Skip to content

Commit

Permalink
Add WaitForCompletionOrCreateCheckStatusResponseAsync to Microsoft.Az…
Browse files Browse the repository at this point in the history
…ure.Functions.Worker.DurableTaskClientExtensions (#2875)

* Initial implementation of WaitForCompletionOrCreateCheckStatusResponseAsync

* Support X-Forwarded-Host et al

* Removed output of request headers used in my debugging

* Set location header to include returnInternalServerErrorOnFailure=true if requested

* update api and add unit test

* update sortings

* Remove unnecessary spaces

* add back forword request handling and update test accordingly

* update by comment

* add summary

* update test

* remove x-original-forwarded as we shouldn't use this

* default getinputsandoutputs to false

* update test by comment

---------

Co-authored-by: [email protected] <[email protected]>
  • Loading branch information
dixonte and nytian authored Nov 5, 2024
1 parent 8470f3d commit 79e2295
Show file tree
Hide file tree
Showing 2 changed files with 346 additions and 3 deletions.
114 changes: 112 additions & 2 deletions src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License. See License.txt in the project root for license information.

using System;
using System.Linq;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -18,6 +19,70 @@ namespace Microsoft.Azure.Functions.Worker;
/// </summary>
public static class DurableTaskClientExtensions
{
/// <summary>
/// Waits for the completion of the specified orchestration instance with a retry interval, controlled by the cancellation token.
/// If the orchestration does not complete within the required time, returns an HTTP response containing the <see cref="HttpManagementPayload"/> class to manage instances.
/// </summary>
/// <param name="client">The <see cref="DurableTaskClient"/>.</param>
/// <param name="request">The HTTP request that this response is for.</param>
/// <param name="instanceId">The ID of the orchestration instance to check.</param>
/// <param name="retryInterval">The timeout between checks for output from the durable function. The default value is 1 second.</param>
/// <param name="returnInternalServerErrorOnFailure">Optional parameter that configures the http response code returned. Defaults to <c>false</c>.</param>
/// <param name="getInputsAndOutputs">Optional parameter that configures whether to get the inputs and outputs of the orchestration. Defaults to <c>false</c>.</param>
/// <param name="cancellation">A token that signals if the wait should be canceled. If canceled, call CreateCheckStatusResponseAsync to return a reponse contains a HttpManagementPayload.</param>
/// <returns></returns>
public static async Task<HttpResponseData> WaitForCompletionOrCreateCheckStatusResponseAsync(
this DurableTaskClient client,
HttpRequestData request,
string instanceId,
TimeSpan? retryInterval = null,
bool returnInternalServerErrorOnFailure = false,
bool getInputsAndOutputs = false,
CancellationToken cancellation = default
)
{
TimeSpan retryIntervalLocal = retryInterval ?? TimeSpan.FromSeconds(1);
try
{
while (true)
{
var status = await client.GetInstanceAsync(instanceId, getInputsAndOutputs: getInputsAndOutputs);
if (status != null)
{
if (status.RuntimeStatus == OrchestrationRuntimeStatus.Completed ||
#pragma warning disable CS0618 // Type or member is obsolete
status.RuntimeStatus == OrchestrationRuntimeStatus.Canceled ||
#pragma warning restore CS0618 // Type or member is obsolete
status.RuntimeStatus == OrchestrationRuntimeStatus.Terminated ||
status.RuntimeStatus == OrchestrationRuntimeStatus.Failed)
{
var response = request.CreateResponse(
(status.RuntimeStatus == OrchestrationRuntimeStatus.Failed && returnInternalServerErrorOnFailure) ? HttpStatusCode.InternalServerError : HttpStatusCode.OK);
await response.WriteAsJsonAsync(new
{
Name = status.Name,
InstanceId = status.InstanceId,
CreatedAt = status.CreatedAt,
LastUpdatedAt = status.LastUpdatedAt,
RuntimeStatus = status.RuntimeStatus.ToString(), // Convert enum to string
SerializedInput = status.SerializedInput,
SerializedOutput = status.SerializedOutput,
SerializedCustomStatus = status.SerializedCustomStatus
}, statusCode: response.StatusCode);

return response;
}
}
await Task.Delay(retryIntervalLocal, cancellation);
}
}
// If the task is canceled, call CreateCheckStatusResponseAsync to return a response containing instance management URLs.
catch (OperationCanceledException)
{
return await CreateCheckStatusResponseAsync(client, request, instanceId);
}
}

/// <summary>
/// Creates an HTTP response that is useful for checking the status of the specified instance.
/// </summary>
Expand Down Expand Up @@ -170,13 +235,13 @@ static string BuildUrl(string url, params string?[] queryValues)
// The base URL could be null if:
// 1. The DurableTaskClient isn't a FunctionsDurableTaskClient (which would have the baseUrl from bindings)
// 2. There's no valid HttpRequestData provided
string? baseUrl = ((request != null) ? request.Url.GetLeftPart(UriPartial.Authority) : GetBaseUrl(client));
string? baseUrl = ((request != null) ? GetBaseUrlFromRequest(request) : GetBaseUrl(client));

if (baseUrl == null)
{
throw new InvalidOperationException("Failed to create HTTP management payload as base URL is null. Either use Functions bindings or provide an HTTP request to create the HttpPayload.");
}

bool isFromRequest = request != null;

string formattedInstanceId = Uri.EscapeDataString(instanceId);
Expand Down Expand Up @@ -214,6 +279,51 @@ private static ObjectSerializer GetObjectSerializer(HttpResponseData response)
?? throw new InvalidOperationException("A serializer is not configured for the worker.");
}

private static string? GetBaseUrlFromRequest(HttpRequestData request)
{
// Default to the scheme from the request URL
string proto = request.Url.Scheme;
string host = request.Url.Authority;

// Check for "Forwarded" header
if (request.Headers.TryGetValues("Forwarded", out var forwardedHeaders))
{
var forwardedDict = forwardedHeaders.FirstOrDefault()?.Split(';')
.Select(pair => pair.Split('='))
.Where(pair => pair.Length == 2)
.ToDictionary(pair => pair[0].Trim(), pair => pair[1].Trim());

if (forwardedDict != null)
{
if (forwardedDict.TryGetValue("proto", out var forwardedProto))
{
proto = forwardedProto;
}
if (forwardedDict.TryGetValue("host", out var forwardedHost))
{
host = forwardedHost;
// Return if either proto or host (or both) were found in "Forwarded" header
return $"{proto}://{forwardedHost}";
}
}
}
// Check for "X-Forwarded-Proto" and "X-Forwarded-Host" headers if "Forwarded" is not present
if (request.Headers.TryGetValues("X-Forwarded-Proto", out var protos))
{
proto = protos.FirstOrDefault() ?? proto;
}
if (request.Headers.TryGetValues("X-Forwarded-Host", out var hosts))
{
// Return base URL if either "X-Forwarded-Proto" or "X-Forwarded-Host" (or both) are found
host = hosts.FirstOrDefault() ?? host;
return $"{proto}://{host}";
}

// Construct and return the base URL from default fallback values
return $"{proto}://{host}";
}


private static string? GetQueryParams(DurableTaskClient client)
{
return client is FunctionsDurableTaskClient functions ? functions.QueryString : null;
Expand Down
Loading

0 comments on commit 79e2295

Please sign in to comment.