diff --git a/src/Bus.cs b/src/Bus.cs index 530267f..b74a421 100644 --- a/src/Bus.cs +++ b/src/Bus.cs @@ -59,11 +59,13 @@ public static Bus Starter throw new ArgumentNullException ("address"); Bus bus; - if (buses.TryGetValue (address, out bus)) - return bus; + lock (buses) { + if (buses.TryGetValue (address, out bus)) + return bus; - bus = new Bus (address); - buses[address] = bus; + bus = new Bus (address); + buses[address] = bus; + } return bus; } diff --git a/src/BusException.cs b/src/BusException.cs index b8926b9..9dfeaec 100644 --- a/src/BusException.cs +++ b/src/BusException.cs @@ -25,7 +25,7 @@ public override string Message { get { - return ErrorName + ": " + ErrorMessage; + return ErrorName + ": " + ErrorMessage + Environment.NewLine + this.StackTrace; } } diff --git a/src/Connection.cs b/src/Connection.cs index 0284e6c..26da834 100644 --- a/src/Connection.cs +++ b/src/Connection.cs @@ -3,9 +3,11 @@ // See COPYING for details using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.IO; using System.Threading; +using System.Threading.Tasks; using System.Reflection; namespace DBus @@ -37,23 +39,24 @@ public class Connection Dictionary pendingCalls = new Dictionary (); Queue inbound = new Queue (); - Dictionary registeredObjects = new Dictionary (); + ConcurrentDictionary registeredObjects = new ConcurrentDictionary (); + private readonly ReadMessageTask readMessageTask; public delegate void MonitorEventHandler (Message msg); public MonitorEventHandler Monitors; // subscribe yourself to this list of observers if you want to get notified about each incoming message protected Connection () { - + readMessageTask = new ReadMessageTask (this); } - internal Connection (Transport transport) + internal Connection (Transport transport) : this () { this.transport = transport; transport.Connection = this; } - internal Connection (string address) + internal Connection (string address) : this () { OpenPrivate (address); Authenticate (); @@ -183,11 +186,12 @@ internal uint GenerateSerial () internal Message SendWithReplyAndBlock (Message msg, bool keepFDs) { - PendingCall pending = SendWithReply (msg, keepFDs); - return pending.Reply; + using (PendingCall pending = SendWithPendingReply (msg, keepFDs)) { + return pending.Reply; + } } - internal PendingCall SendWithReply (Message msg, bool keepFDs) + internal PendingCall SendWithPendingReply (Message msg, bool keepFDs) { msg.ReplyExpected = true; @@ -215,27 +219,23 @@ internal virtual uint Send (Message msg) return msg.Header.Serial; } - //temporary hack - internal void DispatchSignals () + public void Iterate () { - lock (inbound) { - while (inbound.Count != 0) { - Message msg = inbound.Dequeue (); - try { - HandleSignal (msg); - } finally { - msg.Dispose (); - } - } - } + Iterate (new CancellationToken (false)); } - public void Iterate () + public void Iterate (CancellationToken stopWaitToken) { - Message msg = transport.ReadMessage (); - - HandleMessage (msg); - DispatchSignals (); + if (TryGetStoredSignalMessage (out Message inboundMsg)) { + try { + HandleSignal (inboundMsg); + } finally { + inboundMsg.Dispose (); + } + } else { + var msg = readMessageTask.MakeSureTaskRunAndWait (stopWaitToken); + HandleMessage (msg); + } } internal virtual void HandleMessage (Message msg) @@ -251,21 +251,19 @@ internal virtual void HandleMessage (Message msg) try { //TODO: Restrict messages to Local ObjectPath? - { - object field_value = msg.Header[FieldCode.ReplySerial]; + object field_value = msg.Header [FieldCode.ReplySerial]; if (field_value != null) { uint reply_serial = (uint)field_value; - PendingCall pending; lock (pendingCalls) { + PendingCall pending; if (pendingCalls.TryGetValue (reply_serial, out pending)) { - if (pendingCalls.Remove (reply_serial)) { - pending.Reply = msg; - if (pending.KeepFDs) - cleanupFDs = false; // caller is responsible for closing FDs - } - + if (!pendingCalls.Remove (reply_serial)) + return; + pending.Reply = msg; + if (pending.KeepFDs) + cleanupFDs = false; // caller is responsible for closing FDs return; } } @@ -285,8 +283,7 @@ internal virtual void HandleMessage (Message msg) break; case MessageType.Signal: //HandleSignal (msg); - lock (inbound) - inbound.Enqueue (msg); + StoreInboundSignalMessage (msg); //temporary hack cleanupFDs = false; // FDs are closed after signal is handled break; case MessageType.Error: @@ -391,14 +388,15 @@ internal void HandleMethodCall (MessageContainer method_call) //this is messy and inefficient List linkNodes = new List (); int depth = method_call.Path.Decomposed.Length; - foreach (ObjectPath pth in registeredObjects.Keys) { - if (pth.Value == (method_call.Path.Value)) { - ExportObject exo = (ExportObject)registeredObjects[pth]; + foreach(var objKeyValuePair in registeredObjects) { + ObjectPath pth = objKeyValuePair.Key; + ExportObject exo = (ExportObject) objKeyValuePair.Value; + if (pth.Value == method_call.Path.Value) { exo.WriteIntrospect (intro); } else { - for (ObjectPath cur = pth ; cur != null ; cur = cur.Parent) { + for (ObjectPath cur = pth; cur != null; cur = cur.Parent) { if (cur.Value == method_call.Path.Value) { - string linkNode = pth.Decomposed[depth]; + string linkNode = pth.Decomposed [depth]; if (!linkNodes.Contains (linkNode)) { intro.WriteNode (linkNode); linkNodes.Add (linkNode); @@ -415,9 +413,8 @@ internal void HandleMethodCall (MessageContainer method_call) return; } - BusObject bo; - if (registeredObjects.TryGetValue (method_call.Path, out bo)) { - ExportObject eo = (ExportObject)bo; + if (registeredObjects.TryGetValue(method_call.Path, out BusObject bo)) { + ExportObject eo = (ExportObject) bo; eo.HandleMethodCall (method_call); } else { MaybeSendUnknownMethodError (method_call); @@ -464,14 +461,10 @@ public void Register (ObjectPath path, object obj) public object Unregister (ObjectPath path) { - BusObject bo; - - if (!registeredObjects.TryGetValue (path, out bo)) + if (!registeredObjects.TryRemove (path, out BusObject bo)) throw new Exception ("Cannot unregister " + path + " as it isn't registered"); - registeredObjects.Remove (path); - - ExportObject eo = (ExportObject)bo; + ExportObject eo = (ExportObject) bo; eo.Registered = false; return eo.Object; @@ -486,6 +479,25 @@ internal protected virtual void RemoveMatch (string rule) { } + private void StoreInboundSignalMessage (Message msg) + { + lock (inbound) { + inbound.Enqueue (msg); + } + } + + private bool TryGetStoredSignalMessage (out Message msg) + { + msg = null; + lock (inbound) { + if (inbound.Count != 0) { + msg = inbound.Dequeue (); + return true; + } + } + return false; + } + static UUID ReadMachineId (string fname) { byte[] data = File.ReadAllBytes (fname); @@ -494,5 +506,31 @@ static UUID ReadMachineId (string fname) return UUID.Parse (System.Text.Encoding.ASCII.GetString (data, 0, 32)); } + + private class ReadMessageTask + { + private readonly Connection ownerConnection; + private Task task = null; + private object taskLock = new object (); + + public ReadMessageTask (Connection connection) + { + ownerConnection = connection; + } + + public Message MakeSureTaskRunAndWait(CancellationToken stopWaitToken) + { + Task catchedTask = null; + + lock (taskLock) { + if (task == null || task.IsCompleted) { + task = Task.Run (() => ownerConnection.transport.ReadMessage ()); + } + catchedTask = task; + } + catchedTask.Wait (stopWaitToken); + return catchedTask.Result; + } + } } } diff --git a/src/ExportObject.cs b/src/ExportObject.cs index 14c2eb3..0edc3b0 100644 --- a/src/ExportObject.cs +++ b/src/ExportObject.cs @@ -67,9 +67,11 @@ internal virtual void WriteIntrospect (Introspector intro) internal static MethodCaller GetMCaller (MethodInfo mi) { MethodCaller mCaller; - if (!mCallers.TryGetValue (mi, out mCaller)) { - mCaller = TypeImplementer.GenCaller (mi); - mCallers[mi] = mCaller; + lock (mCallers) { + if (!mCallers.TryGetValue (mi, out mCaller)) { + mCaller = TypeImplementer.GenCaller (mi); + mCallers[mi] = mCaller; + } } return mCaller; } @@ -90,11 +92,7 @@ public virtual void HandleMethodCall (MessageContainer method_call) return; } - MethodCaller mCaller; - if (!mCallers.TryGetValue (mi, out mCaller)) { - mCaller = TypeImplementer.GenCaller (mi); - mCallers[mi] = mCaller; - } + MethodCaller mCaller = GetMCaller (mi); Signature inSig, outSig; bool hasDisposableList; diff --git a/src/Protocol/PendingCall.cs b/src/Protocol/PendingCall.cs index f8813a4..2bbe6a8 100644 --- a/src/Protocol/PendingCall.cs +++ b/src/Protocol/PendingCall.cs @@ -7,23 +7,33 @@ namespace DBus.Protocol { - public class PendingCall : IAsyncResult + public class PendingCall : IAsyncResult, IDisposable { - Connection conn; - Message reply; - ManualResetEvent waitHandle; - bool completedSync; - bool keepFDs; - + private Connection conn; + private Message reply; + private ManualResetEvent waitHandle = new ManualResetEvent (false); + private bool completedSync = false; + private bool keepFDs; + private CancellationTokenSource stopWait = new CancellationTokenSource(); + public event Action Completed; - public PendingCall (Connection conn) : this (conn, false) {} + public PendingCall(Connection conn) + : this (conn, false) + { + } + public PendingCall (Connection conn, bool keepFDs) { this.conn = conn; this.keepFDs = keepFDs; } + public void Dispose() + { + stopWait.Dispose (); + } + internal bool KeepFDs { get { @@ -33,36 +43,24 @@ internal bool KeepFDs public Message Reply { get { - if (reply != null) - return reply; - - if (Thread.CurrentThread == conn.mainThread) { - while (reply == null) - conn.HandleMessage (conn.Transport.ReadMessage ()); - - completedSync = true; - - conn.DispatchSignals (); - } else { - if (waitHandle == null) - Interlocked.CompareExchange (ref waitHandle, new ManualResetEvent (false), null); - - while (reply == null) - waitHandle.WaitOne (); - - completedSync = false; + while (reply == null) { + try { + conn.Iterate (stopWait.Token); + } + catch (OperationCanceledException) { + } } - return reply; - } + } + set { if (reply != null) throw new Exception ("Cannot handle reply more than once"); - reply = value; - if (waitHandle != null) - waitHandle.Set (); + waitHandle.Set (); + + stopWait.Cancel (); if (Completed != null) Completed (reply); @@ -84,9 +82,6 @@ object IAsyncResult.AsyncState { WaitHandle IAsyncResult.AsyncWaitHandle { get { - if (waitHandle == null) - waitHandle = new ManualResetEvent (false); - return waitHandle; } } diff --git a/src/Protocol/Transport.cs b/src/Protocol/Transport.cs index 910e652..965b9a1 100644 --- a/src/Protocol/Transport.cs +++ b/src/Protocol/Transport.cs @@ -14,6 +14,7 @@ namespace DBus.Transports abstract class Transport { readonly object writeLock = new object (); + readonly object readLock = new object (); [ThreadStatic] static byte[] readBuffer; @@ -154,12 +155,26 @@ protected void FireWakeUp () WakeUp (this, EventArgs.Empty); } + + internal virtual void WriteMessage(Message msg) + { + lock (writeLock) + { + msg.Header.GetHeaderDataToStream (stream); + if (msg.Body != null && msg.Body.Length != 0) + stream.Write (msg.Body, 0, msg.Body.Length); + stream.Flush (); + } + } + internal Message ReadMessage () { Message msg; try { - msg = ReadMessageReal (); + lock (readLock) + msg = ReadMessageReal (); + if (msg == null) { if (connection.IsConnected) { if (ProtocolInformation.Verbose) @@ -173,30 +188,17 @@ internal Message ReadMessage () connection.IsConnected = false; msg = null; } - + if (connection != null && connection.Monitors != null) connection.Monitors (msg); return msg; } - internal virtual int Read (byte[] buffer, int offset, int count, UnixFDArray fdArray) - { - int read = 0; - while (read < count) { - int nread = stream.Read (buffer, offset + read, count - read); - if (nread == 0) - break; - read += nread; - } - - if (read > count) - throw new Exception (); - - return read; - } - - Message ReadMessageReal () + /// + /// Thread unsafe! Safely called from . + /// + private Message ReadMessageReal () { byte[] header = null; byte[] body = null; @@ -266,18 +268,24 @@ Message ReadMessageReal () } Message msg = Message.FromReceivedBytes (Connection, header, body, fdArray); - return msg; } - internal virtual void WriteMessage (Message msg) + internal virtual int Read (byte[] buffer, int offset, int count, UnixFDArray fdArray) { - lock (writeLock) { - msg.Header.GetHeaderDataToStream (stream); - if (msg.Body != null && msg.Body.Length != 0) - stream.Write (msg.Body, 0, msg.Body.Length); - stream.Flush (); + int read = 0; + while (read < count) + { + int nread = stream.Read (buffer, offset + read, count - read); + if (nread == 0) + break; + read += nread; } + + if (read > count) + throw new Exception (); + + return read; } // Returns true if then transport supports unix FDs, even when the