Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CallFromNativeContract #2051

Merged
merged 4 commits into from
Nov 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/neo/SmartContract/ApplicationEngine.Contract.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ protected internal void CreateContract(byte[] script, byte[] manifest)

ContractMethodDescriptor md = contract.Manifest.Abi.GetMethod("_deploy");
if (md != null)
CallContractInternal(contract, md, new Array(ReferenceCounter) { false }, CallFlags.All, CheckReturnType.EnsureIsEmpty);
CallContractInternal(contract, md, new Array(ReferenceCounter) { false }, CallFlags.All, ReturnTypeConvention.EnsureIsEmpty);
}

protected internal void UpdateContract(byte[] script, byte[] manifest)
Expand Down Expand Up @@ -103,7 +103,7 @@ protected internal void UpdateContract(byte[] script, byte[] manifest)
{
ContractMethodDescriptor md = contract.Manifest.Abi.GetMethod("_deploy");
if (md != null)
CallContractInternal(contract, md, new Array(ReferenceCounter) { true }, CallFlags.All, CheckReturnType.EnsureIsEmpty);
CallContractInternal(contract, md, new Array(ReferenceCounter) { true }, CallFlags.All, ReturnTypeConvention.EnsureIsEmpty);
}
}

Expand All @@ -120,17 +120,17 @@ protected internal void DestroyContract()

protected internal void CallContract(UInt160 contractHash, string method, Array args)
{
CallContractInternal(contractHash, method, args, CallFlags.All);
CallContractEx(contractHash, method, args, CallFlags.All);
}

protected internal void CallContractEx(UInt160 contractHash, string method, Array args, CallFlags callFlags)
{
if ((callFlags & ~CallFlags.All) != 0)
throw new ArgumentOutOfRangeException(nameof(callFlags));
CallContractInternal(contractHash, method, args, callFlags);
CallContractInternal(contractHash, method, args, callFlags, ReturnTypeConvention.EnsureNotEmpty);
}

private void CallContractInternal(UInt160 contractHash, string method, Array args, CallFlags flags)
private void CallContractInternal(UInt160 contractHash, string method, Array args, CallFlags flags, ReturnTypeConvention convention)
{
if (method.StartsWith('_')) throw new ArgumentException($"Invalid Method Name: {method}");

Expand All @@ -143,10 +143,10 @@ private void CallContractInternal(UInt160 contractHash, string method, Array arg
if (currentManifest != null && !currentManifest.CanCall(contract.Manifest, method))
throw new InvalidOperationException($"Cannot Call Method {method} Of Contract {contractHash} From Contract {CurrentScriptHash}");

CallContractInternal(contract, md, args, flags, CheckReturnType.EnsureNotEmpty);
CallContractInternal(contract, md, args, flags, convention);
}

private void CallContractInternal(ContractState contract, ContractMethodDescriptor method, Array args, CallFlags flags, CheckReturnType checkReturnValue)
private void CallContractInternal(ContractState contract, ContractMethodDescriptor method, Array args, CallFlags flags, ReturnTypeConvention convention)
{
if (invocationCounter.TryGetValue(contract.ScriptHash, out var counter))
{
Expand All @@ -157,7 +157,7 @@ private void CallContractInternal(ContractState contract, ContractMethodDescript
invocationCounter[contract.ScriptHash] = 1;
}

GetInvocationState(CurrentContext).NeedCheckReturnValue = checkReturnValue;
GetInvocationState(CurrentContext).Convention = convention;

ExecutionContextState state = CurrentContext.GetState<ExecutionContextState>();
UInt160 callingScriptHash = state.ScriptHash;
Expand Down
16 changes: 8 additions & 8 deletions src/neo/SmartContract/ApplicationEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace Neo.SmartContract
{
public partial class ApplicationEngine : ExecutionEngine
{
private enum CheckReturnType : byte
private enum ReturnTypeConvention : byte
{
None = 0,
EnsureIsEmpty = 1,
Expand All @@ -29,7 +29,7 @@ private class InvocationState
{
public Type ReturnType;
public Delegate Callback;
public CheckReturnType NeedCheckReturnValue;
public ReturnTypeConvention Convention;
}

/// <summary>
Expand Down Expand Up @@ -87,15 +87,15 @@ internal void CallFromNativeContract(Action onComplete, UInt160 hash, string met
InvocationState state = GetInvocationState(CurrentContext);
state.ReturnType = typeof(void);
state.Callback = onComplete;
CallContract(hash, method, new VMArray(ReferenceCounter, args));
CallContractInternal(hash, method, new VMArray(ReferenceCounter, args), CallFlags.All, ReturnTypeConvention.EnsureIsEmpty);
}

internal void CallFromNativeContract<T>(Action<T> onComplete, UInt160 hash, string method, params StackItem[] args)
{
InvocationState state = GetInvocationState(CurrentContext);
state.ReturnType = typeof(T);
state.Callback = onComplete;
CallContract(hash, method, new VMArray(ReferenceCounter, args));
CallContractInternal(hash, method, new VMArray(ReferenceCounter, args), CallFlags.All, ReturnTypeConvention.EnsureNotEmpty);
}

protected override void ContextUnloaded(ExecutionContext context)
Expand All @@ -104,15 +104,15 @@ protected override void ContextUnloaded(ExecutionContext context)
if (!(UncaughtException is null)) return;
if (invocationStates.Count == 0) return;
if (!invocationStates.Remove(CurrentContext, out InvocationState state)) return;
switch (state.NeedCheckReturnValue)
switch (state.Convention)
{
case CheckReturnType.EnsureIsEmpty:
case ReturnTypeConvention.EnsureIsEmpty:
{
if (context.EvaluationStack.Count != 0)
throw new InvalidOperationException();
break;
}
case CheckReturnType.EnsureNotEmpty:
case ReturnTypeConvention.EnsureNotEmpty:
{
if (context.EvaluationStack.Count == 0)
Push(StackItem.Null);
Expand Down Expand Up @@ -164,7 +164,7 @@ protected override void LoadContext(ExecutionContext context)
internal void LoadContext(ExecutionContext context, bool checkReturnValue)
{
if (checkReturnValue)
GetInvocationState(CurrentContext).NeedCheckReturnValue = CheckReturnType.EnsureNotEmpty;
GetInvocationState(CurrentContext).Convention = ReturnTypeConvention.EnsureNotEmpty;
LoadContext(context);
}

Expand Down
2 changes: 1 addition & 1 deletion src/neo/SmartContract/Native/Oracle/OracleContract.cs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ private void Finish(ApplicationEngine engine)
if (request == null) throw new ArgumentException("Oracle request was not found");
engine.SendNotification(Hash, "OracleResponse", new VM.Types.Array { response.Id, request.OriginalTxid.ToArray() });
StackItem userData = BinarySerializer.Deserialize(request.UserData, engine.Limits.MaxStackSize, engine.Limits.MaxItemSize, engine.ReferenceCounter);
engine.CallFromNativeContract(null, request.CallbackContract, request.CallbackMethod, request.Url, userData, (int)response.Code, response.Result);
engine.CallFromNativeContract(() => { }, request.CallbackContract, request.CallbackMethod, request.Url, userData, (int)response.Code, response.Result);
}

private UInt256 GetOriginalTxid(ApplicationEngine engine)
Expand Down