.Net: Tool filters #4922

Closed
wants to merge 26 commits into main from gitri-ms:tool-filters
26 commits
610e7e7
Tool filter and context classes
gitri-ms Feb 5, 2024
39b125a
save progress
gitri-ms Feb 5, 2024
95156bc
Use enum to control tool stop behavior
gitri-ms Feb 7, 2024
e9b61fd
Helper to update chat history in chat options
gitri-ms Feb 7, 2024
4f9a1cd
Move where filters are applied, update args
gitri-ms Feb 7, 2024
f263823
fix chat history issue
gitri-ms Feb 8, 2024
699b11d
Clean up warnings, add iterations to chatResult metadata
gitri-ms Feb 8, 2024
efbfe78
Merge branch 'main' into tool-filters
gitri-ms Feb 8, 2024
29fdedc
Merge branch 'main' into tool-filters
markwallace-microsoft Feb 8, 2024
865f3e5
bug fixes
gitri-ms Feb 12, 2024
a7e6e45
Add unit tests
gitri-ms Feb 13, 2024
6172aa5
Merge branch 'tool-filters' of https://github.com/gitri-ms/semantic-k…
gitri-ms Feb 13, 2024
1f5752f
Fix bug in test, add comments
gitri-ms Feb 13, 2024
3dee7d6
Merge branch 'main' into tool-filters
gitri-ms Feb 13, 2024
f7cd9f7
Streaming impl, add experimental attribute
gitri-ms Feb 13, 2024
037a996
Merge branch 'main' into tool-filters
gitri-ms Feb 13, 2024
a709123
Add tests for StopTools, StopAutoInvoke
gitri-ms Feb 14, 2024
363a9aa
Additional test
gitri-ms Feb 14, 2024
bddcf05
Revert change to example
gitri-ms Feb 14, 2024
cbdf008
Remove blank line
gitri-ms Feb 14, 2024
09d7c86
Test cases for chat streaming
gitri-ms Feb 15, 2024
46e8af1
Merge branch 'main' into tool-filters
gitri-ms Feb 15, 2024
cb3467e
Address pr comments
gitri-ms Feb 15, 2024
85aeeff
Merge branch 'main' into tool-filters
gitri-ms Feb 15, 2024
087b307
reduce code duplication
gitri-ms Feb 15, 2024
a1012bc
dotnet format
gitri-ms Feb 16, 2024
1 change: 1 addition & 0 deletions dotnet/docs/EXPERIMENTS.md
@@ -26,6 +26,7 @@ You can use the following diagnostic IDs to ignore warnings or errors for a part
- SKEXP0013: OpenAI parameters
- SKEXP0014: OpenAI chat history extension
- SKEXP0015: OpenAI file service
- SKEXP0016: OpenAI tool call filters

## Memory connectors

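The new SKEXP0016 ID follows the same opt-in pattern as the other experimental diagnostics listed in this file. A minimal sketch of a call-site opt-in (the helper and its names are illustrative, not part of this PR):

using Microsoft.SemanticKernel.Connectors.OpenAI;

public static class Skexp0016OptIn
{
    // The pragma scopes the experimental-API opt-in to this helper only; a
    // project-wide alternative is adding SKEXP0016 to <NoWarn> in the csproj,
    // as this PR does for the connector project further down.
#pragma warning disable SKEXP0016 // OpenAI tool call filters
    public static void AddFilter(ToolCallBehavior behavior, IToolFilter filter)
        => behavior.Filters.Add(filter);
#pragma warning restore SKEXP0016
}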
122 changes: 114 additions & 8 deletions dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs
@@ -29,6 +29,8 @@ namespace Microsoft.SemanticKernel.Connectors.OpenAI;
/// </summary>
internal abstract class ClientCore
{
private const string ModelIterationsCompletedKey = "ModelIterationsCompleted";

private const int MaxResultsPerPrompt = 128;

/// <summary>
@@ -176,25 +178,27 @@ internal async IAsyncEnumerable<StreamingTextContent> GetStreamingTextContentsAs
};
}

private static Dictionary<string, object?> GetResponseMetadata(ChatCompletions completions)
private static Dictionary<string, object?> GetResponseMetadata(ChatCompletions completions, int modelIterations)
{
return new Dictionary<string, object?>(5)
return new Dictionary<string, object?>(6)
{
{ nameof(completions.Id), completions.Id },
{ nameof(completions.Created), completions.Created },
{ nameof(completions.PromptFilterResults), completions.PromptFilterResults },
{ nameof(completions.SystemFingerprint), completions.SystemFingerprint },
{ nameof(completions.Usage), completions.Usage },
{ ModelIterationsCompletedKey, modelIterations },
};
}

private static Dictionary<string, object?> GetResponseMetadata(StreamingChatCompletionsUpdate completions)
private static Dictionary<string, object?> GetResponseMetadata(StreamingChatCompletionsUpdate completions, int modelIterations)
{
return new Dictionary<string, object?>(3)
return new Dictionary<string, object?>(4)
{
{ nameof(completions.Id), completions.Id },
{ nameof(completions.Created), completions.Created },
{ nameof(completions.SystemFingerprint), completions.SystemFingerprint },
{ ModelIterationsCompletedKey, modelIterations },
};
}

@@ -303,7 +307,7 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
throw new KernelException("Chat completions not found");
}

IReadOnlyDictionary<string, object?> metadata = GetResponseMetadata(responseData);
IReadOnlyDictionary<string, object?> metadata = GetResponseMetadata(responseData, iteration);

// If we don't want to attempt to invoke any functions, just return the result.
// Or if we are auto-invoking but we somehow end up with other than 1 choice even though only 1 was requested, similarly bail.
@@ -367,6 +371,19 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
continue;
}

try
{
// Invoke the pre-invocation filter.
var invokingContext = chatExecutionSettings.ToolCallBehavior?.OnToolInvokingFilter(openAIFunctionToolCall, chat, iteration);
this.ApplyToolFilterContextChanges(invokingContext, chatOptions, chat, chatExecutionSettings, ref autoInvoke);
}
catch (OperationCanceledException)
{
// Add cancellation message to chat history and bail out of any remaining tool calls
AddResponseMessage(chatOptions, chat, null, $"A tool filter requested cancellation before tool invocation. Model iterations completed: {iteration}", toolCall.Id, this.Logger);
break;
}

// Make sure the requested function is one we requested. If we're permitting any kernel function to be invoked,
// then we don't need to check this, as it'll be handled when we look up the function in the kernel to be able
// to invoke it. If we're permitting only a specific list of functions, though, then we need to explicitly check.
@@ -395,7 +412,7 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
functionResult = (await function.InvokeAsync(kernel, functionArgs, cancellationToken: cancellationToken).ConfigureAwait(false)).GetValue<object>() ?? string.Empty;
}
#pragma warning disable CA1031 // Do not catch general exception types
catch (Exception e)
catch (Exception e) when (!e.IsCriticalException())
#pragma warning restore CA1031
{
AddResponseMessage(chatOptions, chat, null, $"Error: Exception while invoking function. {e.Message}", toolCall.Id, this.Logger);
@@ -407,6 +424,18 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
}
AddResponseMessage(chatOptions, chat, functionResult as string ?? JsonSerializer.Serialize(functionResult), errorMessage: null, toolCall.Id, this.Logger);

try
{
// Invoke the post-invocation filter.
var invokedContext = chatExecutionSettings.ToolCallBehavior?.OnToolInvokedFilter(openAIFunctionToolCall, functionResult, chat, iteration);
this.ApplyToolFilterContextChanges(invokedContext, chatOptions, chat, chatExecutionSettings, ref autoInvoke);
}
catch (OperationCanceledException)
{
// The tool call already happened so we can't cancel it, but bail out of any remaining tool calls
break;
}

static void AddResponseMessage(ChatCompletionsOptions chatOptions, ChatHistory chat, string? result, string? errorMessage, string toolId, ILogger logger)
{
// Log any error
@@ -447,6 +476,58 @@ static void AddResponseMessage(ChatCompletionsOptions chatOptions, ChatHistory c
}
}

private void ApplyToolFilterContextChanges(
ToolFilterContext? context,
ChatCompletionsOptions chatOptions,
ChatHistory chatHistory,
OpenAIPromptExecutionSettings executionSettings,
ref bool autoInvoke)
{
if (context is not null)
{
// Since the tool filter has access to the chat history, the chat history may have been modified.
// We want to make sure any subsequent requests to the model reflect these changes. The chatOptions object
// contains all the configuration information for a chat request, including a copy of the chat history.
// So we need to update the chat history stored in the chatOptions object to match what is in the chatHistory object.
this.UpdateChatOptions(chatOptions, chatHistory, executionSettings);

// Check if filter has requested a stop
this.HandleStopBehavior(context, chatOptions, ref autoInvoke);
}
}

private void HandleStopBehavior(ToolFilterContext context, ChatCompletionsOptions chatOptions, ref bool autoInvoke)
{
switch (context.StopBehavior)
{
case ToolFilterStopBehavior.StopAutoInvoke:
autoInvoke = false;
break;
case ToolFilterStopBehavior.StopTools:
chatOptions.ToolChoice = ChatCompletionsToolChoice.None;
break;
case ToolFilterStopBehavior.Cancel:
chatOptions.ToolChoice = ChatCompletionsToolChoice.None;
throw new OperationCanceledException();
}
}

private void UpdateChatOptions(ChatCompletionsOptions options, ChatHistory chatHistory, OpenAIPromptExecutionSettings executionSettings)
{
// Clear out messages, then copy over from chat history
options.Messages.Clear();

if (!string.IsNullOrWhiteSpace(executionSettings?.ChatSystemPrompt) && !chatHistory.Any(m => m.Role == AuthorRole.System))
{
options.Messages.Add(GetRequestMessage(new ChatMessageContent(AuthorRole.System, executionSettings!.ChatSystemPrompt)));
}

foreach (var message in chatHistory)
{
options.Messages.Add(GetRequestMessage(message));
}
}

internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingChatMessageContentsAsync(
ChatHistory chat,
PromptExecutionSettings? executionSettings,
@@ -485,7 +566,7 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
CompletionsFinishReason finishReason = default;
await foreach (StreamingChatCompletionsUpdate update in response.ConfigureAwait(false))
{
metadata ??= GetResponseMetadata(update);
metadata ??= GetResponseMetadata(update, iteration);
streamedRole ??= update.Role;
finishReason = update.FinishReason ?? default;

@@ -557,6 +638,19 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
continue;
}

try
{
// Invoke the pre-invocation filter.
var invokingContext = chatExecutionSettings.ToolCallBehavior?.OnToolInvokingFilter(openAIFunctionToolCall, chat, iteration);
Review comment (Member):

There's a lot that happens between here and the actual function invocation, including actually getting the corresponding KernelFunction object and creating the KernelArguments to pass to it. I see the value of a callback that's really early in the process, so that a callback can see the raw request from the model, but in that case, should it actually be even earlier and be passed just the raw string name and string arguments? And then have a separate callback that's invoked with the Kernel, KernelFunction, KernelArguments, etc. just before function.InvokeAsync is actually called?

I think it'd be helpful to enumerate all the extensibility scenarios we're trying to enable here, i.e. the different things we expect folks will want to do with this, and then write out the example code for each, showing both that it's possible and what the code would look like. Those can all become samples, too.

For example:

  • Want to limit the number of back and forths with the model, to avoid runaway costs or infinite loops, disabling additional function calling after some number of iterations
  • Want to update what functions are available based on the interactions with the model
  • Want to limit the number of recursive function invocations that can be made (e.g. agents talking back and forth to each other via function calling)
  • Want to screen the arguments being passed to a function and replace the argument with something else
  • Want to screen the results of a function and replace it with something else (it's possible this and the above would already be handled by normal function filters)
  • Want to stop iterating if a particular function is requested, returning that function's result as the result of the operation (basically the eventual invocation of that function was the ultimate goal)
  • ...

this.ApplyToolFilterContextChanges(invokingContext, chatOptions, chat, chatExecutionSettings, ref autoInvoke);
}
catch (OperationCanceledException)
Review comment (Member):

This feels a bit icky to me. Does this mean that we're saying the way you early-exit non-exceptionally is to throw an OperationCanceledException?

{
// Add cancellation message to chat history and bail out of any remaining tool calls
AddResponseMessage(chatOptions, chat, streamedRole, toolCall, metadata, null, $"A tool filter requested cancellation before tool invocation. Model iterations completed: {iteration}", this.Logger);
break;
}

// Make sure the requested function is one we requested. If we're permitting any kernel function to be invoked,
// then we don't need to check this, as it'll be handled when we look up the function in the kernel to be able
// to invoke it. If we're permitting only a specific list of functions, though, then we need to explicitly check.
@@ -585,7 +679,7 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
functionResult = (await function.InvokeAsync(kernel, functionArgs, cancellationToken: cancellationToken).ConfigureAwait(false)).GetValue<object>() ?? string.Empty;
}
#pragma warning disable CA1031 // Do not catch general exception types
catch (Exception e)
catch (Exception e) when (!e.IsCriticalException())
#pragma warning restore CA1031
{
AddResponseMessage(chatOptions, chat, streamedRole, toolCall, metadata, result: null, $"Error: Exception while invoking function. {e.Message}", this.Logger);
Expand All @@ -597,6 +691,18 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
}
AddResponseMessage(chatOptions, chat, streamedRole, toolCall, metadata, functionResult as string ?? JsonSerializer.Serialize(functionResult), errorMessage: null, this.Logger);

try
{
// Invoke the post-invocation filter.
var invokedContext = chatExecutionSettings.ToolCallBehavior?.OnToolInvokedFilter(openAIFunctionToolCall, functionResult, chat, iteration);
this.ApplyToolFilterContextChanges(invokedContext, chatOptions, chat, chatExecutionSettings, ref autoInvoke);
}
catch (OperationCanceledException)
{
// This tool call already happened so we can't cancel it, but bail out of any remaining tool calls
Review comment (Member):

We're adding a message in the case of early-exit before tool invocation but not after tool invocation?

break;
}

static void AddResponseMessage(
ChatCompletionsOptions chatOptions, ChatHistory chat, ChatRole? streamedRole, ChatCompletionsToolCall tool, IReadOnlyDictionary<string, object?>? metadata,
string? result, string? errorMessage, ILogger logger)
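To illustrate the post-invocation hook above: ToolInvokedContext is constructed with the chat history, and UpdateChatOptions copies any history edits back into the request, so a filter can screen a function result before the next model iteration sees it. A hedged sketch (the ChatHistory property name on the context is an assumption inferred from the constructor arguments in this diff):

using System.Linq;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;

public sealed class RedactingToolFilter : IToolFilter
{
    public void OnToolInvoking(ToolInvokingContext context)
    {
        // No pre-invocation behavior needed for this scenario.
    }

    public void OnToolInvoked(ToolInvokedContext context)
    {
        // Rewrite the tool message that was just appended to the history;
        // UpdateChatOptions propagates the edit into the next request.
        var lastToolMessage = context.ChatHistory.LastOrDefault(m => m.Role == AuthorRole.Tool);
        if (lastToolMessage?.Content is string content && content.Contains("ssn:"))
        {
            lastToolMessage.Content = "[redacted]";
        }
    }
}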
@@ -6,7 +6,7 @@
<RootNamespace>$(AssemblyName)</RootNamespace>
<TargetFramework>netstandard2.0</TargetFramework>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<NoWarn>$(NoWarn);NU5104;SKEXP0005,SKEXP0013,SKEXP0014</NoWarn>
<NoWarn>$(NoWarn);NU5104;SKEXP0005,SKEXP0013,SKEXP0014,SKEXP0016</NoWarn>
<EnablePackageValidation>true</EnablePackageValidation>
</PropertyGroup>

24 changes: 24 additions & 0 deletions dotnet/src/Connectors/Connectors.OpenAI/IToolFilter.cs
@@ -0,0 +1,24 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Diagnostics.CodeAnalysis;

namespace Microsoft.SemanticKernel.Connectors.OpenAI;

/// <summary>
/// Interface for tool filters.
/// </summary>
[Experimental("SKEXP0016")]
public interface IToolFilter
{
/// <summary>
/// Method which is executed before tool invocation.
/// </summary>
/// <param name="context">Data related to tool before invocation.</param>
void OnToolInvoking(ToolInvokingContext context);

/// <summary>
/// Method which is executed after tool invocation.
/// </summary>
/// <param name="context">Data related to tool after invocation.</param>
void OnToolInvoked(ToolInvokedContext context);
}
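For a concrete use of this interface, here is a minimal sketch of a filter that caps model iterations to guard against runaway costs or infinite loops, one of the scenarios enumerated in review above. The Iteration property and the settable StopBehavior on the context are assumptions inferred from ToolFilterContext's constructor arguments and HandleStopBehavior in this diff:

using Microsoft.SemanticKernel.Connectors.OpenAI;

public sealed class MaxIterationsToolFilter : IToolFilter
{
    private readonly int _maxIterations;

    public MaxIterationsToolFilter(int maxIterations) => this._maxIterations = maxIterations;

    public void OnToolInvoking(ToolInvokingContext context)
    {
        // Once the iteration budget is spent, ask the connector to stop
        // requesting tools (maps to ToolFilterStopBehavior.StopTools above).
        if (context.Iteration >= this._maxIterations)
        {
            context.StopBehavior = ToolFilterStopBehavior.StopTools;
        }
    }

    public void OnToolInvoked(ToolInvokedContext context)
    {
        // No post-invocation behavior needed for this scenario.
    }
}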
44 changes: 44 additions & 0 deletions dotnet/src/Connectors/Connectors.OpenAI/ToolCallBehavior.cs
@@ -2,8 +2,10 @@

using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using Azure.AI.OpenAI;
using Microsoft.SemanticKernel.ChatCompletion;

namespace Microsoft.SemanticKernel.Connectors.OpenAI;

@@ -35,6 +37,12 @@ public abstract class ToolCallBehavior
/// </remarks>
private const int DefaultMaximumAutoInvokeAttempts = 5;
Review comment (Member):

Are we going to expose this configuration?

Reply from gitri-ms (Contributor, Author), Feb 15, 2024:

I would argue that should be in a separate PR, since it's unrelated to the tool filters or the planner updates that this PR covers. (I don't mind creating that PR though, should be fairly quick.) Also, if we expose this field, do we want to expose ToolCallBehavior.MaximumUseAttempts as well?


/// <summary>
/// Gets the collection of filters that will be applied to tool calls.
/// </summary>
[Experimental("SKEXP0016")]
public IList<IToolFilter> Filters { get; } = new List<IToolFilter>();
Review comment from stephentoub (Member), Feb 21, 2024:

I don't think we want this here, at least not as a public mutable list. It means any code can just do ToolCallBehavior.AutoInvokeKernelFunctions.Filters.Add(myFilter), and it'll be added to the singleton that will apply to everyone, which also means you need to remember to remove filters after you're done with them, even in the case of exception.

I think instead we should add overloads to the existing factories below, e.g.

public static ToolCallBehavior EnableFunctions(
    IEnumerable<OpenAIFunction>? functions, // if null, functions are sourced from the Kernel ala AutoInvokeKernelFunctions,
    EnableFunctionsOptions options);
...
public sealed class EnableFunctionsOptions
{
    public bool AutoInvoke { get; set; }
    public IList<IToolFilter> Filters { get; }
    ... // any other customization desired
}

or something along those lines. That's just a sketch; names and overall shape are debatable.
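For comparison, a hypothetical call site under that proposed shape (the names come from the reviewer's sketch above, not a shipped API; MaxIterationsToolFilter is the illustrative filter sketched earlier on this page):

var behavior = ToolCallBehavior.EnableFunctions(
    functions: null, // null: source functions from the Kernel
    new EnableFunctionsOptions
    {
        AutoInvoke = true,
        Filters = { new MaxIterationsToolFilter(3) }, // collection initializer works against the get-only list
    });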


/// <summary>
/// Gets an instance that will provide all of the <see cref="Kernel"/>'s plugins' function information.
/// Function call requests from the model will be propagated back to the caller.
@@ -236,4 +244,40 @@ internal override void ConfigureOptions(Kernel? kernel, ChatCompletionsOptions o
/// </remarks>
internal override int MaximumUseAttempts => 1;
}

#region Filters
internal ToolInvokingContext? OnToolInvokingFilter(OpenAIFunctionToolCall toolCall, ChatHistory chatHistory, int iteration)
{
ToolInvokingContext? context = null;

if (this.Filters is { Count: > 0 })
{
context = new(toolCall, chatHistory, iteration);

for (int i = 0; i < this.Filters.Count; i++)
{
this.Filters[i].OnToolInvoking(context);
}
}

return context;
}

internal ToolInvokedContext? OnToolInvokedFilter(OpenAIFunctionToolCall toolCall, object? result, ChatHistory chatHistory, int iteration)
{
ToolInvokedContext? context = null;

if (this.Filters is { Count: > 0 })
{
context = new(toolCall, result, chatHistory, iteration);

for (int i = 0; i < this.Filters.Count; i++)
{
this.Filters[i].OnToolInvoked(context);
}
}

return context;
}
#endregion
}
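Putting the pieces together, a hedged end-to-end sketch of wiring a filter into a chat request as the PR currently stands. Note the review caveat above: AutoInvokeKernelFunctions is a shared instance, so mutating its Filters list affects every caller. In a real app, plugins would also be registered on the builder.

using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.OpenAI;

var kernel = Kernel.CreateBuilder()
    .AddOpenAIChatCompletion(modelId: "gpt-4", apiKey: "...")
    .Build();

#pragma warning disable SKEXP0016 // OpenAI tool call filters
var behavior = ToolCallBehavior.AutoInvokeKernelFunctions;
behavior.Filters.Add(new MaxIterationsToolFilter(maxIterations: 3));
#pragma warning restore SKEXP0016

var settings = new OpenAIPromptExecutionSettings { ToolCallBehavior = behavior };

// The number of model iterations completed is surfaced in response metadata
// under "ModelIterationsCompleted", per the ClientCore changes above.
var result = await kernel.InvokePromptAsync(
    "Plan a three-city trip and book the flights.",
    new KernelArguments(settings));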