diff --git a/CoreRemoting.Tests/RpcTests.cs b/CoreRemoting.Tests/RpcTests.cs index fcbdb71..6691635 100644 --- a/CoreRemoting.Tests/RpcTests.cs +++ b/CoreRemoting.Tests/RpcTests.cs @@ -768,9 +768,72 @@ void AfterCall(object sender, ServerRpcContext e) => } [Fact] - public void Authentication_is_taken_into_account() + public void BeginCall_event_handler_can_intercept_and_cancel_method_calls() { + var counter = 0; + + void InterceptMethodCalls(object sender, ServerRpcContext e) + { + Interlocked.Increment(ref counter); + + // swap Echo and Reverse methods + e.MethodCallMessage.MethodName = e.MethodCallMessage.MethodName switch + { + "Echo" => "Reverse", + "Reverse" => "Echo", + var others => others + }; + + // disable IHobbitService + if (e.MethodCallMessage.ServiceName.Contains("IHobbitService")) + { + e.Cancel = true; + } + } + + _serverFixture.Server.BeginCall += InterceptMethodCalls; + try + { + using var client = new RemotingClient(new ClientConfig() + { + ConnectionTimeout = 0, + InvocationTimeout = 0, + SendTimeout = 0, + Channel = ClientChannel, + MessageEncryption = false, + ServerPort = _serverFixture.Server.Config.NetworkPort, + }); + + client.Connect(); + + // try swapped methods + var proxy = client.CreateProxy(); + Assert.Equal("321", proxy.Echo("123")); + Assert.Equal("Hello", proxy.Reverse("Hello")); + + // try disabled service + var hobbit = client.CreateProxy(); + Assert.Throws(() => + hobbit.QueryHobbits(h => h.LastName != "")); + + // check interception counter + Assert.Equal(3, counter); + } + finally + { + _serverFixture.Server.BeginCall -= InterceptMethodCalls; + } + } + + [Fact] + public void Authentication_is_taken_into_account_and_RejectCall_event_is_fired() + { + var rejectedMethod = string.Empty; + void RejectCall(object sender, ServerRpcContext e) => + rejectedMethod = e.MethodCallMessage.MethodName; + _serverFixture.Server.Config.AuthenticationRequired = true; + _serverFixture.Server.RejectCall += RejectCall; try { using var client = new RemotingClient(new ClientConfig() @@ -790,10 +853,14 @@ public void Authentication_is_taken_into_account() // Session is not authenticated Assert.Contains("authenticated", ex.Message); + + // Method call was rejected + Assert.Equal("Hello", rejectedMethod); } finally { _serverFixture.Server.Config.AuthenticationRequired = false; + _serverFixture.Server.RejectCall -= RejectCall; } } } diff --git a/CoreRemoting.Tests/Tools/ITestService.cs b/CoreRemoting.Tests/Tools/ITestService.cs index 7685f04..37f4097 100644 --- a/CoreRemoting.Tests/Tools/ITestService.cs +++ b/CoreRemoting.Tests/Tools/ITestService.cs @@ -29,6 +29,8 @@ public interface ITestService : IBaseService string Echo(string text); + string Reverse(string text); + void MethodWithOutParameter(out int counter); void Error(string text); diff --git a/CoreRemoting.Tests/Tools/TestService.cs b/CoreRemoting.Tests/Tools/TestService.cs index 6bdf221..b5e9a5d 100644 --- a/CoreRemoting.Tests/Tools/TestService.cs +++ b/CoreRemoting.Tests/Tools/TestService.cs @@ -1,5 +1,6 @@ using System; using System.Data; +using System.Linq; using System.Threading.Tasks; using CoreRemoting.Tests.ExternalTypes; @@ -54,6 +55,11 @@ public string Echo(string text) return text; } + public string Reverse(string text) + { + return new string(text.Reverse().ToArray()); + } + public void MethodWithOutParameter(out int counter) { _counter++; diff --git a/CoreRemoting/IRemotingServer.cs b/CoreRemoting/IRemotingServer.cs index d9a54bc..7a6d2d8 100644 --- a/CoreRemoting/IRemotingServer.cs +++ b/CoreRemoting/IRemotingServer.cs @@ -12,12 +12,17 @@ namespace CoreRemoting public interface IRemotingServer : IDisposable { /// - /// Event: Fires before an RPC call is invoked. + /// Event: Fires when an RPC call is prepared and can be canceled. + /// + event EventHandler BeginCall; + + /// + /// Event: Fires just before an RPC call is invoked. /// event EventHandler BeforeCall; - + /// - /// Event: Fires after an RPC call is invoked. + /// Event: Fires after an RPC call is invoked, both on success or failure. /// event EventHandler AfterCall; @@ -25,47 +30,52 @@ public interface IRemotingServer : IDisposable /// Event: Fires if an error occurs. /// event EventHandler Error; - + + /// + /// Event: Fires when an RPC call is rejected before BeginCall event. . + /// + event EventHandler RejectCall; + /// /// Gets the unique name of this server instance. /// string UniqueServerInstanceName { get; } - + /// /// Gets the dependency injection container that is used a service registry. /// IDependencyInjectionContainer ServiceRegistry { get; } - + /// /// Gets the configured serializer. /// ISerializerAdapter Serializer { get; } - + /// /// Gets the component for easy building of method call messages. /// MethodCallMessageBuilder MethodCallMessageBuilder { get; } - + /// /// Gets the component for encryption and decryption of messages. /// IMessageEncryptionManager MessageEncryptionManager { get; } - + /// /// Gets the session repository to perform session management tasks. /// ISessionRepository SessionRepository { get; } - + /// /// Gets the configuration settings. /// ServerConfig Config { get; } - + /// /// Starts listening for client requests. /// void Start(); - + /// /// Stops listening for client requests and close all open client connections. /// diff --git a/CoreRemoting/RemotingServer.cs b/CoreRemoting/RemotingServer.cs index 10aa923..cf45688 100644 --- a/CoreRemoting/RemotingServer.cs +++ b/CoreRemoting/RemotingServer.cs @@ -11,6 +11,7 @@ using CoreRemoting.Serialization; using CoreRemoting.Serialization.Bson; using ServiceLifetime = CoreRemoting.DependencyInjection.ServiceLifetime; +using System.Runtime.ExceptionServices; namespace CoreRemoting { @@ -91,7 +92,17 @@ public RemotingServer(ServerConfig config = null) /// Event: Fires if an error occurs. /// public event EventHandler Error; - + + /// + /// Event: Fires when an RPC call is rejected before BeforeCall event. + /// + public event EventHandler RejectCall; + + /// + /// Event: Fires when an RPC call is prepared and can be canceled. + /// + public event EventHandler BeginCall; + /// /// Gets the dependency injection container that is used a service registry. /// @@ -134,7 +145,7 @@ public RemotingServer(ServerConfig config = null) public IServerChannel Channel { get; private set; } /// - /// Fires the OnBeforeCall event. + /// Fires the event. /// /// Server side RPC call context internal void OnBeforeCall(ServerRpcContext serverRpcContext) @@ -143,7 +154,7 @@ internal void OnBeforeCall(ServerRpcContext serverRpcContext) } /// - /// Fires the OnAfterCall event. + /// Fires the event. /// /// Server side RPC call context internal void OnAfterCall(ServerRpcContext serverRpcContext) @@ -152,14 +163,43 @@ internal void OnAfterCall(ServerRpcContext serverRpcContext) } /// - /// Fires the OnError event. + /// Fires the event. /// /// Exception that describes the occurred error internal void OnError(Exception ex) { Error?.Invoke(this, ex); } - + + /// + /// Fires the event. + /// + /// Server side RPC call context + internal void OnRejectCall(ServerRpcContext serverRpcContext) + { + RejectCall?.Invoke(this, serverRpcContext); + } + + /// + /// Fires the event. + /// + /// Server side RPC call context + internal void OnBeginCall(ServerRpcContext serverRpcContext) + { + BeginCall?.Invoke(this, serverRpcContext); + + if (serverRpcContext.Cancel) + { + var cancelEx = serverRpcContext.Exception ?? + new RemoteInvocationException($"Invocation canceled: { + serverRpcContext.MethodCallMessage.ServiceName}.{ + serverRpcContext.MethodCallMessage.MethodName}"); + + // rethrow the exception keeping the original stack trace + ExceptionDispatchInfo.Capture(cancelEx).Throw(); + } + } + /// /// Starts listening for client requests. /// diff --git a/CoreRemoting/RemotingSession.cs b/CoreRemoting/RemotingSession.cs index f85ffa7..c9082de 100644 --- a/CoreRemoting/RemotingSession.cs +++ b/CoreRemoting/RemotingSession.cs @@ -100,7 +100,7 @@ internal RemotingSession(int keySize, byte[] clientPublicKey, IRemotingServer se rawData: rawContent) }; - var rawData = _server.Serializer.Serialize(typeof(SignedMessageData), signedMessageData); + var rawData = _server.Serializer.Serialize(typeof(SignedMessageData), signedMessageData); completeHandshakeMessage = new WireMessage @@ -175,7 +175,7 @@ internal RemotingSession(int keySize, byte[] clientPublicKey, IRemotingServer se /// Optional exception from the transport infrastructure private void OnErrorOccured(string errorMessage, Exception ex) { - var exception = new RemotingException(errorMessage, innerEx: ex); + var exception = new RemotingException(errorMessage, innerEx: ex); ((RemotingServer)_server).OnError(exception); } @@ -249,19 +249,19 @@ private void OnReceiveMessage(byte[] rawMessage) Task.Run(() => { _lastActivityTimestamp = DateTime.Now; - + if (rawMessage == null) return; - + if (rawMessage.Length == 0) return; - + _currentlyProcessedMessagesCounter.AddCount(1); - + try { var message = _server.Serializer.Deserialize(rawMessage); - + switch (message.MessageType.ToLower()) { case "auth": @@ -290,7 +290,7 @@ private void OnReceiveMessage(byte[] rawMessage) } /// - /// Processes a wire message that contains a goodbye message, which is sent from a client to close the session. + /// Processes a wire message that contains a goodbye message, which is sent from a client to close the session. /// /// Wire message from client private void ProcessGoodbyeMessage(WireMessage request) @@ -309,26 +309,26 @@ private void ProcessGoodbyeMessage(WireMessage request) sharedSecret: sharedSecret, sendersPublicKeyBlob: _clientPublicKeyBlob, sendersPublicKeySize: _keyPair?.KeySize ?? 0)); - + if (goodbyeMessage.SessionId != _sessionId) return; var resultMessage = _server.MessageEncryptionManager.CreateWireMessage( messageType: request.MessageType, - serializedMessage: new byte[0], + serializedMessage: [], serializer: _server.Serializer, keyPair: _keyPair, sharedSecret: sharedSecret, uniqueCallKey: request.UniqueCallKey); - + _rawMessageTransport.SendMessage(_server.Serializer.Serialize(resultMessage)); - + Task.Run(() => _server.SessionRepository.RemoveSession(_sessionId)); } /// - /// Processes a wire message that contains a authentication request message, which is sent from a client to request authentication of a set of credentials. + /// Processes a wire message that contains a authentication request message, which is sent from a client to request authentication of a set of credentials. /// /// Wire message from client private void ProcessAuthenticationRequestMessage(WireMessage request) @@ -413,14 +413,13 @@ private void ProcessRpcMessage(WireMessage request) : new Guid(request.UniqueCallKey), ServiceInstance = null, MethodCallMessage = callMessage, + MethodCallParameterValues = [], + MethodCallParameterTypes = [], Session = this }; var serializedResult = Array.Empty(); var method = default(MethodInfo); - var parameterValues = Array.Empty(); - // ReSharper disable once RedundantAssignment - var parameterTypes = Array.Empty(); var oneWay = false; try @@ -430,19 +429,23 @@ private void ProcessRpcMessage(WireMessage request) if (_server.Config.AuthenticationRequired && !_isAuthenticated) throw new NetworkException("Session is not authenticated."); - var service = _server.ServiceRegistry.GetService(callMessage.ServiceName); - var serviceInterfaceType = - _server.ServiceRegistry.GetServiceInterfaceType(callMessage.ServiceName); - CallContext.RestoreFromSnapshot(callMessage.CallContextSnapshot); - serverRpcContext.ServiceInstance = service; - callMessage.UnwrapParametersFromDeserializedMethodCallMessage( - out parameterValues, - out parameterTypes); + out var parameterValues, + out var parameterTypes); parameterValues = MapArguments(parameterValues, parameterTypes); + serverRpcContext.MethodCallParameterValues = parameterValues; + serverRpcContext.MethodCallParameterTypes = parameterTypes; + + ((RemotingServer)_server).OnBeginCall(serverRpcContext); + + var service = _server.ServiceRegistry.GetService(callMessage.ServiceName); + var serviceInterfaceType = + _server.ServiceRegistry.GetServiceInterfaceType(callMessage.ServiceName); + + serverRpcContext.ServiceInstance = service; method = GetMethodInfo(callMessage, serviceInterfaceType, parameterTypes); if (method == null) @@ -459,6 +462,8 @@ private void ProcessRpcMessage(WireMessage request) message: ex.Message, innerEx: ex.ToSerializable()); + ((RemotingServer)_server).OnRejectCall(serverRpcContext); + if (oneWay) return; @@ -480,7 +485,8 @@ private void ProcessRpcMessage(WireMessage request) ((RemotingServer)_server).OnBeforeCall(serverRpcContext); - result = method.Invoke(serverRpcContext.ServiceInstance, parameterValues); + result = method.Invoke(serverRpcContext.ServiceInstance, + serverRpcContext.MethodCallParameterValues); var returnType = method.ReturnType; @@ -551,7 +557,7 @@ private void ProcessRpcMessage(WireMessage request) serializer: _server.Serializer, uniqueCallKey: serverRpcContext.UniqueCallKey, method: method, - args: parameterValues, + args: serverRpcContext.MethodCallParameterValues, returnValue: result); } @@ -648,7 +654,7 @@ private MethodInfo GetMethodInfo(MethodCallMessage callMessage, Type serviceInte private object[] MapArguments(object[] arguments, Type[] argumentTypes) { object[] mappedArguments = new object[arguments.Length]; - + for (int i = 0; i < arguments.Length; i++) { var argument = arguments[i]; @@ -661,10 +667,10 @@ private object[] MapArguments(object[] arguments, Type[] argumentTypes) else mappedArguments[i] = argument; } - + return mappedArguments; } - + /// /// Maps a delegate argument into a delegate proxy. /// @@ -721,10 +727,10 @@ private bool MapLinqExpressionArgument(Type argumentType, object argument, out o mappedArgument = argument; return false; } - + var expression = ((ExpressionNode)argument).ToExpression(); mappedArgument = expression; - + return true; } @@ -744,7 +750,7 @@ public void Close() } #endregion - + #region IDisposable implementation /// @@ -756,18 +762,18 @@ public void Dispose() return; _isDisposing = true; - + _rawMessageTransport.ReceiveMessage -= OnReceiveMessage; _rawMessageTransport.ErrorOccured -= OnErrorOccured; - + _currentlyProcessedMessagesCounter.Signal(); _currentlyProcessedMessagesCounter.Wait(_server.Config.WaitTimeForCurrentlyProcessedMessagesOnDispose); - + var sharedSecret = MessageEncryption ? _sessionId.ToByteArray() : null; - + var wireMessage = _server.MessageEncryptionManager.CreateWireMessage( serializedMessage: Array.Empty(), @@ -788,7 +794,7 @@ public void Dispose() } BeforeDispose?.Invoke(); - + _keyPair?.Dispose(); _delegateProxyFactory = null; _delegateProxyCache.Clear(); @@ -798,7 +804,7 @@ public void Dispose() } #endregion - + #region Retrieving current session /// diff --git a/CoreRemoting/ServerRpcContext.cs b/CoreRemoting/ServerRpcContext.cs index 7d5307a..04b1d4e 100644 --- a/CoreRemoting/ServerRpcContext.cs +++ b/CoreRemoting/ServerRpcContext.cs @@ -13,28 +13,43 @@ public class ServerRpcContext /// Gets or sets the unique key of RPC call. /// public Guid UniqueCallKey { get; set; } - + /// /// Gets or sets the last exception that is occurred. /// public Exception Exception { get; set; } - + + /// + /// Gets or sets a value whether the call is canceled by event handler. + /// + public bool Cancel { get; set; } + /// /// Gets the message that describes the remote method call. /// public MethodCallMessage MethodCallMessage { get; internal set; } - + + /// + /// Gets or sets the unwrapped method call parameter values. + /// + public object[] MethodCallParameterValues { get; set; } + + /// + /// Gets or sets the unwrapped method call parameter types. + /// + public Type[] MethodCallParameterTypes { get; set; } + /// /// Gets or sets the message that contains the results of a remote method call. /// public MethodCallResultMessage MethodCallResultMessage { get; set; } - + /// /// Gets or sets the instance of the service, on which the method is called. /// - [SuppressMessage("ReSharper", "UnusedAutoPropertyAccessor.Global")] + [SuppressMessage("ReSharper", "UnusedAutoPropertyAccessor.Global")] public object ServiceInstance { get; set; } - + /// /// Gets or sets the CoreRemoting session that is used to handle the RPC. ///