Skip to content

Commit

Permalink
Fix CallFromNativeContract (#2051)
Browse files Browse the repository at this point in the history
  • Loading branch information
erikzhang authored Nov 8, 2020
1 parent 402e9b1 commit f26db56
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 17 deletions.
16 changes: 8 additions & 8 deletions src/neo/SmartContract/ApplicationEngine.Contract.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ protected internal void CreateContract(byte[] script, byte[] manifest)

ContractMethodDescriptor md = contract.Manifest.Abi.GetMethod("_deploy");
if (md != null)
CallContractInternal(contract, md, new Array(ReferenceCounter) { false }, CallFlags.All, CheckReturnType.EnsureIsEmpty);
CallContractInternal(contract, md, new Array(ReferenceCounter) { false }, CallFlags.All, ReturnTypeConvention.EnsureIsEmpty);
}

protected internal void UpdateContract(byte[] script, byte[] manifest)
Expand Down Expand Up @@ -103,7 +103,7 @@ protected internal void UpdateContract(byte[] script, byte[] manifest)
{
ContractMethodDescriptor md = contract.Manifest.Abi.GetMethod("_deploy");
if (md != null)
CallContractInternal(contract, md, new Array(ReferenceCounter) { true }, CallFlags.All, CheckReturnType.EnsureIsEmpty);
CallContractInternal(contract, md, new Array(ReferenceCounter) { true }, CallFlags.All, ReturnTypeConvention.EnsureIsEmpty);
}
}

Expand All @@ -120,17 +120,17 @@ protected internal void DestroyContract()

protected internal void CallContract(UInt160 contractHash, string method, Array args)
{
CallContractInternal(contractHash, method, args, CallFlags.All);
CallContractEx(contractHash, method, args, CallFlags.All);
}

protected internal void CallContractEx(UInt160 contractHash, string method, Array args, CallFlags callFlags)
{
if ((callFlags & ~CallFlags.All) != 0)
throw new ArgumentOutOfRangeException(nameof(callFlags));
CallContractInternal(contractHash, method, args, callFlags);
CallContractInternal(contractHash, method, args, callFlags, ReturnTypeConvention.EnsureNotEmpty);
}

private void CallContractInternal(UInt160 contractHash, string method, Array args, CallFlags flags)
private void CallContractInternal(UInt160 contractHash, string method, Array args, CallFlags flags, ReturnTypeConvention convention)
{
if (method.StartsWith('_')) throw new ArgumentException($"Invalid Method Name: {method}");

Expand All @@ -143,10 +143,10 @@ private void CallContractInternal(UInt160 contractHash, string method, Array arg
if (currentManifest != null && !currentManifest.CanCall(contract.Manifest, method))
throw new InvalidOperationException($"Cannot Call Method {method} Of Contract {contractHash} From Contract {CurrentScriptHash}");

CallContractInternal(contract, md, args, flags, CheckReturnType.EnsureNotEmpty);
CallContractInternal(contract, md, args, flags, convention);
}

private void CallContractInternal(ContractState contract, ContractMethodDescriptor method, Array args, CallFlags flags, CheckReturnType checkReturnValue)
private void CallContractInternal(ContractState contract, ContractMethodDescriptor method, Array args, CallFlags flags, ReturnTypeConvention convention)
{
if (invocationCounter.TryGetValue(contract.ScriptHash, out var counter))
{
Expand All @@ -157,7 +157,7 @@ private void CallContractInternal(ContractState contract, ContractMethodDescript
invocationCounter[contract.ScriptHash] = 1;
}

GetInvocationState(CurrentContext).NeedCheckReturnValue = checkReturnValue;
GetInvocationState(CurrentContext).Convention = convention;

ExecutionContextState state = CurrentContext.GetState<ExecutionContextState>();
UInt160 callingScriptHash = state.ScriptHash;
Expand Down
16 changes: 8 additions & 8 deletions src/neo/SmartContract/ApplicationEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace Neo.SmartContract
{
public partial class ApplicationEngine : ExecutionEngine
{
private enum CheckReturnType : byte
private enum ReturnTypeConvention : byte
{
None = 0,
EnsureIsEmpty = 1,
Expand All @@ -29,7 +29,7 @@ private class InvocationState
{
public Type ReturnType;
public Delegate Callback;
public CheckReturnType NeedCheckReturnValue;
public ReturnTypeConvention Convention;
}

/// <summary>
Expand Down Expand Up @@ -87,15 +87,15 @@ internal void CallFromNativeContract(Action onComplete, UInt160 hash, string met
InvocationState state = GetInvocationState(CurrentContext);
state.ReturnType = typeof(void);
state.Callback = onComplete;
CallContract(hash, method, new VMArray(ReferenceCounter, args));
CallContractInternal(hash, method, new VMArray(ReferenceCounter, args), CallFlags.All, ReturnTypeConvention.EnsureIsEmpty);
}

internal void CallFromNativeContract<T>(Action<T> onComplete, UInt160 hash, string method, params StackItem[] args)
{
InvocationState state = GetInvocationState(CurrentContext);
state.ReturnType = typeof(T);
state.Callback = onComplete;
CallContract(hash, method, new VMArray(ReferenceCounter, args));
CallContractInternal(hash, method, new VMArray(ReferenceCounter, args), CallFlags.All, ReturnTypeConvention.EnsureNotEmpty);
}

protected override void ContextUnloaded(ExecutionContext context)
Expand All @@ -104,15 +104,15 @@ protected override void ContextUnloaded(ExecutionContext context)
if (!(UncaughtException is null)) return;
if (invocationStates.Count == 0) return;
if (!invocationStates.Remove(CurrentContext, out InvocationState state)) return;
switch (state.NeedCheckReturnValue)
switch (state.Convention)
{
case CheckReturnType.EnsureIsEmpty:
case ReturnTypeConvention.EnsureIsEmpty:
{
if (context.EvaluationStack.Count != 0)
throw new InvalidOperationException();
break;
}
case CheckReturnType.EnsureNotEmpty:
case ReturnTypeConvention.EnsureNotEmpty:
{
if (context.EvaluationStack.Count == 0)
Push(StackItem.Null);
Expand Down Expand Up @@ -164,7 +164,7 @@ protected override void LoadContext(ExecutionContext context)
internal void LoadContext(ExecutionContext context, bool checkReturnValue)
{
if (checkReturnValue)
GetInvocationState(CurrentContext).NeedCheckReturnValue = CheckReturnType.EnsureNotEmpty;
GetInvocationState(CurrentContext).Convention = ReturnTypeConvention.EnsureNotEmpty;
LoadContext(context);
}

Expand Down
2 changes: 1 addition & 1 deletion src/neo/SmartContract/Native/Oracle/OracleContract.cs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ private void Finish(ApplicationEngine engine)
if (request == null) throw new ArgumentException("Oracle request was not found");
engine.SendNotification(Hash, "OracleResponse", new VM.Types.Array { response.Id, request.OriginalTxid.ToArray() });
StackItem userData = BinarySerializer.Deserialize(request.UserData, engine.Limits.MaxStackSize, engine.Limits.MaxItemSize, engine.ReferenceCounter);
engine.CallFromNativeContract(null, request.CallbackContract, request.CallbackMethod, request.Url, userData, (int)response.Code, response.Result);
engine.CallFromNativeContract(() => { }, request.CallbackContract, request.CallbackMethod, request.Url, userData, (int)response.Code, response.Result);
}

private UInt256 GetOriginalTxid(ApplicationEngine engine)
Expand Down

0 comments on commit f26db56

Please sign in to comment.