diff --git a/Harmony/Documentation/articles/patching-injections.md b/Harmony/Documentation/articles/patching-injections.md index db860873..b0a26e14 100644 --- a/Harmony/Documentation/articles/patching-injections.md +++ b/Harmony/Documentation/articles/patching-injections.md @@ -14,6 +14,10 @@ Patches can use an argument called **`__instance`** to access the instance value Patches can use an argument called **`__result`** to access the returned value. The type must match the return type of the original or be assignable from it. For prefixes, as the original method hasn't run yet, the value of `__result` is the default for that type. For most reference types, that would be `null`. If you wish to **alter** the `__result`, you need to define it **by reference** like `ref string name`. +### __resultRef + +Patches can use an argument called **`__resultRef`** to alter the "**ref return**" reference itself. The type must be `RefResult` by reference, where `T` must match the return type of the original, without `ref` modifier. For example `ref RefResult __resultRef`. + ### __state Patches can use an argument called **`__state`** to store information in the prefix method that can be accessed again in the postfix method. Think of it as a local variable. It can be any type and you are responsible to initialize its value in the prefix. **Note:** It only works if both patches are defined in the same class. diff --git a/Harmony/Extras/RefResult.cs b/Harmony/Extras/RefResult.cs new file mode 100644 index 00000000..c0cbe395 --- /dev/null +++ b/Harmony/Extras/RefResult.cs @@ -0,0 +1,5 @@ +namespace HarmonyLib; + +/// Delegate type for "ref return" injections +/// Return type of the original method, without ref modifier +public delegate ref T RefResult(); diff --git a/Harmony/Internal/MethodPatcher.cs b/Harmony/Internal/MethodPatcher.cs index 448eabce..c22a64ef 100644 --- a/Harmony/Internal/MethodPatcher.cs +++ b/Harmony/Internal/MethodPatcher.cs @@ -15,6 +15,7 @@ internal class MethodPatcher const string ORIGINAL_METHOD_PARAM = "__originalMethod"; const string ARGS_ARRAY_VAR = "__args"; const string RESULT_VAR = "__result"; + const string RESULT_REF_VAR = "__resultRef"; const string STATE_VAR = "__state"; const string EXCEPTION_VAR = "__exception"; const string RUN_ORIGINAL_VAR = "__runOriginal"; @@ -76,6 +77,19 @@ internal MethodInfo CreateReplacement(out Dictionary final privateVars[RESULT_VAR] = resultVariable; } + if (fixes.Any(fix => fix.GetParameters().Any(p => p.Name == RESULT_REF_VAR))) + { + if(returnType.IsByRef) + { + var resultRefVariable = il.DeclareLocal( + typeof(RefResult<>).MakeGenericType(returnType.GetElementType()) + ); + emitter.Emit(OpCodes.Ldnull); + emitter.Emit(OpCodes.Stloc, resultRefVariable); + privateVars[RESULT_REF_VAR] = resultRefVariable; + } + } + LocalBuilder argsArrayVariable = null; if (fixes.Any(fix => fix.GetParameters().Any(p => p.Name == ARGS_ARRAY_VAR))) { @@ -432,10 +446,11 @@ bool EmitOriginalBaseMethod() return true; } - void EmitCallParameter(MethodInfo patch, Dictionary variables, LocalBuilder runOriginalVariable, bool allowFirsParamPassthrough, out LocalBuilder tmpInstanceBoxingVar, out LocalBuilder tmpObjectVar, List> tmpBoxVars) + void EmitCallParameter(MethodInfo patch, Dictionary variables, LocalBuilder runOriginalVariable, bool allowFirsParamPassthrough, out LocalBuilder tmpInstanceBoxingVar, out LocalBuilder tmpObjectVar, out bool refResultUsed, List> tmpBoxVars) { tmpInstanceBoxingVar = null; tmpObjectVar = null; + refResultUsed = false; var isInstance = original.IsStatic is false; var originalParameters = original.GetParameters(); @@ -474,10 +489,10 @@ void EmitCallParameter(MethodInfo patch, Dictionary variab else { var paramType = patchParam.ParameterType; - + var parameterIsRef = paramType.IsByRef; var parameterIsObject = paramType == typeof(object) || paramType == typeof(object).MakeByRefType(); - + if (AccessTools.IsStruct(originalType)) { if (parameterIsObject) @@ -571,7 +586,6 @@ void EmitCallParameter(MethodInfo patch, Dictionary variab // treat __result var special if (patchParam.Name == RESULT_VAR) { - var returnType = AccessTools.GetReturnedType(original); if (returnType == typeof(void)) throw new Exception($"Cannot get result from void method {original.FullDescription()}"); var resultType = patchParam.ParameterType; @@ -597,6 +611,25 @@ void EmitCallParameter(MethodInfo patch, Dictionary variab continue; } + // treat __resultRef delegate special + if (patchParam.Name == RESULT_REF_VAR) + { + if (!returnType.IsByRef) + throw new Exception( + $"Cannot use {RESULT_REF_VAR} with non-ref return type {returnType.FullName} of method {original.FullDescription()}"); + + var resultType = patchParam.ParameterType; + var expectedTypeRef = typeof(RefResult<>).MakeGenericType(returnType.GetElementType()).MakeByRefType(); + if (resultType != expectedTypeRef) + throw new Exception( + $"Wrong type of {RESULT_REF_VAR} for method {original.FullDescription()}. Expected {expectedTypeRef.FullName}, got {resultType.FullName}"); + + emitter.Emit(OpCodes.Ldloca, variables[RESULT_REF_VAR]); + + refResultUsed = true; + continue; + } + // any other declared variables if (variables.TryGetValue(patchParam.Name, out var localBuilder)) { @@ -763,7 +796,7 @@ void AddPrefixes(Dictionary variables, LocalBuilder runOri } var tmpBoxVars = new List>(); - EmitCallParameter(fix, variables, runOriginalVariable, false, out var tmpInstanceBoxingVar, out var tmpObjectVar, tmpBoxVars); + EmitCallParameter(fix, variables, runOriginalVariable, false, out var tmpInstanceBoxingVar, out var tmpObjectVar, out var refResultUsed, tmpBoxVars); emitter.Emit(OpCodes.Call, fix); if (fix.GetParameters().Any(p => p.Name == ARGS_ARRAY_VAR)) RestoreArgumentArray(variables); @@ -774,7 +807,22 @@ void AddPrefixes(Dictionary variables, LocalBuilder runOri emitter.Emit(OpCodes.Unbox_Any, original.DeclaringType); emitter.Emit(OpCodes.Stobj, original.DeclaringType); } - if (tmpObjectVar != null) + if (refResultUsed) + { + var label = il.DefineLabel(); + emitter.Emit(OpCodes.Ldloc, variables[RESULT_REF_VAR]); + emitter.Emit(OpCodes.Brfalse_S, label); + + emitter.Emit(OpCodes.Ldloc, variables[RESULT_REF_VAR]); + emitter.Emit(OpCodes.Callvirt, AccessTools.Method(variables[RESULT_REF_VAR].LocalType, "Invoke")); + emitter.Emit(OpCodes.Stloc, variables[RESULT_VAR]); + emitter.Emit(OpCodes.Ldnull); + emitter.Emit(OpCodes.Stloc, variables[RESULT_REF_VAR]); + + emitter.MarkLabel(label); + emitter.Emit(OpCodes.Nop); + } + else if (tmpObjectVar != null) { emitter.Emit(OpCodes.Ldloc, tmpObjectVar); emitter.Emit(OpCodes.Unbox_Any, AccessTools.GetReturnedType(original)); @@ -815,7 +863,7 @@ bool AddPostfixes(Dictionary variables, LocalBuilder runOr // throw new Exception("Methods without body cannot have postfixes. Use a transpiler instead."); var tmpBoxVars = new List>(); - EmitCallParameter(fix, variables, runOriginalVariable, true, out var tmpInstanceBoxingVar, out var tmpObjectVar, tmpBoxVars); + EmitCallParameter(fix, variables, runOriginalVariable, true, out var tmpInstanceBoxingVar, out var tmpObjectVar, out var refResultUsed, tmpBoxVars); emitter.Emit(OpCodes.Call, fix); if (fix.GetParameters().Any(p => p.Name == ARGS_ARRAY_VAR)) RestoreArgumentArray(variables); @@ -826,7 +874,22 @@ bool AddPostfixes(Dictionary variables, LocalBuilder runOr emitter.Emit(OpCodes.Unbox_Any, original.DeclaringType); emitter.Emit(OpCodes.Stobj, original.DeclaringType); } - if (tmpObjectVar != null) + if (refResultUsed) + { + var label = il.DefineLabel(); + emitter.Emit(OpCodes.Ldloc, variables[RESULT_REF_VAR]); + emitter.Emit(OpCodes.Brfalse_S, label); + + emitter.Emit(OpCodes.Ldloc, variables[RESULT_REF_VAR]); + emitter.Emit(OpCodes.Callvirt, AccessTools.Method(variables[RESULT_REF_VAR].LocalType, "Invoke")); + emitter.Emit(OpCodes.Stloc, variables[RESULT_VAR]); + emitter.Emit(OpCodes.Ldnull); + emitter.Emit(OpCodes.Stloc, variables[RESULT_REF_VAR]); + + emitter.MarkLabel(label); + emitter.Emit(OpCodes.Nop); + } + else if (tmpObjectVar != null) { emitter.Emit(OpCodes.Ldloc, tmpObjectVar); emitter.Emit(OpCodes.Unbox_Any, AccessTools.GetReturnedType(original)); @@ -871,7 +934,7 @@ bool AddFinalizers(Dictionary variables, LocalBuilder runO emitter.MarkBlockBefore(new ExceptionBlock(ExceptionBlockType.BeginExceptionBlock), out var label); var tmpBoxVars = new List>(); - EmitCallParameter(fix, variables, runOriginalVariable, false, out var tmpInstanceBoxingVar, out var tmpObjectVar, tmpBoxVars); + EmitCallParameter(fix, variables, runOriginalVariable, false, out var tmpInstanceBoxingVar, out var tmpObjectVar, out var refResultUsed, tmpBoxVars); emitter.Emit(OpCodes.Call, fix); if (fix.GetParameters().Any(p => p.Name == ARGS_ARRAY_VAR)) RestoreArgumentArray(variables); @@ -882,7 +945,22 @@ bool AddFinalizers(Dictionary variables, LocalBuilder runO emitter.Emit(OpCodes.Unbox_Any, original.DeclaringType); emitter.Emit(OpCodes.Stobj, original.DeclaringType); } - if (tmpObjectVar != null) + if (refResultUsed) + { + var label = il.DefineLabel(); + emitter.Emit(OpCodes.Ldloc, variables[RESULT_REF_VAR]); + emitter.Emit(OpCodes.Brfalse_S, label); + + emitter.Emit(OpCodes.Ldloc, variables[RESULT_REF_VAR]); + emitter.Emit(OpCodes.Callvirt, AccessTools.Method(variables[RESULT_REF_VAR].LocalType, "Invoke")); + emitter.Emit(OpCodes.Stloc, variables[RESULT_VAR]); + emitter.Emit(OpCodes.Ldnull); + emitter.Emit(OpCodes.Stloc, variables[RESULT_REF_VAR]); + + emitter.MarkLabel(label); + emitter.Emit(OpCodes.Nop); + } + else if (tmpObjectVar != null) { emitter.Emit(OpCodes.Ldloc, tmpObjectVar); emitter.Emit(OpCodes.Unbox_Any, AccessTools.GetReturnedType(original)); diff --git a/HarmonyTests/Patching/Assets/Specials.cs b/HarmonyTests/Patching/Assets/Specials.cs index a78386e7..067ea2f6 100644 --- a/HarmonyTests/Patching/Assets/Specials.cs +++ b/HarmonyTests/Patching/Assets/Specials.cs @@ -25,6 +25,70 @@ public static void ResetTest() // ----------------------------------------------------- + public class ResultRefStruct + { + // ReSharper disable FieldCanBeMadeReadOnly.Global + public static int[] numbersPrefix = [0, 0]; + public static int[] numbersPostfix = [0, 0]; + public static int[] numbersPostfixWithNull = [0]; + public static int[] numbersFinalizer = [0]; + public static int[] numbersMixed = [0, 0]; + // ReSharper restore FieldCanBeMadeReadOnly.Global + + [MethodImpl(MethodImplOptions.NoInlining)] + public ref int ToPrefix() => ref numbersPrefix[0]; + + [MethodImpl(MethodImplOptions.NoInlining)] + public ref int ToPostfix() => ref numbersPostfix[0]; + + [MethodImpl(MethodImplOptions.NoInlining)] + public ref int ToPostfixWithNull() => ref numbersPostfixWithNull[0]; + + [MethodImpl(MethodImplOptions.NoInlining)] + public ref int ToFinalizer() => throw new Exception(); + + [MethodImpl(MethodImplOptions.NoInlining)] + public ref int ToMixed() => ref numbersMixed[0]; + } + + [HarmonyPatch(typeof(ResultRefStruct))] + public class ResultRefStruct_Patch + { + [HarmonyPatch(nameof(ResultRefStruct.ToPrefix))] + [HarmonyPrefix] + public static bool Prefix(ref RefResult __resultRef) + { + __resultRef = () => ref ResultRefStruct.numbersPrefix[1]; + return false; + } + + [HarmonyPatch(nameof(ResultRefStruct.ToPostfix))] + [HarmonyPostfix] + public static void Postfix(ref RefResult __resultRef) => __resultRef = () => ref ResultRefStruct.numbersPostfix[1]; + + [HarmonyPatch(nameof(ResultRefStruct.ToPostfixWithNull))] + [HarmonyPostfix] + public static void PostfixWithNull(ref RefResult __resultRef) => __resultRef = null; + + [HarmonyPatch(nameof(ResultRefStruct.ToFinalizer))] + [HarmonyFinalizer] + public static Exception Finalizer(ref RefResult __resultRef) + { + __resultRef = () => ref ResultRefStruct.numbersFinalizer[0]; + return null; + } + + [HarmonyPatch(nameof(ResultRefStruct.ToMixed))] + [HarmonyPostfix] + public static void PostfixMixed(ref int __result, ref RefResult __resultRef) + { + __result = 42; + __resultRef = () => ref ResultRefStruct.numbersMixed[1]; + } + } + + // ----------------------------------------------------- + public class DeadEndCode { [MethodImpl(MethodImplOptions.NoInlining)] diff --git a/HarmonyTests/Patching/Specials.cs b/HarmonyTests/Patching/Specials.cs index 8079b39c..5daa2214 100644 --- a/HarmonyTests/Patching/Specials.cs +++ b/HarmonyTests/Patching/Specials.cs @@ -50,6 +50,49 @@ public void Test_HttpWebRequestGetResponse() Assert.True(HttpWebRequestPatches.postfixCalled, "Postfix not called"); } + [Test] + public void Test_PatchResultRef() + { + ResultRefStruct.numbersPrefix = [0, 0]; + ResultRefStruct.numbersPostfix = [0, 0]; + ResultRefStruct.numbersPostfixWithNull = [0]; + ResultRefStruct.numbersFinalizer = [0]; + ResultRefStruct.numbersMixed = [0, 0]; + + var test = new ResultRefStruct(); + + var instance = new Harmony("result-ref-test"); + Assert.NotNull(instance); + var processor = instance.CreateClassProcessor(typeof(ResultRefStruct_Patch)); + Assert.NotNull(processor, "processor"); + + test.ToPrefix() = 1; + test.ToPostfix() = 2; + test.ToPostfixWithNull() = 3; + test.ToMixed() = 5; + + Assert.AreEqual(new[] { 1, 0 }, ResultRefStruct.numbersPrefix); + Assert.AreEqual(new[] { 2, 0 }, ResultRefStruct.numbersPostfix); + Assert.AreEqual(new[] { 3 }, ResultRefStruct.numbersPostfixWithNull); + Assert.Throws(() => test.ToFinalizer(), "ToFinalizer method does not throw"); + Assert.AreEqual(new[] { 5, 0 }, ResultRefStruct.numbersMixed); + + var replacements = processor.Patch(); + Assert.NotNull(replacements, "replacements"); + + test.ToPrefix() = -1; + test.ToPostfix() = -2; + test.ToPostfixWithNull() = -3; + test.ToFinalizer() = -4; + test.ToMixed() = -5; + + Assert.AreEqual(new[] { 1, -1 }, ResultRefStruct.numbersPrefix); + Assert.AreEqual(new[] { 2, -2 }, ResultRefStruct.numbersPostfix); + Assert.AreEqual(new[] { -3 }, ResultRefStruct.numbersPostfixWithNull); + Assert.AreEqual(new[] { -4 }, ResultRefStruct.numbersFinalizer); + Assert.AreEqual(new[] { 42, -5 }, ResultRefStruct.numbersMixed); + } + [Test] public void Test_Patch_ConcreteClass() { @@ -327,7 +370,7 @@ public void Test_PatchExternalMethod() Assert.NotNull(patcher, "Patch processor"); _ = patcher.Patch(); } - + [Test] public void Test_PatchEventHandler() { @@ -348,7 +391,7 @@ public void Test_PatchEventHandler() new EventHandlerTestClass().Run(); Console.WriteLine($"### EventHandlerTestClass AFTER"); } - + [Test] public void Test_PatchMarshalledClass() { @@ -369,7 +412,7 @@ public void Test_PatchMarshalledClass() new MarshalledTestClass().Run(); Console.WriteLine($"### MarshalledTestClass AFTER"); } - + [Test] public void Test_MarshalledWithEventHandler1() { @@ -390,7 +433,7 @@ public void Test_MarshalledWithEventHandler1() new MarshalledWithEventHandlerTest1Class().Run(); Console.WriteLine($"### MarshalledWithEventHandlerTest1 AFTER"); } - + [Test] public void Test_MarshalledWithEventHandler2() {