Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve SslStream exception after disposal #79329

Merged
merged 6 commits into from
Jan 6, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ private void CloseInternal()

// Ensure a Read or Auth operation is not in progress,
// block potential future read and auth operations since SslStream is disposing.
// This leaves the _nestedRead = 1 and _nestedAuth = 1, but that's ok, since
// This leaves the _nestedRead = 2 and _nestedAuth = 2, but that's ok, since
// subsequent operations check the _exception sentinel first
if (Interlocked.Exchange(ref _nestedRead, 1) == 0 &&
Interlocked.Exchange(ref _nestedAuth, 1) == 0)
if (Interlocked.Exchange(ref _nestedRead, (int)StreamUse.Disposed) == (int)StreamUse.NotInUse &&
wfurt marked this conversation as resolved.
Show resolved Hide resolved
Interlocked.Exchange(ref _nestedAuth, (int)StreamUse.Disposed) == (int)StreamUse.NotInUse)
{
_buffer.ReturnBuffer();
}
Expand Down Expand Up @@ -162,19 +162,22 @@ private async Task ReplyOnReAuthenticationAsync<TIOAdapter>(byte[]? buffer, Canc
private async Task RenegotiateAsync<TIOAdapter>(CancellationToken cancellationToken)
where TIOAdapter : IReadWriteAdapter
{
if (Interlocked.Exchange(ref _nestedAuth, 1) == 1)
if (Interlocked.CompareExchange(ref _nestedAuth, (int)StreamUse.InUse, (int)StreamUse.NotInUse) != (int)StreamUse.NotInUse)
{
ObjectDisposedException.ThrowIf(_nestedAuth == (int)StreamUse.Disposed, this);
throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "authenticate"));
}

if (Interlocked.Exchange(ref _nestedRead, 1) == 1)
if (Interlocked.CompareExchange(ref _nestedRead, (int)StreamUse.InUse, (int)StreamUse.NotInUse) != (int)StreamUse.NotInUse)
{
ObjectDisposedException.ThrowIf(_nestedRead == (int)StreamUse.Disposed, this);
throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "read"));
}

if (Interlocked.Exchange(ref _nestedWrite, 1) == 1)
// Write is different since we do not do anything special in Dispose
if (Interlocked.Exchange(ref _nestedWrite, (int)StreamUse.InUse) != (int)StreamUse.NotInUse)
{
_nestedRead = 0;
_nestedRead = (int)StreamUse.NotInUse;
throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "write"));
}

Expand Down Expand Up @@ -231,8 +234,8 @@ private async Task RenegotiateAsync<TIOAdapter>(CancellationToken cancellationTo
_buffer.ReturnBuffer();
}

_nestedRead = 0;
_nestedWrite = 0;
_nestedRead = (int)StreamUse.NotInUse;
_nestedWrite = (int)StreamUse.NotInUse;
_isRenego = false;
// We will not release _nestedAuth at this point to prevent another renegotiation attempt.
}
Expand All @@ -248,7 +251,7 @@ private async Task ForceAuthenticationAsync<TIOAdapter>(bool receiveFirst, byte[
if (reAuthenticationData == null)
{
// prevent nesting only when authentication functions are called explicitly. e.g. handle renegotiation transparently.
if (Interlocked.Exchange(ref _nestedAuth, 1) == 1)
if (Interlocked.Exchange(ref _nestedAuth, (int)StreamUse.InUse) == (int)StreamUse.InUse)
{
throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "authenticate"));
}
Expand Down Expand Up @@ -335,7 +338,7 @@ private async Task ForceAuthenticationAsync<TIOAdapter>(bool receiveFirst, byte[
{
if (reAuthenticationData == null)
{
_nestedAuth = 0;
_nestedAuth = (int)StreamUse.NotInUse;
_isRenego = false;
}
}
Expand Down Expand Up @@ -494,7 +497,7 @@ private bool CompleteHandshake(ref ProtocolToken? alertToken, out SslPolicyError
{
ProcessHandshakeSuccess();

if (_nestedAuth != 1)
if (_nestedAuth != (int)StreamUse.InUse)
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, $"Ignoring unsolicited renegotiated certificate.");
// ignore certificates received outside of handshake or requested renegotiation.
Expand Down Expand Up @@ -763,13 +766,16 @@ private SecurityStatusPal DecryptData(int frameSize)
private async ValueTask<int> ReadAsyncInternal<TIOAdapter>(Memory<byte> buffer, CancellationToken cancellationToken)
where TIOAdapter : IReadWriteAdapter
{
if (Interlocked.Exchange(ref _nestedRead, 1) == 1)
// Throw first if we already have exception.
// Check for disposal is not atomic so we will check again below.
ThrowIfExceptionalOrNotAuthenticated();

if (Interlocked.CompareExchange(ref _nestedRead, (int)StreamUse.InUse, (int)StreamUse.NotInUse) != (int)StreamUse.NotInUse)
{
ObjectDisposedException.ThrowIf(_nestedRead == (int)StreamUse.Disposed, this);
throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "read"));
}

ThrowIfExceptionalOrNotAuthenticated();

try
{
int processedLength = 0;
Expand Down Expand Up @@ -904,7 +910,7 @@ private async ValueTask<int> ReadAsyncInternal<TIOAdapter>(Memory<byte> buffer,
finally
{
ReturnReadBufferIfEmpty();
_nestedRead = 0;
_nestedRead = (int)StreamUse.NotInUse;
}
}

Expand All @@ -919,7 +925,7 @@ private async ValueTask WriteAsyncInternal<TIOAdapter>(ReadOnlyMemory<byte> buff
return;
}

if (Interlocked.Exchange(ref _nestedWrite, 1) == 1)
if (Interlocked.Exchange(ref _nestedWrite, (int)StreamUse.InUse) == (int)StreamUse.InUse)
{
throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "write"));
}
Expand All @@ -942,7 +948,7 @@ private async ValueTask WriteAsyncInternal<TIOAdapter>(ReadOnlyMemory<byte> buff
}
finally
{
_nestedWrite = 0;
_nestedWrite = (int)StreamUse.NotInUse;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,13 @@ public void ReturnBuffer()
}
}

private enum StreamUse
{
NotInUse = 0,
InUse = 1,
Disposed = 2,
};

private int _nestedWrite;
private int _nestedRead;

Expand Down Expand Up @@ -703,7 +710,7 @@ public override async ValueTask DisposeAsync()
public override int ReadByte()
{
ThrowIfExceptionalOrNotAuthenticated();
if (Interlocked.Exchange(ref _nestedRead, 1) == 1)
if (Interlocked.Exchange(ref _nestedRead, (int)StreamUse.InUse) == (int)StreamUse.InUse)
{
throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "read"));
}
Expand All @@ -724,7 +731,7 @@ public override int ReadByte()
// Regardless of whether we were able to read a byte from the buffer,
// reset the read tracking. If we weren't able to read a byte, the
// subsequent call to Read will set the flag again.
_nestedRead = 0;
_nestedRead = (int)StreamUse.NotInUse;
}

