Skip to content

Commit

Permalink
Resolve sync with main
Browse files Browse the repository at this point in the history
  • Loading branch information
crickman committed Aug 19, 2024
2 parents 86e1df6 + c78c4da commit ac038d2
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 149 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace Agents;

/// <summary>
/// Demonstrate service selection for <see cref="ChatCompletionAgent"/> through setting service-id
/// on <see cref="ChatCompletionAgent.Arguments"/> and also providing override <see cref="KernelArguments"/>
/// on <see cref="ChatHistoryKernelAgent.Arguments"/> and also providing override <see cref="KernelArguments"/>
/// when calling <see cref="ChatCompletionAgent.InvokeAsync"/>
/// </summary>
public class ChatCompletion_ServiceSelection(ITestOutputHelper output) : BaseTest(output)
Expand Down
49 changes: 4 additions & 45 deletions dotnet/src/Agents/Core/ChatCompletionAgent.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Globalization;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.Agents.History;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Services;

Expand All @@ -16,20 +13,12 @@ namespace Microsoft.SemanticKernel.Agents;
/// </summary>
/// <remarks>
/// NOTE: Enable OpenAIPromptExecutionSettings.ToolCallBehavior for agent plugins.
/// (<see cref="ChatCompletionAgent.Arguments"/>)
/// (<see cref="ChatHistoryKernelAgent.Arguments"/>)
/// </remarks>
public sealed class ChatCompletionAgent : KernelAgent, IChatHistoryHandler
public sealed class ChatCompletionAgent : ChatHistoryKernelAgent
{
/// <summary>
/// Optional arguments for the agent.
/// </summary>
public KernelArguments? Arguments { get; init; }

/// <inheritdoc/>
public IChatHistoryReducer? HistoryReducer { get; init; }

/// <inheritdoc/>
public async IAsyncEnumerable<ChatMessageContent> InvokeAsync(
public override async IAsyncEnumerable<ChatMessageContent> InvokeAsync(
ChatHistory history,
KernelArguments? arguments = null,
Kernel? kernel = null,
Expand Down Expand Up @@ -74,7 +63,7 @@ await chatCompletionService.GetChatMessageContentsAsync(
}

/// <inheritdoc/>
public async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingAsync(
public override async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingAsync(
ChatHistory history,
KernelArguments? arguments = null,
Kernel? kernel = null,
Expand Down Expand Up @@ -118,36 +107,6 @@ public async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingAsync(
}
}

/// <inheritdoc/>
protected override IEnumerable<string> GetChannelKeys()
{
// Distinguish from other channel types.
yield return typeof(ChatHistoryChannel).FullName!;

// Agents with different reducers shall not share the same channel.
// Agents with the same or equivalent reducer shall share the same channel.
if (this.HistoryReducer != null)
{
// Explicitly include the reducer type to eliminate the possibility of hash collisions
// with custom implementations of IChatHistoryReducer.
yield return this.HistoryReducer.GetType().FullName!;

yield return this.HistoryReducer.GetHashCode().ToString(CultureInfo.InvariantCulture);
}
}

/// <inheritdoc/>
protected override Task<AgentChannel> CreateChannelAsync(CancellationToken cancellationToken)
{
ChatHistoryChannel channel =
new()
{
Logger = this.LoggerFactory.CreateLogger<ChatHistoryChannel>()
};

return Task.FromResult<AgentChannel>(channel);
}

internal static (IChatCompletionService service, PromptExecutionSettings? executionSettings) GetChatCompletionService(Kernel kernel, KernelArguments? arguments)
{
// Need to provide a KernelFunction to the service selector as a container for the execution-settings.
Expand Down
9 changes: 4 additions & 5 deletions dotnet/src/Agents/Core/ChatHistoryChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.Agents.Extensions;
using Microsoft.SemanticKernel.Agents.History;
using Microsoft.SemanticKernel.ChatCompletion;

namespace Microsoft.SemanticKernel.Agents;

/// <summary>
/// A <see cref="AgentChannel"/> specialization for that acts upon a <see cref="IChatHistoryHandler"/>.
/// A <see cref="AgentChannel"/> specialization for that acts upon a <see cref="ChatHistoryKernelAgent"/>.
/// </summary>
public sealed class ChatHistoryChannel : AgentChannel
{
Expand All @@ -22,13 +21,13 @@ public sealed class ChatHistoryChannel : AgentChannel
Agent agent,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (agent is not IChatHistoryHandler historyHandler)
if (agent is not ChatHistoryKernelAgent historyAgent)
{
throw new KernelException($"Invalid channel binding for agent: {agent.Id} ({agent.GetType().FullName})");
}

// Pre-process history reduction.
await this._history.ReduceAsync(historyHandler.HistoryReducer, cancellationToken).ConfigureAwait(false);
await historyAgent.ReduceAsync(this._history, cancellationToken).ConfigureAwait(false);

// Capture the current message count to evaluate history mutation.
int messageCount = this._history.Count;
Expand All @@ -38,7 +37,7 @@ public sealed class ChatHistoryChannel : AgentChannel
Queue<ChatMessageContent> messageQueue = [];

ChatMessageContent? yieldMessage = null;
await foreach (ChatMessageContent responseMessage in historyHandler.InvokeAsync(this._history, null, null, cancellationToken).ConfigureAwait(false))
await foreach (ChatMessageContent responseMessage in historyAgent.InvokeAsync(this._history, null, null, cancellationToken).ConfigureAwait(false))
{
// Capture all messages that have been included in the mutated the history.
for (int messageIndex = messageCount; messageIndex < this._history.Count; messageIndex++)
Expand Down
80 changes: 80 additions & 0 deletions dotnet/src/Agents/Core/ChatHistoryKernelAgent.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Globalization;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.Agents.History;
using Microsoft.SemanticKernel.ChatCompletion;

namespace Microsoft.SemanticKernel.Agents;

/// <summary>
/// A <see cref="KernelAgent"/> specialization bound to a <see cref="ChatHistoryChannel"/>.
/// </summary>
/// <remarks>
/// NOTE: Enable OpenAIPromptExecutionSettings.ToolCallBehavior for agent plugins.
/// (<see cref="ChatHistoryKernelAgent.Arguments"/>)
/// </remarks>
public abstract class ChatHistoryKernelAgent : KernelAgent
{
/// <summary>
/// Optional arguments for the agent.
/// </summary>
public KernelArguments? Arguments { get; init; }

/// <inheritdoc/>
public IChatHistoryReducer? HistoryReducer { get; init; }

/// <inheritdoc/>
public abstract IAsyncEnumerable<ChatMessageContent> InvokeAsync(
ChatHistory history,
KernelArguments? arguments = null,
Kernel? kernel = null,
CancellationToken cancellationToken = default);

/// <inheritdoc/>
public abstract IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingAsync(
ChatHistory history,
KernelArguments? arguments = null,
Kernel? kernel = null,
CancellationToken cancellationToken = default);

/// <summary>
/// Reduce the provided history
/// </summary>
/// <param name="history">The source history</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns></returns>
public Task<bool> ReduceAsync(ChatHistory history, CancellationToken cancellationToken = default) =>
history.ReduceAsync(this.HistoryReducer, cancellationToken);

/// <inheritdoc/>
protected sealed override IEnumerable<string> GetChannelKeys()
{
yield return typeof(ChatHistoryChannel).FullName!;

// Agents with different reducers shall not share the same channel.
// Agents with the same or equivalent reducer shall share the same channel.
if (this.HistoryReducer != null)
{
// Explicitly include the reducer type to eliminate the possibility of hash collisions
// with custom implementations of IChatHistoryReducer.
yield return this.HistoryReducer.GetType().FullName!;

yield return this.HistoryReducer.GetHashCode().ToString(CultureInfo.InvariantCulture);
}
}

/// <inheritdoc/>
protected sealed override Task<AgentChannel> CreateChannelAsync(CancellationToken cancellationToken)
{
ChatHistoryChannel channel =
new()
{
Logger = this.LoggerFactory.CreateLogger<ChatHistoryChannel>()
};

return Task.FromResult<AgentChannel>(channel);
}
}
46 changes: 0 additions & 46 deletions dotnet/src/Agents/Core/IChatHistoryHandler.cs

This file was deleted.

23 changes: 0 additions & 23 deletions dotnet/src/Agents/Core/IChatHistoryHandlerExtensions.cs

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public class ChatHistoryChannelTests
{
/// <summary>
/// Verify a <see cref="ChatHistoryChannel"/> throws if passed an agent that
/// does not implement <see cref="IChatHistoryHandler"/>.
/// does not implement <see cref="ChatHistoryKernelAgent"/>.
/// </summary>
[Fact]
public async Task VerifyAgentWithoutIChatHistoryHandlerAsync()
Expand Down
32 changes: 4 additions & 28 deletions dotnet/src/Agents/UnitTests/MockAgent.cs
Original file line number Diff line number Diff line change
@@ -1,29 +1,23 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents;
using Microsoft.SemanticKernel.Agents.History;
using Microsoft.SemanticKernel.ChatCompletion;

namespace SemanticKernel.Agents.UnitTests;

/// <summary>
/// Mock definition of <see cref="KernelAgent"/> with a <see cref="IChatHistoryHandler"/> contract.
/// Mock definition of <see cref="KernelAgent"/> with a <see cref="ChatHistoryKernelAgent"/> contract.
/// </summary>
internal class MockAgent : KernelAgent, IChatHistoryHandler
internal class MockAgent : ChatHistoryKernelAgent
{
public int InvokeCount { get; private set; }

public IReadOnlyList<ChatMessageContent> Response { get; set; } = [];

public IChatHistoryReducer? HistoryReducer { get; init; }

public IAsyncEnumerable<ChatMessageContent> InvokeAsync(
public override IAsyncEnumerable<ChatMessageContent> InvokeAsync(
ChatHistory history,
KernelArguments? arguments = null,
Kernel? kernel = null,
Expand All @@ -34,7 +28,7 @@ public IAsyncEnumerable<ChatMessageContent> InvokeAsync(
return this.Response.ToAsyncEnumerable();
}

public IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingAsync(
public override IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingAsync(
ChatHistory history,
KernelArguments? arguments = null,
Kernel? kernel = null,
Expand All @@ -43,22 +37,4 @@ public IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingAsync(
this.InvokeCount++;
return this.Response.Select(m => new StreamingChatMessageContent(m.Role, m.Content)).ToAsyncEnumerable();
}

/// <inheritdoc/>
protected internal override IEnumerable<string> GetChannelKeys()
{
yield return Guid.NewGuid().ToString();
}

/// <inheritdoc/>
protected internal override Task<AgentChannel> CreateChannelAsync(CancellationToken cancellationToken)
{
ChatHistoryChannel channel =
new()
{
Logger = this.LoggerFactory.CreateLogger<ChatHistoryChannel>()
};

return Task.FromResult<AgentChannel>(channel);
}
}

0 comments on commit ac038d2

Please sign in to comment.