Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor NativeContract #1693

Merged
merged 5 commits into from
Jun 12, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 38 additions & 36 deletions src/neo/SmartContract/ApplicationEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public ExecutionContext LoadScript(Script script, CallFlags callFlags, int rvcou
return context;
}

private StackItem ConvertReturnValue(object value)
internal StackItem Convert(object value)
{
return value switch
{
Expand All @@ -85,18 +85,49 @@ private StackItem ConvertReturnValue(object value)
uint i => i,
long i => i,
ulong i => i,
Enum e => ConvertReturnValue(Convert.ChangeType(e, e.GetTypeCode())),
Enum e => Convert(System.Convert.ChangeType(e, e.GetTypeCode())),
byte[] data => data,
string s => s,
UInt160 i => i.ToArray(),
UInt256 i => i.ToArray(),
BigInteger i => i,
IInteroperable interoperable => interoperable.ToStackItem(ReferenceCounter),
IInteroperable[] array => new VMArray(ReferenceCounter, array.Select(p => p.ToStackItem(ReferenceCounter))),
ISerializable i => i.ToArray(),
StackItem item => item,
(object a, object b) => new Struct(ReferenceCounter) { Convert(a), Convert(b) },
Array array => new VMArray(ReferenceCounter, array.OfType<object>().Select(p => Convert(p))),
_ => StackItem.FromInterface(value)
};
}

internal object Convert(StackItem item, InteropParameterDescriptor descriptor)
{
if (descriptor.IsArray)
{
Array av;
if (item is VMArray array)
{
av = Array.CreateInstance(descriptor.Type.GetElementType(), array.Count);
for (int i = 0; i < av.Length; i++)
av.SetValue(descriptor.Converter(array[i]), i);
}
else
{
av = Array.CreateInstance(descriptor.Type.GetElementType(), (int)item.GetBigInteger());
shargon marked this conversation as resolved.
Show resolved Hide resolved
for (int i = 0; i < av.Length; i++)
av.SetValue(descriptor.Converter(Pop()), i);
erikzhang marked this conversation as resolved.
Show resolved Hide resolved
}
return av;
}
else
{
object value = descriptor.Converter(item);
if (descriptor.IsEnum)
value = System.Convert.ChangeType(value, descriptor.Type);
else if (descriptor.IsInterface)
value = ((InteropInterface)value).GetInterface<object>();
return value;
}
}

public override void Dispose()
{
foreach (IDisposable disposable in disposables)
Expand All @@ -120,39 +151,10 @@ protected override bool OnSysCall(uint method)
? new List<object>()
: null;
foreach (var pd in descriptor.Parameters)
{
StackItem item = Pop();
object value;
if (pd.IsArray)
{
Array av;
if (item is VMArray array)
{
av = Array.CreateInstance(pd.Type.GetElementType(), array.Count);
for (int i = 0; i < av.Length; i++)
av.SetValue(pd.Converter(array[i]), i);
}
else
{
av = Array.CreateInstance(pd.Type.GetElementType(), (int)item.GetBigInteger());
for (int i = 0; i < av.Length; i++)
av.SetValue(pd.Converter(Pop()), i);
}
value = av;
}
else
{
value = pd.Converter(item);
if (pd.IsEnum)
value = Convert.ChangeType(value, pd.Type);
else if (pd.IsInterface)
value = ((InteropInterface)value).GetInterface<object>();
}
parameters.Add(value);
}
parameters.Add(Convert(Pop(), pd));
object returnValue = descriptor.Handler.Invoke(this, parameters?.ToArray());
if (descriptor.Handler.ReturnType != typeof(void))
Push(ConvertReturnValue(returnValue));
Push(Convert(returnValue));
return true;
}

Expand Down
4 changes: 4 additions & 0 deletions src/neo/SmartContract/InteropParameterDescriptor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
using Neo.VM.Types;
using System;
using System.Collections.Generic;
using System.Numerics;
using System.Reflection;

