Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add tests and consolidate some reflection to find handlers #4435

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 136 additions & 21 deletions dotnet/src/Microsoft.AutoGen/Client/AgentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,46 @@

namespace Microsoft.AutoGen.Core;

/// <summary>
/// Represents the base class for an agent in the AutoGen system.
/// </summary>
public abstract class AgentBase
{
/// <summary>
/// The activity source for tracing.
/// </summary>
public static readonly ActivitySource s_source = new("AutoGen.Agent");

/// <summary>
/// Gets the unique identifier of the agent.
/// </summary>
public AgentId AgentId => _context.AgentId;

private readonly object _lock = new();
private readonly ConcurrentDictionary<string, TaskCompletionSource<RpcResponse>> _pendingRequests = [];
private readonly ConcurrentDictionary<string, TaskCompletionSource<RpcResponse>> _pendingRequests = new();

private readonly Channel<object> _mailbox = Channel.CreateUnbounded<object>();
private readonly RuntimeContext _context;

/// <summary>
/// Gets the runtime context of the agent.
/// </summary>
public RuntimeContext Context => _context;

/// <summary>
/// Gets or sets the route of the agent.
/// </summary>
public string Route { get; set; } = "base";

protected internal ILogger<AgentBase> _logger;
protected readonly EventTypes EventTypes;

/// <summary>
/// Initializes a new instance of the <see cref="AgentBase"/> class.
/// </summary>
/// <param name="context">The runtime context of the agent.</param>
/// <param name="eventTypes">The event types associated with the agent.</param>
/// <param name="logger">The logger instance for logging.</param>
protected AgentBase(
RuntimeContext context,
EventTypes eventTypes,
Expand All @@ -37,10 +62,19 @@ protected AgentBase(
context.AgentInstance = this;
EventTypes = eventTypes;
_logger = logger ?? LoggerFactory.Create(builder => { }).CreateLogger<AgentBase>();

Completion = Start();
}

/// <summary>
/// Gets the task representing the completion of the agent's operations.
/// </summary>
internal Task Completion { get; }

/// <summary>
/// Starts the message pump for the agent.
/// </summary>
/// <returns>A task representing the asynchronous operation.</returns>
internal Task Start()
{
var didSuppress = false;
Expand All @@ -62,6 +96,11 @@ internal Task Start()
}
}
}

/// <summary>
/// Receives a message and writes it to the mailbox.
/// </summary>
/// <param name="message">The message to receive.</param>
public void ReceiveMessage(Message message) => _mailbox.Writer.TryWrite(message);

private async Task RunMessagePump()
Expand All @@ -86,6 +125,13 @@ private async Task RunMessagePump()
}
}
}

/// <summary>
/// Handles an RPC message.
/// </summary>
/// <param name="msg">The message to handle.</param>
/// <param name="cancellationToken">A token to cancel the operation.</param>
/// <returns>A task representing the asynchronous operation.</returns>
protected internal async Task HandleRpcMessage(Message msg, CancellationToken cancellationToken = default)
{
switch (msg.MessageCase)
Expand Down Expand Up @@ -115,6 +161,12 @@ await this.InvokeWithActivityAsync(
break;
}
}

/// <summary>
/// Subscribes to a topic.
/// </summary>
/// <param name="topic">The topic to subscribe to.</param>
/// <returns>A list of subscribed topics.</returns>
public List<string> Subscribe(string topic)
{
Message message = new()
Expand All @@ -136,16 +188,31 @@ public List<string> Subscribe(string topic)

return new List<string> { topic };
}

/// <summary>
/// Stores the agent state asynchronously.
/// </summary>
/// <param name="state">The agent state to store.</param>
/// <param name="cancellationToken">A token to cancel the operation.</param>
/// <returns>A task representing the asynchronous operation.</returns>
public async Task StoreAsync(AgentState state, CancellationToken cancellationToken = default)
{
await _context.StoreAsync(state, cancellationToken).ConfigureAwait(false);
return;
}

/// <summary>
/// Reads the agent state asynchronously.
/// </summary>
/// <typeparam name="T">The type of the agent state.</typeparam>
/// <param name="agentId">The ID of the agent whose state is to be read.</param>
/// <param name="cancellationToken">A token to cancel the operation.</param>
/// <returns>A task representing the asynchronous operation, containing the agent state.</returns>
public async Task<T> ReadAsync<T>(AgentId agentId, CancellationToken cancellationToken = default) where T : IMessage, new()
{
var agentstate = await _context.ReadAsync(agentId, cancellationToken).ConfigureAwait(false);
return agentstate.FromAgentState<T>();
var agentState = await _context.ReadAsync(agentId, cancellationToken).ConfigureAwait(false);
return agentState.FromAgentState<T>();
}

private void OnResponseCore(RpcResponse response)
{
var requestId = response.RequestId;
Expand All @@ -160,13 +227,14 @@ private void OnResponseCore(RpcResponse response)

completion.SetResult(response);
}
private async Task OnRequestCoreAsync(RpcRequest request, CancellationToken cancellationToken = default)

