Skip to content
This repository has been archived by the owner on Nov 8, 2022. It is now read-only.

Fix race conditions is case of threaded Iterate() call #71

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/Bus.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,13 @@ public static Bus Starter
throw new ArgumentNullException ("address");

Bus bus;
if (buses.TryGetValue (address, out bus))
return bus;
lock (buses) {
if (buses.TryGetValue (address, out bus))
return bus;

bus = new Bus (address);
buses[address] = bus;
bus = new Bus (address);
buses[address] = bus;
}

return bus;
}
Expand Down
2 changes: 1 addition & 1 deletion src/BusException.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public override string Message
{
get
{
return ErrorName + ": " + ErrorMessage;
return ErrorName + ": " + ErrorMessage + Environment.NewLine + this.StackTrace;
}
}

Expand Down
136 changes: 87 additions & 49 deletions src/Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
// See COPYING for details

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using System.Reflection;

namespace DBus
Expand Down Expand Up @@ -37,23 +39,24 @@ public class Connection

Dictionary<uint,PendingCall> pendingCalls = new Dictionary<uint,PendingCall> ();
Queue<Message> inbound = new Queue<Message> ();
Dictionary<ObjectPath,BusObject> registeredObjects = new Dictionary<ObjectPath,BusObject> ();
ConcurrentDictionary<ObjectPath,BusObject> registeredObjects = new ConcurrentDictionary<ObjectPath,BusObject> ();
private readonly ReadMessageTask readMessageTask;

public delegate void MonitorEventHandler (Message msg);
public MonitorEventHandler Monitors; // subscribe yourself to this list of observers if you want to get notified about each incoming message

protected Connection ()
{

readMessageTask = new ReadMessageTask (this);
}

internal Connection (Transport transport)
internal Connection (Transport transport) : this ()
{
this.transport = transport;
transport.Connection = this;
}

