diff --git a/src/Renci.SshNet/Common/Extensions.cs b/src/Renci.SshNet/Common/Extensions.cs index 8fed310c1..0305e17d5 100644 --- a/src/Renci.SshNet/Common/Extensions.cs +++ b/src/Renci.SshNet/Common/Extensions.cs @@ -245,6 +245,9 @@ public static bool IsEqualTo(this byte[] left, byte[] right) return true; } +#if NET6_0_OR_GREATER + return left.AsSpan().SequenceEqual(right); +#else if (left.Length != right.Length) { return false; @@ -259,6 +262,7 @@ public static bool IsEqualTo(this byte[] left, byte[] right) } return true; +#endif } /// diff --git a/src/Renci.SshNet/Common/SshDataStream.cs b/src/Renci.SshNet/Common/SshDataStream.cs index 1eedd3ffc..656485a76 100644 --- a/src/Renci.SshNet/Common/SshDataStream.cs +++ b/src/Renci.SshNet/Common/SshDataStream.cs @@ -1,4 +1,7 @@ using System; +#if NET6_0_OR_GREATER +using System.Buffers.Binary; +#endif using System.Globalization; using System.IO; using System.Text; @@ -62,8 +65,14 @@ public bool IsEndOfData /// data to write. public void Write(uint value) { +#if NET6_0_OR_GREATER + Span bytes = stackalloc byte[4]; + BinaryPrimitives.WriteUInt32BigEndian(bytes, value); + Write(bytes); +#else var bytes = Pack.UInt32ToBigEndian(value); Write(bytes, 0, bytes.Length); +#endif } /// @@ -72,8 +81,14 @@ public void Write(uint value) /// data to write. public void Write(ulong value) { +#if NET6_0_OR_GREATER + Span bytes = stackalloc byte[8]; + BinaryPrimitives.WriteUInt64BigEndian(bytes, value); + Write(bytes); +#else var bytes = Pack.UInt64ToBigEndian(value); Write(bytes, 0, bytes.Length); +#endif } /// @@ -188,8 +203,14 @@ public BigInteger ReadBigInt() /// public uint ReadUInt32() { +#if NET6_0_OR_GREATER + Span bytes = stackalloc byte[4]; + ReadBytes(bytes); + return BinaryPrimitives.ReadUInt32BigEndian(bytes); +#else var data = ReadBytes(4); return Pack.BigEndianToUInt32(data); +#endif } /// @@ -200,8 +221,14 @@ public uint ReadUInt32() /// public ulong ReadUInt64() { +#if NET6_0_OR_GREATER + Span bytes = stackalloc byte[8]; + ReadBytes(bytes); + return BinaryPrimitives.ReadUInt64BigEndian(bytes); +#else var data = ReadBytes(8); return Pack.BigEndianToUInt64(data); +#endif } /// @@ -264,5 +291,22 @@ private byte[] ReadBytes(int length) return data; } + +#if NET6_0_OR_GREATER + /// + /// Fills the specified span with bytes from the internal buffer. + /// + /// The span to fill. + /// The Length of is greater than the actual number of bytes read. + private void ReadBytes(Span destination) + { + var bytesRead = Read(destination); + + if (bytesRead < destination.Length) + { + throw new ArgumentException(nameof(destination), string.Format(CultureInfo.InvariantCulture, "The requested length ({0}) is greater than the actual number of bytes read ({1}).", destination.Length, bytesRead)); + } + } +#endif } } diff --git a/src/Renci.SshNet/Session.cs b/src/Renci.SshNet/Session.cs index 9fec6bd7e..ced015563 100644 --- a/src/Renci.SshNet/Session.cs +++ b/src/Renci.SshNet/Session.cs @@ -1254,11 +1254,18 @@ private Message ReceiveMessage(Socket socket) if (_serverMac != null) { var clientHash = _serverMac.ComputeHash(data, 0, data.Length - serverMacLength); + + bool serverHashEqualsClientHash; + +#if NET6_0_OR_GREATER + var serverHash = data.AsSpan(data.Length - serverMacLength); + serverHashEqualsClientHash = serverHash.SequenceEqual(clientHash); +#else var serverHash = data.Take(data.Length - serverMacLength, serverMacLength); + serverHashEqualsClientHash = serverHash.IsEqualTo(clientHash); +#endif - // TODO add IsEqualTo overload that takes left+right index and number of bytes to compare; - // TODO that way we can eliminate the extra allocation of the Take above - if (!serverHash.IsEqualTo(clientHash)) + if (!serverHashEqualsClientHash) { throw new SshConnectionException("MAC error", DisconnectReason.MacError); }