diff --git a/CoreRemoting.Tests/CoreRemoting.Tests.csproj b/CoreRemoting.Tests/CoreRemoting.Tests.csproj index 441d771..39ae35d 100644 --- a/CoreRemoting.Tests/CoreRemoting.Tests.csproj +++ b/CoreRemoting.Tests/CoreRemoting.Tests.csproj @@ -16,6 +16,7 @@ + diff --git a/CoreRemoting.Tests/DisposableTests.cs b/CoreRemoting.Tests/DisposableTests.cs new file mode 100644 index 0000000..df339ba --- /dev/null +++ b/CoreRemoting.Tests/DisposableTests.cs @@ -0,0 +1,57 @@ +using CoreRemoting.Toolbox; +using System.Threading.Tasks; +using Xunit; + +namespace CoreRemoting.Tests +{ + public class DisposableTests + { + [Fact] + public void Disposable_executes_action_on_Dispose() + { + var disposed = false; + + void Dispose() => + disposed = true; + + using (Disposable.Create(Dispose)) + Assert.False(disposed); + + Assert.True(disposed); + } + + [Fact] + public async Task AsyncDisposable_executes_action_on_DisposeAsync() + { + var disposed = false; + + async Task DisposeTask() + { + await Task.Yield(); + disposed = true; + } + + await using (Disposable.Create(DisposeTask)) + Assert.False(disposed); + + Assert.True(disposed); + } + + [Fact] + public async Task AsyncTaskDisposable_executes_action_on_DisposeAsync() + { + var disposed = false; + + async ValueTask DisposeAsync() + { + await Task.Yield(); + disposed = true; + } + + await using (Disposable.Create(DisposeAsync)) + Assert.False(disposed); + + Assert.True(disposed); + } + } +} \ No newline at end of file diff --git a/CoreRemoting.Tests/RpcTests.cs b/CoreRemoting.Tests/RpcTests.cs index 6691635..ddd6f5a 100644 --- a/CoreRemoting.Tests/RpcTests.cs +++ b/CoreRemoting.Tests/RpcTests.cs @@ -3,6 +3,7 @@ using System.Diagnostics; using System.Threading; using System.Threading.Tasks; +using CoreRemoting.Authentication; using CoreRemoting.Channels; using CoreRemoting.Serialization; using CoreRemoting.Tests.ExternalTypes; @@ -832,8 +833,10 @@ public void Authentication_is_taken_into_account_and_RejectCall_event_is_fired() void RejectCall(object sender, ServerRpcContext e) => rejectedMethod = e.MethodCallMessage.MethodName; - _serverFixture.Server.Config.AuthenticationRequired = true; - _serverFixture.Server.RejectCall += RejectCall; + var server = _serverFixture.Server; + server.RejectCall += RejectCall; + server.Config.AuthenticationRequired = true; + try { using var client = new RemotingClient(new ClientConfig() @@ -859,8 +862,44 @@ void RejectCall(object sender, ServerRpcContext e) => } finally { - _serverFixture.Server.Config.AuthenticationRequired = false; - _serverFixture.Server.RejectCall -= RejectCall; + server.Config.AuthenticationRequired = false; + server.RejectCall -= RejectCall; + } + } + + [Fact] + public void Authentication_handler_has_access_to_the_current_session() + { + var server = _serverFixture.Server; + var authProvider = server.Config.AuthenticationProvider; + server.Config.AuthenticationRequired = true; + server.Config.AuthenticationProvider = new FakeAuthProvider + { + AuthenticateFake = c => RemotingSession.Current != null + }; + + try + { + using var client = new RemotingClient(new ClientConfig() + { + ConnectionTimeout = 0, + InvocationTimeout = 0, + SendTimeout = 0, + Channel = ClientChannel, + MessageEncryption = false, + ServerPort = _serverFixture.Server.Config.NetworkPort, + Credentials = [new Credential()], + }); + + client.Connect(); + + var proxy = client.CreateProxy(); + Assert.Equal("123", proxy.Reverse("321")); + } + finally + { + server.Config.AuthenticationProvider = authProvider; + server.Config.AuthenticationRequired = false; } } } diff --git a/CoreRemoting/IRemotingServer.cs b/CoreRemoting/IRemotingServer.cs index 7a6d2d8..074ea9d 100644 --- a/CoreRemoting/IRemotingServer.cs +++ b/CoreRemoting/IRemotingServer.cs @@ -27,14 +27,14 @@ public interface IRemotingServer : IDisposable event EventHandler AfterCall; /// - /// Event: Fires if an error occurs. + /// Event: Fires when an RPC call is rejected before BeginCall event. . /// - event EventHandler Error; + event EventHandler RejectCall; /// - /// Event: Fires when an RPC call is rejected before BeginCall event. . + /// Event: Fires if an error occurs. /// - event EventHandler RejectCall; + event EventHandler Error; /// /// Gets the unique name of this server instance. diff --git a/CoreRemoting/RemotingSession.cs b/CoreRemoting/RemotingSession.cs index c9082de..63477b3 100644 --- a/CoreRemoting/RemotingSession.cs +++ b/CoreRemoting/RemotingSession.cs @@ -13,6 +13,7 @@ using CoreRemoting.Encryption; using CoreRemoting.Serialization; using Serialize.Linq.Nodes; +using CoreRemoting.Toolbox; namespace CoreRemoting { @@ -258,6 +259,8 @@ private void OnReceiveMessage(byte[] rawMessage) _currentlyProcessedMessagesCounter.AddCount(1); + CurrentSession.Value = this; + try { var message = _server.Serializer.Deserialize(rawMessage); @@ -285,6 +288,8 @@ private void OnReceiveMessage(byte[] rawMessage) finally { _currentlyProcessedMessagesCounter.Signal(); + + CurrentSession.Value = null; } }); } @@ -424,8 +429,6 @@ private void ProcessRpcMessage(WireMessage request) try { - CurrentSession.Value = this; - if (_server.Config.AuthenticationRequired && !_isAuthenticated) throw new NetworkException("Session is not authenticated."); @@ -470,10 +473,6 @@ private void ProcessRpcMessage(WireMessage request) serializedResult = _server.Serializer.Serialize(serverRpcContext.Exception); } - finally - { - CurrentSession.Value = null; - } object result = null; @@ -481,8 +480,6 @@ private void ProcessRpcMessage(WireMessage request) { try { - CurrentSession.Value = this; - ((RemotingServer)_server).OnBeforeCall(serverRpcContext); result = method.Invoke(serverRpcContext.ServiceInstance, @@ -543,10 +540,6 @@ private void ProcessRpcMessage(WireMessage request) serializedResult = _server.Serializer.Serialize(serverRpcContext.Exception); } - finally - { - CurrentSession.Value = null; - } if (!oneWay) { @@ -587,8 +580,6 @@ private void ProcessRpcMessage(WireMessage request) _rawMessageTransport.SendMessage( _server.Serializer.Serialize(methodResultMessage)); - - CurrentSession.Value = null; } private MethodInfo GetMethodInfo(MethodCallMessage callMessage, Type serviceInterfaceType, Type[] parameterTypes) diff --git a/CoreRemoting/Toolbox/Disposable.cs b/CoreRemoting/Toolbox/Disposable.cs new file mode 100644 index 0000000..44933f0 --- /dev/null +++ b/CoreRemoting/Toolbox/Disposable.cs @@ -0,0 +1,52 @@ +using System; +using System.Threading.Tasks; + +namespace CoreRemoting.Toolbox +{ + /// + /// Helper class to create disposable primitives. + /// + public static class Disposable + { + private class SyncDisposable(Action disposeAction) : IDisposable + { + void IDisposable.Dispose() => + disposeAction?.Invoke(); + } + + /// + /// Creates a disposable object. + /// + /// An action to invoke on disposal. + public static IDisposable Create(Action disposeAction) => + new SyncDisposable(disposeAction); + + private class AsyncDisposable( + Func disposeAsync, + Func disposeTaskAsync) : IAsyncDisposable + { + async ValueTask IAsyncDisposable.DisposeAsync() + { + if (disposeAsync != null) + await disposeAsync(); + + if (disposeTaskAsync != null) + await disposeTaskAsync(); + } + } + + /// + /// Creates an asynchronous disposable object. + /// + /// An action to invoke on disposal. + public static IAsyncDisposable Create(Func disposeAsync) => + new AsyncDisposable(disposeAsync, null); + + /// + /// Creates an asynchronous disposable object. + /// + /// An action to invoke on disposal. + public static IAsyncDisposable Create(Func disposeAsync) => + new AsyncDisposable(null, disposeAsync); + } +}