Skip to content

Commit

Permalink
Add support for custom delegates as SKFunctions
Browse files Browse the repository at this point in the history
  • Loading branch information
dluc committed May 10, 2023
1 parent b24c7dc commit 4f938c4
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 46 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,36 @@ Task<SKContext> Test(SKContext cx)
Assert.Equal("foo", context.Result);
}

[Fact]
public async Task ItSupportsAsyncType7Async()
{
// Arrange
[SKFunction("Test")]
[SKFunctionName("Test")]
async Task<SKContext> TestAsync(SKContext cx)
{
await Task.Delay(0);
s_canary = s_expected;
cx.Variables.Update("foo");
cx["canary"] = s_expected;
return cx;
}

var context = this.MockContext("");

// Act
var function = SKFunction.FromNativeMethod(Method(TestAsync), log: this._log.Object);
Assert.NotNull(function);
SKContext result = await function.InvokeAsync(context);

// Assert
Assert.False(result.ErrorOccurred);
this.VerifyFunctionTypeMatch(7);
Assert.Equal(s_expected, s_canary);
Assert.Equal(s_expected, context["canary"]);
Assert.Equal("foo", context.Result);
}

[Fact]
public async Task ItSupportsType8Async()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Memory;
using Microsoft.SemanticKernel.Orchestration;
using Microsoft.SemanticKernel.SkillDefinition;
using Moq;
using SemanticKernel.UnitTests.XunitHelpers;
using Xunit;

namespace SemanticKernel.UnitTests.SkillDefinition;
Expand Down Expand Up @@ -64,6 +68,43 @@ public void ItThrowsForInvalidFunctions()
Assert.Equal(3, count);
}

[Fact]
public async Task ItCanImportNativeFunctionsAsync()
{
// Arrange
var variables = new ContextVariables();
var skills = new SkillCollection();
var logger = TestConsoleLogger.Log;
var cancellationToken = default(CancellationToken);
var memory = new Mock<ISemanticTextMemory>();
var context = new SKContext(variables, memory.Object, skills.ReadOnlySkillCollection, logger, cancellationToken);

// Note: the function doesn't have any SK attributes
async Task<SKContext> ExecuteAsync(SKContext contextIn)
{
Assert.Equal("NO", contextIn["done"]);
contextIn["canary"] = "YES";

await Task.Delay(0);
return contextIn;
}

// Act
context["done"] = "NO";
ISKFunction function = SKFunction.FromNativeFunction(
nativeFunction: ExecuteAsync,
parameters: null,
description: "description",
skillName: "skillName",
functionName: "functionName");

SKContext result = await function.InvokeAsync(context, cancellationToken: cancellationToken);

// Assert
Assert.Equal("YES", context["canary"]);
Assert.Equal("YES", result["canary"]);
}

private sealed class InvalidSkill
{
[SKFunction("one")]
Expand Down
90 changes: 49 additions & 41 deletions dotnet/src/SemanticKernel/SkillDefinition/SKFunction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public CompleteRequestSettings RequestSettings
{
if (string.IsNullOrWhiteSpace(skillName)) { skillName = SkillCollection.GlobalSkill; }

MethodDetails methodDetails = GetMethodDetails(methodSignature, methodContainerInstance, log);
MethodDetails methodDetails = GetMethodDetails(methodSignature, methodContainerInstance, true, log);

// If the given method is not a valid SK function
if (!methodDetails.HasSkFunctionAttribute)
Expand All @@ -83,6 +83,37 @@ public CompleteRequestSettings RequestSettings
log: log);
}

/// <summary>
/// Create a native function instance, wrapping a delegate function
/// </summary>
/// <param name="nativeFunction">Function to invoke</param>
/// <param name="skillName">SK skill name</param>
/// <param name="functionName">SK function name</param>
/// <param name="description">SK function description</param>
/// <param name="parameters">SK function parameters</param>
/// <param name="log">Application logger</param>
/// <returns>SK function instance</returns>
public static ISKFunction FromNativeFunction(
Delegate nativeFunction,
string skillName,
string functionName,
string description,
IEnumerable<ParameterView>? parameters = null,
ILogger? log = null)
{
MethodDetails methodDetails = GetMethodDetails(nativeFunction.Method, null, false, log);

return new SKFunction(
delegateType: methodDetails.Type,
delegateFunction: methodDetails.Function,
parameters: (parameters ?? Enumerable.Empty<ParameterView>()).ToList(),
description: description,
skillName: skillName,
functionName: functionName,
isSemantic: false,
log: log);
}

