diff --git a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/TypePreinit.cs b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/TypePreinit.cs index ba9a156aad4d3..2b37f94b307c6 100644 --- a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/TypePreinit.cs +++ b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/TypePreinit.cs @@ -539,7 +539,8 @@ private Status TryScanMethod(MethodIL methodIL, Value[] parameters, Stack::.ctor(void*, int32) + StackEntry entry = stack.Pop(); + long size = entry.ValueKind switch + { + StackValueKind.Int32 => entry.Value.AsInt32(), + StackValueKind.NativeInt => (context.Target.PointerSize == 4) + ? entry.Value.AsInt32() : entry.Value.AsInt64(), + _ => long.MaxValue + }; + + // Arbitrary limit for allocation size to prevent compiler OOM + if (size < 0 || size > 8192) + return Status.Fail(methodIL.OwningMethod, ILOpcode.localloc); + + opcode = reader.ReadILOpcode(); + if (opcode < ILOpcode.ldc_i4_0 || opcode > ILOpcode.ldc_i4) + return Status.Fail(methodIL.OwningMethod, ILOpcode.localloc); + + int maybeSpanLength = opcode switch + { + ILOpcode.ldc_i4_s => (sbyte)reader.ReadILByte(), + ILOpcode.ldc_i4 => (int)reader.ReadILUInt32(), + _ => opcode - ILOpcode.ldc_i4_0, + }; + + opcode = reader.ReadILOpcode(); + if (opcode != ILOpcode.newobj) + return Status.Fail(methodIL.OwningMethod, ILOpcode.localloc); + + var ctorMethod = (MethodDesc)methodIL.GetObject(reader.ReadILToken()); + if (!TryGetSpanElementType(ctorMethod.OwningType, isReadOnlySpan: false, out MetadataType elementType) + || ctorMethod.Signature.Length != 2 + || !ctorMethod.Signature[0].IsPointer + || !ctorMethod.Signature[1].IsWellKnownType(WellKnownType.Int32) + || maybeSpanLength * elementType.InstanceFieldSize.AsInt != size) + return Status.Fail(methodIL.OwningMethod, ILOpcode.localloc); + + var instance = new ReadOnlySpanValue(elementType, new byte[size], index: 0, (int)size); + stack.PushFromLocation(ctorMethod.OwningType, instance); + } + break; + case ILOpcode.stfld: { FieldDesc field = (FieldDesc)methodIL.GetObject(reader.ReadILToken()); @@ -703,14 +755,17 @@ private Status TryScanMethod(MethodIL methodIL, Value[] parameters, Stack throw new NotImplementedException() // unreachable }; } - else if (value1.ValueKind == StackValueKind.Int64 && value2.ValueKind == StackValueKind.Int64) + else if (value1.ValueKind.WithNormalizedNativeInt(context) == StackValueKind.Int64 && value2.ValueKind.WithNormalizedNativeInt(context) == StackValueKind.Int64) { branchTaken = normalizedOpcode switch { @@ -1136,7 +1215,7 @@ private Status TryScanMethod(MethodIL methodIL, Value[] parameters, Stack throw new NotImplementedException(), // unreachable }; - stack.Push(StackValueKind.Int32, ValueTypeValue.FromInt32(result)); + stack.Push(isNint ? StackValueKind.NativeInt : StackValueKind.Int32, ValueTypeValue.FromInt32(result)); } - else if (value1.ValueKind == StackValueKind.Int64 && value2.ValueKind == StackValueKind.Int64) + else if (value1.ValueKind.WithNormalizedNativeInt(context) == StackValueKind.Int64 && value2.ValueKind.WithNormalizedNativeInt(context) == StackValueKind.Int64) { if (isDivRem && value2.Value.AsInt64() == 0) return Status.Fail(methodIL.OwningMethod, opcode, "Division by zero"); @@ -1279,7 +1361,7 @@ private Status TryScanMethod(MethodIL methodIL, Value[] parameters, Stack throw new NotImplementedException(), // unreachable }; - stack.Push(StackValueKind.Int64, ValueTypeValue.FromInt64(result)); + stack.Push(isNint ? StackValueKind.NativeInt : StackValueKind.Int64, ValueTypeValue.FromInt64(result)); } else if (value1.ValueKind == StackValueKind.Float && value2.ValueKind == StackValueKind.Float) { @@ -1305,7 +1387,32 @@ private Status TryScanMethod(MethodIL methodIL, Value[] parameters, Stack addend.Value.AsInt32(), + _ => context.Target.PointerSize == 8 ? addend.Value.AsInt64() : addend.Value.AsInt32() + }; + + var previousByRef = (ByRefValue)reference.Value; + if (addition > previousByRef.PointedToBytes.Length - previousByRef.PointedToOffset + || addition + previousByRef.PointedToOffset < 0) + return Status.Fail(methodIL.OwningMethod, "Out of range byref access"); + + stack.Push(StackValueKind.ByRef, new ByRefValue(previousByRef.PointedToBytes, (int)(previousByRef.PointedToOffset + addition))); } else { @@ -1599,6 +1706,32 @@ private Status TryScanMethod(MethodIL methodIL, Value[] parameters, Stack addressValue.PointedToBytes.Length - addressValue.PointedToOffset + || sizeBytes > int.MaxValue /* paranoid check that cast to int is legit */) + return Status.Fail(methodIL.OwningMethod, opcode); + + Array.Fill(addressValue.PointedToBytes, (byte)value.Value.AsInt32(), addressValue.PointedToOffset, (int)sizeBytes); + } + break; + default: return Status.Fail(methodIL.OwningMethod, opcode); } @@ -1608,14 +1741,14 @@ private Status TryScanMethod(MethodIL methodIL, Value[] parameters, Stack(), 0, 0); } + else if (TryGetSpanElementType(locationType, isReadOnlySpan: false, out MetadataType spanElementType)) + { + return new ReadOnlySpanValue(spanElementType, Array.Empty(), 0, 0); + } else { Debug.Assert(locationType.IsValueType || locationType.IsPointer || locationType.IsFunctionPointer); @@ -1948,18 +2085,6 @@ public Value PopIntoLocation(TypeDesc locationType) } } - private enum StackValueKind - { - Unknown, - Int32, - Int64, - NativeInt, - Float, - ByRef, - ObjRef, - ValueType, - } - /// /// Represents a field value that can be serialized into a preinitialized blob. /// @@ -2340,10 +2465,19 @@ public bool TryAccessElement(int index, out Value value) public Value GetField(FieldDesc field) { - if (field.Name != "_length") + MetadataType elementType; + if (!TryGetSpanElementType(field.OwningType, isReadOnlySpan: true, out elementType) + && !TryGetSpanElementType(field.OwningType, isReadOnlySpan: false, out elementType)) + ThrowHelper.ThrowInvalidProgramException(); + + if (elementType != _elementType) ThrowHelper.ThrowInvalidProgramException(); - return ValueTypeValue.FromInt32(_length / _elementType.InstanceFieldSize.AsInt); + if (field.Name == "_length") + return ValueTypeValue.FromInt32(_length / _elementType.InstanceFieldSize.AsInt); + + Debug.Assert(field.Name == "_reference"); + return new ByRefValue(_bytes, _index); } public ByRefValue GetFieldAddress(FieldDesc field) @@ -3016,4 +3150,15 @@ public sealed class TypeLoaderAwarePreinitializationPolicy : TypePreinitializati public override bool CanPreinitializeAllConcreteFormsForCanonForm(DefType type) => false; } } + +#pragma warning disable SA1400 // Element 'Extensions' should declare an access modifier + file static class Extensions + { + public static StackValueKind WithNormalizedNativeInt(this StackValueKind kind, TypeSystemContext context) + => kind switch + { + StackValueKind.NativeInt => context.Target.PointerSize == 8 ? StackValueKind.Int64 : StackValueKind.Int32, + _ => kind + }; + } } diff --git a/src/tests/nativeaot/SmokeTests/Preinitialization/Preinitialization.cs b/src/tests/nativeaot/SmokeTests/Preinitialization/Preinitialization.cs index 4d905240fa697..25018eed9eb67 100644 --- a/src/tests/nativeaot/SmokeTests/Preinitialization/Preinitialization.cs +++ b/src/tests/nativeaot/SmokeTests/Preinitialization/Preinitialization.cs @@ -51,11 +51,13 @@ private static int Main() TestInstanceDelegate.Run(); TestStringFields.Run(); TestSharedCode.Run(); + TestSpan.Run(); TestReadOnlySpan.Run(); TestStaticInterfaceMethod.Run(); TestConstrainedCall.Run(); TestTypeHandles.Run(); TestIndirectLoads.Run(); + TestInitBlock.Run(); #else Console.WriteLine("Preinitialization is disabled in multimodule builds for now. Skipping test."); #endif @@ -1054,6 +1056,41 @@ public static void Run() } } +class TestSpan +{ + class StackAlloc + { + public static byte FirstByte; + public static byte LastByte; + public static char FirstChar; + public static char LastChar; + + static StackAlloc() + { + Span s1 = stackalloc byte[8]; + s1.Slice(0, 1)[0] = 42; + s1.Slice(s1.Length - 1, 1)[0] = 100; + FirstByte = s1[0]; + LastByte = s1[7]; + + Span s2 = stackalloc char[8]; + s2.Slice(0, 1)[0] = 'H'; + s2.Slice(s2.Length - 1, 1)[0] = '!'; + FirstChar = s2[0]; + LastChar = s2[7]; + } + } + + public static void Run() + { + Assert.IsPreinitialized(typeof(StackAlloc)); + Assert.AreEqual(42, StackAlloc.FirstByte); + Assert.AreEqual(100, StackAlloc.LastByte); + Assert.AreEqual('H', StackAlloc.FirstChar); + Assert.AreEqual('!', StackAlloc.LastChar); + } +} + class TestReadOnlySpan { class SimpleReadOnlySpanAccess @@ -1259,6 +1296,42 @@ public static void Run() } } +class TestInitBlock +{ + class Simple + { + public static byte Value; + + static Simple() + { + Value = 123; + Unsafe.InitBlockUnaligned(ref Value, 42, 1); + } + } + + class Overrun + { + public static byte Value; + public static byte Pad; + + static Overrun() + { + Value = 123; + Unsafe.InitBlockUnaligned(ref Value, 42, 2); + } + } + + public static void Run() + { + Assert.IsPreinitialized(typeof(Simple)); + Assert.AreEqual(42, Simple.Value); + + Assert.IsLazyInitialized(typeof(Overrun)); + Assert.AreEqual(42, Overrun.Value); + Assert.AreEqual(42, Overrun.Pad); + } +} + static class Assert { [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2070:UnrecognizedReflectionPattern",