Skip to content

Commit

Permalink
Refactor NativeContract (neo-project#1693)
Browse files Browse the repository at this point in the history
  • Loading branch information
erikzhang authored and KickSeason committed Jun 15, 2020
1 parent c10e74a commit af07180
Show file tree
Hide file tree
Showing 11 changed files with 202 additions and 268 deletions.
76 changes: 40 additions & 36 deletions src/neo/SmartContract/ApplicationEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public ExecutionContext LoadScript(Script script, CallFlags callFlags, int rvcou
return context;
}

private StackItem ConvertReturnValue(object value)
internal StackItem Convert(object value)
{
return value switch
{
Expand All @@ -85,18 +85,51 @@ private StackItem ConvertReturnValue(object value)
uint i => i,
long i => i,
ulong i => i,
Enum e => ConvertReturnValue(Convert.ChangeType(e, e.GetTypeCode())),
Enum e => Convert(System.Convert.ChangeType(e, e.GetTypeCode())),
byte[] data => data,
string s => s,
UInt160 i => i.ToArray(),
UInt256 i => i.ToArray(),
BigInteger i => i,
IInteroperable interoperable => interoperable.ToStackItem(ReferenceCounter),
IInteroperable[] array => new VMArray(ReferenceCounter, array.Select(p => p.ToStackItem(ReferenceCounter))),
ISerializable i => i.ToArray(),
StackItem item => item,
(object a, object b) => new Struct(ReferenceCounter) { Convert(a), Convert(b) },
Array array => new VMArray(ReferenceCounter, array.OfType<object>().Select(p => Convert(p))),
_ => StackItem.FromInterface(value)
};
}

internal object Convert(StackItem item, InteropParameterDescriptor descriptor)
{
if (descriptor.IsArray)
{
Array av;
if (item is VMArray array)
{
av = Array.CreateInstance(descriptor.Type.GetElementType(), array.Count);
for (int i = 0; i < av.Length; i++)
av.SetValue(descriptor.Converter(array[i]), i);
}
else
{
int count = (int)item.GetBigInteger();
if (count > MaxStackSize) throw new InvalidOperationException();
av = Array.CreateInstance(descriptor.Type.GetElementType(), count);
for (int i = 0; i < av.Length; i++)
av.SetValue(descriptor.Converter(Pop()), i);
}
return av;
}
else
{
object value = descriptor.Converter(item);
if (descriptor.IsEnum)
value = System.Convert.ChangeType(value, descriptor.Type);
else if (descriptor.IsInterface)
value = ((InteropInterface)value).GetInterface<object>();
return value;
}
}

public override void Dispose()
{
foreach (IDisposable disposable in disposables)
Expand All @@ -120,39 +153,10 @@ protected override bool OnSysCall(uint method)
? new List<object>()
: null;
foreach (var pd in descriptor.Parameters)
{
StackItem item = Pop();
object value;
if (pd.IsArray)
{
Array av;
if (item is VMArray array)
{
av = Array.CreateInstance(pd.Type.GetElementType(), array.Count);
for (int i = 0; i < av.Length; i++)
av.SetValue(pd.Converter(array[i]), i);
}
else
{
av = Array.CreateInstance(pd.Type.GetElementType(), (int)item.GetBigInteger());
for (int i = 0; i < av.Length; i++)
av.SetValue(pd.Converter(Pop()), i);
}
value = av;
}
else
{
value = pd.Converter(item);
if (pd.IsEnum)
value = Convert.ChangeType(value, pd.Type);
else if (pd.IsInterface)
value = ((InteropInterface)value).GetInterface<object>();
}
parameters.Add(value);
}
parameters.Add(Convert(Pop(), pd));
object returnValue = descriptor.Handler.Invoke(this, parameters?.ToArray());
if (descriptor.Handler.ReturnType != typeof(void))
Push(ConvertReturnValue(returnValue));
Push(Convert(returnValue));
return true;
}

Expand Down
4 changes: 4 additions & 0 deletions src/neo/SmartContract/InteropParameterDescriptor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
using Neo.VM.Types;
using System;
using System.Collections.Generic;
using System.Numerics;
using System.Reflection;

namespace Neo.SmartContract
{
internal class InteropParameterDescriptor
{
public string Name { get; }
public Type Type { get; }
public Func<StackItem, object> Converter { get; }
public bool IsEnum => Type.IsEnum;
Expand All @@ -28,6 +30,7 @@ internal class InteropParameterDescriptor
[typeof(uint)] = p => (uint)p.GetBigInteger(),
[typeof(long)] = p => (long)p.GetBigInteger(),
[typeof(ulong)] = p => (ulong)p.GetBigInteger(),
[typeof(BigInteger)] = p => p.GetBigInteger(),
[typeof(byte[])] = p => p.IsNull ? null : p.GetSpan().ToArray(),
[typeof(string)] = p => p.IsNull ? null : p.GetString(),
[typeof(UInt160)] = p => p.IsNull ? null : new UInt160(p.GetSpan()),
Expand All @@ -37,6 +40,7 @@ internal class InteropParameterDescriptor

public InteropParameterDescriptor(ParameterInfo parameterInfo)
{
Name = parameterInfo.Name;
Type = parameterInfo.ParameterType;
if (IsEnum)
{
Expand Down
8 changes: 2 additions & 6 deletions src/neo/SmartContract/Native/ContractMethodAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,16 @@

namespace Neo.SmartContract.Native
{
[AttributeUsage(AttributeTargets.Method, AllowMultiple = false)]
[AttributeUsage(AttributeTargets.Method | AttributeTargets.Property, AllowMultiple = false)]
internal class ContractMethodAttribute : Attribute
{
public string Name { get; set; }
public long Price { get; }
public ContractParameterType ReturnType { get; }
public ContractParameterType[] ParameterTypes { get; set; } = Array.Empty<ContractParameterType>();
public string[] ParameterNames { get; set; } = Array.Empty<string>();
public CallFlags RequiredCallFlags { get; }

public ContractMethodAttribute(long price, ContractParameterType returnType, CallFlags requiredCallFlags)
public ContractMethodAttribute(long price, CallFlags requiredCallFlags)
{
this.Price = price;
this.ReturnType = returnType;
this.RequiredCallFlags = requiredCallFlags;
}
}
Expand Down
38 changes: 33 additions & 5 deletions src/neo/SmartContract/Native/ContractMethodMetadata.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,41 @@
using Neo.VM.Types;
using Neo.Persistence;
using System;
using VMArray = Neo.VM.Types.Array;
using System.Linq;
using System.Reflection;

namespace Neo.SmartContract.Native
{
internal class ContractMethodMetadata
{
public Func<ApplicationEngine, VMArray, StackItem> Delegate;
public long Price;
public CallFlags RequiredCallFlags;
public string Name { get; }
public MethodInfo Handler { get; }
public InteropParameterDescriptor[] Parameters { get; }
public bool NeedApplicationEngine { get; }
public bool NeedSnapshot { get; }
public long Price { get; }
public CallFlags RequiredCallFlags { get; }

public ContractMethodMetadata(MemberInfo member, ContractMethodAttribute attribute)
{
this.Name = attribute.Name ?? member.Name.ToLower()[0] + member.Name[1..];
this.Handler = member switch
{
MethodInfo m => m,
PropertyInfo p => p.GetMethod,
_ => throw new ArgumentException(nameof(member))
};
ParameterInfo[] parameterInfos = this.Handler.GetParameters();
if (parameterInfos.Length > 0)
{
NeedApplicationEngine = parameterInfos[0].ParameterType.IsAssignableFrom(typeof(ApplicationEngine));
NeedSnapshot = parameterInfos[0].ParameterType.IsAssignableFrom(typeof(StoreView));
}
if (NeedApplicationEngine || NeedSnapshot)
this.Parameters = parameterInfos.Skip(1).Select(p => new InteropParameterDescriptor(p)).ToArray();
else
this.Parameters = parameterInfos.Select(p => new InteropParameterDescriptor(p)).ToArray();
this.Price = attribute.Price;
this.RequiredCallFlags = attribute.RequiredCallFlags;
}
}
}
108 changes: 64 additions & 44 deletions src/neo/SmartContract/Native/NativeContract.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#pragma warning disable IDE0060

using Neo.IO;
using Neo.Ledger;
using Neo.SmartContract.Manifest;
Expand All @@ -9,6 +7,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Numerics;
using System.Reflection;
using Array = Neo.VM.Types.Array;

Expand All @@ -26,11 +25,13 @@ public abstract class NativeContract
public static GasToken GAS { get; } = new GasToken();
public static PolicyContract Policy { get; } = new PolicyContract();

[ContractMethod(0, CallFlags.None)]
public abstract string Name { get; }
public byte[] Script { get; }
public UInt160 Hash { get; }
public abstract int Id { get; }
public ContractManifest Manifest { get; }
[ContractMethod(0, CallFlags.None)]
public virtual string[] SupportedStandards { get; } = { "NEP-10" };

protected NativeContract()
Expand All @@ -44,36 +45,31 @@ protected NativeContract()
this.Hash = Script.ToScriptHash();
List<ContractMethodDescriptor> descriptors = new List<ContractMethodDescriptor>();
List<string> safeMethods = new List<string>();
foreach (MethodInfo method in GetType().GetMethods(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public))
foreach (MemberInfo member in GetType().GetMembers(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public))
{
ContractMethodAttribute attribute = method.GetCustomAttribute<ContractMethodAttribute>();
ContractMethodAttribute attribute = member.GetCustomAttribute<ContractMethodAttribute>();
if (attribute is null) continue;
string name = attribute.Name ?? (method.Name.ToLower()[0] + method.Name.Substring(1));
ContractMethodMetadata metadata = new ContractMethodMetadata(member, attribute);
descriptors.Add(new ContractMethodDescriptor
{
Name = name,
ReturnType = attribute.ReturnType,
Parameters = attribute.ParameterTypes.Zip(attribute.ParameterNames, (t, n) => new ContractParameterDefinition { Type = t, Name = n }).ToArray()
});
if (!attribute.RequiredCallFlags.HasFlag(CallFlags.AllowModifyStates)) safeMethods.Add(name);
methods.Add(name, new ContractMethodMetadata
{
Delegate = (Func<ApplicationEngine, Array, StackItem>)method.CreateDelegate(typeof(Func<ApplicationEngine, Array, StackItem>), this),
Price = attribute.Price,
RequiredCallFlags = attribute.RequiredCallFlags
Name = metadata.Name,
ReturnType = ToParameterType(metadata.Handler.ReturnType),
Parameters = metadata.Parameters.Select(p => new ContractParameterDefinition { Type = ToParameterType(p.Type), Name = p.Name }).ToArray()
});
if (!attribute.RequiredCallFlags.HasFlag(CallFlags.AllowModifyStates)) safeMethods.Add(metadata.Name);
methods.Add(metadata.Name, metadata);
}
this.Manifest = new ContractManifest
{
Permissions = new[] { ContractPermission.DefaultPermission },
Abi = new ContractAbi()
{
Hash = Hash,
Events = new ContractEventDescriptor[0],
Events = System.Array.Empty<ContractEventDescriptor>(),
Methods = descriptors.ToArray()
},
Features = ContractFeatures.NoProperty,
Groups = new ContractGroup[0],
Groups = System.Array.Empty<ContractGroup>(),
SafeMethods = WildcardContainer<string>.Create(safeMethods.ToArray()),
Trusts = WildcardContainer<UInt160>.Create(),
Extra = null,
Expand Down Expand Up @@ -114,19 +110,24 @@ public static NativeContract GetContract(string name)

internal bool Invoke(ApplicationEngine engine)
{
if (!engine.CurrentScriptHash.Equals(Hash))
return false;
string operation = engine.CurrentContext.EvaluationStack.Pop().GetString();
Array args = (Array)engine.CurrentContext.EvaluationStack.Pop();
if (!methods.TryGetValue(operation, out ContractMethodMetadata method))
return false;
if (!engine.CurrentScriptHash.Equals(Hash)) return false;
if (!engine.TryPop(out string operation)) return false;
if (!engine.TryPop(out Array args)) return false;
if (!methods.TryGetValue(operation, out ContractMethodMetadata method)) return false;
ExecutionContextState state = engine.CurrentContext.GetState<ExecutionContextState>();
if (!state.CallFlags.HasFlag(method.RequiredCallFlags))
return false;
if (!engine.AddGas(method.Price))
return false;
StackItem result = method.Delegate(engine, args);
engine.CurrentContext.EvaluationStack.Push(result);
if (!state.CallFlags.HasFlag(method.RequiredCallFlags)) return false;
if (!engine.AddGas(method.Price)) return false;
List<object> parameters = new List<object>();
if (method.NeedApplicationEngine) parameters.Add(engine);
if (method.NeedSnapshot) parameters.Add(engine.Snapshot);
for (int i = 0; i < method.Parameters.Length; i++)
{
StackItem item = i < args.Count ? args[i] : StackItem.Null;
parameters.Add(engine.Convert(item, method.Parameters[i]));
}
object returnValue = method.Handler.Invoke(this, parameters.ToArray());
if (method.Handler.ReturnType != typeof(void))
engine.Push(engine.Convert(returnValue));
return true;
}

Expand All @@ -139,22 +140,11 @@ internal virtual void Initialize(ApplicationEngine engine)
{
}

[ContractMethod(0, ContractParameterType.Boolean, CallFlags.AllowModifyStates)]
protected StackItem OnPersist(ApplicationEngine engine, Array args)
[ContractMethod(0, CallFlags.AllowModifyStates)]
protected virtual void OnPersist(ApplicationEngine engine)
{
if (engine.Trigger != TriggerType.System) return false;
return OnPersist(engine);
}

protected virtual bool OnPersist(ApplicationEngine engine)
{
return true;
}

[ContractMethod(0, ContractParameterType.Array, CallFlags.None, Name = "supportedStandards")]
protected StackItem SupportedStandardsMethod(ApplicationEngine engine, Array args)
{
return new Array(engine.ReferenceCounter, SupportedStandards.Select(p => (StackItem)p));
if (engine.Trigger != TriggerType.System)
throw new InvalidOperationException();
}

public ApplicationEngine TestCall(string operation, params object[] args)
Expand All @@ -165,5 +155,35 @@ public ApplicationEngine TestCall(string operation, params object[] args)
return ApplicationEngine.Run(sb.ToArray(), testMode: true);
}
}

private static ContractParameterType ToParameterType(Type type)
{
if (type == typeof(void)) return ContractParameterType.Void;
if (type == typeof(bool)) return ContractParameterType.Boolean;
if (type == typeof(sbyte)) return ContractParameterType.Integer;
if (type == typeof(byte)) return ContractParameterType.Integer;
if (type == typeof(short)) return ContractParameterType.Integer;
if (type == typeof(ushort)) return ContractParameterType.Integer;
if (type == typeof(int)) return ContractParameterType.Integer;
if (type == typeof(uint)) return ContractParameterType.Integer;
if (type == typeof(long)) return ContractParameterType.Integer;
if (type == typeof(ulong)) return ContractParameterType.Integer;
if (type == typeof(BigInteger)) return ContractParameterType.Integer;
if (type == typeof(byte[])) return ContractParameterType.ByteArray;
if (type == typeof(string)) return ContractParameterType.ByteArray;
if (type == typeof(VM.Types.Boolean)) return ContractParameterType.Boolean;
if (type == typeof(Integer)) return ContractParameterType.Integer;
if (type == typeof(ByteString)) return ContractParameterType.ByteArray;
if (type == typeof(VM.Types.Buffer)) return ContractParameterType.ByteArray;
if (type == typeof(Array)) return ContractParameterType.Array;
if (type == typeof(Struct)) return ContractParameterType.Array;
if (type == typeof(Map)) return ContractParameterType.Map;
if (type == typeof(StackItem)) return ContractParameterType.Any;
if (typeof(IInteroperable).IsAssignableFrom(type)) return ContractParameterType.Array;
if (typeof(ISerializable).IsAssignableFrom(type)) return ContractParameterType.ByteArray;
if (type.IsArray) return ContractParameterType.Array;
if (type.IsEnum) return ContractParameterType.Integer;
return ContractParameterType.Any;
}
}
}
Loading

0 comments on commit af07180

Please sign in to comment.