From 4f938c4ba0da48163ceafad6f168fcf5e19ccaa9 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Tue, 9 May 2023 14:50:52 -0700 Subject: [PATCH] Add support for custom delegates as SKFunctions --- .../Diagnostics/CompilerServicesAttributes.cs | 2 +- .../SkillDefinition/SKFunctionTests2.cs | 30 +++++++ .../SkillDefinition/SKFunctionTests3.cs | 41 +++++++++ .../SkillDefinition/SKFunction.cs | 90 ++++++++++--------- .../Extensions/KernelGrpcExtensions.cs | 4 +- .../Extensions/KernelOpenApiExtensions.cs | 4 +- .../webapi/Skills/CopilotChatPlanner.cs | 1 + 7 files changed, 126 insertions(+), 46 deletions(-) diff --git a/dotnet/src/InternalUtilities/Diagnostics/CompilerServicesAttributes.cs b/dotnet/src/InternalUtilities/Diagnostics/CompilerServicesAttributes.cs index 06722df25475..92b546f23556 100644 --- a/dotnet/src/InternalUtilities/Diagnostics/CompilerServicesAttributes.cs +++ b/dotnet/src/InternalUtilities/Diagnostics/CompilerServicesAttributes.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. diff --git a/dotnet/src/SemanticKernel.UnitTests/SkillDefinition/SKFunctionTests2.cs b/dotnet/src/SemanticKernel.UnitTests/SkillDefinition/SKFunctionTests2.cs index 276a650e3344..65cbf999af0b 100644 --- a/dotnet/src/SemanticKernel.UnitTests/SkillDefinition/SKFunctionTests2.cs +++ b/dotnet/src/SemanticKernel.UnitTests/SkillDefinition/SKFunctionTests2.cs @@ -246,6 +246,36 @@ Task Test(SKContext cx) Assert.Equal("foo", context.Result); } + [Fact] + public async Task ItSupportsAsyncType7Async() + { + // Arrange + [SKFunction("Test")] + [SKFunctionName("Test")] + async Task TestAsync(SKContext cx) + { + await Task.Delay(0); + s_canary = s_expected; + cx.Variables.Update("foo"); + cx["canary"] = s_expected; + return cx; + } + + var context = this.MockContext(""); + + // Act + var function = SKFunction.FromNativeMethod(Method(TestAsync), log: this._log.Object); + Assert.NotNull(function); + SKContext result = await function.InvokeAsync(context); + + // Assert + Assert.False(result.ErrorOccurred); + this.VerifyFunctionTypeMatch(7); + Assert.Equal(s_expected, s_canary); + Assert.Equal(s_expected, context["canary"]); + Assert.Equal("foo", context.Result); + } + [Fact] public async Task ItSupportsType8Async() { diff --git a/dotnet/src/SemanticKernel.UnitTests/SkillDefinition/SKFunctionTests3.cs b/dotnet/src/SemanticKernel.UnitTests/SkillDefinition/SKFunctionTests3.cs index 770d6c78c1fb..0290433b0570 100644 --- a/dotnet/src/SemanticKernel.UnitTests/SkillDefinition/SKFunctionTests3.cs +++ b/dotnet/src/SemanticKernel.UnitTests/SkillDefinition/SKFunctionTests3.cs @@ -3,10 +3,14 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; +using System.Threading; using System.Threading.Tasks; using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Memory; using Microsoft.SemanticKernel.Orchestration; using Microsoft.SemanticKernel.SkillDefinition; +using Moq; +using SemanticKernel.UnitTests.XunitHelpers; using Xunit; namespace SemanticKernel.UnitTests.SkillDefinition; @@ -64,6 +68,43 @@ public void ItThrowsForInvalidFunctions() Assert.Equal(3, count); } + [Fact] + public async Task ItCanImportNativeFunctionsAsync() + { + // Arrange + var variables = new ContextVariables(); + var skills = new SkillCollection(); + var logger = TestConsoleLogger.Log; + var cancellationToken = default(CancellationToken); + var memory = new Mock(); + var context = new SKContext(variables, memory.Object, skills.ReadOnlySkillCollection, logger, cancellationToken); + + // Note: the function doesn't have any SK attributes + async Task ExecuteAsync(SKContext contextIn) + { + Assert.Equal("NO", contextIn["done"]); + contextIn["canary"] = "YES"; + + await Task.Delay(0); + return contextIn; + } + + // Act + context["done"] = "NO"; + ISKFunction function = SKFunction.FromNativeFunction( + nativeFunction: ExecuteAsync, + parameters: null, + description: "description", + skillName: "skillName", + functionName: "functionName"); + + SKContext result = await function.InvokeAsync(context, cancellationToken: cancellationToken); + + // Assert + Assert.Equal("YES", context["canary"]); + Assert.Equal("YES", result["canary"]); + } + private sealed class InvalidSkill { [SKFunction("one")] diff --git a/dotnet/src/SemanticKernel/SkillDefinition/SKFunction.cs b/dotnet/src/SemanticKernel/SkillDefinition/SKFunction.cs index fd70cfc7aa1d..dd5a29a9b446 100644 --- a/dotnet/src/SemanticKernel/SkillDefinition/SKFunction.cs +++ b/dotnet/src/SemanticKernel/SkillDefinition/SKFunction.cs @@ -64,7 +64,7 @@ public CompleteRequestSettings RequestSettings { if (string.IsNullOrWhiteSpace(skillName)) { skillName = SkillCollection.GlobalSkill; } - MethodDetails methodDetails = GetMethodDetails(methodSignature, methodContainerInstance, log); + MethodDetails methodDetails = GetMethodDetails(methodSignature, methodContainerInstance, true, log); // If the given method is not a valid SK function if (!methodDetails.HasSkFunctionAttribute) @@ -83,6 +83,37 @@ public CompleteRequestSettings RequestSettings log: log); } + /// + /// Create a native function instance, wrapping a delegate function + /// + /// Function to invoke + /// SK skill name + /// SK function name + /// SK function description + /// SK function parameters + /// Application logger + /// SK function instance + public static ISKFunction FromNativeFunction( + Delegate nativeFunction, + string skillName, + string functionName, + string description, + IEnumerable? parameters = null, + ILogger? log = null) + { + MethodDetails methodDetails = GetMethodDetails(nativeFunction.Method, null, false, log); + + return new SKFunction( + delegateType: methodDetails.Type, + delegateFunction: methodDetails.Function, + parameters: (parameters ?? Enumerable.Empty()).ToList(), + description: description, + skillName: skillName, + functionName: functionName, + isSemantic: false, + log: log); + } + /// /// Create a native function instance, given a semantic function configuration. /// @@ -115,13 +146,15 @@ async Task LocalFunc( } catch (AIException ex) { - const string Message = "Something went wrong while rendering the semantic function or while executing the text completion. Function: {0}.{1}. Error: {2}. Details: {3}"; + const string Message = "Something went wrong while rendering the semantic function" + + " or while executing the text completion. Function: {0}.{1}. Error: {2}. Details: {3}"; log?.LogError(ex, Message, skillName, functionName, ex.Message, ex.Detail); context.Fail(ex.Message, ex); } catch (Exception ex) when (!ex.IsCriticalException()) { - const string Message = "Something went wrong while rendering the semantic function or while executing the text completion. Function: {0}.{1}. Error: {2}"; + const string Message = "Something went wrong while rendering the semantic function" + + " or while executing the text completion. Function: {0}.{1}. Error: {2}"; log?.LogError(ex, Message, skillName, functionName, ex.Message); context.Fail(ex.Message, ex); } @@ -140,37 +173,6 @@ async Task LocalFunc( log: log); } - /// - /// Create a native function instance, wrapping a native object method - /// - /// Signature of the method to invoke - /// SK skill name - /// SK function name - /// SK function description - /// SK function parameters - /// Application logger - /// SK function instance - public static ISKFunction FromCustomMethod( - Func> customFunction, - string skillName, - string functionName, - string description, - IEnumerable? parameters = null, - ILogger? log = null) - { - var function = new SKFunction( - delegateType: SKFunction.DelegateTypes.ContextSwitchInSKContextOutTaskSKContext, - delegateFunction: customFunction, - parameters: (parameters ?? Enumerable.Empty()).ToList(), - description: description, - skillName: skillName, - functionName: functionName, - isSemantic: false, - log: log); - - return function; - } - /// public FunctionView Describe() { @@ -524,7 +526,11 @@ private void EnsureContextHasSkills(SKContext context) context.Skills ??= this._skillCollection; } - private static MethodDetails GetMethodDetails(MethodInfo methodSignature, object? methodContainerInstance, ILogger? log = null) + private static MethodDetails GetMethodDetails( + MethodInfo methodSignature, + object? methodContainerInstance, + bool skAttributesRequired = true, + ILogger? log = null) { Verify.NotNull(methodSignature); @@ -544,11 +550,13 @@ private static MethodDetails GetMethodDetails(MethodInfo methodSignature, object if (!result.HasSkFunctionAttribute || skFunctionAttribute == null) { - log?.LogTrace("Method '{0}' doesn't have '{1}' attribute.", result.Name, typeof(SKFunctionAttribute).Name); - return result; + log?.LogTrace("Method '{0}' doesn't have '{1}' attribute", result.Name, typeof(SKFunctionAttribute).Name); + if (skAttributesRequired) { return result; } + } + else + { + result.HasSkFunctionAttribute = true; } - - result.HasSkFunctionAttribute = true; (result.Type, result.Function, bool hasStringParam) = GetDelegateInfo(methodContainerInstance, methodSignature); @@ -598,9 +606,9 @@ private static MethodDetails GetMethodDetails(MethodInfo methodSignature, object // Note: the name "input" is reserved for the main parameter Verify.ParametersUniqueness(result.Parameters); - result.Description = skFunctionAttribute.Description ?? ""; + result.Description = skFunctionAttribute?.Description ?? ""; - log?.LogTrace("Method '{0}' found.", result.Name); + log?.LogTrace("Method '{0}' found, type `{1}`", result.Name, result.Type.ToString("G")); return result; } diff --git a/dotnet/src/Skills/Skills.Grpc/Extensions/KernelGrpcExtensions.cs b/dotnet/src/Skills/Skills.Grpc/Extensions/KernelGrpcExtensions.cs index 26b63941b4cc..67071f1aff2c 100644 --- a/dotnet/src/Skills/Skills.Grpc/Extensions/KernelGrpcExtensions.cs +++ b/dotnet/src/Skills/Skills.Grpc/Extensions/KernelGrpcExtensions.cs @@ -172,8 +172,8 @@ async Task ExecuteAsync(SKContext context) return context; } - var function = SKFunction.FromCustomMethod( - customFunction: ExecuteAsync, + var function = SKFunction.FromNativeFunction( + nativeFunction: ExecuteAsync, parameters: operationParameters.ToList(), description: operation.Name, skillName: skillName, diff --git a/dotnet/src/Skills/Skills.OpenAPI/Extensions/KernelOpenApiExtensions.cs b/dotnet/src/Skills/Skills.OpenAPI/Extensions/KernelOpenApiExtensions.cs index b54bc1a4f2fc..3e3cf2c307f3 100644 --- a/dotnet/src/Skills/Skills.OpenAPI/Extensions/KernelOpenApiExtensions.cs +++ b/dotnet/src/Skills/Skills.OpenAPI/Extensions/KernelOpenApiExtensions.cs @@ -318,8 +318,8 @@ async Task ExecuteAsync(SKContext context) }) .ToList(); - var function = SKFunction.FromCustomMethod( - customFunction: ExecuteAsync, + var function = SKFunction.FromNativeFunction( + nativeFunction: ExecuteAsync, parameters: parameters, description: operation.Description, skillName: skillName, diff --git a/samples/apps/copilot-chat-app/webapi/Skills/CopilotChatPlanner.cs b/samples/apps/copilot-chat-app/webapi/Skills/CopilotChatPlanner.cs index 96328f896fd2..c37cbf51ba33 100644 --- a/samples/apps/copilot-chat-app/webapi/Skills/CopilotChatPlanner.cs +++ b/samples/apps/copilot-chat-app/webapi/Skills/CopilotChatPlanner.cs @@ -47,6 +47,7 @@ public Task CreatePlanAsync(string goal) // No functions are available - return an empty plan. return Task.FromResult(new Plan(goal)); } + return new ActionPlanner(this.Kernel).CreatePlanAsync(goal); } }