Skip to content

Commit

Permalink
EdgeHub: Allow multiplexing client connections over AMQP (#587)
Browse files Browse the repository at this point in the history
* Add AMQP Downstream Multiplexing support

* Amqp Mux changes

* Fix link handlers

* Cleanup

* Get product code to build

* Cleanup

* Fix tests

* Fix tests

* Format and cleanup

* Fix merge

* fix inheritance

* Update edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/CbsNode.cs

Co-Authored-By: varunpuranik <[email protected]>

* Update edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/CbsNode.cs

Co-Authored-By: varunpuranik <[email protected]>

* Remove commented members

* Add C2D subscription if not module identity

* Fix tests
  • Loading branch information
varunpuranik authored Dec 2, 2018
1 parent c33eaee commit 93be534
Show file tree
Hide file tree
Showing 36 changed files with 562 additions and 525 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ void OnConnectionOpening(object sender, OpenEventArgs e)
amqpConnection.Extensions.Add(cbsNode);
}

IConnectionHandler connectionHandler = new ConnectionHandler(new EdgeAmqpConnection(amqpConnection), this.connectionProvider);
IClientConnectionsHandler connectionHandler = new ClientConnectionsHandler(this.connectionProvider);
amqpConnection.Extensions.Add(connectionHandler);
}

Expand Down
21 changes: 2 additions & 19 deletions edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/CbsNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp
/// This class is used to get tokens from the Client on the CBS link. It generates
/// an identity from the received token and authenticates it.
/// </summary>
class CbsNode : ICbsNode
class CbsNode : ICbsNode, IAmqpAuthenticator
{
static readonly List<UriPathTemplate> ResourceTemplates = new List<UriPathTemplate>
{
Expand Down Expand Up @@ -71,22 +71,6 @@ public void RegisterLink(IAmqpLink link)
Events.LinkRegistered(link);
}

// TODO: Temporary implementation - just get the first credentials and return it.
public async Task<AmqpAuthentication> GetAmqpAuthentication()
{
if (!this.clientCredentialsMap.Any())
{
throw new InvalidOperationException("No valid credentials found");
}

KeyValuePair<string, CredentialsInfo> creds = this.clientCredentialsMap.First();
if (!creds.Value.IsAuthenticated)
{
creds.Value.IsAuthenticated = await this.authenticator.AuthenticateAsync(creds.Value.ClientCredentials);
}
return new AmqpAuthentication(creds.Value.IsAuthenticated, Option.Some(creds.Value.ClientCredentials));
}

public async Task<bool> AuthenticateAsync(string id)
{
try
Expand Down Expand Up @@ -119,7 +103,7 @@ public async Task<bool> AuthenticateAsync(string id)
Events.ErrorAuthenticatingIdentity(id, e);
return false;
}
}
}

