Skip to content

Commit

Permalink
Rysweet 4671 remove iagent base (#4673)
Browse files Browse the repository at this point in the history
* refactor renaming agent base and removing unused stuff

Authored-by: Kosta Petan <[email protected]>
Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: Kosta Petan <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Ryan Sweet <[email protected]>
Co-authored-by: Victor Dibia <[email protected]>
  • Loading branch information
4 people authored Dec 12, 2024
1 parent b9d682c commit 7d4bf9b
Show file tree
Hide file tree
Showing 12 changed files with 28 additions and 54 deletions.
2 changes: 1 addition & 1 deletion dotnet/samples/dev-team/DevTeam.Backend/Agents/Sandbox.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

// namespace DevTeam.Backend;

// public sealed class Sandbox : AgentBase
// public sealed class Sandbox : Agent
// {
// private const string ReminderName = "SandboxRunReminder";
// private readonly IManageAzure _azService;
Expand Down
23 changes: 0 additions & 23 deletions dotnet/src/Microsoft.AutoGen/Abstractions/IAgentBase.cs

This file was deleted.

17 changes: 7 additions & 10 deletions dotnet/src/Microsoft.AutoGen/Agents/Agent.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Agent.cs

using System.Collections.Concurrent;
using System.Diagnostics;
using System.Reflection;
using System.Text;
Expand All @@ -12,16 +13,15 @@

namespace Microsoft.AutoGen.Agents;

public abstract class Agent : IAgentBase, IHandle
public abstract class Agent : IHandle
{
public static readonly ActivitySource s_source = new("AutoGen.Agent");
public AgentId AgentId => _runtime.AgentId;
private readonly object _lock = new();
private readonly Dictionary<string, TaskCompletionSource<RpcResponse>> _pendingRequests = [];
private readonly ConcurrentDictionary<string, TaskCompletionSource<RpcResponse>> _pendingRequests = [];

private readonly Channel<object> _mailbox = Channel.CreateUnbounded<object>();
private readonly IAgentRuntime _runtime;
public string Route { get; set; } = "base";

protected internal ILogger<Agent> _logger;
public IAgentRuntime Context => _runtime;
Expand Down Expand Up @@ -235,18 +235,15 @@ protected async Task<RpcResponse> RequestAsync(AgentId target, string method, Di
activity?.SetTag("peer.service", target.ToString());

var completion = new TaskCompletionSource<RpcResponse>(TaskCreationOptions.RunContinuationsAsynchronously);
_runtime.Update(request, activity);
Context!.Update(request, activity);
await this.InvokeWithActivityAsync(
static async ((Agent Agent, RpcRequest Request, TaskCompletionSource<RpcResponse>) state, CancellationToken ct) =>
static async (state, ct) =>
{
var (self, request, completion) = state;

lock (self._lock)
{
self._pendingRequests[request.RequestId] = completion;
}
self._pendingRequests.AddOrUpdate(request.RequestId, _ => completion, (_, __) => completion);

await state.Agent._runtime.SendRequestAsync(state.Agent, state.Request).ConfigureAwait(false);
await state.Item1.Context!.SendRequestAsync(state.Item1, state.request, ct).ConfigureAwait(false);

await completion.Task.ConfigureAwait(false);
},
Expand Down
6 changes: 3 additions & 3 deletions dotnet/src/Microsoft.AutoGen/Agents/AgentRuntime.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ internal sealed class AgentRuntime(AgentId agentId, IAgentWorker worker, ILogger
private readonly IAgentWorker worker = worker;

public AgentId AgentId { get; } = agentId;
public ILogger<Agent> Logger { get; } = logger;
public IAgentBase? AgentInstance { get; set; }
private ILogger<Agent> Logger { get; } = logger;
public Agent? AgentInstance { get; set; }
private DistributedContextPropagator DistributedContextPropagator { get; } = distributedContextPropagator;
public (string?, string?) GetTraceIdAndState(IDictionary<string, string> metadata)
{
Expand Down Expand Up @@ -79,7 +79,7 @@ public async ValueTask SendResponseAsync(RpcRequest request, RpcResponse respons
response.RequestId = request.RequestId;
await worker.SendResponseAsync(response, cancellationToken).ConfigureAwait(false);
}
public async ValueTask SendRequestAsync(IAgentBase agent, RpcRequest request, CancellationToken cancellationToken = default)
public async ValueTask SendRequestAsync(Agent agent, RpcRequest request, CancellationToken cancellationToken = default)
{
await worker.SendRequestAsync(agent, request, cancellationToken).ConfigureAwait(false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// InferenceAgent.cs

using Google.Protobuf;
using Microsoft.AutoGen.Abstractions;
using Microsoft.Extensions.AI;
namespace Microsoft.AutoGen.Agents;
public abstract class InferenceAgent<T>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System.Globalization;
using System.Text;
using Microsoft.AutoGen.Abstractions;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using Microsoft.SemanticKernel.Memory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ namespace Microsoft.AutoGen.Agents;

public interface IHandleConsole : IHandle<Output>, IHandle<Input>
{
string Route { get; }
AgentId AgentId { get; }
ValueTask PublishMessageAsync<T>(T message, string? source = null, CancellationToken token = default) where T : IMessage;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
// IAgentRuntime.cs

using System.Diagnostics;
using Microsoft.AutoGen.Abstractions;

namespace Microsoft.AutoGen.Abstractions;
namespace Microsoft.AutoGen.Agents;

public interface IAgentRuntime
{
AgentId AgentId { get; }
IAgentBase? AgentInstance { get; set; }
Agent? AgentInstance { get; set; }
ValueTask StoreAsync(AgentState value, CancellationToken cancellationToken = default);
ValueTask<AgentState> ReadAsync(AgentId agentId, CancellationToken cancellationToken = default);
ValueTask SendResponseAsync(RpcRequest request, RpcResponse response, CancellationToken cancellationToken = default);
ValueTask SendRequestAsync(IAgentBase agent, RpcRequest request, CancellationToken cancellationToken = default);
ValueTask SendRequestAsync(Agent agent, RpcRequest request, CancellationToken cancellationToken = default);
ValueTask SendMessageAsync(Message message, CancellationToken cancellationToken = default);
ValueTask PublishEventAsync(CloudEvent @event, CancellationToken cancellationToken = default);
void Update(RpcRequest request, Activity? activity);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IAgentWorker.cs
namespace Microsoft.AutoGen.Abstractions;
using Microsoft.AutoGen.Abstractions;
namespace Microsoft.AutoGen.Agents;

public interface IAgentWorker
{
ValueTask PublishEventAsync(CloudEvent evt, CancellationToken cancellationToken = default);
ValueTask SendRequestAsync(IAgentBase agent, RpcRequest request, CancellationToken cancellationToken = default);
ValueTask SendRequestAsync(Agent agent, RpcRequest request, CancellationToken cancellationToken = default);
ValueTask SendResponseAsync(RpcResponse response, CancellationToken cancellationToken = default);
ValueTask SendMessageAsync(Message message, CancellationToken cancellationToken = default);
ValueTask StoreAsync(AgentState value, CancellationToken cancellationToken = default);
Expand Down
8 changes: 4 additions & 4 deletions dotnet/src/Microsoft.AutoGen/Agents/Services/AgentWorker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ public class AgentWorker :
IAgentWorker
{
private readonly ConcurrentDictionary<string, Type> _agentTypes = new();
private readonly ConcurrentDictionary<(string Type, string Key), IAgentBase> _agents = new();
private readonly ConcurrentDictionary<(string Type, string Key), Agent> _agents = new();
private readonly ILogger<AgentWorker> _logger;
private readonly Channel<object> _mailbox = Channel.CreateUnbounded<object>();
private readonly ConcurrentDictionary<string, AgentState> _agentStates = new();
private readonly ConcurrentDictionary<string, (IAgentBase Agent, string OriginalRequestId)> _pendingClientRequests = new();
private readonly ConcurrentDictionary<string, (Agent Agent, string OriginalRequestId)> _pendingClientRequests = new();
private readonly CancellationTokenSource _shutdownCts;
private readonly IServiceProvider _serviceProvider;
private readonly IEnumerable<Tuple<string, Type>> _configuredAgentTypes;
Expand Down Expand Up @@ -54,7 +54,7 @@ public async ValueTask PublishEventAsync(CloudEvent cloudEvent, CancellationToke
agent.ReceiveMessage(new Message { CloudEvent = cloudEvent });
}
}
public async ValueTask SendRequestAsync(IAgentBase agent, RpcRequest request, CancellationToken cancellationToken = default)
public async ValueTask SendRequestAsync(Agent agent, RpcRequest request, CancellationToken cancellationToken = default)
{
var requestId = Guid.NewGuid().ToString();
_pendingClientRequests[requestId] = (agent, request.RequestId);
Expand Down Expand Up @@ -190,7 +190,7 @@ public async Task StopAsync(CancellationToken cancellationToken)
{
}
}
private IAgentBase GetOrActivateAgent(AgentId agentId)
private Agent GetOrActivateAgent(AgentId agentId)
{
if (!_agents.TryGetValue((agentId.Type, agentId.Key), out var agent))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ public sealed class GrpcAgentWorker(
{
private readonly object _channelLock = new();
private readonly ConcurrentDictionary<string, Type> _agentTypes = new();
private readonly ConcurrentDictionary<(string Type, string Key), IAgentBase> _agents = new();
private readonly ConcurrentDictionary<string, (IAgentBase Agent, string OriginalRequestId)> _pendingRequests = new();
private readonly ConcurrentDictionary<(string Type, string Key), Agent> _agents = new();
private readonly ConcurrentDictionary<string, (Agent Agent, string OriginalRequestId)> _pendingRequests = new();
private readonly Channel<(Message Message, TaskCompletionSource WriteCompletionSource)> _outboundMessagesChannel = Channel.CreateBounded<(Message, TaskCompletionSource)>(new BoundedChannelOptions(1024)
{
AllowSynchronousContinuations = true,
Expand Down Expand Up @@ -187,7 +187,7 @@ private async Task RunWritePump()
item.WriteCompletionSource.TrySetCanceled();
}
}
private IAgentBase GetOrActivateAgent(AgentId agentId)
private Agent GetOrActivateAgent(AgentId agentId)
{
if (!_agents.TryGetValue((agentId.Type, agentId.Key), out var agent))
{
Expand Down Expand Up @@ -275,7 +275,7 @@ await WriteChannelAsync(new Message
await WriteChannelAsync(new Message { Response = response }, cancellationToken).ConfigureAwait(false);
}
// new is intentional
public new async ValueTask SendRequestAsync(IAgentBase agent, RpcRequest request, CancellationToken cancellationToken = default)
public new async ValueTask SendRequestAsync(Agent agent, RpcRequest request, CancellationToken cancellationToken = default)
{
var requestId = Guid.NewGuid().ToString();
_pendingRequests[requestId] = (agent, request.RequestId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ public static IHostApplicationBuilder AddAgentWorker(this IHostApplicationBuilde

var eventsMap = AppDomain.CurrentDomain.GetAssemblies()
.SelectMany(assembly => assembly.GetTypes())
.Where(type => type != null && ReflectionHelper.IsSubclassOfGeneric(type, typeof(Agent)) && !type.IsAbstract)
.Where(type => ReflectionHelper.IsSubclassOfGeneric(type, typeof(Agent)) && !type.IsAbstract)

.Select(t => (t, t.GetInterfaces()
.Where(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IHandle<>))
.Select(i => (GetMessageDescriptor(i.GetGenericArguments().First())?.FullName ?? "")).ToHashSet()))
Expand Down

0 comments on commit 7d4bf9b

Please sign in to comment.