diff --git a/dotnet/docs/EXPERIMENTS.md b/dotnet/docs/EXPERIMENTS.md index a45db90a77ab..5974b9946848 100644 --- a/dotnet/docs/EXPERIMENTS.md +++ b/dotnet/docs/EXPERIMENTS.md @@ -26,6 +26,7 @@ You can use the following diagnostic IDs to ignore warnings or errors for a part - SKEXP0013: OpenAI parameters - SKEXP0014: OpenAI chat history extension - SKEXP0015: OpenAI file service +- SKEXP0016: OpenAI tool call filters ## Memory connectors diff --git a/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs b/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs index 4ef0fe2d5de9..1de30b590d7d 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs @@ -29,6 +29,8 @@ namespace Microsoft.SemanticKernel.Connectors.OpenAI; /// internal abstract class ClientCore { + private const string ModelIterationsCompletedKey = "ModelIterationsCompleted"; + private const int MaxResultsPerPrompt = 128; /// @@ -176,25 +178,27 @@ internal async IAsyncEnumerable GetStreamingTextContentsAs }; } - private static Dictionary GetResponseMetadata(ChatCompletions completions) + private static Dictionary GetResponseMetadata(ChatCompletions completions, int modelIterations) { - return new Dictionary(5) + return new Dictionary(6) { { nameof(completions.Id), completions.Id }, { nameof(completions.Created), completions.Created }, { nameof(completions.PromptFilterResults), completions.PromptFilterResults }, { nameof(completions.SystemFingerprint), completions.SystemFingerprint }, { nameof(completions.Usage), completions.Usage }, + { ModelIterationsCompletedKey, modelIterations }, }; } - private static Dictionary GetResponseMetadata(StreamingChatCompletionsUpdate completions) + private static Dictionary GetResponseMetadata(StreamingChatCompletionsUpdate completions, int modelIterations) { - return new Dictionary(3) + return new Dictionary(4) { { nameof(completions.Id), completions.Id }, { nameof(completions.Created), completions.Created }, { nameof(completions.SystemFingerprint), completions.SystemFingerprint }, + { ModelIterationsCompletedKey, modelIterations }, }; } @@ -303,7 +307,7 @@ internal async Task> GetChatMessageContentsAsy throw new KernelException("Chat completions not found"); } - IReadOnlyDictionary metadata = GetResponseMetadata(responseData); + IReadOnlyDictionary metadata = GetResponseMetadata(responseData, iteration); // If we don't want to attempt to invoke any functions, just return the result. // Or if we are auto-invoking but we somehow end up with other than 1 choice even though only 1 was requested, similarly bail. @@ -367,6 +371,19 @@ internal async Task> GetChatMessageContentsAsy continue; } + try + { + // Invoke the pre-invocation filter. + var invokingContext = chatExecutionSettings.ToolCallBehavior?.OnToolInvokingFilter(openAIFunctionToolCall, chat, iteration); + this.ApplyToolFilterContextChanges(invokingContext, chatOptions, chat, chatExecutionSettings, ref autoInvoke); + } + catch (OperationCanceledException) + { + // Add cancellation message to chat history and bail out of any remaining tool calls + AddResponseMessage(chatOptions, chat, null, $"A tool filter requested cancellation before tool invocation. Model iterations completed: {iteration}", toolCall.Id, this.Logger); + break; + } + // Make sure the requested function is one we requested. 
If we're permitting any kernel function to be invoked, // then we don't need to check this, as it'll be handled when we look up the function in the kernel to be able // to invoke it. If we're permitting only a specific list of functions, though, then we need to explicitly check. @@ -395,7 +412,7 @@ internal async Task> GetChatMessageContentsAsy functionResult = (await function.InvokeAsync(kernel, functionArgs, cancellationToken: cancellationToken).ConfigureAwait(false)).GetValue() ?? string.Empty; } #pragma warning disable CA1031 // Do not catch general exception types - catch (Exception e) + catch (Exception e) when (!e.IsCriticalException()) #pragma warning restore CA1031 { AddResponseMessage(chatOptions, chat, null, $"Error: Exception while invoking function. {e.Message}", toolCall.Id, this.Logger); @@ -407,6 +424,18 @@ internal async Task> GetChatMessageContentsAsy } AddResponseMessage(chatOptions, chat, functionResult as string ?? JsonSerializer.Serialize(functionResult), errorMessage: null, toolCall.Id, this.Logger); + try + { + // Invoke the post-invocation filter. + var invokedContext = chatExecutionSettings.ToolCallBehavior?.OnToolInvokedFilter(openAIFunctionToolCall, functionResult, chat, iteration); + this.ApplyToolFilterContextChanges(invokedContext, chatOptions, chat, chatExecutionSettings, ref autoInvoke); + } + catch (OperationCanceledException) + { + // The tool call already happened so we can't cancel it, but bail out of any remaining tool calls + break; + } + static void AddResponseMessage(ChatCompletionsOptions chatOptions, ChatHistory chat, string? result, string? errorMessage, string toolId, ILogger logger) { // Log any error @@ -447,6 +476,58 @@ static void AddResponseMessage(ChatCompletionsOptions chatOptions, ChatHistory c } } + private void ApplyToolFilterContextChanges( + ToolFilterContext? context, + ChatCompletionsOptions chatOptions, + ChatHistory chatHistory, + OpenAIPromptExecutionSettings executionSettings, + ref bool autoInvoke) + { + if (context is not null) + { + // Since the tool filter has access to the chat history, the chat history may have been modified. + // We want to make sure any subsequent requests to the model reflect these changes. The chatOptions object + // contains all the configuration information for a chat request, including a copy of the chat history. + // So we need to update the chat history stored in the chatOptions object to match what is in the chatHistory object. 
+ this.UpdateChatOptions(chatOptions, chatHistory, executionSettings); + + // Check if filter has requested a stop + this.HandleStopBehavior(context, chatOptions, ref autoInvoke); + } + } + + private void HandleStopBehavior(ToolFilterContext context, ChatCompletionsOptions chatOptions, ref bool autoInvoke) + { + switch (context.StopBehavior) + { + case ToolFilterStopBehavior.StopAutoInvoke: + autoInvoke = false; + break; + case ToolFilterStopBehavior.StopTools: + chatOptions.ToolChoice = ChatCompletionsToolChoice.None; + break; + case ToolFilterStopBehavior.Cancel: + chatOptions.ToolChoice = ChatCompletionsToolChoice.None; + throw new OperationCanceledException(); + } + } + + private void UpdateChatOptions(ChatCompletionsOptions options, ChatHistory chatHistory, OpenAIPromptExecutionSettings executionSettings) + { + // Clear out messages, then copy over from chat history + options.Messages.Clear(); + + if (!string.IsNullOrWhiteSpace(executionSettings?.ChatSystemPrompt) && !chatHistory.Any(m => m.Role == AuthorRole.System)) + { + options.Messages.Add(GetRequestMessage(new ChatMessageContent(AuthorRole.System, executionSettings!.ChatSystemPrompt))); + } + + foreach (var message in chatHistory) + { + options.Messages.Add(GetRequestMessage(message)); + } + } + internal async IAsyncEnumerable GetStreamingChatMessageContentsAsync( ChatHistory chat, PromptExecutionSettings? executionSettings, @@ -485,7 +566,7 @@ internal async IAsyncEnumerable GetStreamingC CompletionsFinishReason finishReason = default; await foreach (StreamingChatCompletionsUpdate update in response.ConfigureAwait(false)) { - metadata ??= GetResponseMetadata(update); + metadata ??= GetResponseMetadata(update, iteration); streamedRole ??= update.Role; finishReason = update.FinishReason ?? default; @@ -557,6 +638,19 @@ internal async IAsyncEnumerable GetStreamingC continue; } + try + { + // Invoke the pre-invocation filter. + var invokingContext = chatExecutionSettings.ToolCallBehavior?.OnToolInvokingFilter(openAIFunctionToolCall, chat, iteration); + this.ApplyToolFilterContextChanges(invokingContext, chatOptions, chat, chatExecutionSettings, ref autoInvoke); + } + catch (OperationCanceledException) + { + // Add cancellation message to chat history and bail out of any remaining tool calls + AddResponseMessage(chatOptions, chat, streamedRole, toolCall, metadata, null, $"A tool filter requested cancellation before tool invocation. Model iterations completed: {iteration}", this.Logger); + break; + } + // Make sure the requested function is one we requested. If we're permitting any kernel function to be invoked, // then we don't need to check this, as it'll be handled when we look up the function in the kernel to be able // to invoke it. If we're permitting only a specific list of functions, though, then we need to explicitly check. @@ -585,7 +679,7 @@ internal async IAsyncEnumerable GetStreamingC functionResult = (await function.InvokeAsync(kernel, functionArgs, cancellationToken: cancellationToken).ConfigureAwait(false)).GetValue() ?? string.Empty; } #pragma warning disable CA1031 // Do not catch general exception types - catch (Exception e) + catch (Exception e) when (!e.IsCriticalException()) #pragma warning restore CA1031 { AddResponseMessage(chatOptions, chat, streamedRole, toolCall, metadata, result: null, $"Error: Exception while invoking function. 
{e.Message}", this.Logger); @@ -597,6 +691,18 @@ internal async IAsyncEnumerable GetStreamingC } AddResponseMessage(chatOptions, chat, streamedRole, toolCall, metadata, functionResult as string ?? JsonSerializer.Serialize(functionResult), errorMessage: null, this.Logger); + try + { + // Invoke the post-invocation filter. + var invokedContext = chatExecutionSettings.ToolCallBehavior?.OnToolInvokedFilter(openAIFunctionToolCall, functionResult, chat, iteration); + this.ApplyToolFilterContextChanges(invokedContext, chatOptions, chat, chatExecutionSettings, ref autoInvoke); + } + catch (OperationCanceledException) + { + // This tool call already happened so we can't cancel it, but bail out of any remaining tool calls + break; + } + static void AddResponseMessage( ChatCompletionsOptions chatOptions, ChatHistory chat, ChatRole? streamedRole, ChatCompletionsToolCall tool, IReadOnlyDictionary? metadata, string? result, string? errorMessage, ILogger logger) diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj b/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj index 85bea2268368..b66184035320 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj +++ b/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj @@ -6,7 +6,7 @@ $(AssemblyName) netstandard2.0 true - $(NoWarn);NU5104;SKEXP0005,SKEXP0013,SKEXP0014 + $(NoWarn);NU5104;SKEXP0005,SKEXP0013,SKEXP0014,SKEXP0016 true diff --git a/dotnet/src/Connectors/Connectors.OpenAI/IToolFilter.cs b/dotnet/src/Connectors/Connectors.OpenAI/IToolFilter.cs new file mode 100644 index 000000000000..6b7820db689b --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI/IToolFilter.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Connectors.OpenAI; + +/// +/// Interface for tool filters. +/// +[Experimental("SKEXP0016")] +public interface IToolFilter +{ + /// + /// Method which is executed before tool invocation. + /// + /// Data related to tool before invocation. + void OnToolInvoking(ToolInvokingContext context); + + /// + /// Method which is executed after tool invocation. + /// + /// Data related to tool after invocation. + void OnToolInvoked(ToolInvokedContext context); +} diff --git a/dotnet/src/Connectors/Connectors.OpenAI/ToolCallBehavior.cs b/dotnet/src/Connectors/Connectors.OpenAI/ToolCallBehavior.cs index 2650775f034b..447745f5cb83 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/ToolCallBehavior.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/ToolCallBehavior.cs @@ -2,8 +2,10 @@ using System.Collections.Generic; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.Linq; using Azure.AI.OpenAI; +using Microsoft.SemanticKernel.ChatCompletion; namespace Microsoft.SemanticKernel.Connectors.OpenAI; @@ -35,6 +37,12 @@ public abstract class ToolCallBehavior /// private const int DefaultMaximumAutoInvokeAttempts = 5; + /// + /// Gets the collection of filters that will be applied to tool calls. + /// + [Experimental("SKEXP0016")] + public IList Filters { get; } = new List(); + /// /// Gets an instance that will provide all of the 's plugins' function information. /// Function call requests from the model will be propagated back to the caller. @@ -236,4 +244,40 @@ internal override void ConfigureOptions(Kernel? kernel, ChatCompletionsOptions o /// internal override int MaximumUseAttempts => 1; } + + #region Filters + internal ToolInvokingContext? 
diff --git a/dotnet/src/Connectors/Connectors.OpenAI/ToolCallBehavior.cs b/dotnet/src/Connectors/Connectors.OpenAI/ToolCallBehavior.cs
index 2650775f034b..447745f5cb83 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI/ToolCallBehavior.cs
+++ b/dotnet/src/Connectors/Connectors.OpenAI/ToolCallBehavior.cs
@@ -2,8 +2,10 @@
 
 using System.Collections.Generic;
 using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
 using System.Linq;
 using Azure.AI.OpenAI;
+using Microsoft.SemanticKernel.ChatCompletion;
 
 namespace Microsoft.SemanticKernel.Connectors.OpenAI;
 
@@ -35,6 +37,12 @@ public abstract class ToolCallBehavior
     /// </summary>
     private const int DefaultMaximumAutoInvokeAttempts = 5;
 
+    /// <summary>
+    /// Gets the collection of filters that will be applied to tool calls.
+    /// </summary>
+    [Experimental("SKEXP0016")]
+    public IList<IToolFilter> Filters { get; } = new List<IToolFilter>();
+
     /// <summary>
     /// Gets an instance that will provide all of the <see cref="Kernel"/>'s plugins' function information.
     /// Function call requests from the model will be propagated back to the caller.
@@ -236,4 +244,40 @@ internal override void ConfigureOptions(Kernel? kernel, ChatCompletionsOptions o
     /// </summary>
     internal override int MaximumUseAttempts => 1;
 }
+
+    #region Filters
+    internal ToolInvokingContext? OnToolInvokingFilter(OpenAIFunctionToolCall toolCall, ChatHistory chatHistory, int iteration)
+    {
+        ToolInvokingContext? context = null;
+
+        if (this.Filters is { Count: > 0 })
+        {
+            context = new(toolCall, chatHistory, iteration);
+
+            for (int i = 0; i < this.Filters.Count; i++)
+            {
+                this.Filters[i].OnToolInvoking(context);
+            }
+        }
+
+        return context;
+    }
+
+    internal ToolInvokedContext? OnToolInvokedFilter(OpenAIFunctionToolCall toolCall, object? result, ChatHistory chatHistory, int iteration)
+    {
+        ToolInvokedContext? context = null;
+
+        if (this.Filters is { Count: > 0 })
+        {
+            context = new(toolCall, result, chatHistory, iteration);
+
+            for (int i = 0; i < this.Filters.Count; i++)
+            {
+                this.Filters[i].OnToolInvoked(context);
+            }
+        }
+
+        return context;
+    }
+    #endregion
 }
diff --git a/dotnet/src/Connectors/Connectors.OpenAI/ToolFilterContext.cs b/dotnet/src/Connectors/Connectors.OpenAI/ToolFilterContext.cs
new file mode 100644
index 000000000000..8180e02baefd
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.OpenAI/ToolFilterContext.cs
@@ -0,0 +1,80 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Diagnostics.CodeAnalysis;
+using Microsoft.SemanticKernel.ChatCompletion;
+
+namespace Microsoft.SemanticKernel.Connectors.OpenAI;
+
+/// <summary>
+/// Enum describing the different ways tool calling can be stopped.
+/// </summary>
+[Experimental("SKEXP0016")]
+public enum ToolFilterStopBehavior
+{
+    /// <summary>
+    /// Continue using tools
+    /// </summary>
+    None,
+
+    /// <summary>
+    /// Cancel the current tool call, and don't invoke or request any more tools
+    /// </summary>
+    Cancel,
+
+    /// <summary>
+    /// Invoke the current tool call(s) but don't request any more tools
+    /// </summary>
+    StopTools,
+
+    /// <summary>
+    /// Continue requesting tools, but turn off auto-invoke
+    /// </summary>
+    StopAutoInvoke
+};
+
+/// <summary>
+/// Base class with data related to tool invocation.
+/// </summary>
+public abstract class ToolFilterContext
+{
+    /// <summary>
+    /// Initializes a new instance of the <see cref="ToolFilterContext"/> class.
+    /// </summary>
+    /// <param name="toolCall">The <see cref="OpenAIFunctionToolCall"/> with which this filter is associated.</param>
+    /// <param name="chatHistory">The chat history associated with the operation.</param>
+    /// <param name="modelIterations">The number of model iterations completed thus far for the request.</param>
+    internal ToolFilterContext(OpenAIFunctionToolCall toolCall, ChatHistory chatHistory, int modelIterations)
+    {
+        Verify.NotNull(toolCall);
+
+        this.ToolCall = toolCall;
+        this.ChatHistory = chatHistory;
+        this.ModelIterations = modelIterations;
+    }
+
+    /// <summary>
+    /// Gets the tool call associated with this filter.
+    /// </summary>
+    public OpenAIFunctionToolCall ToolCall { get; }
+
+    /// <summary>
+    /// Gets the chat history associated with the operation.
+    /// </summary>
+    public ChatHistory ChatHistory { get; }
+
+    /// <summary>
+    /// Gets the number of model iterations that have been completed for the request so far.
+    /// </summary>
+    public int ModelIterations { get; }
+
+    /// <summary>
+    /// Gets or sets a value indicating whether subsequent tool calls should be stopped,
+    /// and if so, which stop behavior should be followed.
+    /// </summary>
+    /// <remarks>
+    /// If there are multiple filters registered, subsequent filters
+    /// may see and change a value set by a previous filter. The final result is what will
+    /// be considered by the component that triggers the filter.
+    /// </remarks>
+    public ToolFilterStopBehavior StopBehavior { get; set; } = ToolFilterStopBehavior.None;
+}
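StopBehavior is the filters' main control surface, but filters can also edit the ChatHistory they receive: ApplyToolFilterContextChanges in ClientCore copies the modified history back into the ChatCompletionsOptions before the next model round trip. As a hedged sketch of that pattern (not part of this diff; whether injecting extra system messages between tool calls is desirable depends on the application), a post-invocation auditing filter might look like this:

// Hypothetical post-invocation filter that annotates the chat history.
using Microsoft.SemanticKernel.Connectors.OpenAI;

public sealed class AuditingToolFilter : IToolFilter
{
    public void OnToolInvoking(ToolInvokingContext context) { }

    public void OnToolInvoked(ToolInvokedContext context)
    {
        // Record what the tool returned; the updated history is synced back
        // into the request options before the next request to the model.
        context.ChatHistory.AddSystemMessage(
            $"[audit] {context.ToolCall.FullyQualifiedName} returned on iteration {context.ModelIterations}: {context.Result}");
    }
}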
diff --git a/dotnet/src/Connectors/Connectors.OpenAI/ToolInvokedContext.cs b/dotnet/src/Connectors/Connectors.OpenAI/ToolInvokedContext.cs
new file mode 100644
index 000000000000..b68e806e0c79
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.OpenAI/ToolInvokedContext.cs
@@ -0,0 +1,31 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Diagnostics.CodeAnalysis;
+using Microsoft.SemanticKernel.ChatCompletion;
+
+namespace Microsoft.SemanticKernel.Connectors.OpenAI;
+
+/// <summary>
+/// Class with data related to tool after invocation.
+/// </summary>
+[Experimental("SKEXP0016")]
+public sealed class ToolInvokedContext : ToolFilterContext
+{
+    /// <summary>
+    /// Initializes a new instance of the <see cref="ToolInvokedContext"/> class.
+    /// </summary>
+    /// <param name="toolCall">The <see cref="OpenAIFunctionToolCall"/> with which this filter is associated.</param>
+    /// <param name="result">The result of the tool's invocation.</param>
+    /// <param name="chatHistory">The chat history associated with the operation.</param>
+    /// <param name="modelIterations">The number of model iterations completed thus far for the request.</param>
+    public ToolInvokedContext(OpenAIFunctionToolCall toolCall, object? result, ChatHistory chatHistory, int modelIterations)
+        : base(toolCall, chatHistory, modelIterations)
+    {
+        this.Result = result;
+    }
+
+    /// <summary>
+    /// Gets the result of the tool's invocation.
+    /// </summary>
+    public object? Result { get; }
+}
diff --git a/dotnet/src/Connectors/Connectors.OpenAI/ToolInvokingContext.cs b/dotnet/src/Connectors/Connectors.OpenAI/ToolInvokingContext.cs
new file mode 100644
index 000000000000..426ccb045395
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.OpenAI/ToolInvokingContext.cs
@@ -0,0 +1,24 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Diagnostics.CodeAnalysis;
+using Microsoft.SemanticKernel.ChatCompletion;
+
+namespace Microsoft.SemanticKernel.Connectors.OpenAI;
+
+/// <summary>
+/// Class with data related to tool before invocation.
+/// </summary>
+[Experimental("SKEXP0016")]
+public sealed class ToolInvokingContext : ToolFilterContext
+{
+    /// <summary>
+    /// Initializes a new instance of the <see cref="ToolInvokingContext"/> class.
+    /// </summary>
+    /// <param name="toolCall">The <see cref="OpenAIFunctionToolCall"/> with which this filter is associated.</param>
+    /// <param name="chatHistory">The chat history associated with the operation.</param>
+    /// <param name="modelIterations">The number of model iterations completed thus far for the request.</param>
+    public ToolInvokingContext(OpenAIFunctionToolCall toolCall, ChatHistory chatHistory, int modelIterations)
+        : base(toolCall, chatHistory, modelIterations)
+    {
+    }
+}
diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj b/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj
index 9e5be20bb853..cb7f6220db04 100644
--- a/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj
+++ b/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj
@@ -10,7 +10,7 @@
     enable
     disable
     false
-    <NoWarn>CA2007,CA1806,CA1869,CA1861,IDE0300,VSTHRD111,SKEXP0001,SKEXP0002,SKEXP0003,SKEXP0004,SKEXP0005,SKEXP0010,SKEXP0011,SKEXP0012,SKEXP0013,SKEXP0014,SKEXP0015,SKEXP0020,SKEXP0021,SKEXP0022,SKEXP0023,SKEXP0024,SKEXP0025,SKEXP0026,SKEXP0027,SKEXP0028,SKEXP0029,SKEXP0030,SKEXP0031,SKEXP0032,SKEXP0052</NoWarn>
+    <NoWarn>CA2007,CA1806,CA1869,CA1861,IDE0300,VSTHRD111,SKEXP0001,SKEXP0002,SKEXP0003,SKEXP0004,SKEXP0005,SKEXP0010,SKEXP0011,SKEXP0012,SKEXP0013,SKEXP0014,SKEXP0015,SKEXP0016,SKEXP0020,SKEXP0021,SKEXP0022,SKEXP0023,SKEXP0024,SKEXP0025,SKEXP0026,SKEXP0027,SKEXP0028,SKEXP0029,SKEXP0030,SKEXP0031,SKEXP0032,SKEXP0052</NoWarn>
diff --git a/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ChatCompletion/OpenAIChatCompletionServiceTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ChatCompletion/OpenAIChatCompletionServiceTests.cs
index ebc14928d444..e07bd54dd1ee 100644
--- a/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ChatCompletion/OpenAIChatCompletionServiceTests.cs
+++ b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ChatCompletion/OpenAIChatCompletionServiceTests.cs
@@ -281,7 +281,7 @@ public void Dispose()
                 ""role"": ""assistant"",
                 ""content"": null,
                 ""function_call"": {
-                  ""name"": ""TimePlugin_Date"",
+
""name"": ""TimePlugin-Date"", ""arguments"": ""{}"" } }, diff --git a/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ChatCompletion/ToolFilterTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ChatCompletion/ToolFilterTests.cs new file mode 100644 index 000000000000..3bfce01f3f3b --- /dev/null +++ b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ChatCompletion/ToolFilterTests.cs @@ -0,0 +1,626 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.OpenAI; +using Xunit; + +namespace SemanticKernel.Connectors.UnitTests.OpenAI.ChatCompletion; + +public sealed class ToolFilterTests : IDisposable +{ + private readonly MultipleHttpMessageHandlerStub _messageHandlerStub; + private readonly HttpClient _httpClient; + private readonly OpenAIChatCompletionService _service; + private readonly OpenAIPromptExecutionSettings _settings; + + public ToolFilterTests() + { + this._messageHandlerStub = new MultipleHttpMessageHandlerStub(); + this._httpClient = new HttpClient(this._messageHandlerStub, false); + + this._service = new OpenAIChatCompletionService(modelId: "gpt-3.5-turbo", apiKey: "NOKEY", httpClient: this._httpClient); + this._settings = new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }; + } + + [Fact] + public async Task PreInvocationToolFilterIsTriggeredAsync() + { + // Arrange + var toolInvocations = 0; + var filterInvocations = 0; + + var kernel = new Kernel(); + kernel.ImportPluginFromObject(new FakePlugin(() => toolInvocations++)); + + this._settings.ToolCallBehavior!.Filters.Clear(); + this._settings.ToolCallBehavior.Filters.Add( + new FakeToolFilter(onToolInvoking: (context) => + { + filterInvocations++; + })); + + using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(ToolResponseNoArgs) }; + using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(OpenAITestHelper.GetTestResponse("chat_completion_test_response.json")) }; + this._messageHandlerStub.ResponsesToReturn = [response1, response2]; + + // Act + var result = await this._service.GetChatMessageContentsAsync([], this._settings, kernel); + + // Assert + Assert.Equal(1, toolInvocations); + Assert.Equal(1, filterInvocations); + } + + [Fact] + public async Task PreInvocationToolFilterChangesArgumentAsync() + { + // Arrange + const string NewInput = "newValue"; + + var kernel = new Kernel(); + kernel.ImportPluginFromObject(new FakePluginWithArg((string originalInput) => originalInput)); + + this._settings.ToolCallBehavior!.Filters.Clear(); + this._settings.ToolCallBehavior.Filters.Add( + new FakeToolFilter(onToolInvoking: (context) => + { + context.ToolCall.Arguments!["input"] = NewInput; + })); + + using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(ToolResponseWithArgs) }; + using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(OpenAITestHelper.GetTestResponse("chat_completion_test_response.json")) }; + this._messageHandlerStub.ResponsesToReturn = [response1, response2]; + + var chatHistory = new ChatHistory(); + + // Act + var result = await this._service.GetChatMessageContentsAsync(chatHistory, 
this._settings, kernel); + + // Assert + Assert.Equal(NewInput, chatHistory.Where(m => m.Role == AuthorRole.Tool).First().Content); + } + + [Fact] + public async Task PreInvocationToolFilterCancellationWorksCorrectlyAsync() + { + // Arrange + var functionInvocations = 0; + var preFilterInvocations = 0; + var postFilterInvocations = 0; + + var kernel = new Kernel(); + kernel.ImportPluginFromObject(new FakePlugin(() => functionInvocations++)); + + this._settings.ToolCallBehavior!.Filters.Clear(); + this._settings.ToolCallBehavior.Filters.Add( + new FakeToolFilter( + onToolInvoking: (context) => + { + preFilterInvocations++; + context.StopBehavior = ToolFilterStopBehavior.Cancel; + }, + onToolInvoked: (context) => + { + postFilterInvocations++; + })); + + using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(ToolResponseNoArgs) }; + using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(OpenAITestHelper.GetTestResponse("chat_completion_test_response.json")) }; + this._messageHandlerStub.ResponsesToReturn = [response1, response2]; + + var chatHistory = new ChatHistory(); + + // Act + var result = await this._service.GetChatMessageContentsAsync(chatHistory, this._settings, kernel); + + // Assert + Assert.Equal(1, preFilterInvocations); + Assert.Equal(0, functionInvocations); + Assert.Equal(0, postFilterInvocations); + Assert.Equal("A tool filter requested cancellation before tool invocation. Model iterations completed: 1", chatHistory.Last().Content); + + var requestContents = this._messageHandlerStub.RequestContents; + Assert.Equal(2, requestContents.Count); + requestContents.ForEach(Assert.NotNull); + var secondContent = Encoding.UTF8.GetString(requestContents[1]!); + var secondContentJson = JsonSerializer.Deserialize(secondContent); + Assert.Equal("none", secondContentJson.GetProperty("tool_choice").GetString()); + } + + [Fact] + public async Task PostInvocationToolFilterIsTriggeredAsync() + { + // Arrange + var functionInvocations = 0; + var filterInvocations = 0; + + var kernel = new Kernel(); + kernel.ImportPluginFromObject(new FakePlugin(() => functionInvocations++)); + + this._settings.ToolCallBehavior!.Filters.Clear(); + this._settings.ToolCallBehavior.Filters.Add( + new FakeToolFilter(onToolInvoked: (context) => + { + filterInvocations++; + })); + + using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(ToolResponseNoArgs) }; + using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(OpenAITestHelper.GetTestResponse("chat_completion_test_response.json")) }; + this._messageHandlerStub.ResponsesToReturn = [response1, response2]; + + // Act + var result = await this._service.GetChatMessageContentsAsync([], this._settings, kernel); + + // Assert + Assert.Equal(1, functionInvocations); + Assert.Equal(1, filterInvocations); + } + + [Fact] + public async Task PostInvocationToolFilterCancellationWorksCorrectlyAsync() + { + // Arrange + var functionInvocations = 0; + var preFilterInvocations = 0; + var postFilterInvocations = 0; + + var kernel = new Kernel(); + kernel.ImportPluginFromObject(new FakePlugin(() => functionInvocations++)); + + this._settings.ToolCallBehavior!.Filters.Clear(); + this._settings.ToolCallBehavior.Filters.Add( + new FakeToolFilter( + onToolInvoking: (context) => + { + preFilterInvocations++; + }, + onToolInvoked: (context) => + { + postFilterInvocations++; + context.StopBehavior = ToolFilterStopBehavior.Cancel; + 
})); + + using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(ToolResponseMultipleToolCalls) }; + using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(OpenAITestHelper.GetTestResponse("chat_completion_test_response.json")) }; + this._messageHandlerStub.ResponsesToReturn = [response1, response2]; + + var chatHistory = new ChatHistory(); + + // Act + var result = await this._service.GetChatMessageContentsAsync(chatHistory, this._settings, kernel); + + // Assert + Assert.Equal(1, preFilterInvocations); + Assert.Equal(1, functionInvocations); + Assert.Equal(1, postFilterInvocations); + + var requestContents = this._messageHandlerStub.RequestContents; + Assert.Equal(2, requestContents.Count); + requestContents.ForEach(Assert.NotNull); + var secondContent = Encoding.UTF8.GetString(requestContents[1]!); + var secondContentJson = JsonSerializer.Deserialize(secondContent); + Assert.Equal("none", secondContentJson.GetProperty("tool_choice").GetString()); + } + + [Fact] + public async Task PostInvocationToolFilterChangesChatHistoryAsync() + { + // Arrange + var kernel = new Kernel(); + kernel.ImportPluginFromObject(new FakePlugin(() => { })); + + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("Hello, world!"); + + this._settings.ToolCallBehavior!.Filters.Clear(); + this._settings.ToolCallBehavior.Filters.Add( + new FakeToolFilter(onToolInvoked: (context) => + { + context.ChatHistory.AddAssistantMessage("Tool filter was invoked."); + context.ChatHistory.First().Content = "Hello, SK!"; + })); + + using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(ToolResponseNoArgs) }; + using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(OpenAITestHelper.GetTestResponse("chat_completion_test_response.json")) }; + this._messageHandlerStub.ResponsesToReturn = [response1, response2]; + + // Act + var result = await this._service.GetChatMessageContentsAsync(chatHistory, this._settings, kernel); + + // Assert + Assert.Equal(4, chatHistory.Count); // includes tool call and tool result messages + Assert.Equal("Hello, SK!", chatHistory.First().Content); + Assert.Equal("Tool filter was invoked.", chatHistory.Last().Content); + } + + [Fact] + public async Task ToolFilterStopAutoInvokeWorksCorrectlyAsync() + { + // Arrange + var toolInvocations = 0; + var kernel = new Kernel(); + kernel.ImportPluginFromObject(new FakePlugin(() => { toolInvocations++; })); + + this._settings.ToolCallBehavior!.Filters.Clear(); + this._settings.ToolCallBehavior.Filters.Add( + new FakeToolFilter(onToolInvoked: (context) => + { + context.StopBehavior = ToolFilterStopBehavior.StopAutoInvoke; + })); + + using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(ToolResponseNoArgs) }; + using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(ToolResponseNoArgs) }; + this._messageHandlerStub.ResponsesToReturn = [response1, response2]; + + // Act + var result = await this._service.GetChatMessageContentsAsync([], this._settings, kernel); + + // Assert + Assert.Equal(1, toolInvocations); + + var requestContents = this._messageHandlerStub.RequestContents; + Assert.Equal(2, requestContents.Count); + requestContents.ForEach(Assert.NotNull); + var secondContent = Encoding.UTF8.GetString(requestContents[1]!); + var secondContentJson = JsonSerializer.Deserialize(secondContent); + Assert.Equal("auto", 
secondContentJson.GetProperty("tool_choice").GetString()); + } + + [Fact] + public async Task ToolFilterStopToolsWorksCorrectlyAsync() + { + // Arrange + var kernel = new Kernel(); + kernel.ImportPluginFromObject(new FakePlugin(() => { })); + + this._settings.ToolCallBehavior!.Filters.Clear(); + this._settings.ToolCallBehavior.Filters.Add( + new FakeToolFilter(onToolInvoked: (context) => + { + context.StopBehavior = ToolFilterStopBehavior.StopTools; + })); + + using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(ToolResponseNoArgs) }; + using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(OpenAITestHelper.GetTestResponse("chat_completion_test_response.json")) }; + this._messageHandlerStub.ResponsesToReturn = [response1, response2]; + + // Act + var result = await this._service.GetChatMessageContentsAsync([], this._settings, kernel); + + // Assert + var requestContents = this._messageHandlerStub.RequestContents; + Assert.Equal(2, requestContents.Count); + requestContents.ForEach(Assert.NotNull); + var secondContent = Encoding.UTF8.GetString(requestContents[1]!); + var secondContentJson = JsonSerializer.Deserialize(secondContent); + Assert.Equal("none", secondContentJson.GetProperty("tool_choice").GetString()); + } + + [Fact] + public async Task MultipleToolFiltersCancellationWorksCorrectlyAsync() + { + // Arrange + var functionInvocations = 0; + var filterInvocations = 0; + var kernel = new Kernel(); + kernel.ImportPluginFromObject(new FakePlugin(() => functionInvocations++)); + + var toolFilter1 = new FakeToolFilter(onToolInvoking: (context) => + { + filterInvocations++; + context.StopBehavior = ToolFilterStopBehavior.Cancel; + }); + + var toolFilter2 = new FakeToolFilter(onToolInvoking: (context) => + { + Assert.Equal(ToolFilterStopBehavior.Cancel, context.StopBehavior); + + filterInvocations++; + context.StopBehavior = ToolFilterStopBehavior.None; + }); + + this._settings.ToolCallBehavior!.Filters.Clear(); + this._settings.ToolCallBehavior.Filters.Add(toolFilter1); + this._settings.ToolCallBehavior.Filters.Add(toolFilter2); + + using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(ToolResponseNoArgs) }; + using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(OpenAITestHelper.GetTestResponse("chat_completion_test_response.json")) }; + this._messageHandlerStub.ResponsesToReturn = [response1, response2]; + + // Act + var result = await this._service.GetChatMessageContentsAsync([], this._settings, kernel); + + // Assert + Assert.Equal(1, functionInvocations); + Assert.Equal(2, filterInvocations); + } + + [Fact] + public async Task ToolFiltersAreExecutedInCorrectOrderAsync() + { + // Arrange + var executionOrder = new List(); + + var toolFilter1 = new FakeToolFilter( + onToolInvoking: (context) => executionOrder.Add("ToolFilter1-Invoking"), + onToolInvoked: (context) => executionOrder.Add("ToolFilter1-Invoked")); + + var toolFilter2 = new FakeToolFilter( + onToolInvoking: (context) => executionOrder.Add("ToolFilter2-Invoking"), + onToolInvoked: (context) => executionOrder.Add("ToolFilter2-Invoked")); + + var toolFilter3 = new FakeToolFilter( + onToolInvoking: (context) => executionOrder.Add("ToolFilter3-Invoking"), + onToolInvoked: (context) => executionOrder.Add("ToolFilter3-Invoked")); + + var kernel = new Kernel(); + kernel.ImportPluginFromObject(new FakePlugin(() => { })); + + this._settings.ToolCallBehavior!.Filters.Clear(); 
+ this._settings.ToolCallBehavior.Filters.Add(toolFilter1); + this._settings.ToolCallBehavior.Filters.Add(toolFilter2); + this._settings.ToolCallBehavior.Filters.Insert(1, toolFilter3); + + using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(ToolResponseNoArgs) }; + using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(OpenAITestHelper.GetTestResponse("chat_completion_test_response.json")) }; + this._messageHandlerStub.ResponsesToReturn = [response1, response2]; + + // Act + var result = await this._service.GetChatMessageContentAsync(new ChatHistory(), this._settings, kernel); + + // Assert + Assert.Equal("ToolFilter1-Invoking", executionOrder[0]); + Assert.Equal("ToolFilter3-Invoking", executionOrder[1]); + Assert.Equal("ToolFilter2-Invoking", executionOrder[2]); + Assert.Equal("ToolFilter1-Invoked", executionOrder[3]); + Assert.Equal("ToolFilter3-Invoked", executionOrder[4]); + Assert.Equal("ToolFilter2-Invoked", executionOrder[5]); + } + + [Fact] + public async Task ToolFiltersAreTriggeredOnStreamingAsync() + { + // Arrange + int functionCallCount = 0; + int preFilterInvocations = 0; + int postFilterInvocations = 0; + + var kernel = Kernel.CreateBuilder().Build(); + var function1 = KernelFunctionFactory.CreateFromMethod((string location) => + { + functionCallCount++; + return "Some weather"; + }, "GetCurrentWeather"); + + kernel.Plugins.Add(KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1])); + + var toolFilter = new FakeToolFilter( + onToolInvoking: (context) => preFilterInvocations++, + onToolInvoked: (context) => postFilterInvocations++); + + this._settings.ToolCallBehavior!.Filters.Clear(); + this._settings.ToolCallBehavior.Filters.Add(toolFilter); + + using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(OpenAITestHelper.GetTestResponse("chat_completion_streaming_single_function_call_test_response.txt")) }; + using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(OpenAITestHelper.GetTestResponse("chat_completion_streaming_test_response.txt")) }; + + this._messageHandlerStub.ResponsesToReturn = [response1, response2]; + + // Act & Assert + await foreach (var chunk in this._service.GetStreamingChatMessageContentsAsync([], this._settings, kernel)) + { + Assert.Equal("Test chat streaming response", chunk.Content); + } + + Assert.Equal(1, functionCallCount); + Assert.Equal(1, preFilterInvocations); + Assert.Equal(1, postFilterInvocations); + } + + [Fact] + public async Task PreInvocationToolFilterCancellationWorksOnStreamingAsync() + { + // Arrange + int functionCallCount = 0; + int preFilterInvocations = 0; + int postFilterInvocations = 0; + + var kernel = Kernel.CreateBuilder().Build(); + var function1 = KernelFunctionFactory.CreateFromMethod((string location) => + { + functionCallCount++; + return "Some weather"; + }, "GetCurrentWeather"); + + kernel.Plugins.Add(KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1])); + + var toolFilter = new FakeToolFilter( + onToolInvoking: (context) => + { + context.StopBehavior = ToolFilterStopBehavior.Cancel; + preFilterInvocations++; + }, + onToolInvoked: (context) => postFilterInvocations++); + + this._settings.ToolCallBehavior!.Filters.Clear(); + this._settings.ToolCallBehavior.Filters.Add(toolFilter); + + using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new 
StringContent(OpenAITestHelper.GetTestResponse("chat_completion_streaming_single_function_call_test_response.txt")) }; + using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(OpenAITestHelper.GetTestResponse("chat_completion_streaming_test_response.txt")) }; + + this._messageHandlerStub.ResponsesToReturn = [response1, response2]; + + // Act & Assert + await foreach (var chunk in this._service.GetStreamingChatMessageContentsAsync([], this._settings, kernel)) + { + Assert.Equal("Test chat streaming response", chunk.Content); + } + + Assert.Equal(1, preFilterInvocations); + Assert.Equal(0, functionCallCount); + Assert.Equal(0, postFilterInvocations); + + var requestContents = this._messageHandlerStub.RequestContents; + Assert.Equal(2, requestContents.Count); + requestContents.ForEach(Assert.NotNull); + var secondContent = Encoding.UTF8.GetString(requestContents[1]!); + var secondContentJson = JsonSerializer.Deserialize(secondContent); + Assert.Equal("none", secondContentJson.GetProperty("tool_choice").GetString()); + } + + private sealed class FakeToolFilter( + Action? onToolInvoking = null, + Action? onToolInvoked = null) : IToolFilter + { + private readonly Action? _onToolInvoking = onToolInvoking; + private readonly Action? _onToolInvoked = onToolInvoked; + + public void OnToolInvoked(ToolInvokedContext context) => + this._onToolInvoked?.Invoke(context); + + public void OnToolInvoking(ToolInvokingContext context) => + this._onToolInvoking?.Invoke(context); + } + + private sealed class FakePlugin(Action action) + { + [KernelFunction] + public void Foo() + { + action(); + } + } + + private sealed class FakePluginWithArg(Func action) + { + [KernelFunction] + public string Bar(string input) + { + return action(input); + } + } + + public void Dispose() + { + this._httpClient.Dispose(); + this._messageHandlerStub.Dispose(); + } + + private const string ToolResponseNoArgs = @"{ + ""id"": ""response-id"", + ""object"": ""chat.completion"", + ""created"": 1699896916, + ""model"": ""gpt-3.5-turbo-0613"", + ""choices"": [ + { + ""index"": 0, + ""message"": { + ""role"": ""assistant"", + ""content"": null, + ""tool_calls"": [ + { + ""id"": ""1"", + ""type"": ""function"", + ""function"": { + ""name"": ""FakePlugin-Foo"", + ""arguments"": ""{}"" + } + } + ] + }, + ""logprobs"": null, + ""finish_reason"": ""tool_calls"" + } + ], + ""usage"": { + ""prompt_tokens"": 82, + ""completion_tokens"": 17, + ""total_tokens"": 99 + } +} +"; + + private const string ToolResponseWithArgs = @"{ + ""id"": ""response-id"", + ""object"": ""chat.completion"", + ""created"": 1699896916, + ""model"": ""gpt-3.5-turbo-0613"", + ""choices"": [ + { + ""index"": 0, + ""message"": { + ""role"": ""assistant"", + ""content"": null, + ""tool_calls"": [ + { + ""id"": ""1"", + ""type"": ""function"", + ""function"": { + ""name"": ""FakePluginWithArg-Bar"", + ""arguments"": ""{\n\""input\"": \""oldValue\""\n}"" + } + } + ] + }, + ""logprobs"": null, + ""finish_reason"": ""tool_calls"" + } + ], + ""usage"": { + ""prompt_tokens"": 82, + ""completion_tokens"": 17, + ""total_tokens"": 99 + } +}"; + + private const string ToolResponseMultipleToolCalls = @"{ + ""id"": ""response-id"", + ""object"": ""chat.completion"", + ""created"": 1699896916, + ""model"": ""gpt-3.5-turbo-0613"", + ""choices"": [ + { + ""index"": 0, + ""message"": { + ""role"": ""assistant"", + ""content"": null, + ""tool_calls"": [ + { + ""id"": ""1"", + ""type"": ""function"", + ""function"": { + ""name"": ""FakePlugin-Foo"", + 
""arguments"": ""{}"" + } + }, + { + ""id"": ""2"", + ""type"": ""function"", + ""function"": { + ""name"": ""FakePlugin-Foo"", + ""arguments"": ""{}"" + } + } + ] + }, + ""logprobs"": null, + ""finish_reason"": ""tool_calls"" + } + ], + ""usage"": { + ""prompt_tokens"": 82, + ""completion_tokens"": 17, + ""total_tokens"": 99 + } +} +"; +} diff --git a/dotnet/src/Planners/Planners.OpenAI/Stepwise/FunctionCallingStepwisePlanner.cs b/dotnet/src/Planners/Planners.OpenAI/Stepwise/FunctionCallingStepwisePlanner.cs index 9c1bdd484547..174e06b60373 100644 --- a/dotnet/src/Planners/Planners.OpenAI/Stepwise/FunctionCallingStepwisePlanner.cs +++ b/dotnet/src/Planners/Planners.OpenAI/Stepwise/FunctionCallingStepwisePlanner.cs @@ -1,13 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections.Generic; using System.ComponentModel; -using System.Diagnostics.CodeAnalysis; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; -using Json.More; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.SemanticKernel.ChatCompletion; @@ -70,82 +66,56 @@ private async Task ExecuteCoreAsync( // Clone the kernel so that we can add planner-specific plugins without affecting the original kernel instance var clonedKernel = kernel.Clone(); - clonedKernel.ImportPluginFromType(); + + // The final answer flag is set when the UserInteraction plugin is invoked + bool finalAnswerFound = false; + string finalAnswer = string.Empty; + + // Import the UserInteraction plugin to the kernel + var userInteraction = new UserInteraction((answer) => { finalAnswerFound = true; finalAnswer = answer; }); + clonedKernel.ImportPluginFromObject(userInteraction); // Create and invoke a kernel function to generate the initial plan var initialPlan = await this.GeneratePlanAsync(question, clonedKernel, logger, cancellationToken).ConfigureAwait(false); var chatHistoryForSteps = await this.BuildChatHistoryForStepAsync(question, initialPlan, clonedKernel, promptTemplateFactory, cancellationToken).ConfigureAwait(false); - for (int i = 0; i < this._options.MaxIterations; i++) + for (int iteration = 0; iteration < this._options.MaxIterations; /* iteration is incremented within the loop */) { // sleep for a bit to avoid rate limiting - if (i > 0) + if (iteration > 0) { await Task.Delay(this._options.MinIterationTimeMs, cancellationToken).ConfigureAwait(false); } // For each step, request another completion to select a function for that step chatHistoryForSteps.AddUserMessage(StepwiseUserMessage); - var chatResult = await this.GetCompletionWithFunctionsAsync(chatHistoryForSteps, clonedKernel, chatCompletion, stepExecutionSettings, logger, cancellationToken).ConfigureAwait(false); + var chatResult = await this.GetCompletionWithFunctionsAsync(iteration, chatHistoryForSteps, clonedKernel, chatCompletion, stepExecutionSettings, logger, cancellationToken).ConfigureAwait(false); chatHistoryForSteps.Add(chatResult); - // Check for function response - if (!this.TryGetFunctionResponse(chatResult, out IReadOnlyList? functionResponses, out string? functionResponseError)) + // Increment iteration based on the number of model round trips that occurred as a result of the request + object? value = null; + chatResult.Metadata?.TryGetValue("ModelIterationsCompleted", out value); + if (value is not null and int) { - // No function response found. Either AI returned a chat message, or something went wrong when parsing the function. 
- // Log the error (if applicable), then let the planner continue. - if (functionResponseError is not null) - { - chatHistoryForSteps.AddUserMessage(functionResponseError); - } - continue; + iteration += (int)value; } - - // Check for final answer in the function response - foreach (OpenAIFunctionToolCall functionResponse in functionResponses) + else { - if (this.TryFindFinalAnswer(functionResponse, out string finalAnswer, out string? finalAnswerError)) - { - if (finalAnswerError is not null) - { - // We found a final answer, but failed to parse it properly. - // Log the error message in chat history and let the planner try again. - chatHistoryForSteps.AddUserMessage(finalAnswerError); - continue; - } - - // Success! We found a final answer, so return the planner result. - return new FunctionCallingStepwisePlannerResult - { - FinalAnswer = finalAnswer, - ChatHistory = chatHistoryForSteps, - Iterations = i + 1, - }; - } + // Could not find iterations in metadata, so assume just one + iteration++; } - // Look up function in kernel - foreach (OpenAIFunctionToolCall functionResponse in functionResponses) + // Check for final answer + if (finalAnswerFound) { - if (clonedKernel.Plugins.TryGetFunctionAndArguments(functionResponse, out KernelFunction? pluginFunction, out KernelArguments? arguments)) - { - try - { - // Execute function and add to result to chat history - var result = (await clonedKernel.InvokeAsync(pluginFunction, arguments, cancellationToken).ConfigureAwait(false)).GetValue(); - chatHistoryForSteps.AddMessage(AuthorRole.Tool, ParseObjectAsString(result), metadata: new Dictionary(1) { { OpenAIChatMessageContent.ToolIdProperty, functionResponse.Id } }); - } - catch (Exception ex) when (!ex.IsCriticalException()) - { - chatHistoryForSteps.AddMessage(AuthorRole.Tool, ex.Message, metadata: new Dictionary(1) { { OpenAIChatMessageContent.ToolIdProperty, functionResponse.Id } }); - chatHistoryForSteps.AddUserMessage($"Failed to execute function {functionResponse.FullyQualifiedName}. Try something else!"); - } - } - else + // Success! we found a final answer, so return the planner result + return new FunctionCallingStepwisePlannerResult { - chatHistoryForSteps.AddUserMessage($"Function {functionResponse.FullyQualifiedName} does not exist in the kernel. Try something else!"); - } + FinalAnswer = chatResult.Content ?? 
finalAnswer, + ChatHistory = chatHistoryForSteps, + Iterations = iteration, + }; } } @@ -159,6 +129,7 @@ private async Task ExecuteCoreAsync( } private async Task GetCompletionWithFunctionsAsync( + int iterationsCompleted, ChatHistory chatHistory, Kernel kernel, IChatCompletionService chatCompletion, @@ -166,7 +137,12 @@ private async Task GetCompletionWithFunctionsAsync( ILogger logger, CancellationToken cancellationToken) { - openAIExecutionSettings.ToolCallBehavior = ToolCallBehavior.EnableKernelFunctions; + openAIExecutionSettings.ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions; + + // Set filters to stop automatic tool invocation when a final answer is found or max iterations limit is reached + int iterationsRemaining = this._options.MaxIterations - iterationsCompleted; + openAIExecutionSettings.ToolCallBehavior.Filters.Add(new FinalAnswerFilter()); + openAIExecutionSettings.ToolCallBehavior.Filters.Add(new MaxIterationsFilter((iteration) => { return iteration < iterationsRemaining; })); await this.ValidateTokenCountAsync(chatHistory, kernel, logger, openAIExecutionSettings, cancellationToken).ConfigureAwait(false); return await chatCompletion.GetChatMessageContentAsync(chatHistory, openAIExecutionSettings, kernel, cancellationToken).ConfigureAwait(false); @@ -214,76 +190,6 @@ private async Task BuildChatHistoryForStepAsync( return chatHistory; } - private bool TryGetFunctionResponse(ChatMessageContent chatMessage, [NotNullWhen(true)] out IReadOnlyList? functionResponses, out string? errorMessage) - { - OpenAIChatMessageContent? openAiChatMessage = chatMessage as OpenAIChatMessageContent; - Verify.NotNull(openAiChatMessage, nameof(openAiChatMessage)); - - functionResponses = null; - errorMessage = null; - try - { - functionResponses = openAiChatMessage.GetOpenAIFunctionToolCalls(); - } - catch (JsonException) - { - errorMessage = "That function call is invalid. Try something else!"; - } - - return functionResponses is { Count: > 0 }; - } - - private bool TryFindFinalAnswer(OpenAIFunctionToolCall functionResponse, out string finalAnswer, out string? errorMessage) - { - finalAnswer = string.Empty; - errorMessage = null; - - if (functionResponse.PluginName == "UserInteraction" && functionResponse.FunctionName == "SendFinalAnswer") - { - if (functionResponse.Arguments is { Count: > 0 } arguments && arguments.TryGetValue("answer", out object? valueObj)) - { - finalAnswer = ParseObjectAsString(valueObj); - } - else - { - errorMessage = "Returned answer in incorrect format. Try again!"; - } - return true; - } - return false; - } - - private static string ParseObjectAsString(object? valueObj) - { - string resultStr = string.Empty; - - if (valueObj is RestApiOperationResponse apiResponse) - { - resultStr = apiResponse.Content as string ?? string.Empty; - } - else if (valueObj is string valueStr) - { - resultStr = valueStr; - } - else if (valueObj is JsonElement valueElement) - { - if (valueElement.ValueKind == JsonValueKind.String) - { - resultStr = valueElement.GetString() ?? ""; - } - else - { - resultStr = valueElement.ToJsonString(); - } - } - else - { - resultStr = JsonSerializer.Serialize(valueObj); - } - - return resultStr; - } - private async Task ValidateTokenCountAsync( ChatHistory chatHistory, Kernel kernel, @@ -347,17 +253,70 @@ private async Task ValidateTokenCountAsync( /// public sealed class UserInteraction { + private readonly Action _setCompleted; + + /// + /// Constructs the object with a specified callback to indicate the plan completion. 
+ /// + /// Delegate to tell the planner that the plan is completed and a final answer has been found. + public UserInteraction(Action setCompleted) + { + this._setCompleted = setCompleted; + } + /// /// This function is used by the to indicate when the final answer has been found. /// /// The final answer for the plan. [KernelFunction] [Description("This function is used to send the final answer of a plan to the user.")] -#pragma warning disable IDE0060 // Remove unused parameter. The parameter is purely an indication to the LLM and is not intended to be used. public string SendFinalAnswer([Description("The final answer")] string answer) -#pragma warning restore IDE0060 { - return "Thanks"; + this._setCompleted(answer); + return answer; + } + } + + #region Filters + + // A tool filter that stops tool calling once the final answer has been found + private sealed class FinalAnswerFilter : IToolFilter + { + public void OnToolInvoking(ToolInvokingContext context) { } + + public void OnToolInvoked(ToolInvokedContext context) + { + if (context.ToolCall.FullyQualifiedName.Equals($"UserInteraction{OpenAIFunction.NameSeparator}SendFinalAnswer", StringComparison.Ordinal)) + { + // We've found the final answer, so cancel any remaining tool calls. + context.StopBehavior = ToolFilterStopBehavior.Cancel; + } + } + } + + /// + /// A tool filter that stops tool calling once the maximum model iterations have been reached. + /// + private sealed class MaxIterationsFilter : IToolFilter + { + private readonly Func _shouldContinue; + + public MaxIterationsFilter(Func shouldContinue) + { + this._shouldContinue = shouldContinue; + } + + public void OnToolInvoking(ToolInvokingContext context) { } + + public void OnToolInvoked(ToolInvokedContext context) + { + if (!this._shouldContinue(context.ModelIterations)) + { + // We've reached the maximum iterations for the planner. + // Invoke any tool calls already specified, but stop requesting more tools. + context.StopBehavior = ToolFilterStopBehavior.StopTools; + } } } + #endregion }
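Putting the pieces together, a hypothetical end-to-end registration that mirrors the unit tests above; the model id, API key, plugin, and prompt are placeholders, and GuardrailToolFilter is the sketch shown earlier after IToolFilter.cs. The Filters.Clear() call mirrors the tests, since ToolCallBehavior.AutoInvokeKernelFunctions appears to expose a shared filter list.

// Hypothetical usage sketch (not part of this diff).
using System.ComponentModel;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;

var kernel = new Kernel();
kernel.ImportPluginFromObject(new WeatherPlugin()); // placeholder plugin

var service = new OpenAIChatCompletionService(modelId: "gpt-3.5-turbo", apiKey: "...");

var settings = new OpenAIPromptExecutionSettings
{
    ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions,
};
settings.ToolCallBehavior!.Filters.Clear();
settings.ToolCallBehavior.Filters.Add(new GuardrailToolFilter(maxIterations: 3));

var chat = new ChatHistory();
chat.AddUserMessage("What's the weather in Seattle?");

// Filters run before and after each auto-invoked tool call; the response
// metadata carries "ModelIterationsCompleted", as added in ClientCore above.
var reply = await service.GetChatMessageContentAsync(chat, settings, kernel);

public sealed class WeatherPlugin
{
    [KernelFunction, Description("Gets the current weather for a city.")]
    public string GetWeather(string city) => $"Sunny in {city}";
}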