async void OnMessageReceived(AmqpMessage message)
{
Expand Down Expand Up @@ -175,7 +159,6 @@ async Task HandleTokenUpdate(AmqpMessage message)
{
credentialsInfo.ClientCredentials = clientCredentials;
}

if (credentialsInfo.IsAuthenticated)
{
await this.credentialsCache.Add(clientCredentials);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp
{
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using System.Web;
using Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers;
Expand All @@ -18,91 +19,42 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp
/// It maintains the IIdentity and the IDeviceListener for the connection, and provides it to the link handlers.
/// It also maintains a registry of the links open on that connection, and makes sure duplicate/invalid links are not opened.
/// </summary>
class ConnectionHandler : IConnectionHandler
class ClientConnectionHandler : IConnectionHandler
{
readonly IDictionary<LinkType, ILinkHandler> registry = new Dictionary<LinkType, ILinkHandler>();
bool isInitialized;
IDeviceListener deviceListener;
AmqpAuthentication amqpAuthentication;
readonly IIdentity identity;

readonly AsyncLock initializationLock = new AsyncLock();
readonly AsyncLock registryUpdateLock = new AsyncLock();
readonly IAmqpConnection connection;
readonly IConnectionProvider connectionProvider;
Option<IDeviceListener> deviceListener = Option.None<IDeviceListener>();

public ConnectionHandler(IAmqpConnection connection, IConnectionProvider connectionProvider)
public ClientConnectionHandler(IIdentity identity, IConnectionProvider connectionProvider)
{
this.connection = Preconditions.CheckNotNull(connection, nameof(connection));
this.identity = Preconditions.CheckNotNull(identity, nameof(identity));
this.connectionProvider = Preconditions.CheckNotNull(connectionProvider, nameof(connectionProvider));
}

public async Task<IDeviceListener> GetDeviceListener()
public Task<IDeviceListener> GetDeviceListener()
{
await this.EnsureInitialized();
return this.deviceListener;
}

public async Task<AmqpAuthentication> GetAmqpAuthentication()
{
await this.EnsureInitialized();
return this.amqpAuthentication;
}

async Task EnsureInitialized()
{
if (!this.isInitialized)
{
using (await this.initializationLock.LockAsync())
{
if (!this.isInitialized)
return this.deviceListener.Map(d => Task.FromResult(d))
.GetOrElse(
async () =>
{
AmqpAuthentication amqpAuth;
// Check if Principal is SaslPrincipal
if (this.connection.Principal is SaslPrincipal saslPrincipal)
{
amqpAuth = saslPrincipal.AmqpAuthentication;
}
else
using (await this.initializationLock.LockAsync())
{
// Else the connection uses CBS authentication. Get AmqpAuthentication from the CbsNode
var cbsNode = this.connection.FindExtension<ICbsNode>();
if (cbsNode == null)
{
throw new InvalidOperationException("CbsNode is null");
}

amqpAuth = await cbsNode.GetAmqpAuthentication();
return await this.deviceListener.Map(d => Task.FromResult(d))
.GetOrElse(
async () =>
{
IDeviceListener dl = await this.connectionProvider.GetDeviceListenerAsync(this.identity);
var deviceProxy = new DeviceProxy(this, this.identity);
dl.BindDeviceProxy(deviceProxy);
this.deviceListener = Option.Some(dl);
return dl;
});
}

if (!amqpAuth.IsAuthenticated)
{
throw new InvalidOperationException("Connection not authenticated");
}

IClientCredentials clientCredentials = amqpAuth.ClientCredentials.Expect(() => new InvalidOperationException("Authenticated connection should have a valid identity"));
this.deviceListener = await this.connectionProvider.GetDeviceListenerAsync(clientCredentials.Identity);
var deviceProxy = new DeviceProxy(this, clientCredentials.Identity);
this.deviceListener.BindDeviceProxy(deviceProxy);
this.amqpAuthentication = amqpAuth;
this.isInitialized = true;
Events.InitializedConnectionHandler(clientCredentials.Identity);
}
}
}
}

async Task<Option<IClientCredentials>> GetUpdatedAuthenticatedIdentity()
{
var cbsNode = this.connection.FindExtension<ICbsNode>();
if (cbsNode != null)
{
AmqpAuthentication updatedAmqpAuthentication = await cbsNode.GetAmqpAuthentication();
if (updatedAmqpAuthentication.IsAuthenticated)
{
return updatedAmqpAuthentication.ClientCredentials;
}
}
return Option.None<IClientCredentials>();
});
}

public async Task RegisterLinkHandler(ILinkHandler linkHandler)
Expand Down Expand Up @@ -170,23 +122,29 @@ public async Task RemoveLinkHandler(ILinkHandler linkHandler)
}
}

Task CloseAllLinks()
{
IList<ILinkHandler> links = this.registry.Values.ToList();
IEnumerable<Task> closeTasks = links.Select(l => l.CloseAsync(Constants.DefaultTimeout));
return Task.WhenAll(closeTasks);
}

async Task CloseConnection()
{
using (await this.initializationLock.LockAsync())
{
this.isInitialized = false;
await (this.deviceListener?.CloseAsync() ?? Task.CompletedTask);
await this.deviceListener.ForEachAsync(d => d.CloseAsync());
}
}

public class DeviceProxy : IDeviceProxy
{
readonly ConnectionHandler connectionHandler;
readonly ClientConnectionHandler clientConnectionHandler;
readonly AtomicBoolean isActive = new AtomicBoolean(true);

public DeviceProxy(ConnectionHandler connectionHandler, IIdentity identity)
public DeviceProxy(ClientConnectionHandler clientConnectionHandler, IIdentity identity)
{
this.connectionHandler = connectionHandler;
this.clientConnectionHandler = clientConnectionHandler;
this.Identity = identity;
}

Expand All @@ -195,14 +153,14 @@ public Task CloseAsync(Exception ex)
if (this.isActive.GetAndSet(false))
{
Events.ClosingProxy(this.Identity, ex);
return this.connectionHandler.connection.Close();
return this.clientConnectionHandler.CloseAllLinks();
}
return Task.CompletedTask;
}

public Task SendC2DMessageAsync(IMessage message)
{
if (!this.connectionHandler.registry.TryGetValue(LinkType.C2D, out ILinkHandler linkHandler))
if (!this.clientConnectionHandler.registry.TryGetValue(LinkType.C2D, out ILinkHandler linkHandler))
{
Events.LinkNotFound(LinkType.ModuleMessages, this.Identity, "C2D message");
return Task.CompletedTask;
Expand All @@ -216,7 +174,7 @@ public Task SendC2DMessageAsync(IMessage message)

public Task SendMessageAsync(IMessage message, string input)
{
if (!this.connectionHandler.registry.TryGetValue(LinkType.ModuleMessages, out ILinkHandler linkHandler))
if (!this.clientConnectionHandler.registry.TryGetValue(LinkType.ModuleMessages, out ILinkHandler linkHandler))
{
Events.LinkNotFound(LinkType.ModuleMessages, this.Identity, "message");
return Task.CompletedTask;
Expand All @@ -228,7 +186,7 @@ public Task SendMessageAsync(IMessage message, string input)

public async Task<DirectMethodResponse> InvokeMethodAsync(DirectMethodRequest request)
{
if (!this.connectionHandler.registry.TryGetValue(LinkType.MethodSending, out ILinkHandler linkHandler))
if (!this.clientConnectionHandler.registry.TryGetValue(LinkType.MethodSending, out ILinkHandler linkHandler))
{
Events.LinkNotFound(LinkType.ModuleMessages, this.Identity, "method request");
return default(DirectMethodResponse);
Expand All @@ -251,7 +209,7 @@ public async Task<DirectMethodResponse> InvokeMethodAsync(DirectMethodRequest re

public Task OnDesiredPropertyUpdates(IMessage desiredProperties)
{
if (!this.connectionHandler.registry.TryGetValue(LinkType.TwinSending, out ILinkHandler linkHandler))
if (!this.clientConnectionHandler.registry.TryGetValue(LinkType.TwinSending, out ILinkHandler linkHandler))
{
Events.LinkNotFound(LinkType.ModuleMessages, this.Identity, "desired properties update");
return Task.CompletedTask;
Expand All @@ -263,7 +221,7 @@ public Task OnDesiredPropertyUpdates(IMessage desiredProperties)

public Task SendTwinUpdate(IMessage twin)
{
if (!this.connectionHandler.registry.TryGetValue(LinkType.TwinSending, out ILinkHandler linkHandler))
if (!this.clientConnectionHandler.registry.TryGetValue(LinkType.TwinSending, out ILinkHandler linkHandler))
{
Events.LinkNotFound(LinkType.ModuleMessages, this.Identity, "twin update");
return Task.CompletedTask;
Expand All @@ -283,12 +241,12 @@ public void SetInactive()
this.isActive.Set(false);
}

public Task<Option<IClientCredentials>> GetUpdatedIdentity() => this.connectionHandler.GetUpdatedAuthenticatedIdentity();
public Task<Option<IClientCredentials>> GetUpdatedIdentity() => throw new NotImplementedException();
}

static class Events
{
static readonly ILogger Log = Logger.Factory.CreateLogger<ConnectionHandler>();
static readonly ILogger Log = Logger.Factory.CreateLogger<ClientConnectionHandler>();
const int IdStart = AmqpEventIds.ConnectionHandler;

enum EventIds
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft. All rights reserved.

namespace Microsoft.Azure.Devices.Edge.Hub.Amqp
{
using System.Collections.Concurrent;
using Microsoft.Azure.Devices.Edge.Hub.Core;
using Microsoft.Azure.Devices.Edge.Hub.Core.Identity;
using Microsoft.Azure.Devices.Edge.Util;

class ClientConnectionsHandler : IClientConnectionsHandler
{
readonly ConcurrentDictionary<string, ClientConnectionHandler> connectionHandlers = new ConcurrentDictionary<string, ClientConnectionHandler>();
readonly IConnectionProvider connectionProvider;

public ClientConnectionsHandler(IConnectionProvider connectionProvider)
{
this.connectionProvider = Preconditions.CheckNotNull(connectionProvider, nameof(connectionProvider));
}

public IConnectionHandler GetConnectionHandler(IIdentity identity) =>
this.connectionHandlers.GetOrAdd(identity.Id, i => new ClientConnectionHandler(identity, this.connectionProvider));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public async Task<IPrincipal> AuthenticateAsync(string identity, string password
throw new EdgeHubConnectionException("Authentication failed.");
}

return new SaslPrincipal(new AmqpAuthentication(true, Option.Some(deviceIdentity)));
return new SaslPrincipal(true, deviceIdentity);
}
catch (Exception ex) when (!ex.IsFatal())
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.

namespace Microsoft.Azure.Devices.Edge.Hub.Amqp
{
using System.Threading.Tasks;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@
namespace Microsoft.Azure.Devices.Edge.Hub.Amqp
{
using System;
using System.Threading.Tasks;

public interface ICbsNode : IAmqpAuthenticator, IDisposable
public interface ICbsNode : IDisposable
{
void RegisterLink(IAmqpLink link);

Task<AmqpAuthentication> GetAmqpAuthentication();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Copyright (c) Microsoft. All rights reserved.

namespace Microsoft.Azure.Devices.Edge.Hub.Amqp
{
using Microsoft.Azure.Devices.Edge.Hub.Core.Identity;

public interface IClientConnectionsHandler
{
IConnectionHandler GetConnectionHandler(IIdentity identity);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ public interface IConnectionHandler
{
Task<IDeviceListener> GetDeviceListener();

Task<AmqpAuthentication> GetAmqpAuthentication();

Task RegisterLinkHandler(ILinkHandler linkHandler);

Task RemoveLinkHandler(ILinkHandler linkHandler);
Expand Down
Loading

0 comments on commit 93be534

Please sign in to comment.