Skip to content
This repository has been archived by the owner on Nov 8, 2022. It is now read-only.

Commit

Permalink
Fix race conditions is case of threaded Iterate() call
Browse files Browse the repository at this point in the history
  • Loading branch information
nvoronchev committed Dec 12, 2019
1 parent b288288 commit a49bb67
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 122 deletions.
10 changes: 6 additions & 4 deletions src/Bus.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,13 @@ public static Bus Starter
throw new ArgumentNullException ("address");

Bus bus;
if (buses.TryGetValue (address, out bus))
return bus;
lock (buses) {
if (buses.TryGetValue (address, out bus))
return bus;

bus = new Bus (address);
buses[address] = bus;
bus = new Bus (address);
buses[address] = bus;
}

return bus;
}
Expand Down
2 changes: 1 addition & 1 deletion src/BusException.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public override string Message
{
get
{
return ErrorName + ": " + ErrorMessage;
return ErrorName + ": " + ErrorMessage + Environment.NewLine + this.StackTrace;
}
}

Expand Down
136 changes: 87 additions & 49 deletions src/Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
// See COPYING for details

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using System.Reflection;

namespace DBus
Expand Down Expand Up @@ -37,23 +39,24 @@ public class Connection

Dictionary<uint,PendingCall> pendingCalls = new Dictionary<uint,PendingCall> ();
Queue<Message> inbound = new Queue<Message> ();
Dictionary<ObjectPath,BusObject> registeredObjects = new Dictionary<ObjectPath,BusObject> ();
ConcurrentDictionary<ObjectPath,BusObject> registeredObjects = new ConcurrentDictionary<ObjectPath,BusObject> ();
private readonly ReadMessageTask readMessageTask;

public delegate void MonitorEventHandler (Message msg);
public MonitorEventHandler Monitors; // subscribe yourself to this list of observers if you want to get notified about each incoming message

protected Connection ()
{

readMessageTask = new ReadMessageTask (this);
}

internal Connection (Transport transport)
internal Connection (Transport transport) : this ()
{
this.transport = transport;
transport.Connection = this;
}