/// <summary>
/// Create a native function instance, given a semantic function configuration.
/// </summary>
Expand Down Expand Up @@ -115,13 +146,15 @@ async Task<SKContext> LocalFunc(
}
catch (AIException ex)
{
const string Message = "Something went wrong while rendering the semantic function or while executing the text completion. Function: {0}.{1}. Error: {2}. Details: {3}";
const string Message = "Something went wrong while rendering the semantic function" +
" or while executing the text completion. Function: {0}.{1}. Error: {2}. Details: {3}";
log?.LogError(ex, Message, skillName, functionName, ex.Message, ex.Detail);
context.Fail(ex.Message, ex);
}
catch (Exception ex) when (!ex.IsCriticalException())
{
const string Message = "Something went wrong while rendering the semantic function or while executing the text completion. Function: {0}.{1}. Error: {2}";
const string Message = "Something went wrong while rendering the semantic function" +
" or while executing the text completion. Function: {0}.{1}. Error: {2}";
log?.LogError(ex, Message, skillName, functionName, ex.Message);
context.Fail(ex.Message, ex);
}
Expand All @@ -140,37 +173,6 @@ async Task<SKContext> LocalFunc(
log: log);
}

/// <summary>
/// Create a native function instance, wrapping a native object method
/// </summary>
/// <param name="customFunction">Signature of the method to invoke</param>
/// <param name="skillName">SK skill name</param>
/// <param name="functionName">SK function name</param>
/// <param name="description">SK function description</param>
/// <param name="parameters">SK function parameters</param>
/// <param name="log">Application logger</param>
/// <returns>SK function instance</returns>
public static ISKFunction FromCustomMethod(
Func<SKContext, Task<SKContext>> customFunction,
string skillName,
string functionName,
string description,
IEnumerable<ParameterView>? parameters = null,
ILogger? log = null)
{
var function = new SKFunction(
delegateType: SKFunction.DelegateTypes.ContextSwitchInSKContextOutTaskSKContext,
delegateFunction: customFunction,
parameters: (parameters ?? Enumerable.Empty<ParameterView>()).ToList(),
description: description,
skillName: skillName,
functionName: functionName,
isSemantic: false,
log: log);

return function;
}

/// <inheritdoc/>
public FunctionView Describe()
{
Expand Down Expand Up @@ -524,7 +526,11 @@ private void EnsureContextHasSkills(SKContext context)
context.Skills ??= this._skillCollection;
}

private static MethodDetails GetMethodDetails(MethodInfo methodSignature, object? methodContainerInstance, ILogger? log = null)
private static MethodDetails GetMethodDetails(
MethodInfo methodSignature,
object? methodContainerInstance,
bool skAttributesRequired = true,
ILogger? log = null)
{
Verify.NotNull(methodSignature);

Expand All @@ -544,11 +550,13 @@ private static MethodDetails GetMethodDetails(MethodInfo methodSignature, object

if (!result.HasSkFunctionAttribute || skFunctionAttribute == null)
{
log?.LogTrace("Method '{0}' doesn't have '{1}' attribute.", result.Name, typeof(SKFunctionAttribute).Name);
return result;
log?.LogTrace("Method '{0}' doesn't have '{1}' attribute", result.Name, typeof(SKFunctionAttribute).Name);
if (skAttributesRequired) { return result; }
}
else
{
result.HasSkFunctionAttribute = true;
}

result.HasSkFunctionAttribute = true;

(result.Type, result.Function, bool hasStringParam) = GetDelegateInfo(methodContainerInstance, methodSignature);

Expand Down Expand Up @@ -598,9 +606,9 @@ private static MethodDetails GetMethodDetails(MethodInfo methodSignature, object
// Note: the name "input" is reserved for the main parameter
Verify.ParametersUniqueness(result.Parameters);

result.Description = skFunctionAttribute.Description ?? "";
result.Description = skFunctionAttribute?.Description ?? "";

log?.LogTrace("Method '{0}' found.", result.Name);
log?.LogTrace("Method '{0}' found, type `{1}`", result.Name, result.Type.ToString("G"));

return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ async Task<SKContext> ExecuteAsync(SKContext context)
return context;
}

var function = SKFunction.FromCustomMethod(
customFunction: ExecuteAsync,
var function = SKFunction.FromNativeFunction(
nativeFunction: ExecuteAsync,
parameters: operationParameters.ToList(),
description: operation.Name,
skillName: skillName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,8 @@ async Task<SKContext> ExecuteAsync(SKContext context)
})
.ToList();

var function = SKFunction.FromCustomMethod(
customFunction: ExecuteAsync,
var function = SKFunction.FromNativeFunction(
nativeFunction: ExecuteAsync,
parameters: parameters,
description: operation.Description,
skillName: skillName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public Task<Plan> CreatePlanAsync(string goal)
// No functions are available - return an empty plan.
return Task.FromResult(new Plan(goal));
}

return new ActionPlanner(this.Kernel).CreatePlanAsync(goal);
}
}

0 comments on commit 4f938c4

Please sign in to comment.