Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for using extension methods in bindings #931

Merged
merged 20 commits into from
Feb 23, 2021
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,14 @@ public void BindingCompiler_Invalid_LambdaParameters(string expr)
Assert.ThrowsException<AggregateException>(() => ExecuteBinding(expr, viewModel));
}

[TestMethod]
public void BindingCompiler_Valid_ExtensionMethods()
{
var viewModel = new TestViewModel();
var result = (long[])ExecuteBinding("LongArray.Where((long item) => item % 2 != 0).ToArray()", viewModel);
CollectionAssert.AreEqual(viewModel.LongArray.Where(item => item % 2 != 0).ToArray(), result);
}

class MoqComponent : DotvvmBindableObject
{
public object Property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -368,9 +368,11 @@ public void JsTranslator_ArrayIndexer()
}

[TestMethod]
public void JsTranslator_EnumerableWhere()
[DataRow("Enumerable.Where(LongArray, (long item) => item % 2 == 0)", DisplayName = "Regular call of Enumerable.Where")]
[DataRow("LongArray.Where((long item) => item % 2 == 0)", DisplayName = "Syntax sugar - extension method")]
public void JsTranslator_EnumerableWhere(string binding)
{
var result = CompileBinding("Enumerable.Where(LongArray, (long item) => item % 2 == 0)", new[] { typeof(TestViewModel) });
var result = CompileBinding(binding, new[] { typeof(TestViewModel) });
Assert.AreEqual("LongArray().filter(function(item){return ko.unwrap(item)%2==0;})", result);
}

Expand All @@ -382,12 +384,21 @@ public void JsTranslator_NestedEnumerableMethods()
}

[TestMethod]
public void JsTranslator_EnumerableSelect()
[DataRow("Enumerable.Select(LongArray, (long item) => -item)", DisplayName = "Regular call of Enumerable.Select")]
[DataRow("LongArray.Select((long item) => -item)", DisplayName = "Syntax sugar - extension method")]
public void JsTranslator_EnumerableSelect(string binding)
{
var result = CompileBinding("Enumerable.Select(LongArray, (long item) => -item)", new[] { typeof(TestViewModel) });
var result = CompileBinding(binding, new[] { typeof(TestViewModel) });
Assert.AreEqual("LongArray().map(function(item){return -ko.unwrap(item);})", result);
}

[TestMethod]
public void JsTranslator_ValidMethod_UnsupportedTranslation()
{
Assert.ThrowsException<NotSupportedException>(() =>
CompileBinding("Enumerable.Skip<long>(LongArray, 2)", new[] { typeof(TestViewModel) }));
}

