diff --git a/SshAgentLib/WindowsOpenSshPipe.cs b/SshAgentLib/WindowsOpenSshPipe.cs index e0110cf6..57f515e4 100644 --- a/SshAgentLib/WindowsOpenSshPipe.cs +++ b/SshAgentLib/WindowsOpenSshPipe.cs @@ -28,95 +28,98 @@ using System.IO; using System.IO.Pipes; using System.Runtime.InteropServices; -using System.Security.AccessControl; -using System.Security.Principal; -using System.Threading; namespace dlech.SshAgentLib { - public class WindowsOpenSshPipe : IDisposable + public sealed class WindowsOpenSshPipe : IDisposable { - const string agentPipeId = "openssh-ssh-agent"; - const int receiveBufferSize = 5 * 1024; + private const string agentPipeName = "openssh-ssh-agent"; - static uint threadId = 0; + private const int BufferSizeIn = 5 * 1024; + private const int BufferSizeOut = 5 * 1024; + + private bool disposed; + private NamedPipeServerStream listeningServer; - NamedPipeServerStream listeningServer; - public delegate void ConnectionHandlerFunc(Stream stream, Process process); public ConnectionHandlerFunc ConnectionHandler { get; set; } - + public WindowsOpenSshPipe() { - if (File.Exists(string.Format("//./pipe/{0}", agentPipeId))) { + if (File.Exists($"//./pipe/{agentPipeName}")) { throw new PageantRunningException(); } - var thread = new Thread(listenerThread) { - Name = "WindowsOpenSshPipe.Listener", - IsBackground = true - }; - thread.Start(); + + AwaitConnection(); } - + [DllImport("kernel32.dll", SetLastError = true)] - static extern bool GetNamedPipeClientProcessId(IntPtr Pipe, out long ClientProcessId); + private static extern bool GetNamedPipeClientProcessId(IntPtr Pipe, out uint ClientProcessId); - void listenerThread() + private void AwaitConnection() { + if (disposed) { + return; + } + + listeningServer = new NamedPipeServerStream(agentPipeName, + PipeDirection.InOut, + NamedPipeServerStream.MaxAllowedServerInstances, + PipeTransmissionMode.Byte, + // TODO: Consider setting PipeOptions.CurrentUserOnly + PipeOptions.Asynchronous | PipeOptions.WriteThrough, + BufferSizeIn, + BufferSizeOut); + try { - while (true) { - var server = new NamedPipeServerStream(agentPipeId, PipeDirection.InOut, NamedPipeServerStream.MaxAllowedServerInstances, - PipeTransmissionMode.Byte, PipeOptions.WriteThrough, receiveBufferSize, receiveBufferSize); - listeningServer = server; - server.WaitForConnection(); - listeningServer = null; - var thread = new Thread(connectionThread) { - Name = string.Format("WindowsOpenSshPipe.Connection{0}", threadId++), - IsBackground = true - }; - thread.Start(server); - } + listeningServer.BeginWaitForConnection(AcceptConnection, listeningServer); + Debug.WriteLine("Started new server and awaiting connection ..."); } - catch (Exception) { - // don't crash background thread + catch (ObjectDisposedException) { + // Could happen if we're disposing while starting a server + } + catch (Exception ex) { + // Should never happen but we don't want to crash KeePass + Debug.WriteLine($"{ex.GetType()} in AwaitConnection(): {ex.Message}"); + listeningServer.Dispose(); } } - void connectionThread(object obj) + private void AcceptConnection(IAsyncResult result) { + Debug.WriteLine("Received new connection ..."); + AwaitConnection(); + + var server = result.AsyncState as NamedPipeServerStream; try { - var server = obj as NamedPipeServerStream; + server.EndWaitForConnection(result); - long clientPid; - if (!GetNamedPipeClientProcessId(server.SafePipeHandle.DangerousGetHandle(), out clientPid)) { + if (!GetNamedPipeClientProcessId(server.SafePipeHandle.DangerousGetHandle(), out var clientPid)) { throw new IOException("Failed to get client PID", Marshal.GetHRForLastWin32Error()); } - var proc = Process.GetProcessById((int)clientPid); - ConnectionHandler(server, proc); - server.Disconnect(); - server.Dispose(); + var clientProcess = Process.GetProcessById((int)clientPid); + Debug.WriteLine($"Processing request from process: {clientProcess.MainModule.ModuleName} (PID: {clientPid})"); + ConnectionHandler(server, clientProcess); } - catch (Exception) { - // TODO: add event to notify when there is a problem + catch (ObjectDisposedException) { + // Server has been disposed + } + catch (Exception ex) { + // Should never happen but we don't want to crash KeePass + Debug.WriteLine($"{ex.GetType()} in AcceptConnection(): {ex.Message}"); + } + finally { + server.Dispose(); } - } - - public void Dispose() - { - Dispose(true); - GC.SuppressFinalize(this); } - void Dispose(bool disposing) + public void Dispose() { - if (disposing) { - if (listeningServer != null) { - listeningServer.Dispose(); - listeningServer = null; - } - } + disposed = true; + listeningServer?.Dispose(); + listeningServer = null; } } }