internal Connection (string address)
internal Connection (string address) : this ()
{
OpenPrivate (address);
Authenticate ();
Expand Down Expand Up @@ -183,11 +186,12 @@ internal uint GenerateSerial ()

internal Message SendWithReplyAndBlock (Message msg, bool keepFDs)
{
PendingCall pending = SendWithReply (msg, keepFDs);
return pending.Reply;
using (PendingCall pending = SendWithPendingReply (msg, keepFDs)) {
return pending.Reply;
}
}

internal PendingCall SendWithReply (Message msg, bool keepFDs)
internal PendingCall SendWithPendingReply (Message msg, bool keepFDs)
{
msg.ReplyExpected = true;

Expand Down Expand Up @@ -215,27 +219,23 @@ internal virtual uint Send (Message msg)
return msg.Header.Serial;
}

//temporary hack
internal void DispatchSignals ()
public void Iterate ()
{
lock (inbound) {
while (inbound.Count != 0) {
Message msg = inbound.Dequeue ();
try {
HandleSignal (msg);
} finally {
msg.Dispose ();
}
}
}
Iterate (new CancellationToken (false));
}

public void Iterate ()
public void Iterate (CancellationToken stopWaitToken)
{
Message msg = transport.ReadMessage ();

HandleMessage (msg);
DispatchSignals ();
if (TryGetStoredSignalMessage (out Message inboundMsg)) {
try {
HandleSignal (inboundMsg);
} finally {
inboundMsg.Dispose ();
}
} else {
var msg = readMessageTask.MakeSureTaskRunAndWait (stopWaitToken);
HandleMessage (msg);
}
}

internal virtual void HandleMessage (Message msg)
Expand All @@ -251,21 +251,19 @@ internal virtual void HandleMessage (Message msg)
try {

//TODO: Restrict messages to Local ObjectPath?

{
object field_value = msg.Header[FieldCode.ReplySerial];
object field_value = msg.Header [FieldCode.ReplySerial];
if (field_value != null) {
uint reply_serial = (uint)field_value;
PendingCall pending;

lock (pendingCalls) {
PendingCall pending;
if (pendingCalls.TryGetValue (reply_serial, out pending)) {
if (pendingCalls.Remove (reply_serial)) {
pending.Reply = msg;
if (pending.KeepFDs)
cleanupFDs = false; // caller is responsible for closing FDs
}

if (!pendingCalls.Remove (reply_serial))
return;
pending.Reply = msg;
if (pending.KeepFDs)
cleanupFDs = false; // caller is responsible for closing FDs
return;
}
}
Expand All @@ -285,8 +283,7 @@ internal virtual void HandleMessage (Message msg)
break;
case MessageType.Signal:
//HandleSignal (msg);
lock (inbound)
inbound.Enqueue (msg);
StoreInboundSignalMessage (msg); //temporary hack
cleanupFDs = false; // FDs are closed after signal is handled
break;
case MessageType.Error:
Expand Down Expand Up @@ -391,14 +388,15 @@ internal void HandleMethodCall (MessageContainer method_call)
//this is messy and inefficient
List<string> linkNodes = new List<string> ();
int depth = method_call.Path.Decomposed.Length;
foreach (ObjectPath pth in registeredObjects.Keys) {
if (pth.Value == (method_call.Path.Value)) {
ExportObject exo = (ExportObject)registeredObjects[pth];
foreach(var objKeyValuePair in registeredObjects) {
ObjectPath pth = objKeyValuePair.Key;
ExportObject exo = (ExportObject) objKeyValuePair.Value;
if (pth.Value == method_call.Path.Value) {
exo.WriteIntrospect (intro);
} else {
for (ObjectPath cur = pth ; cur != null ; cur = cur.Parent) {
for (ObjectPath cur = pth; cur != null; cur = cur.Parent) {
if (cur.Value == method_call.Path.Value) {
string linkNode = pth.Decomposed[depth];
string linkNode = pth.Decomposed [depth];
if (!linkNodes.Contains (linkNode)) {
intro.WriteNode (linkNode);
linkNodes.Add (linkNode);
Expand All @@ -415,9 +413,8 @@ internal void HandleMethodCall (MessageContainer method_call)
return;
}

BusObject bo;
if (registeredObjects.TryGetValue (method_call.Path, out bo)) {
ExportObject eo = (ExportObject)bo;
if (registeredObjects.TryGetValue(method_call.Path, out BusObject bo)) {
ExportObject eo = (ExportObject) bo;
eo.HandleMethodCall (method_call);
} else {
MaybeSendUnknownMethodError (method_call);
Expand Down Expand Up @@ -464,14 +461,10 @@ public void Register (ObjectPath path, object obj)

public object Unregister (ObjectPath path)
{
BusObject bo;

if (!registeredObjects.TryGetValue (path, out bo))
if (!registeredObjects.TryRemove (path, out BusObject bo))
throw new Exception ("Cannot unregister " + path + " as it isn't registered");

registeredObjects.Remove (path);

ExportObject eo = (ExportObject)bo;
ExportObject eo = (ExportObject) bo;
eo.Registered = false;

return eo.Object;
Expand All @@ -486,6 +479,25 @@ internal protected virtual void RemoveMatch (string rule)
{
}

private void StoreInboundSignalMessage (Message msg)
{
lock (inbound) {
inbound.Enqueue (msg);
}
}

private bool TryGetStoredSignalMessage (out Message msg)
{
msg = null;
lock (inbound) {
if (inbound.Count != 0) {
msg = inbound.Dequeue ();
return true;
}
}
return false;
}

static UUID ReadMachineId (string fname)
{
byte[] data = File.ReadAllBytes (fname);
Expand All @@ -494,5 +506,31 @@ static UUID ReadMachineId (string fname)

return UUID.Parse (System.Text.Encoding.ASCII.GetString (data, 0, 32));
}

private class ReadMessageTask
{
private readonly Connection ownerConnection;
private Task<Message> task = null;
private object taskLock = new object ();

public ReadMessageTask (Connection connection)
{
ownerConnection = connection;
}

public Message MakeSureTaskRunAndWait(CancellationToken stopWaitToken)
{
Task<Message> catchedTask = null;

lock (taskLock) {
if (task == null || task.IsCompleted) {
task = Task<Message>.Run (() => ownerConnection.transport.ReadMessage ());
}
catchedTask = task;
}
catchedTask.Wait (stopWaitToken);
return catchedTask.Result;
}
}
}
}
14 changes: 6 additions & 8 deletions src/ExportObject.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@ internal virtual void WriteIntrospect (Introspector intro)
internal static MethodCaller GetMCaller (MethodInfo mi)
{
MethodCaller mCaller;
if (!mCallers.TryGetValue (mi, out mCaller)) {
mCaller = TypeImplementer.GenCaller (mi);
mCallers[mi] = mCaller;
lock (mCallers) {
if (!mCallers.TryGetValue (mi, out mCaller)) {
mCaller = TypeImplementer.GenCaller (mi);
mCallers[mi] = mCaller;
}
}
return mCaller;
}
Expand All @@ -90,11 +92,7 @@ public virtual void HandleMethodCall (MessageContainer method_call)
return;
}

MethodCaller mCaller;
if (!mCallers.TryGetValue (mi, out mCaller)) {
mCaller = TypeImplementer.GenCaller (mi);
mCallers[mi] = mCaller;
}
MethodCaller mCaller = GetMCaller (mi);

Signature inSig, outSig;
bool hasDisposableList;
Expand Down
63 changes: 29 additions & 34 deletions src/Protocol/PendingCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,33 @@

namespace DBus.Protocol
{
public class PendingCall : IAsyncResult
public class PendingCall : IAsyncResult, IDisposable
{
Connection conn;
Message reply;
ManualResetEvent waitHandle;
bool completedSync;
bool keepFDs;

private Connection conn;
private Message reply;
private ManualResetEvent waitHandle = new ManualResetEvent (false);
private bool completedSync = false;
private bool keepFDs;
private CancellationTokenSource stopWait = new CancellationTokenSource();

public event Action<Message> Completed;

public PendingCall (Connection conn) : this (conn, false) {}
public PendingCall(Connection conn)
: this (conn, false)
{
}

public PendingCall (Connection conn, bool keepFDs)
{
this.conn = conn;
this.keepFDs = keepFDs;
}

public void Dispose()
{
stopWait.Dispose ();
}

internal bool KeepFDs
{
get {
Expand All @@ -33,36 +43,24 @@ internal bool KeepFDs

public Message Reply {
get {
if (reply != null)
return reply;

if (Thread.CurrentThread == conn.mainThread) {
while (reply == null)
conn.HandleMessage (conn.Transport.ReadMessage ());

completedSync = true;

conn.DispatchSignals ();
} else {
if (waitHandle == null)
Interlocked.CompareExchange (ref waitHandle, new ManualResetEvent (false), null);

while (reply == null)
waitHandle.WaitOne ();

completedSync = false;
while (reply == null) {
try {
conn.Iterate (stopWait.Token);
}
catch (OperationCanceledException) {
}
}

return reply;
}
}

set {
if (reply != null)
throw new Exception ("Cannot handle reply more than once");

reply = value;

if (waitHandle != null)
waitHandle.Set ();
waitHandle.Set ();

stopWait.Cancel ();

if (Completed != null)
Completed (reply);
Expand All @@ -84,9 +82,6 @@ object IAsyncResult.AsyncState {

WaitHandle IAsyncResult.AsyncWaitHandle {
get {
if (waitHandle == null)
waitHandle = new ManualResetEvent (false);

return waitHandle;
}
}
Expand Down
Loading

0 comments on commit a49bb67

Please sign in to comment.