[TestMethod]
public void JavascriptCompilation_GuidToString()
{
Expand Down
67 changes: 52 additions & 15 deletions src/DotVVM.Framework/Compilation/Binding/ExpressionHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,16 @@ public static Expression GetMember(Expression target, string name, Type[] typeAr

if (members.Length == 0)
{
if (throwExceptions) throw new Exception($"Could not find { (isStatic ? "static" : "instance") } member { name } on type { type.FullName }.");
else return null;
// We did not find any match in regular methods => try extension methods
var extensions = type.GetAllExtensions()
.Where(m => ((isGeneric && m is TypeInfo) ? genericName : name) == m.Name)
.ToArray();
members = extensions;

if (members.Length == 0 && throwExceptions)
throw new Exception($"Could not find { (isStatic ? "static" : "instance") } member { name } on type { type.FullName }.");
else if (members.Length == 0 && !throwExceptions)
return null;
}
if (members.Length == 1)
{
Expand Down Expand Up @@ -147,36 +155,64 @@ public static Expression Call(Expression target, Expression[] arguments)
public static Expression CallMethod(Expression target, BindingFlags flags, string name, Type[] typeArguments, Expression[] arguments, IDictionary<string, Expression> namedArgs = null)
{
// the following piece of code is nicer and more readable than method recognition done in roslyn, C# dynamic and also expression evaluator :)
var method = FindValidMethodOveloads(target.Type, name, flags, typeArguments, arguments, namedArgs);
var method = FindValidMethodOveloads(target, target.Type, name, flags, typeArguments, arguments, namedArgs);

if (method.IsExtension)
{
// Change to a static call
var newArguments = new[] { target }.Concat(arguments);
return Expression.Call(method.Method, newArguments);
}

return Expression.Call(target, method.Method, method.Arguments);
}

public static Expression CallMethod(Type target, BindingFlags flags, string name, Type[] typeArguments, Expression[] arguments, IDictionary<string, Expression> namedArgs = null)
{
// the following piece of code is nicer and more readable than method recognition done in roslyn, C# dynamic and also expression evaluator :)
var method = FindValidMethodOveloads(target, name, flags, typeArguments, arguments, namedArgs);
var method = FindValidMethodOveloads(null, target, name, flags, typeArguments, arguments, namedArgs);
return Expression.Call(method.Method, method.Arguments);
}


private static MethodRecognitionResult FindValidMethodOveloads(Type type, string name, BindingFlags flags, Type[] typeArguments, Expression[] arguments, IDictionary<string, Expression> namedArgs)
private static MethodRecognitionResult FindValidMethodOveloads(Expression target, Type type, string name, BindingFlags flags, Type[] typeArguments, Expression[] arguments, IDictionary<string, Expression> namedArgs)
{
var methods = FindValidMethodOveloads(type.GetAllMembers(flags).OfType<MethodInfo>().Where(m => m.Name == name), typeArguments, arguments, namedArgs).ToList();
var methods = FindValidMethodOveloads(type.GetAllMembers(flags).OfType<MethodInfo>().Where(m => m.Name == name), typeArguments, arguments, namedArgs).ToList();

if (methods.Count == 1) return methods.FirstOrDefault();
if (methods.Count == 0) throw new InvalidOperationException($"Could not find overload of method '{name}'.");
else
if (methods.Count == 0)
{
methods = methods.OrderBy(s => s.CastCount).ThenBy(s => s.AutomaticTypeArgCount).ToList();
var method = methods.FirstOrDefault();
var method2 = methods.Skip(1).FirstOrDefault();
if (method.AutomaticTypeArgCount == method2.AutomaticTypeArgCount && method.CastCount == method2.CastCount)
// We did not find any match in regular methods => try extension methods
if (target != null)
{
// TODO: this behavior is not completed. Implement the same behavior as in roslyn.
throw new InvalidOperationException($"Found ambiguous overloads of method '{name}'.");
// Change to a static call
var newArguments = new[] { target }.Concat(arguments).ToArray();
var extensions = FindValidMethodOveloads(type.GetAllExtensions().OfType<MethodInfo>().Where(m => m.Name == name), typeArguments, newArguments, namedArgs)
.Select(method => { method.IsExtension = true; return method; }).ToList();

// We found an extension method
if (extensions.Count == 1)
return extensions.FirstOrDefault();

target = null;
methods = extensions;
arguments = newArguments;
}
return method;

if (methods.Count == 0)
throw new InvalidOperationException($"Could not find method overload nor extension method that matched '{name}'.");
}

// There are multiple method candidates
methods = methods.OrderBy(s => s.CastCount).ThenBy(s => s.AutomaticTypeArgCount).ToList();
var method = methods.FirstOrDefault();
var method2 = methods.Skip(1).FirstOrDefault();
if (method.AutomaticTypeArgCount == method2.AutomaticTypeArgCount && method.CastCount == method2.CastCount)
{
// TODO: this behavior is not completed. Implement the same behavior as in roslyn.
throw new InvalidOperationException($"Found ambiguous overloads of method '{name}'.");
}
return method;
}

private static IEnumerable<MethodRecognitionResult> FindValidMethodOveloads(IEnumerable<MethodInfo> methods, Type[] typeArguments, Expression[] arguments, IDictionary<string, Expression> namedArgs)
Expand All @@ -192,6 +228,7 @@ class MethodRecognitionResult
public int CastCount { get; set; }
public Expression[] Arguments { get; set; }
public MethodInfo Method { get; set; }
public bool IsExtension { get; set; }
}

private static MethodRecognitionResult TryCallMethod(MethodInfo method, Type[] typeArguments, Expression[] positionalArguments, IDictionary<string, Expression> namedArguments)
Expand Down
5 changes: 5 additions & 0 deletions src/DotVVM.Framework/Compilation/Binding/TypeRegistry.cs
Original file line number Diff line number Diff line change
Expand Up @@ -127,5 +127,10 @@ public static Expression CreateStatic(Type type)
ImmutableList<Func<string, Expression>>.Empty
.Add(type => CreateStatic(compiledAssemblyCache.FindType(type + (assemblyName != null ? $", {assemblyName}" : ""))))
);

public static IEnumerable<Type> GetRegisteredTypesForExtensionMethodsLookup()
{
yield return typeof(Enumerable);
}
}
}
7 changes: 7 additions & 0 deletions src/DotVVM.Framework/Utils/ReflectionUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using System.Text;
using System.Globalization;
using System.Collections.Concurrent;
using DotVVM.Framework.Compilation.Binding;

namespace DotVVM.Framework.Utils
{
Expand Down Expand Up @@ -63,6 +64,12 @@ public static IEnumerable<MemberInfo> GetAllMembers(this Type type, BindingFlags
return type.GetMembers(flags);
}

public static IEnumerable<MethodInfo> GetAllExtensions(this Type type, BindingFlags flags = BindingFlags.Public | BindingFlags.Static)
{
foreach (var registeredType in TypeRegistry.GetRegisteredTypesForExtensionMethodsLookup())
foreach (var method in registeredType.GetMethods(flags))
yield return method;
}

/// <summary>
/// Gets filesystem path of assembly CodeBase
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,21 @@

<div>
<h2>Operations</h2>
<p>
Showcasing LINQ and JsTranslator <br/>
Note: you can use either explicit expressions (example: Enumerable.Where(Collection, ...)),
but also extension-method-like calls (example: Collection.Where(...))
</p>

<dot:Button Text="Get all even numbers" Validation.Enabled="false"
Click="{command: SetResult(Enumerable.Where(Array, (int item) => item % 2 == 0))}" />
<%--Click="{command: SetResult(Enumerable.Where(Array, (int item) => item % 2 == 0))}" />--%>
Click="{command: SetResult(Array.Where((int item) => item % 2 == 0))}" />
<dot:Button Text="Get all odd numbers" Validation.Enabled="false"
Click="{command: SetResult(Enumerable.Where(Array, (int item) => item % 2 == 1))}" />
<%--Click="{command: SetResult(Enumerable.Where(Array, (int item) => item % 2 == 1))}" />--%>
Click="{command: SetResult(Array.Where((int item) => item % 2 == 1))}" />
<dot:Button Text="Negate numbers" Validation.Enabled="false"
Click="{command: SetResult(Enumerable.Select(Array, (int item) => -item))}" />
<%--Click="{command: SetResult(Enumerable.Select(Array, (int item) => -item))}" />--%>
Click="{command: SetResult(Array.Select((int item) => -item))}" />
</div>

<div>
Expand Down