Skip to content

Commit

Permalink
(#64): Migrate Requester to new infra
Browse files Browse the repository at this point in the history
That is, using the TopologyProvider for topology stuff and using the
channel factory for retrieving channel.
  • Loading branch information
par.dahlman committed Mar 3, 2016
1 parent 720dfbb commit 3b6ac11
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 95 deletions.
1 change: 1 addition & 0 deletions src/RawRabbit.vNext/IServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ public static IServiceCollection AddRawRabbit<TMessageContext>(this IServiceColl
p.GetService<IMessageContextProvider<TMessageContext>>(),
p.GetService<IErrorHandlingStrategy>(),
p.GetService<IBasicPropertiesProvider>(),
p.GetService<ITopologyProvider>(),
p.GetService<RawRabbitConfiguration>().RequestTimeout))
.AddTransient<IBusClient<TMessageContext>, BaseBusClient<TMessageContext>>();
custom?.Invoke(collection);
Expand Down
6 changes: 3 additions & 3 deletions src/RawRabbit/ErrorHandling/DefaultStrategy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@ private static Exception UnwrapInnerException(Exception exception)
return exception;
}

public Task OnResponseRecievedAsync<TResponse>(BasicDeliverEventArgs args, TaskCompletionSource<object> responseTcs)
public Task OnResponseRecievedAsync(BasicDeliverEventArgs args, TaskCompletionSource<object> responseTcs)
{
OnResponseRecieved<TResponse>(args, responseTcs);
OnResponseRecieved(args, responseTcs);
return Task.FromResult(true);
}

public void OnResponseRecieved<TResponse>(BasicDeliverEventArgs args, TaskCompletionSource<object> responseTcs)
public void OnResponseRecieved(BasicDeliverEventArgs args, TaskCompletionSource<object> responseTcs)
{
var containsException = args?.BasicProperties?.Headers?.ContainsKey(PropertyHeaders.ExceptionHeader) ?? false;

Expand Down
4 changes: 2 additions & 2 deletions src/RawRabbit/ErrorHandling/IErrorHandlingStrategy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace RawRabbit.ErrorHandling
public interface IErrorHandlingStrategy
{
Task OnRequestHandlerExceptionAsync(IRawConsumer rawConsumer, IConsumerConfiguration cfg, BasicDeliverEventArgs args, Exception exception);
Task OnResponseRecievedAsync<TResponse>(BasicDeliverEventArgs args, TaskCompletionSource<object> responseTcs);
void OnResponseRecieved<TResponse>(BasicDeliverEventArgs args, TaskCompletionSource<object> responseTcs);
Task OnResponseRecievedAsync(BasicDeliverEventArgs args, TaskCompletionSource<object> responseTcs);
void OnResponseRecieved(BasicDeliverEventArgs args, TaskCompletionSource<object> responseTcs);
}
}
175 changes: 85 additions & 90 deletions src/RawRabbit/Operations/Requester.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using RabbitMQ.Client;
using RawRabbit.Common;
using RawRabbit.Configuration.Request;
using RawRabbit.Configuration.Respond;
Expand All @@ -15,19 +17,23 @@