internal Connection (string address)
internal Connection (string address) : this ()
{
OpenPrivate (address);
Authenticate ();
Expand Down Expand Up @@ -183,11 +186,12 @@ internal uint GenerateSerial ()

internal Message SendWithReplyAndBlock (Message msg, bool keepFDs)
{
PendingCall pending = SendWithReply (msg, keepFDs);
return pending.Reply;
using (PendingCall pending = SendWithPendingReply (msg, keepFDs)) {
return pending.Reply;
}
}

internal PendingCall SendWithReply (Message msg, bool keepFDs)
internal PendingCall SendWithPendingReply (Message msg, bool keepFDs)
{
msg.ReplyExpected = true;

Expand Down Expand Up @@ -215,27 +219,23 @@ internal virtual uint Send (Message msg)
return msg.Header.Serial;
}

//temporary hack
internal void DispatchSignals ()
public void Iterate ()
{
lock (inbound) {
while (inbound.Count != 0) {
Message msg = inbound.Dequeue ();
try {
HandleSignal (msg);
} finally {
msg.Dispose ();
}
}
}
Iterate (new CancellationToken (false));
}

public void Iterate ()
public void Iterate (CancellationToken stopWaitToken)
{
Message msg = transport.ReadMessage ();

HandleMessage (msg);
DispatchSignals ();
if (TryGetStoredSignalMessage (out Message inboundMsg)) {
try {
HandleSignal (inboundMsg);
} finally {
inboundMsg.Dispose ();
}
} else {
var msg = readMessageTask.MakeSureTaskRunAndWait (stopWaitToken);
HandleMessage (msg);
}
}

internal virtual void HandleMessage (Message msg)
Expand All @@ -251,21 +251,19 @@ internal virtual void HandleMessage (Message msg)
try {

//TODO: Restrict messages to Local ObjectPath?

{
object field_value = msg.Header[FieldCode.ReplySerial];
object field_value = msg.Header [FieldCode.ReplySerial];
if (field_value != null) {
uint reply_serial = (uint)field_value;
PendingCall pending;

lock (pendingCalls) {
PendingCall pending;
if (pendingCalls.TryGetValue (reply_serial, out pending)) {
if (pendingCalls.Remove (reply_serial)) {
pending.Reply = msg;
if (pending.KeepFDs)
cleanupFDs = false; // caller is responsible for closing FDs
}

if (!pendingCalls.Remove (reply_serial))
return;
pending.Reply = msg;
if (pending.KeepFDs)
cleanupFDs = false; // caller is responsible for closing FDs
return;
}
}
Expand All @@ -285,8 +283,7 @@ internal virtual void HandleMessage (Message msg)
break;
case MessageType.Signal:
//HandleSignal (msg);
lock (inbound)
inbound.Enqueue (msg);
StoreInboundSignalMessage (msg); //temporary hack
cleanupFDs = false; // FDs are closed after signal is handled
break;
case MessageType.Error:
Expand Down Expand Up @@ -391,14 +388,15 @@ internal void HandleMethodCall (MessageContainer method_call)
//this is messy and inefficient
List<string> linkNodes = new List<string> ();
int depth = method_call.Path.Decomposed.Length;
foreach (ObjectPath pth in registeredObjects.Keys) {
if (pth.Value == (method_call.Path.Value)) {
ExportObject exo = (ExportObject)registeredObjects[pth];
foreach(var objKeyValuePair in registeredObjects) {
ObjectPath pth = objKeyValuePair.Key;
ExportObject exo = (ExportObject) objKeyValuePair.Value;
if (pth.Value == method_call.Path.Value) {
exo.WriteIntrospect (intro);
} else {
for (ObjectPath cur = pth ; cur != null ; cur = cur.Parent) {
for (ObjectPath cur = pth; cur != null; cur = cur.Parent) {
if (cur.Value == method_call.Path.Value) {
string linkNode = pth.Decomposed[depth];
string linkNode = pth.Decomposed [depth];
if (!linkNodes.Contains (linkNode)) {
intro.WriteNode (linkNode);
linkNodes.Add (linkNode);
Expand All @@ -415,9 +413,8 @@ internal void HandleMethodCall (MessageContainer method_call)
return;
}

BusObject bo;
if (registeredObjects.TryGetValue (method_call.Path, out bo)) {
ExportObject eo = (ExportObject)bo;
if (registeredObjects.TryGetValue(method_call.Path, out BusObject bo)) {
ExportObject eo = (ExportObject) bo;
eo.HandleMethodCall (method_call);
} else {
MaybeSendUnknownMethodError (method_call);
Expand Down Expand Up @@ -464,14 +461,10 @@ public void Register (ObjectPath path, object obj)

public object Unregister (ObjectPath path)
{
BusObject bo;

if (!registeredObjects.TryGetValue (path, out bo))
if (!registeredObjects.TryRemove (path, out BusObject bo))
throw new Exception ("Cannot unregister " + path + " as it isn't registered");

registeredObjects.Remove (path);

ExportObject eo = (ExportObject)bo;
ExportObject eo = (ExportObject) bo;
eo.Registered = false;

return eo.Object;
Expand All @@ -486,6 +479,25 @@ internal protected virtual void RemoveMatch (string rule)
{
}

private void StoreInboundSignalMessage (Message msg)
{
lock (inbound) {
inbound.Enqueue (msg);
}
}

private bool TryGetStoredSignalMessage (out Message msg)
{
msg = null;
lock (inbound) {
if (inbound.Count != 0) {
msg = inbound.Dequeue ();
return true;
}
}
return false;
}

static UUID ReadMachineId (string fname)
{
byte[] data = File.ReadAllBytes (fname);
Expand All @@ -494,5 +506,31 @@ static UUID ReadMachineId (string fname)

return UUID.Parse (System.Text.Encoding.ASCII.GetString (data, 0, 32));
}

private class ReadMessageTask
{
private readonly Connection ownerConnection;
private Task<Message> task = null;
private object taskLock = new object ();

public ReadMessageTask (Connection connection)
{
ownerConnection = connection;
}

public Message MakeSureTaskRunAndWait(CancellationToken stopWaitToken)
{
Task<Message> catchedTask = null;

lock (taskLock) {
if (task == null || task.IsCompleted) {
task = Task<Message>.Run (() => ownerConnection.transport.ReadMessage ());
}
catchedTask = task;
}
catchedTask.Wait (stopWaitToken);
return catchedTask.Result;
}
}
}
}
14 changes: 6 additions & 8 deletions src/ExportObject.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@ internal virtual void WriteIntrospect (Introspector intro)
internal static MethodCaller GetMCaller (MethodInfo mi)
{
MethodCaller mCaller;
if (!mCallers.TryGetValue (mi, out mCaller)) {
mCaller = TypeImplementer.GenCaller (mi);
mCallers[mi] = mCaller;
lock (mCallers) {
if (!mCallers.TryGetValue (mi, out mCaller)) {
mCaller = TypeImplementer.GenCaller (mi);
mCallers[mi] = mCaller;
}
}
return mCaller;
}
Expand All @@ -90,11 +92,7 @@ public virtual void HandleMethodCall (MessageContainer method_call)
return;
}

MethodCaller mCaller;
if (!mCallers.TryGetValue (mi, out mCaller)) {
mCaller = TypeImplementer.GenCaller (mi);
mCallers[mi] = mCaller;
}
MethodCaller mCaller = GetMCaller (mi);

Signature inSig, outSig;
bool hasDisposableList;
Expand Down
63 changes: 29 additions & 34 deletions src/Protocol/PendingCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,33 @@

namespace DBus.Protocol
{
public class PendingCall : IAsyncResult
public class PendingCall : IAsyncResult, IDisposable
{
Connection conn;
Message reply;
ManualResetEvent waitHandle;
bool completedSync;
bool keepFDs;

private Connection conn;
private Message reply;
private ManualResetEvent waitHandle = new ManualResetEvent (false);
private bool completedSync = false;
private bool keepFDs;
private CancellationTokenSource stopWait = new CancellationTokenSource();

public event Action<Message> Completed;

public PendingCall (Connection conn) : this (conn, false) {}
public PendingCall(Connection conn)
: this (conn, false)
{
}

public PendingCall (Connection conn, bool keepFDs)
{
this.conn = conn;
this.keepFDs = keepFDs;
}

public void Dispose()
{
stopWait.Dispose ();
}

internal bool KeepFDs
{
get {
Expand All @@ -33,36 +43,24 @@ internal bool KeepFDs

public Message Reply {
get {
if (reply != null)
return reply;

if (Thread.CurrentThread == conn.mainThread) {
while (reply == null)
conn.HandleMessage (conn.Transport.ReadMessage ());

completedSync = true;

conn.DispatchSignals ();
} else {
if (waitHandle == null)
Interlocked.CompareExchange (ref waitHandle, new ManualResetEvent (false), null);

while (reply == null)
waitHandle.WaitOne ();

completedSync = false;
while (reply == null) {
try {
conn.Iterate (stopWait.Token);
}
catch (OperationCanceledException) {
}
}

return reply;
}
}

set {
if (reply != null)
throw new Exception ("Cannot handle reply more than once");

reply = value;

if (waitHandle != null)
waitHandle.Set ();
waitHandle.Set ();

stopWait.Cancel ();

if (Completed != null)
Completed (reply);
Expand All @@ -84,9 +82,6 @@ object IAsyncResult.AsyncState {

WaitHandle IAsyncResult.AsyncWaitHandle {
get {
if (waitHandle == null)
waitHandle = new ManualResetEvent (false);

return waitHandle;
}
}
Expand Down
Loading