namespace Neo.SmartContract
{
internal class InteropParameterDescriptor
{
public string Name { get; }
public Type Type { get; }
public Func<StackItem, object> Converter { get; }
public bool IsEnum => Type.IsEnum;
Expand All @@ -28,6 +30,7 @@ internal class InteropParameterDescriptor
[typeof(uint)] = p => (uint)p.GetBigInteger(),
[typeof(long)] = p => (long)p.GetBigInteger(),
[typeof(ulong)] = p => (ulong)p.GetBigInteger(),
[typeof(BigInteger)] = p => p.GetBigInteger(),
[typeof(byte[])] = p => p.IsNull ? null : p.GetSpan().ToArray(),
[typeof(string)] = p => p.IsNull ? null : p.GetString(),
[typeof(UInt160)] = p => p.IsNull ? null : new UInt160(p.GetSpan()),
Expand All @@ -37,6 +40,7 @@ internal class InteropParameterDescriptor

public InteropParameterDescriptor(ParameterInfo parameterInfo)
{
Name = parameterInfo.Name;
Type = parameterInfo.ParameterType;
if (IsEnum)
{
Expand Down
8 changes: 2 additions & 6 deletions src/neo/SmartContract/Native/ContractMethodAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,16 @@

namespace Neo.SmartContract.Native
{
[AttributeUsage(AttributeTargets.Method, AllowMultiple = false)]
[AttributeUsage(AttributeTargets.Method | AttributeTargets.Property, AllowMultiple = false)]
internal class ContractMethodAttribute : Attribute
{
public string Name { get; set; }
public long Price { get; }
public ContractParameterType ReturnType { get; }
public ContractParameterType[] ParameterTypes { get; set; } = Array.Empty<ContractParameterType>();
public string[] ParameterNames { get; set; } = Array.Empty<string>();
public CallFlags RequiredCallFlags { get; }

public ContractMethodAttribute(long price, ContractParameterType returnType, CallFlags requiredCallFlags)
public ContractMethodAttribute(long price, CallFlags requiredCallFlags)
{
this.Price = price;
this.ReturnType = returnType;
this.RequiredCallFlags = requiredCallFlags;
}
}
Expand Down
40 changes: 34 additions & 6 deletions src/neo/SmartContract/Native/ContractMethodMetadata.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,41 @@
using Neo.VM.Types;
using System;
using VMArray = Neo.VM.Types.Array;
using Neo.Persistence;
using System.Linq;
using System.Reflection;

namespace Neo.SmartContract.Native
{
internal class ContractMethodMetadata
{
public Func<ApplicationEngine, VMArray, StackItem> Delegate;
public long Price;
public CallFlags RequiredCallFlags;
public string Name { get; }
public MethodInfo Handler { get; }
public InteropParameterDescriptor[] Parameters { get; }
public bool NeedApplicationEngine { get; }
public bool NeedSnapshot { get; }
public long Price { get; }
public CallFlags RequiredCallFlags { get; }

public ContractMethodMetadata(MethodInfo handler, ContractMethodAttribute attribute)
{
this.Name = attribute.Name ?? GetDefaultMethodName(handler.Name);
this.Handler = handler;
ParameterInfo[] parameterInfos = handler.GetParameters();
if (parameterInfos.Length > 0)
{
NeedApplicationEngine = parameterInfos[0].ParameterType.IsAssignableFrom(typeof(ApplicationEngine));
shargon marked this conversation as resolved.
Show resolved Hide resolved
NeedSnapshot = parameterInfos[0].ParameterType.IsAssignableFrom(typeof(StoreView));
}
if (NeedApplicationEngine || NeedSnapshot)
this.Parameters = parameterInfos.Skip(1).Select(p => new InteropParameterDescriptor(p)).ToArray();
else
this.Parameters = parameterInfos.Select(p => new InteropParameterDescriptor(p)).ToArray();
this.Price = attribute.Price;
this.RequiredCallFlags = attribute.RequiredCallFlags;
}

private static string GetDefaultMethodName(string name)
{
if (name.StartsWith("get_")) name = name[4..];
return name.ToLower()[0] + name[1..];
}
}
}
110 changes: 68 additions & 42 deletions src/neo/SmartContract/Native/NativeContract.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#pragma warning disable IDE0060

using Neo.IO;
using Neo.Ledger;
using Neo.SmartContract.Manifest;
Expand All @@ -9,6 +7,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Numerics;
using System.Reflection;
using Array = Neo.VM.Types.Array;

Expand All @@ -26,11 +25,13 @@ public abstract class NativeContract
public static GasToken GAS { get; } = new GasToken();
public static PolicyContract Policy { get; } = new PolicyContract();

[ContractMethod(0, CallFlags.None)]
public abstract string Name { get; }
public byte[] Script { get; }
public UInt160 Hash { get; }
public abstract int Id { get; }
public ContractManifest Manifest { get; }
[ContractMethod(0, CallFlags.None)]
public virtual string[] SupportedStandards { get; } = { "NEP-10" };

protected NativeContract()
Expand All @@ -44,24 +45,25 @@ protected NativeContract()
this.Hash = Script.ToScriptHash();
List<ContractMethodDescriptor> descriptors = new List<ContractMethodDescriptor>();
List<string> safeMethods = new List<string>();
foreach (MethodInfo method in GetType().GetMethods(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public))
foreach (MemberInfo member in GetType().GetMembers(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public))
{
ContractMethodAttribute attribute = method.GetCustomAttribute<ContractMethodAttribute>();
ContractMethodAttribute attribute = member.GetCustomAttribute<ContractMethodAttribute>();
if (attribute is null) continue;
string name = attribute.Name ?? (method.Name.ToLower()[0] + method.Name.Substring(1));
descriptors.Add(new ContractMethodDescriptor
MethodInfo method = member switch
{
Name = name,
ReturnType = attribute.ReturnType,
Parameters = attribute.ParameterTypes.Zip(attribute.ParameterNames, (t, n) => new ContractParameterDefinition { Type = t, Name = n }).ToArray()
});
if (!attribute.RequiredCallFlags.HasFlag(CallFlags.AllowModifyStates)) safeMethods.Add(name);
methods.Add(name, new ContractMethodMetadata
MethodInfo m => m,
PropertyInfo p => p.GetMethod,
_ => throw new InvalidOperationException()
};
ContractMethodMetadata metadata = new ContractMethodMetadata(method, attribute);
descriptors.Add(new ContractMethodDescriptor
{
Delegate = (Func<ApplicationEngine, Array, StackItem>)method.CreateDelegate(typeof(Func<ApplicationEngine, Array, StackItem>), this),
Price = attribute.Price,
RequiredCallFlags = attribute.RequiredCallFlags
Name = metadata.Name,
ReturnType = ToParameterType(method.ReturnType),
Parameters = metadata.Parameters.Select(p => new ContractParameterDefinition { Type = ToParameterType(p.Type), Name = p.Name }).ToArray()
});
if (!attribute.RequiredCallFlags.HasFlag(CallFlags.AllowModifyStates)) safeMethods.Add(metadata.Name);
methods.Add(metadata.Name, metadata);
}
this.Manifest = new ContractManifest
{
Expand Down Expand Up @@ -114,19 +116,24 @@ public static NativeContract GetContract(string name)

internal bool Invoke(ApplicationEngine engine)
{
if (!engine.CurrentScriptHash.Equals(Hash))
return false;
string operation = engine.CurrentContext.EvaluationStack.Pop().GetString();
Array args = (Array)engine.CurrentContext.EvaluationStack.Pop();
if (!methods.TryGetValue(operation, out ContractMethodMetadata method))
return false;
if (!engine.CurrentScriptHash.Equals(Hash)) return false;
if (!engine.TryPop(out string operation)) return false;
if (!engine.TryPop(out Array args)) return false;
if (!methods.TryGetValue(operation, out ContractMethodMetadata method)) return false;
ExecutionContextState state = engine.CurrentContext.GetState<ExecutionContextState>();
if (!state.CallFlags.HasFlag(method.RequiredCallFlags))
return false;
if (!engine.AddGas(method.Price))
return false;
StackItem result = method.Delegate(engine, args);
engine.CurrentContext.EvaluationStack.Push(result);
if (!state.CallFlags.HasFlag(method.RequiredCallFlags)) return false;
if (!engine.AddGas(method.Price)) return false;
List<object> parameters = new List<object>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can use an array because we know the size and it's faster than a list

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But we also need to consider NeedApplicationEngine and NeedSnapshot.

if (method.NeedApplicationEngine) parameters.Add(engine);
if (method.NeedSnapshot) parameters.Add(engine.Snapshot);
for (int i = 0; i < method.Parameters.Length; i++)
{
StackItem item = i < args.Count ? args[i] : StackItem.Null;
parameters.Add(engine.Convert(item, method.Parameters[i]));
}
object returnValue = method.Handler.Invoke(this, parameters.ToArray());
if (method.Handler.ReturnType != typeof(void))
engine.Push(engine.Convert(returnValue));
return true;
}

Expand All @@ -139,22 +146,11 @@ internal virtual void Initialize(ApplicationEngine engine)
{
}

[ContractMethod(0, ContractParameterType.Boolean, CallFlags.AllowModifyStates)]
protected StackItem OnPersist(ApplicationEngine engine, Array args)
{
if (engine.Trigger != TriggerType.System) return false;
return OnPersist(engine);
}

protected virtual bool OnPersist(ApplicationEngine engine)
{
return true;
}

[ContractMethod(0, ContractParameterType.Array, CallFlags.None, Name = "supportedStandards")]
protected StackItem SupportedStandardsMethod(ApplicationEngine engine, Array args)
[ContractMethod(0, CallFlags.AllowModifyStates)]
protected virtual void OnPersist(ApplicationEngine engine)
{
return new Array(engine.ReferenceCounter, SupportedStandards.Select(p => (StackItem)p));
if (engine.Trigger != TriggerType.System)
throw new InvalidOperationException();
}

public ApplicationEngine TestCall(string operation, params object[] args)
Expand All @@ -165,5 +161,35 @@ public ApplicationEngine TestCall(string operation, params object[] args)
return ApplicationEngine.Run(sb.ToArray(), testMode: true);
}
}

private static ContractParameterType ToParameterType(Type type)
{
if (type == typeof(void)) return ContractParameterType.Void;
if (type == typeof(bool)) return ContractParameterType.Boolean;
if (type == typeof(sbyte)) return ContractParameterType.Integer;
if (type == typeof(byte)) return ContractParameterType.Integer;
if (type == typeof(short)) return ContractParameterType.Integer;
if (type == typeof(ushort)) return ContractParameterType.Integer;
if (type == typeof(int)) return ContractParameterType.Integer;
if (type == typeof(uint)) return ContractParameterType.Integer;
if (type == typeof(long)) return ContractParameterType.Integer;
if (type == typeof(ulong)) return ContractParameterType.Integer;
if (type == typeof(BigInteger)) return ContractParameterType.Integer;
if (type == typeof(byte[])) return ContractParameterType.ByteArray;
if (type == typeof(string)) return ContractParameterType.ByteArray;
if (type == typeof(VM.Types.Boolean)) return ContractParameterType.Boolean;
if (type == typeof(Integer)) return ContractParameterType.Integer;
if (type == typeof(ByteString)) return ContractParameterType.ByteArray;
if (type == typeof(VM.Types.Buffer)) return ContractParameterType.ByteArray;
if (type == typeof(Array)) return ContractParameterType.Array;
if (type == typeof(Struct)) return ContractParameterType.Array;
if (type == typeof(Map)) return ContractParameterType.Map;
if (type == typeof(StackItem)) return ContractParameterType.Any;
if (typeof(IInteroperable).IsAssignableFrom(type)) return ContractParameterType.Array;
if (typeof(ISerializable).IsAssignableFrom(type)) return ContractParameterType.ByteArray;
if (type.IsArray) return ContractParameterType.Array;
if (type.IsEnum) return ContractParameterType.Integer;
return ContractParameterType.Any;
}
}
}
Loading