// Otherwise, fall back to reading a byte via Read, the same way Stream.ReadByte does.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.IO;
using System.Net.Test.Common;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;

using Xunit;
Expand All @@ -12,13 +12,13 @@ namespace System.Net.Security.Tests
{
using Configuration = System.Net.Test.Common.Configuration;

public abstract class SslStreamDisposeTest
public class SslStreamDisposeTest
{
[Fact]
public async Task DisposeAsync_NotConnected_ClosesStream()
{
bool disposed = false;
var stream = new SslStream(new DelegateStream(disposeFunc: _ => disposed = true), false, delegate { return true; });
var stream = new SslStream(new DelegateStream(disposeFunc: _ => disposed = true, canReadFunc: () => true, canWriteFunc: () => true), false, delegate { return true; });

Assert.False(disposed);
await stream.DisposeAsync();
Expand Down Expand Up @@ -50,5 +50,57 @@ await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
await serverStream.DisposeAsync();
Assert.NotEqual(0, trackingStream2.TimesCalled(nameof(Stream.DisposeAsync)));
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task Dispose_PendingReadAsync_ThrowsODE(bool bufferedRead)
{
using CancellationTokenSource cts = new CancellationTokenSource();
cts.CancelAfter(TestConfiguration.PassingTestTimeout);

(SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams(leaveInnerStreamOpen: true);
using (client)
using (server)
using (X509Certificate2 serverCertificate = Configuration.Certificates.GetServerCertificate())
using (X509Certificate2 clientCertificate = Configuration.Certificates.GetClientCertificate())
{
SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions()
{
TargetHost = Guid.NewGuid().ToString("N"),
};
clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;

SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions()
{
ServerCertificate = serverCertificate,
};

await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
client.AuthenticateAsClientAsync(clientOptions, default),
server.AuthenticateAsServerAsync(serverOptions, default));

await TestHelper.PingPong(client, server, cts.Token);

await server.WriteAsync("PINGPONG"u8.ToArray(), cts.Token);
var readBuffer = new byte[1024];

Task<int>? task = null;
if (bufferedRead)
{
// This will read everything into internal buffer. Following ReadAsync will not need IO.
task = client.ReadAsync(readBuffer, 0, 4, cts.Token);
client.Dispose();
int readLength = await task.ConfigureAwait(false);
Assert.Equal(4, readLength);
}
else
{
client.Dispose();
}

await Assert.ThrowsAnyAsync<ObjectDisposedException>(() => client.ReadAsync(readBuffer, cts.Token).AsTask());
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ public static bool AllowAnyServerCertificate(object sender, X509Certificate cert
return true;
}

public static (SslStream ClientStream, SslStream ServerStream) GetConnectedSslStreams()
public static (SslStream ClientStream, SslStream ServerStream) GetConnectedSslStreams(bool leaveInnerStreamOpen = false)
{
(Stream clientStream, Stream serverStream) = GetConnectedStreams();
return (new SslStream(clientStream), new SslStream(serverStream));
return (new SslStream(clientStream, leaveInnerStreamOpen), new SslStream(serverStream, leaveInnerStreamOpen));
}

public static (Stream ClientStream, Stream ServerStream) GetConnectedStreams()
Expand Down