Skip to content

Commit

Permalink
Remove class constraint from Interlocked.{Compare}Exchange
Browse files Browse the repository at this point in the history
Today `Interlocked.CompareExchange<T>` and `Interlocked.Exchange<T>` support only reference type `T`s. Now that we have corresponding {Compare}Exchange methods that support types of size 1, 2, 4, and 8, we can remove the constraint and support any `T` that's either a reference type, a primitive type, or an enum type, making the generic overload more useful and avoiding consumers needing to choose less-than-ideal types just because of the need for atomicity with Interlocked.{Compare}Exchange.
  • Loading branch information
stephentoub committed Jul 10, 2024
1 parent 4a1a076 commit b23c285
Show file tree
Hide file tree
Showing 60 changed files with 919 additions and 554 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
Expand Down Expand Up @@ -102,21 +103,6 @@ public static long Exchange(ref long location1, long value)
[return: NotNullIfNotNull(nameof(location1))]
[MethodImpl(MethodImplOptions.InternalCall)]
private static extern object? ExchangeObject([NotNullIfNotNull(nameof(value))] ref object? location1, object? value);

// The below whole method reduces to a single call to Exchange(ref object, object) but
// the JIT thinks that it will generate more native code than it actually does.

/// <summary>Sets a variable of the specified type <typeparamref name="T"/> to a specified value and returns the original value, as an atomic operation.</summary>
/// <param name="location1">The variable to set to the specified value.</param>
/// <param name="value">The value to which the <paramref name="location1"/> parameter is set.</param>
/// <returns>The original value of <paramref name="location1"/>.</returns>
/// <exception cref="NullReferenceException">The address of location1 is a null pointer.</exception>
/// <typeparam name="T">The type to be used for <paramref name="location1"/> and <paramref name="value"/>. This type must be a reference type.</typeparam>
[Intrinsic]
[return: NotNullIfNotNull(nameof(location1))]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static T Exchange<T>([NotNullIfNotNull(nameof(value))] ref T location1, T value) where T : class? =>
Unsafe.As<T>(Exchange(ref Unsafe.As<T, object?>(ref location1), value));
#endregion