private async Task OnRequestCoreAsync(RpcRequest request, CancellationToken cancellationToken)
{
RpcResponse response;

try
{
response = await HandleRequest(request).ConfigureAwait(false);
response = await HandleRequestAsync(request).ConfigureAwait(false);
}
catch (Exception ex)
{
Expand All @@ -175,6 +243,13 @@ private async Task OnRequestCoreAsync(RpcRequest request, CancellationToken canc
await _context.SendResponseAsync(request, response, cancellationToken).ConfigureAwait(false);
}

/// <summary>
/// Sends a request asynchronously.
/// </summary>
/// <param name="target">The target agent ID.</param>
/// <param name="method">The method to call.</param>
/// <param name="parameters">The parameters for the method.</param>
/// <returns>A task representing the asynchronous operation, containing the RPC response.</returns>
protected async Task<RpcResponse> RequestAsync(AgentId target, string method, Dictionary<string, string> parameters)
{
var requestId = Guid.NewGuid().ToString();
Expand Down Expand Up @@ -219,13 +294,27 @@ static async (state, ct) =>
return await completion.Task.ConfigureAwait(false);
}

/// <summary>
/// Publishes a message asynchronously.
/// </summary>
/// <typeparam name="T">The type of the message.</typeparam>
/// <param name="message">The message to publish.</param>
/// <param name="source">The source of the message.</param>
/// <param name="token">A token to cancel the operation.</param>
/// <returns>A task representing the asynchronous operation.</returns>
public async ValueTask PublishMessageAsync<T>(T message, string? source = null, CancellationToken token = default) where T : IMessage
{
var src = string.IsNullOrWhiteSpace(source) ? AgentId.Key : source;
var evt = message.ToCloudEvent(src);
await PublishEventAsync(evt, token).ConfigureAwait(false);
}

/// <summary>
/// Publishes a cloud event asynchronously.
/// </summary>
/// <param name="item">The cloud event to publish.</param>
/// <param name="cancellationToken">A token to cancel the operation.</param>
/// <returns>A task representing the asynchronous operation.</returns>
public async ValueTask PublishEventAsync(CloudEvent item, CancellationToken cancellationToken = default)
{
var activity = s_source.StartActivity($"PublishEventAsync '{item.Type}'", ActivityKind.Client, Activity.Current?.Context ?? default);
Expand All @@ -236,13 +325,19 @@ public async ValueTask PublishEventAsync(CloudEvent item, CancellationToken canc
await this.InvokeWithActivityAsync(
static async (state, ct) =>
{
await state.Item1._context.PublishEventAsync(state.item, cancellationToken : ct).ConfigureAwait(false);
await state.Item1._context.PublishEventAsync(state.item, cancellationToken: ct).ConfigureAwait(false);
},
(this, item),
activity,
item.Type, cancellationToken).ConfigureAwait(false);
}

/// <summary>
/// Calls the handler for a cloud event.
/// </summary>
/// <param name="item">The cloud event to handle.</param>
/// <param name="cancellationToken">A token to cancel the operation.</param>
/// <returns>A task representing the asynchronous operation.</returns>
public Task CallHandler(CloudEvent item, CancellationToken cancellationToken)
{
// Only send the event to the handler if the agent type is handling that type
Expand All @@ -263,12 +358,12 @@ public Task CallHandler(CloudEvent item, CancellationToken cancellationToken)
{
methodInfo = genericInterfaceType.GetMethod("Handle", BindingFlags.Public | BindingFlags.Instance)
?? throw new InvalidOperationException($"Method not found on type {genericInterfaceType.FullName}");
return methodInfo.Invoke(this, [convertedPayload, cancellationToken]) as Task ?? Task.CompletedTask;
return methodInfo.Invoke(this, new object[] { convertedPayload, cancellationToken }) as Task ?? Task.CompletedTask;
}

// The error here is we have registered for an event that we do not have code to listen to
throw new InvalidOperationException($"No handler found for event '{item.Type}'; expecting IHandle<{item.Type}> implementation.");

}
catch (Exception ex)
{
Expand All @@ -280,26 +375,46 @@ public Task CallHandler(CloudEvent item, CancellationToken cancellationToken)
return Task.CompletedTask;
}

public Task<RpcResponse> HandleRequest(RpcRequest request) => Task.FromResult(new RpcResponse { Error = "Not implemented" });

//TODO: should this be async and cancellable?
public virtual Task HandleObject(object item)
/// <summary>
/// Handles an RPC request.
/// </summary>
/// <param name="request">The request to handle.</param>
/// <param name="cancellationToken">A token to cancel the operation.</param>
/// <returns>A task representing the asynchronous operation, containing the RPC response.</returns>
public Task<RpcResponse> HandleRequestAsync(RpcRequest request, CancellationToken cancellationToken = default) => Task.FromResult(new RpcResponse { Error = "Not implemented" });


/// <summary>
/// Handles an object asynchronously by invoking the appropriate handler method based on the object's type.
/// </summary>
/// <param name="item">The object to handle.</param>
/// <param name="cancellationToken">A token to cancel the operation.</param>
/// <returns>A task representing the asynchronous operation.</returns>
/// <exception cref="InvalidOperationException">Thrown when no handler is found for the object's type.</exception>
public virtual Task HandleObjectAsync(object item, CancellationToken cancellationToken)
{
// get all Handle<T> methods
var handleTMethods = GetType().GetMethods().Where(m => m.Name == "Handle" && m.GetParameters().Length == 1).ToList();
var lookup = GetType().GetHandlersLookupTable();

// get the one that matches the type of the item
var handleTMethod = handleTMethods.FirstOrDefault(m => m.GetParameters()[0].ParameterType == item.GetType());

// if we found one, invoke it
if (handleTMethod != null)
if (lookup.TryGetValue(item.GetType(), out var method))
{
return (Task)handleTMethod.Invoke(this, [item])!;
if (method is null)
{
throw new InvalidOperationException($"No handler found for type {item.GetType().FullName}");
}
return (Task)method.Invoke(this, [item, cancellationToken])!;
}

// otherwise, complain
throw new InvalidOperationException($"No handler found for type {item.GetType().FullName}");
}

/// <summary>
/// Publishes a cloud event asynchronously.
/// </summary>
/// <param name="topic">The topic of the event.</param>
/// <param name="evt">The event to publish.</param>
/// <param name="cancellationToken">A token to cancel the operation.</param>
/// <returns>A task representing the asynchronous operation.</returns>
public async ValueTask PublishEventAsync(string topic, IMessage evt, CancellationToken cancellationToken = default)
{
await PublishEventAsync(evt.ToCloudEvent(topic), cancellationToken).ConfigureAwait(false);
Expand Down
39 changes: 39 additions & 0 deletions dotnet/src/Microsoft.AutoGen/Client/IHandleExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IHandleExtensions.cs

using System.Reflection;

namespace Microsoft.AutoGen.Core;


/// <summary>
/// Provides extension methods for types implementing the IHandle interface.
/// </summary>
public static class IHandleExtensions
{
/// <summary>
/// Gets all the handler methods from the interfaces implemented by the specified type.
/// </summary>
/// <param name="type">The type to get the handler methods from.</param>
/// <returns>An array of MethodInfo objects representing the handler methods.</returns>
public static MethodInfo[] GetHandlers(this Type type)
{
var handlers = type.GetInterfaces().Where(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IHandle<>));
return handlers.SelectMany(h => h.GetMethods().Where(m => m.Name == "Handle")).ToArray();
}

/// <summary>
/// Gets a lookup table of handler methods from the interfaces implemented by the specified type.
/// </summary>
/// <param name="type">The type to get the handler methods from.</param>
/// <returns>A dictionary where the key is the generic type and the value is the MethodInfo of the handler method.</returns>
public static Dictionary<Type, MethodInfo> GetHandlersLookupTable(this Type type)
{
var handlers = type.GetHandlers();
return handlers.ToDictionary(h =>
{
var generic = h.DeclaringType!.GetGenericArguments();
return generic[0];
});
}
}
4 changes: 2 additions & 2 deletions dotnet/test/Microsoft.AutoGen.Agents.Tests/AgentBaseTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ public async Task ItInvokeRightHandlerTestAsync()
.Returns(new ValueTask());
var agent = new TestAgent(mockContext.Object, new EventTypes(TypeRegistry.Empty, [], []), new Logger<AgentBase>(new LoggerFactory()));

await agent.HandleObject("hello world");
await agent.HandleObject(42);
await agent.HandleObjectAsync("hello world", CancellationToken.None);
await agent.HandleObjectAsync(42, CancellationToken.None);

agent.ReceivedItems.Should().HaveCount(2);
agent.ReceivedItems[0].Should().Be("hello world");
Expand Down
13 changes: 0 additions & 13 deletions dotnet/test/Microsoft.AutoGen.Core.Tests/AgentMetadataTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using FluentAssertions;
using FluentAssertions.Execution;
using Microsoft.Extensions.Logging;
using Tests.Events;
using Xunit;

Expand All @@ -21,15 +20,3 @@ public void EventTypes_IsPopulated_From_Assembly()
eventTypes.CheckIfTypeHandles(typeof(TestAgent), GoodBye.Descriptor.FullName).Should().BeTrue();
}
}

public class TestAgent : AgentBase, IHandle<GoodBye>
{
public TestAgent(RuntimeContext context, EventTypes eventTypes, ILogger<AgentBase>? logger = null) : base(context, eventTypes, logger)
{
}

public Task Handle(GoodBye item, CancellationToken cancellationToken)
{
throw new NotImplementedException();
}
}
Loading