namespace RawRabbit.Operations
{
public class Requester<TMessageContext> : OperatorBase, IRequester where TMessageContext : IMessageContext
public class Requester<TMessageContext> : IRequester where TMessageContext : IMessageContext
{
private readonly IChannelFactory _channelFactory;
private readonly IConsumerFactory _consumerFactory;
private readonly IMessageSerializer _serializer;
private readonly IMessageContextProvider<TMessageContext> _contextProvider;
private readonly IErrorHandlingStrategy _errorStrategy;
private readonly IBasicPropertiesProvider _propertiesProvider;
private readonly ITopologyProvider _topologyProvider;
private readonly TimeSpan _requestTimeout;
private readonly ConcurrentDictionary<Type, IRawConsumer> _typeToConsumer;
private readonly ConcurrentDictionary<string, TaskCompletionSource<object>> _responseTcsDictionary;
private readonly ConcurrentDictionary<string, Timer> _requestTimerDictionary;
private Timer _disposeConsumerTimer;
private readonly ConcurrentDictionary<IModel, IRawConsumer> _channelToConsumer;
private readonly ConcurrentDictionary<IRawConsumer, List<string>> _consumerToQueue;
private readonly ILogger _logger = LogManager.GetLogger<Requester<TMessageContext>>();
private bool _channelActive;
private readonly object _topologyLock = new object();
private readonly object _consumerLock = new object();

public Requester(
IChannelFactory channelFactory,
Expand All @@ -36,115 +42,106 @@ public Requester(
IMessageContextProvider<TMessageContext> contextProvider,
IErrorHandlingStrategy errorStrategy,
IBasicPropertiesProvider propertiesProvider,
ITopologyProvider topologyProvider,
TimeSpan requestTimeout)
: base(channelFactory, serializer)
{
_channelFactory = channelFactory;
_consumerFactory = consumerFactory;
_serializer = serializer;
_contextProvider = contextProvider;
_errorStrategy = errorStrategy;
_propertiesProvider = propertiesProvider;
_topologyProvider = topologyProvider;
_requestTimeout = requestTimeout;
_typeToConsumer = new ConcurrentDictionary<Type, IRawConsumer>();
_responseTcsDictionary = new ConcurrentDictionary<string, TaskCompletionSource<object>>();
_requestTimerDictionary = new ConcurrentDictionary<string, Timer>();
_channelToConsumer = new ConcurrentDictionary<IModel, IRawConsumer>();
_consumerToQueue = new ConcurrentDictionary<IRawConsumer, List<string>>();
}

public Task<TResponse> RequestAsync<TRequest, TResponse>(TRequest message, Guid globalMessageId, RequestConfiguration cfg)
{
var props = _propertiesProvider.GetProperties<TResponse>(p =>
{
p.ReplyTo = cfg.ReplyQueue.QueueName;
p.CorrelationId = Guid.NewGuid().ToString();
p.Expiration = _requestTimeout.TotalMilliseconds.ToString();
p.Headers.Add(PropertyHeaders.Context, _contextProvider.GetMessageContext(globalMessageId));
});
var consumer = GetOrCreateConsumerForType<TResponse>(cfg);
var body = Serializer.Serialize(message);

Task.Run(() => CreateOrUpdateDisposeTimer());

var responseTcs = new TaskCompletionSource<object>();
_responseTcsDictionary.TryAdd(props.CorrelationId, responseTcs);
var queueTask = _topologyProvider.DeclareQueueAsync(cfg.Queue);
var exchangeTask = _topologyProvider.DeclareExchangeAsync(cfg.Exchange);
var consumerTask = GetOrCreateConsumerAsync(cfg);

_requestTimerDictionary.TryAdd(
props.CorrelationId,
new Timer(state =>
return Task
.WhenAll(consumerTask, queueTask, exchangeTask)
.ContinueWith(t =>
{
Timer timer;
if (!_requestTimerDictionary.TryRemove(props.CorrelationId, out timer))
var consumer = consumerTask.Result;
lock (consumer)
{
_logger.LogWarning($"Unable to find request timer for {props.CorrelationId}.");
if (!_consumerToQueue[consumer].Contains(cfg.Queue.QueueName))
{
consumer.Model.BasicConsume(cfg.Queue.QueueName, cfg.NoAck, consumer);
_consumerToQueue[consumer].Add(cfg.Queue.QueueName);
}
}
timer?.Dispose();
responseTcs.TrySetException(new TimeoutException($"The request '{props.CorrelationId}' timed out after {_requestTimeout.ToString("g")}."));
}, null, _requestTimeout, new TimeSpan(-1)));
var props = _propertiesProvider.GetProperties<TResponse>(p =>
{
p.ReplyTo = cfg.ReplyQueue.QueueName;
p.CorrelationId = Guid.NewGuid().ToString();
p.Expiration = _requestTimeout.TotalMilliseconds.ToString();
p.Headers.Add(PropertyHeaders.Context, _contextProvider.GetMessageContext(globalMessageId));
});
var body = _serializer.Serialize(message);
var responseTcs = new TaskCompletionSource<object>();
_responseTcsDictionary.TryAdd(props.CorrelationId, responseTcs);
_requestTimerDictionary.TryAdd(
props.CorrelationId,
new Timer(state =>
{
Timer timer;
if (!_requestTimerDictionary.TryRemove(props.CorrelationId, out timer))
{
_logger.LogWarning($"Unable to find request timer for {props.CorrelationId}.");
}
timer?.Dispose();
responseTcs.TrySetException(new TimeoutException($"The request '{props.CorrelationId}' timed out after {_requestTimeout.ToString("g")}."));
}, null, _requestTimeout, new TimeSpan(-1)));
consumer.Model.BasicPublish(
exchange: cfg.Exchange.ExchangeName,
routingKey: cfg.RoutingKey,
basicProperties: props,
body: body
);
return responseTcs.Task.ContinueWith(tResponse => (TResponse) tResponse.Result);
consumer.Model.BasicPublish(
exchange: cfg.Exchange.ExchangeName,
routingKey: cfg.RoutingKey,
basicProperties: props,
body: body
);
return responseTcs.Task.ContinueWith(tResponse => (TResponse) tResponse.Result);
})
.Unwrap();
}

private void CreateOrUpdateDisposeTimer()
private Task<IRawConsumer> GetOrCreateConsumerAsync(IConsumerConfiguration cfg)
{
if (_disposeConsumerTimer != null)
{
_channelActive = true;
return;
}
_disposeConsumerTimer = new Timer(state =>
{
if (_channelActive)
{
_channelActive = false;
return;
}
if (!_responseTcsDictionary.IsEmpty)
{
return;
}
_disposeConsumerTimer?.Dispose();
_disposeConsumerTimer = null;
foreach (var type in _typeToConsumer.Keys)
return _channelFactory
.GetChannelAsync()
.ContinueWith(tChannel =>
{
IRawConsumer consumer;
if (_typeToConsumer.TryRemove(type, out consumer))
IRawConsumer existingConsumer;
if (_channelToConsumer.TryGetValue(tChannel.Result, out existingConsumer))
{
consumer?.Disconnect();
consumer?.Model?.Dispose();
return existingConsumer;
}
}
}, null, _requestTimeout, _requestTimeout);
lock (_consumerLock)
{
if (_channelToConsumer.TryGetValue(tChannel.Result, out existingConsumer))
{
return existingConsumer;
}
var newConsumer = _consumerFactory.CreateConsumer(cfg, tChannel.Result);
WireUpConsumer(newConsumer);
_channelToConsumer.TryAdd(tChannel.Result, newConsumer);
_consumerToQueue.TryAdd(newConsumer, new List<string>());
return newConsumer;
}
});
}

private IRawConsumer GetOrCreateConsumerForType<TResponse>(IConsumerConfiguration cfg)
private void WireUpConsumer(IRawConsumer consumer)
{
var responseType = typeof(TResponse);
IRawConsumer existingConsumer;
if (_typeToConsumer.TryGetValue(responseType, out existingConsumer))
{
_logger.LogDebug($"Channel for existing cunsomer of {responseType.Name} found.");
if (existingConsumer.Model.IsOpen)
{
_logger.LogDebug($"Channel is open and will be reused.");
return existingConsumer;
}
else
{
existingConsumer?.Model?.Dispose();
_logger.LogInformation($"Channel for consumer of {responseType.Name} is closed. A new consumer will be created, of course.");
}
}

_logger.LogInformation($"Creatinga new consumer for message {responseType.Name}.");
var consumer = _consumerFactory.CreateConsumer(cfg, ChannelFactory.CreateChannel());
_typeToConsumer.TryAdd(typeof(TResponse), consumer);

DeclareQueue(cfg.Queue, consumer.Model);
DeclareExchange(cfg.Exchange, consumer.Model);
consumer.OnMessageAsync = (o, args) =>
{
TaskCompletionSource<object> responseTcs;
Expand All @@ -161,20 +158,18 @@ private IRawConsumer GetOrCreateConsumerForType<TResponse>(IConsumerConfiguratio
{
_logger.LogInformation($"Unable to find request timer for message {args.BasicProperties.CorrelationId}.");
}
_errorStrategy.OnResponseRecievedAsync<TResponse>(args, responseTcs);
_errorStrategy.OnResponseRecievedAsync(args, responseTcs);
if (responseTcs?.Task?.IsFaulted ?? true)
{
return Task.FromResult(true);
}
var response = Serializer.Deserialize(args);
var response = _serializer.Deserialize(args);
responseTcs.TrySetResult(response);
return Task.FromResult(true);
}
_logger.LogWarning($"Unable to find callback for {args.BasicProperties.CorrelationId}.");
throw new Exception($"Can not find callback for {args.BasicProperties.CorrelationId}");
};
consumer.Model.BasicConsume(cfg.Queue.QueueName, cfg.NoAck, consumer);
return consumer;
}
}
}

0 comments on commit 3b6ac11

Please sign in to comment.