#region CompareExchange
Expand Down Expand Up @@ -183,29 +169,6 @@ public static long CompareExchange(ref long location1, long value, long comparan
[MethodImpl(MethodImplOptions.InternalCall)]
[return: NotNullIfNotNull(nameof(location1))]
private static extern object? CompareExchangeObject(ref object? location1, object? value, object? comparand);

// Note that getILIntrinsicImplementationForInterlocked() in vm\jitinterface.cpp replaces
// the body of the following method with the following IL:
// ldarg.0
// ldarg.1
// ldarg.2
// call System.Threading.Interlocked::CompareExchange(ref Object, Object, Object)
// ret
// The workaround is no longer strictly necessary now that we have Unsafe.As but it does
// have the advantage of being less sensitive to JIT's inliner decisions.

/// <summary>Compares two instances of the specified reference type <typeparamref name="T"/> for reference equality and, if they are equal, replaces the first one.</summary>
/// <param name="location1">The destination, whose value is compared by reference with <paramref name="comparand"/> and possibly replaced.</param>
/// <param name="value">The value that replaces the destination value if the comparison by reference results in equality.</param>
/// <param name="comparand">The object that is compared by reference to the value at <paramref name="location1"/>.</param>
/// <returns>The original value in <paramref name="location1"/>.</returns>
/// <exception cref="NullReferenceException">The address of <paramref name="location1"/> is a null pointer.</exception>
/// <typeparam name="T">The type to be used for <paramref name="location1"/>, <paramref name="value"/>, and <paramref name="comparand"/>. This type must be a reference type.</typeparam>
[Intrinsic]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
[return: NotNullIfNotNull(nameof(location1))]
public static T CompareExchange<T>(ref T location1, T value, T comparand) where T : class? =>
Unsafe.As<T>(CompareExchange(ref Unsafe.As<T, object?>(ref location1), value, comparand));
#endregion

#region Add
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ private static unsafe void ClrReportEvent(string eventSource, short type, ushort
Interop.Advapi32.DeregisterEventSource(handle);
}

private static byte s_once;
private static bool s_once;

public static bool ShouldLogInEventLog
{
Expand All @@ -180,7 +180,7 @@ public static bool ShouldLogInEventLog
if (Interop.Kernel32.IsDebuggerPresent())
return false;

if (s_once == 1 || Interlocked.Exchange(ref s_once, 1) == 1)
if (s_once || Interlocked.Exchange(ref s_once, true))
return false;

return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,6 @@ public static long CompareExchange(ref long location1, long value, long comparan
#endif
}

[Intrinsic]
[return: NotNullIfNotNull(nameof(location1))]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static T CompareExchange<T>(ref T location1, T value, T comparand) where T : class?
{
return Unsafe.As<T>(CompareExchange(ref Unsafe.As<T, object?>(ref location1), value, comparand));
}

[Intrinsic]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
[return: NotNullIfNotNull(nameof(location1))]
Expand Down Expand Up @@ -92,16 +84,6 @@ public static long Exchange(ref long location1, long value)
#endif
}

[Intrinsic]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
[return: NotNullIfNotNull(nameof(location1))]
public static T Exchange<T>([NotNullIfNotNull(nameof(value))] ref T location1, T value) where T : class?
{
if (Unsafe.IsNullRef(ref location1))
ThrowHelper.ThrowNullReferenceException();
return Unsafe.As<T>(RuntimeImports.InterlockedExchange(ref Unsafe.As<T, object?>(ref location1), value));
}

[Intrinsic]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
[return: NotNullIfNotNull(nameof(location1))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,45 @@ public static MethodIL EmitIL(
if (compilationModuleGroup.ContainsType(method.OwningType))
#endif // READYTORUN
{
TypeDesc objectType = method.Context.GetWellKnownType(WellKnownType.Object);
MethodDesc compareExchangeObject = method.OwningType.GetKnownMethod("CompareExchange",
new MethodSignature(
MethodSignatureFlags.Static,
genericParameterCount: 0,
returnType: objectType,
parameters: new TypeDesc[] { objectType.MakeByRefType(), objectType, objectType }));

ILEmitter emit = new ILEmitter();
ILCodeStream codeStream = emit.NewCodeStream();
codeStream.EmitLdArg(0);
codeStream.EmitLdArg(1);
codeStream.EmitLdArg(2);
codeStream.Emit(ILOpcode.call, emit.NewToken(compareExchangeObject));
codeStream.Emit(ILOpcode.ret);
return emit.Link(method);
// Rewrite the generic Interlocked.CompareExchange<T> to be a call to one of the non-generic overloads.
TypeDesc ceArgType = null;

TypeDesc tType = method.Instantiation[0];
if (!tType.IsValueType)
{
ceArgType = method.Context.GetWellKnownType(WellKnownType.Object);
}
else if (tType.IsPrimitive || tType.IsEnum)
{
int size = tType.GetElementSize().AsInt;
Debug.Assert(size is 1 or 2 or 4 or 8);
ceArgType = size switch
{
1 => method.Context.GetWellKnownType(WellKnownType.Byte),
2 => method.Context.GetWellKnownType(WellKnownType.UInt16),
4 => method.Context.GetWellKnownType(WellKnownType.UInt32),
_ => method.Context.GetWellKnownType(WellKnownType.UInt64),
};
}

if (ceArgType is not null)
{
MethodDesc compareExchangeNonGeneric = method.OwningType.GetKnownMethod("CompareExchange",
new MethodSignature(
MethodSignatureFlags.Static,
genericParameterCount: 0,
returnType: ceArgType,
parameters: [ceArgType.MakeByRefType(), ceArgType, ceArgType]));

ILEmitter emit = new ILEmitter();
ILCodeStream codeStream = emit.NewCodeStream();
codeStream.EmitLdArg(0);
codeStream.EmitLdArg(1);
codeStream.EmitLdArg(2);
codeStream.Emit(ILOpcode.call, emit.NewToken(compareExchangeNonGeneric));
codeStream.Emit(ILOpcode.ret);
return emit.Link(method);
}
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/coreclr/vm/corelib.h
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,10 @@ DEFINE_METHOD(MEMORY_MARSHAL, GET_ARRAY_DATA_REFERENCE_MDARRAY, GetArrayDa
DEFINE_CLASS(INTERLOCKED, Threading, Interlocked)
DEFINE_METHOD(INTERLOCKED, COMPARE_EXCHANGE_T, CompareExchange, GM_RefT_T_T_RetT)
DEFINE_METHOD(INTERLOCKED, COMPARE_EXCHANGE_OBJECT,CompareExchange, SM_RefObject_Object_Object_RetObject)
DEFINE_METHOD(INTERLOCKED, COMPARE_EXCHANGE_BYTE, CompareExchange, SM_RefByte_Byte_Byte_RetByte)
DEFINE_METHOD(INTERLOCKED, COMPARE_EXCHANGE_USHRT, CompareExchange, SM_RefUShrt_UShrt_UShrt_RetUShrt)
DEFINE_METHOD(INTERLOCKED, COMPARE_EXCHANGE_UINT, CompareExchange, SM_RefUInt_UInt_UInt_RetUInt)
DEFINE_METHOD(INTERLOCKED, COMPARE_EXCHANGE_ULONG, CompareExchange, SM_RefULong_ULong_ULong_RetULong)

DEFINE_CLASS(RAW_DATA, CompilerServices, RawData)
DEFINE_FIELD(RAW_DATA, DATA, Data)
Expand Down
42 changes: 34 additions & 8 deletions src/coreclr/vm/jitinterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7250,8 +7250,34 @@ bool getILIntrinsicImplementationForInterlocked(MethodDesc * ftn,
if (ftn->GetMemberDef() != CoreLibBinder::GetMethod(METHOD__INTERLOCKED__COMPARE_EXCHANGE_T)->GetMemberDef())
return false;

// Get MethodDesc for non-generic System.Threading.Interlocked.CompareExchange()
MethodDesc* cmpxchgObject = CoreLibBinder::GetMethod(METHOD__INTERLOCKED__COMPARE_EXCHANGE_OBJECT);
// Determine the type of the generic T method parameter
_ASSERTE(ftn->HasMethodInstantiation());
_ASSERTE(ftn->GetNumGenericMethodArgs() == 1);
TypeHandle typeHandle = ftn->GetMethodInstantiation()[0];
MethodTable* methodTable = typeHandle.GetMethodTable();

MethodDesc* cmpxchg;

// Based on the generic method parameter, determine which overload of CompareExchange
// to delegate to, or if we can't handle the type at all.
if (!typeHandle.IsValueType())
{
cmpxchg = CoreLibBinder::GetMethod(METHOD__INTERLOCKED__COMPARE_EXCHANGE_OBJECT);
}
else if (CorTypeInfo::IsPrimitiveType(typeHandle.GetVerifierCorElementType()))
{
UINT size = typeHandle.GetSize();
_ASSERTE(size == 1 || size == 2 || size == 4 || size == 8);
cmpxchg =
size == 1 ? CoreLibBinder::GetMethod(METHOD__INTERLOCKED__COMPARE_EXCHANGE_BYTE) :
size == 2 ? CoreLibBinder::GetMethod(METHOD__INTERLOCKED__COMPARE_EXCHANGE_USHRT) :
size == 4 ? CoreLibBinder::GetMethod(METHOD__INTERLOCKED__COMPARE_EXCHANGE_UINT) :
CoreLibBinder::GetMethod(METHOD__INTERLOCKED__COMPARE_EXCHANGE_ULONG);
}
else
{
return false;
}

// Setup up the body of the method
static BYTE il[] = {
Expand All @@ -7262,12 +7288,12 @@ bool getILIntrinsicImplementationForInterlocked(MethodDesc * ftn,
CEE_RET
};

// Get the token for non-generic System.Threading.Interlocked.CompareExchange(), and patch [target]
mdMethodDef cmpxchgObjectToken = cmpxchgObject->GetMemberDef();
il[4] = (BYTE)((int)cmpxchgObjectToken >> 0);
il[5] = (BYTE)((int)cmpxchgObjectToken >> 8);
il[6] = (BYTE)((int)cmpxchgObjectToken >> 16);
il[7] = (BYTE)((int)cmpxchgObjectToken >> 24);
// Get the token for the relevant System.Threading.Interlocked.CompareExchange overload, and patch [target]
mdMethodDef cmpxchgToken = cmpxchg->GetMemberDef();
il[4] = (BYTE)((int)cmpxchgToken >> 0);
il[5] = (BYTE)((int)cmpxchgToken >> 8);
il[6] = (BYTE)((int)cmpxchgToken >> 16);
il[7] = (BYTE)((int)cmpxchgToken >> 24);

// Initialize methInfo
methInfo->ILCode = const_cast<BYTE*>(il);
Expand Down
4 changes: 4 additions & 0 deletions src/coreclr/vm/metasig.h
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,10 @@ DEFINE_METASIG_T(SM(RefDec_RetVoid, r(g(DECIMAL)), v))

DEFINE_METASIG(GM(RefT_T_T_RetT, IMAGE_CEE_CS_CALLCONV_DEFAULT, 1, r(M(0)) M(0) M(0), M(0)))
DEFINE_METASIG(SM(RefObject_Object_Object_RetObject, r(j) j j, j))
DEFINE_METASIG(SM(RefByte_Byte_Byte_RetByte, r(b) b b, b))
DEFINE_METASIG(SM(RefUShrt_UShrt_UShrt_RetUShrt, r(H) H H, H))
DEFINE_METASIG(SM(RefUInt_UInt_UInt_RetUInt, r(K) K K, K))
DEFINE_METASIG(SM(RefULong_ULong_ULong_RetULong, r(L) L L, L))

DEFINE_METASIG_T(SM(RefCleanupWorkListElement_RetVoid, r(C(CLEANUP_WORK_LIST_ELEMENT)), v))
DEFINE_METASIG_T(SM(RefCleanupWorkListElement_SafeHandle_RetIntPtr, r(C(CLEANUP_WORK_LIST_ELEMENT)) C(SAFE_HANDLE), I))
Expand Down
12 changes: 6 additions & 6 deletions src/libraries/Common/src/System/Net/StreamBuffer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -292,15 +292,15 @@ private sealed class ResettableValueTaskSource : IValueTaskSource

private ManualResetValueTaskSourceCore<bool> _waitSource; // mutable struct, do not make this readonly
private CancellationTokenRegistration _waitSourceCancellation;
private int _hasWaiter;
private bool _hasWaiter;

ValueTaskSourceStatus IValueTaskSource.GetStatus(short token) => _waitSource.GetStatus(token);

void IValueTaskSource.OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) => _waitSource.OnCompleted(continuation, state, token, flags);

void IValueTaskSource.GetResult(short token)
{
Debug.Assert(_hasWaiter == 0);
Debug.Assert(!_hasWaiter);

// Clean up the registration. This will wait for any in-flight cancellation to complete.
_waitSourceCancellation.Dispose();
Expand All @@ -312,7 +312,7 @@ void IValueTaskSource.GetResult(short token)

public void SignalWaiter()
{
if (Interlocked.Exchange(ref _hasWaiter, 0) == 1)
if (Interlocked.Exchange(ref _hasWaiter, false))
{
_waitSource.SetResult(true);
}
Expand All @@ -322,21 +322,21 @@ private void CancelWaiter(CancellationToken cancellationToken)
{
Debug.Assert(cancellationToken.IsCancellationRequested);

if (Interlocked.Exchange(ref _hasWaiter, 0) == 1)
if (Interlocked.Exchange(ref _hasWaiter, false))
{
_waitSource.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException(cancellationToken)));
}
}

public void Reset()
{
if (_hasWaiter != 0)
if (_hasWaiter)
{
throw new InvalidOperationException("Concurrent use is not supported");
}

_waitSource.Reset();
Volatile.Write(ref _hasWaiter, 1);
Volatile.Write(ref _hasWaiter, true);
}

public void Wait()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ public class WebSocketStream : Stream
// Used by the class to indicate that the stream is writable.
private bool _writeable;

// Whether Dispose has been called. 0 == false, 1 == true
private int _disposed;
// Whether Dispose has been called.
private bool _disposed;

public WebSocketStream(WebSocket socket)
: this(socket, FileAccess.ReadWrite, ownsSocket: false)
Expand Down Expand Up @@ -140,7 +140,7 @@ public void Close(int timeout)

protected override void Dispose(bool disposing)
{
if (Interlocked.Exchange(ref _disposed, 1) != 0)
if (Interlocked.Exchange(ref _disposed, true))
{
return;
}
Expand Down Expand Up @@ -269,7 +269,7 @@ public override void SetLength(long value)

private void ThrowIfDisposed()
{
ObjectDisposedException.ThrowIf(_disposed != 0, this);
ObjectDisposedException.ThrowIf(_disposed, this);
}

private static IOException WrapException(string resourceFormatString, Exception innerException)
Expand Down
Loading

0 comments on commit b23c285

Please sign in to comment.