From ee083afa8ff641bd0ebdc26a72500eec8dbfc47f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C2=A0?= <> Date: Wed, 21 Jul 2021 01:59:41 -0700 Subject: [PATCH 1/3] Add WSL support via UnixSocket --- SshAgentLib/PageantAgent.cs | 35 ++++++ SshAgentLib/SshAgentLib.csproj | 1 + SshAgentLib/UnixSocket.cs | 221 +++++++++++++++++++++++++++++++++ 3 files changed, 257 insertions(+) create mode 100644 SshAgentLib/UnixSocket.cs diff --git a/SshAgentLib/PageantAgent.cs b/SshAgentLib/PageantAgent.cs index fedbc68f..52ac58fc 100644 --- a/SshAgentLib/PageantAgent.cs +++ b/SshAgentLib/PageantAgent.cs @@ -70,6 +70,7 @@ public class PageantAgent : Agent object lockObject = new object(); CygwinSocket cygwinSocket; MsysSocket msysSocket; + UnixSocket wslSocket; WindowsOpenSshPipe opensshPipe; #endregion @@ -286,6 +287,39 @@ public void StopMsysSocket() msysSocket = null; } + /// + /// Starts a wsl style socket that can be used by the ssh program + /// that comes with wsl. + /// + /// The path to the socket file that will be created. + public void StartWslSocket(string path) + { + if (disposed) { + throw new ObjectDisposedException("PagentAgent"); + } + if (wslSocket != null) { + return; + } + // only overwrite a file if it looks like a WslSocket file. + // TODO: Might be good to test that there are not network sockets using + // the port specified in this file. + if (File.Exists(path) && UnixSocket.TestFile(path)) { + File.Delete(path); + } + wslSocket = new UnixSocket(path); + wslSocket.ConnectionHandler = connectionHandler; + } + + public void StopWslSocket() + { + if (disposed) + throw new ObjectDisposedException("PagentAgent"); + if (wslSocket == null) + return; + wslSocket.Dispose(); + wslSocket = null; + } + public void StartWindowsOpenSshPipe() { if (disposed) { @@ -351,6 +385,7 @@ private void RunWindowInNewAppcontext() // make sure socket files are cleaned up when we stop. StopCygwinSocket(); StopMsysSocket(); + StopWslSocket(); StopWindowsOpenSshPipe(); if (hwnd != IntPtr.Zero) { diff --git a/SshAgentLib/SshAgentLib.csproj b/SshAgentLib/SshAgentLib.csproj index 6b257def..d2c1e815 100644 --- a/SshAgentLib/SshAgentLib.csproj +++ b/SshAgentLib/SshAgentLib.csproj @@ -75,6 +75,7 @@ + diff --git a/SshAgentLib/UnixSocket.cs b/SshAgentLib/UnixSocket.cs new file mode 100644 index 00000000..59a1512d --- /dev/null +++ b/SshAgentLib/UnixSocket.cs @@ -0,0 +1,221 @@ +// +// UnixSocket.cs +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +using Microsoft.Win32.SafeHandles; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Net; +using System.Net.Sockets; +using System.Reflection; +using System.Runtime.InteropServices; +using System.Security.AccessControl; +using System.Security.Principal; +using System.Text; +using System.Threading; + +namespace dlech.SshAgentLib +{ + public class UnixSocket : IDisposable + { + const string waitHandleNamePrefix = "unix.local_socket.secret"; + + static int clientCount = 0; + + string path; + Socket socket; + SocketAddress sockaddr; + Guid guid; + bool disposed; + List clientSockets = new List(); + object clientSocketsLock = new object(); + + public delegate void ConnectionHandlerFunc(Stream stream, Process process); + public ConnectionHandlerFunc ConnectionHandler { get; set; } + + /// + /// Create new "unix domain" socket for use with Linux + /// + /// The name of the file to use for the socket + public UnixSocket(string path) + { + this.path = path; + guid = Guid.NewGuid(); + { + try { + socket = new Socket(AddressFamily.Unix, SocketType.Stream, + ProtocolType.Unspecified); + var endpoint = new UnixEndPoint(path); + sockaddr = endpoint.Serialize(); + socket.Bind(endpoint); + var fileSecurity = File.GetAccessControl(path); + // This turns off ACL inheritance and removes all inherited rules + fileSecurity.SetAccessRuleProtection(true, false); + // We are left with no permissions at all, so we have to add them + // back for the current user + var userOnlyRule = new FileSystemAccessRule( + WindowsIdentity.GetCurrent().User, + FileSystemRights.FullControl, + AccessControlType.Allow); + fileSecurity.SetAccessRule(userOnlyRule); + File.SetAccessControl(path, fileSecurity); + socket.Listen(5); + var socketThread = new Thread(AcceptConnections); + socketThread.Name = "UnixSocket"; + socketThread.Start(); + } catch (Exception) { + if (socket != null) + socket.Close(); + File.Delete(path); + throw; + } + } + } + + /// + /// Tests a file to see if it looks like a Unix socket file + /// + /// The path to the file. + /// true if the file contents look correct + public static bool TestFile(string path) + { + var info = new FileInfo(path); + return info.Length == 0; + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + void Dispose(bool disposing) + { + if (!disposed) { + disposed = true; + if (disposing) { + // Dispose managed resources + foreach (var clientSocket in clientSockets) { + clientSocket.Dispose(); + } + socket.Dispose(); + File.Delete(path); + } + // Dispose unmanaged resources + } + } + + void AcceptConnections() + { + var buffer = new byte[16]; + while (true) { + try { + BindingFlags + sbinding = BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public, + ibinding = BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public; + var m_Size = typeof(SocketAddress).GetField("m_Size", ibinding); + var args = new object[] + { + typeof(Socket).GetField("m_Handle", ibinding).GetValue(socket), + (byte[])typeof(SocketAddress).GetField("m_Buffer", ibinding).GetValue(sockaddr), + (int)m_Size.GetValue(sockaddr) + }; + var acceptedHandle = (SafeHandleMinusOneIsInvalid)typeof(Socket).Assembly.GetType("System.Net.SafeCloseSocket").GetMethod("Accept", sbinding).Invoke(null, args); + m_Size.SetValue(sockaddr, (int)args[2]); + var endpoint = new UnixEndPoint(null).Create(sockaddr); + var clientSocket = acceptedHandle.IsInvalid ? null : (Socket)typeof(Socket).GetMethod("CreateAcceptSocket", ibinding).Invoke(socket, new object[] { acceptedHandle, endpoint, false }); + if (clientSocket == null) { Marshal.ThrowExceptionForHR(Marshal.GetHRForLastWin32Error()); } + var clientThread = new Thread(() => { + try { + using (var stream = new NetworkStream(clientSocket)) { + Process proc = null; + if (ConnectionHandler != null) { + ConnectionHandler(stream, proc); + } + } + } catch { + // can throw if remote closes the connection at a bad time + } finally { + lock (clientSocketsLock) { + clientSockets.Remove(clientSocket); + } + } + }); + lock (clientSocketsLock) { + clientSockets.Add(clientSocket); + } + clientThread.Name = string.Format("UnixClient{0}", clientCount++); + clientThread.Start(); + } catch (Exception ex) { + Debug.Assert(disposed, ex.ToString()); + break; + } + } + } + } + + [Serializable] + public class UnixEndPoint : EndPoint + { + public string Filename { get; private set; } + + public UnixEndPoint(string path) : base() + { + this.Filename = path; + } + + public override AddressFamily AddressFamily { get { return AddressFamily.Unix; } } + + public override EndPoint Create(SocketAddress socketAddress) + { + int size = socketAddress.Size - 2; + var bytes = new byte[size]; + for (int i = 0; i < bytes.Length; i++) + { + bytes[i] = socketAddress[i + 2]; + if (i > 0 && bytes[i] == 0) + { + size = i; + break; + } + } + return new UnixEndPoint(Encoding.UTF8.GetString(bytes, 0, size)); + } + public override SocketAddress Serialize() + { + var bytes = Encoding.UTF8.GetBytes(this.Filename); + var maxLen = 108; + if (bytes.Length > maxLen) { + throw new PathTooLongException(string.Format("Path ({0} bytes) was too long for UNIX-domain socket (max {1} bytes)", bytes.Length, maxLen)); + } + var addr = new SocketAddress(AddressFamily.Unix, sizeof(short) + maxLen); + for (int i = 0; i < bytes.Length && i < maxLen; i++) + { + addr[2 + i] = bytes[i]; + } + if (bytes.Length < maxLen) { + addr[sizeof(short) + bytes.Length] = 0; + } + return addr; + } + } +} From fea1c64aeffe86760a3c4b3a761521d43fea552f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C2=A0?= <> Date: Thu, 6 Jan 2022 19:29:20 -0800 Subject: [PATCH 2/3] Address review comments --- SshAgentLib/UnixSocket.cs | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/SshAgentLib/UnixSocket.cs b/SshAgentLib/UnixSocket.cs index 59a1512d..3590a71d 100644 --- a/SshAgentLib/UnixSocket.cs +++ b/SshAgentLib/UnixSocket.cs @@ -1,6 +1,8 @@ // // UnixSocket.cs // +// Allows WSL1 connections via AF_UNIX sockets on Windows 10 and above. +// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights @@ -132,15 +134,18 @@ void AcceptConnections() BindingFlags sbinding = BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public, ibinding = BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public; - var m_Size = typeof(SocketAddress).GetField("m_Size", ibinding); var args = new object[] { typeof(Socket).GetField("m_Handle", ibinding).GetValue(socket), (byte[])typeof(SocketAddress).GetField("m_Buffer", ibinding).GetValue(sockaddr), - (int)m_Size.GetValue(sockaddr) + sockaddr.Size }; var acceptedHandle = (SafeHandleMinusOneIsInvalid)typeof(Socket).Assembly.GetType("System.Net.SafeCloseSocket").GetMethod("Accept", sbinding).Invoke(null, args); - m_Size.SetValue(sockaddr, (int)args[2]); + // Accept() returns the socket address for the accepted connection, along with its size. + // We want the SocketAddress object to contain that address when creating the endpoint. + // We use reflection to access the field since there is no public setter. + // This is safe because accept() returns the actual number of bytes written to the buffer, so the output cannot overflow the buffer. + typeof(SocketAddress).GetField("m_Size", ibinding).SetValue(sockaddr, (int)args[2]); var endpoint = new UnixEndPoint(null).Create(sockaddr); var clientSocket = acceptedHandle.IsInvalid ? null : (Socket)typeof(Socket).GetMethod("CreateAcceptSocket", ibinding).Invoke(socket, new object[] { acceptedHandle, endpoint, false }); if (clientSocket == null) { Marshal.ThrowExceptionForHR(Marshal.GetHRForLastWin32Error()); } From 7c85e3adf592833a6af6c5efd3d193d789904bba Mon Sep 17 00:00:00 2001 From: David Lechner Date: Sat, 29 Jan 2022 15:13:00 -0600 Subject: [PATCH 3/3] rework WslSocket - Use backported UnixDomainSocketEndPoint from .NET core. - Rename UnixSocket to WslSocket since this is Windows-only. - Remove use of System.Reflection. - Use async/await instead of blocking threads. - Add unit tests. - Fix bugs discoverd by unit tests. - Add changelog entry. --- .editorconfig | 2 +- CHANGELOG.md | 3 +- .../UnixDomainSocketEndPoint.Windows.cs | 27 +++ .../Microsoft/UnixDomainSocketEndPoint.cs | 161 +++++++++++++ SshAgentLib/PageantAgent.cs | 15 +- SshAgentLib/SshAgentLib.csproj | 7 +- SshAgentLib/UnixSocket.cs | 226 ------------------ SshAgentLib/WslSocket.cs | 162 +++++++++++++ SshAgentLibTests/SshAgentLibTests.csproj | 1 + SshAgentLibTests/WslSocketTests.cs | 100 ++++++++ 10 files changed, 467 insertions(+), 237 deletions(-) create mode 100644 SshAgentLib/Microsoft/UnixDomainSocketEndPoint.Windows.cs create mode 100644 SshAgentLib/Microsoft/UnixDomainSocketEndPoint.cs delete mode 100644 SshAgentLib/UnixSocket.cs create mode 100644 SshAgentLib/WslSocket.cs create mode 100644 SshAgentLibTests/WslSocketTests.cs diff --git a/.editorconfig b/.editorconfig index 758eb9d9..ea604296 100644 --- a/.editorconfig +++ b/.editorconfig @@ -147,5 +147,5 @@ csharp_preserve_single_line_statements = true csharp_preserve_single_line_blocks = true # starting to convert to 4 spaces -[*WindowsOpenSshPipe*.cs] +[*{WindowsOpenSshPipe,WslSocket,UnixDomainSocketEndPoint}*.cs] indent_size = 4 diff --git a/CHANGELOG.md b/CHANGELOG.md index b7ffa653..5a0e4a5e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,8 @@ ## Added - Added this changelog. -- Added PuTTY private key v3 support +- Added PuTTY private key v3 support. +- Added Window UNIX socket for WSL `ssh-agent` support. ## Fixed - Fixed using incorrect unmanaged memory free function in `PagentClent.SendMessage()`. diff --git a/SshAgentLib/Microsoft/UnixDomainSocketEndPoint.Windows.cs b/SshAgentLib/Microsoft/UnixDomainSocketEndPoint.Windows.cs new file mode 100644 index 00000000..0274c1f5 --- /dev/null +++ b/SshAgentLib/Microsoft/UnixDomainSocketEndPoint.Windows.cs @@ -0,0 +1,27 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Net.Sockets +{ + /// Represents a Unix Domain Socket endpoint as a path. + public sealed partial class UnixDomainSocketEndPoint : EndPoint + { +#pragma warning disable CA1802 // on Unix these need to be static readonly rather than const, so we do the same on Windows for consistency + private static readonly int s_nativePathOffset = 2; // sizeof(sun_family) + private static readonly int s_nativePathLength = 108; // sizeof(sun_path) + private static readonly int s_nativeAddressSize = s_nativePathOffset + s_nativePathLength; // sizeof(sockaddr_un) +#pragma warning restore CA1802 + + private SocketAddress CreateSocketAddressForSerialize() => + new SocketAddress(AddressFamily.Unix, s_nativeAddressSize); + + // from afunix.h: + //#define UNIX_PATH_MAX 108 + //typedef struct sockaddr_un + //{ + // ADDRESS_FAMILY sun_family; /* AF_UNIX */ + // char sun_path[UNIX_PATH_MAX]; /* pathname */ + //} + //SOCKADDR_UN, *PSOCKADDR_UN; + } +} diff --git a/SshAgentLib/Microsoft/UnixDomainSocketEndPoint.cs b/SshAgentLib/Microsoft/UnixDomainSocketEndPoint.cs new file mode 100644 index 00000000..f2fe1716 --- /dev/null +++ b/SshAgentLib/Microsoft/UnixDomainSocketEndPoint.cs @@ -0,0 +1,161 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Text; +using System.IO; + +namespace System.Net.Sockets +{ + /// Represents a Unix Domain Socket endpoint as a path. + public sealed partial class UnixDomainSocketEndPoint : EndPoint + { + private const AddressFamily EndPointAddressFamily = AddressFamily.Unix; + + private readonly string _path; + private readonly byte[] _encodedPath; + + // Tracks the file Socket should delete on Dispose. + internal string BoundFileName { get; } + + public UnixDomainSocketEndPoint(string path) + : this(path, null) + { } + + private UnixDomainSocketEndPoint(string path, string boundFileName) + { + if (path == null) + { + throw new ArgumentNullException(nameof(path)); + } + + BoundFileName = boundFileName; + + // Pathname socket addresses should be null-terminated. + // Linux abstract socket addresses start with a zero byte, they must not be null-terminated. + bool isAbstract = IsAbstract(path); + int bufferLength = Encoding.UTF8.GetByteCount(path); + if (!isAbstract) + { + // for null terminator + bufferLength++; + } + + if (path.Length == 0 || bufferLength > s_nativePathLength) + { + const string ArgumentOutOfRange_PathLengthInvalid = + "The path '{0}' is of an invalid length for use with domain sockets on this platform. The length must be between 1 and {1} characters, inclusive."; + + throw new ArgumentOutOfRangeException( + nameof(path), path, + string.Format(ArgumentOutOfRange_PathLengthInvalid, path, s_nativePathLength)); + //SR.Format(SR.ArgumentOutOfRange_PathLengthInvalid, path, s_nativePathLength)); + } + + _path = path; + _encodedPath = new byte[bufferLength]; + int bytesEncoded = Encoding.UTF8.GetBytes(path, 0, path.Length, _encodedPath, 0); + Debug.Assert(bufferLength - (isAbstract ? 0 : 1) == bytesEncoded); + + // FIXME: see https://github.com/dotnet/runtime/blob/f85ea976f81945ea18cd5dc71959cccecdc93cd2/src/libraries/Common/src/System/Net/SocketProtocolSupportPal.Windows.cs#L14 + //if (!Socket.OSSupportsUnixDomainSockets) + //{ + // throw new PlatformNotSupportedException(); + //} + } + + internal static int MaxAddressSize => s_nativeAddressSize; + + internal UnixDomainSocketEndPoint(SocketAddress socketAddress) + { + if (socketAddress == null) + { + throw new ArgumentNullException(nameof(socketAddress)); + } + + if (socketAddress.Family != EndPointAddressFamily || + socketAddress.Size > s_nativeAddressSize) + { + throw new ArgumentOutOfRangeException(nameof(socketAddress)); + } + + if (socketAddress.Size > s_nativePathOffset) + { + _encodedPath = new byte[socketAddress.Size - s_nativePathOffset]; + for (int i = 0; i < _encodedPath.Length; i++) + { + _encodedPath[i] = socketAddress[s_nativePathOffset + i]; + } + + // Strip trailing null of pathname socket addresses. + int length = _encodedPath.Length; + if (!IsAbstract(_encodedPath)) + { + // Since this isn't an abstract path, we're sure our first byte isn't 0. + while (_encodedPath[length - 1] == 0) + { + length--; + } + } + _path = Encoding.UTF8.GetString(_encodedPath, 0, length); + } + else + { + _encodedPath = Array.Empty(); + _path = string.Empty; + } + } + + public override SocketAddress Serialize() + { + SocketAddress result = CreateSocketAddressForSerialize(); + + for (int index = 0; index < _encodedPath.Length; index++) + { + result[s_nativePathOffset + index] = _encodedPath[index]; + } + + return result; + } + + public override EndPoint Create(SocketAddress socketAddress) => new UnixDomainSocketEndPoint(socketAddress); + + public override AddressFamily AddressFamily => EndPointAddressFamily; + + public override string ToString() + { + bool isAbstract = IsAbstract(_path); + if (isAbstract) + { + // return string.Concat("@", _path.AsSpan(1)); + return "@" + _path.Substring(1); + } + else + { + return _path; + } + } + + internal UnixDomainSocketEndPoint CreateBoundEndPoint() + { + if (IsAbstract(_path)) + { + return this; + } + return new UnixDomainSocketEndPoint(_path, Path.GetFullPath(_path)); + } + + internal UnixDomainSocketEndPoint CreateUnboundEndPoint() + { + if (IsAbstract(_path) || BoundFileName is null) + { + return this; + } + return new UnixDomainSocketEndPoint(_path, null); + } + + private static bool IsAbstract(string path) => path.Length > 0 && path[0] == '\0'; + + private static bool IsAbstract(byte[] encodedPath) => encodedPath.Length > 0 && encodedPath[0] == 0; + } +} diff --git a/SshAgentLib/PageantAgent.cs b/SshAgentLib/PageantAgent.cs index 4cd43d86..05244fa4 100644 --- a/SshAgentLib/PageantAgent.cs +++ b/SshAgentLib/PageantAgent.cs @@ -1,4 +1,4 @@ -// +// // PageantAgent.cs // // Author(s): David Lechner @@ -53,6 +53,7 @@ public class PageantAgent : Agent const int ERROR_CLASS_ALREADY_EXISTS = 1410; const int WM_COPYDATA = 0x004A; const int WSAECONNABORTED = 10053; + const int WSAECONNRESET = 10054; /* From PuTTY source code */ @@ -70,7 +71,7 @@ public class PageantAgent : Agent object lockObject = new object(); CygwinSocket cygwinSocket; MsysSocket msysSocket; - UnixSocket wslSocket; + WslSocket wslSocket; WindowsOpenSshPipe opensshPipe; Thread winThread; @@ -302,13 +303,10 @@ public void StartWslSocket(string path) return; } // only overwrite a file if it looks like a WslSocket file. - // TODO: Might be good to test that there are not network sockets using - // the port specified in this file. - if (File.Exists(path) && UnixSocket.TestFile(path)) { + if (File.Exists(path) && WslSocket.TestFile(path)) { File.Delete(path); } - wslSocket = new UnixSocket(path); - wslSocket.ConnectionHandler = connectionHandler; + wslSocket = new WslSocket(path, connectionHandler); } public void StopWslSocket() @@ -490,7 +488,8 @@ void connectionHandler(Stream stream, Process process) } } catch (IOException ex) { var socketException = ex.InnerException as SocketException; - if (socketException != null && socketException.ErrorCode == WSAECONNABORTED) { + if (socketException != null && ( + socketException.ErrorCode == WSAECONNABORTED || socketException.ErrorCode == WSAECONNRESET)) { // expected error return; } diff --git a/SshAgentLib/SshAgentLib.csproj b/SshAgentLib/SshAgentLib.csproj index 65b6d2f0..a3523e30 100644 --- a/SshAgentLib/SshAgentLib.csproj +++ b/SshAgentLib/SshAgentLib.csproj @@ -95,7 +95,9 @@ - + + + @@ -160,6 +162,9 @@ 1.8.1.3 + + 4.3.0 +