Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for using extension methods in bindings #931

Merged
merged 20 commits into from
Feb 23, 2021
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,14 @@ public void BindingCompiler_Invalid_LambdaParameters(string expr)
Assert.ThrowsException<AggregateException>(() => ExecuteBinding(expr, viewModel));
}

[TestMethod]
public void BindingCompiler_Valid_ExtensionMethods()
{
var viewModel = new TestViewModel();
var result = (long[])ExecuteBinding("LongArray.Where((long item) => item % 2 != 0).ToArray()", viewModel);
CollectionAssert.AreEqual(viewModel.LongArray.Where(item => item % 2 != 0).ToArray(), result);
}

class MoqComponent : DotvvmBindableObject
{
public object Property
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Text;
using DotVVM.Framework.Compilation.Binding;
using DotVVM.Framework.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace DotVVM.Framework.Tests.Common.Binding
{
[TestClass]
public class CustomExtensionMethodTests
{
private DotvvmConfiguration configuration;
private MemberExpressionFactory memberExpressionFactory;

[TestInitialize]
public void INIT()
acizmarik marked this conversation as resolved.
Show resolved Hide resolved
{
this.configuration = DotvvmTestHelper.CreateConfiguration(services => services.AddScoped<IExtensionsProvider, TestExtensionsProvider>());
this.memberExpressionFactory = configuration.ServiceProvider.GetRequiredService<MemberExpressionFactory>();
}

[TestMethod]
public void Call_FindCustomExtensionMethod()
{
var target = new MethodGroupExpression()
{
MethodName = nameof(TestExtensions.Increment),
Target = Expression.Constant(11)
};

var expression = memberExpressionFactory.Call(target, Array.Empty<Expression>());
var result = Expression.Lambda<Func<int>>(expression).Compile().Invoke();
Assert.AreEqual(12, result);
}
}

static class TestExtensions
{
public static int Increment(this int number)
=> ++number;
}

class TestExtensionsProvider : DefaultExtensionsProvider
{
public TestExtensionsProvider()
{
AddTypeForExtensionsLookup(typeof(TestExtensions));
}
}
}
19 changes: 14 additions & 5 deletions src/DotVVM.Framework.Tests.Common/Binding/ExpressionHelperTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using DotVVM.Framework.Utils;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CSharp.RuntimeBinder;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace DotVVM.Framework.Tests.Common.Binding
Expand All @@ -17,13 +18,21 @@ namespace DotVVM.Framework.Tests.Common.Binding
public class ExpressionHelperTests
{
public TestContext TestContext { get; set; }
private MemberExpressionFactory memberExpressionFactory;

[TestInitialize]
public void INIT()
{
var configuration = DotvvmTestHelper.CreateConfiguration();
memberExpressionFactory = configuration.ServiceProvider.GetRequiredService<MemberExpressionFactory>();
}

[TestMethod]
public void UpdateMember_GetValue()
{
var cP = Expression.Parameter(typeof(DotvvmControl), "c");
var newValueP = Expression.Parameter(typeof(object), "newValue");
var updateExpr = ExpressionHelper.UpdateMember(ExpressionUtils.Replace((DotvvmControl c) => c.GetValue(DotvvmBindableObject.DataContextProperty, true), cP), newValueP);
var updateExpr = memberExpressionFactory.UpdateMember(ExpressionUtils.Replace((DotvvmControl c) => c.GetValue(DotvvmBindableObject.DataContextProperty, true), cP), newValueP);
Assert.IsNotNull(updateExpr);
Assert.AreEqual("c.SetValue(DotvvmBindableObject.DataContextProperty, newValue)", updateExpr.ToString());
}
Expand All @@ -33,7 +42,7 @@ public void UpdateMember_NormalProperty()
{
var vmP = Expression.Parameter(typeof(Tests.Binding.TestViewModel), "vm");
var newValueP = Expression.Parameter(typeof(DateTime), "newValue");
var updateExpr = ExpressionHelper.UpdateMember(ExpressionUtils.Replace((Tests.Binding.TestViewModel c) => c.DateFrom, vmP), newValueP);
var updateExpr = memberExpressionFactory.UpdateMember(ExpressionUtils.Replace((Tests.Binding.TestViewModel c) => c.DateFrom, vmP), newValueP);
Assert.IsNotNull(updateExpr);
Assert.AreEqual("(vm.DateFrom = Convert(newValue, Nullable`1))", updateExpr.ToString());
}
Expand All @@ -43,7 +52,7 @@ public void UpdateMember_ReadOnlyProperty()
{
var vmP = Expression.Parameter(typeof(Tests.Binding.TestViewModel), "vm");
var newValueP = Expression.Parameter(typeof(long[]), "newValue");
var updateExpr = ExpressionHelper.UpdateMember(ExpressionUtils.Replace((Tests.Binding.TestViewModel c) => c.LongArray, vmP), newValueP);
var updateExpr = memberExpressionFactory.UpdateMember(ExpressionUtils.Replace((Tests.Binding.TestViewModel c) => c.LongArray, vmP), newValueP);
Assert.IsNull(updateExpr);
}

Expand All @@ -57,7 +66,7 @@ public void Call_FindOverload_Generic_FirstLevel(Type resultIdentifierType, Type
Call_FindOverload_Generic(typeof(MethodsGenericArgumentsResolvingSampleObject), MethodsGenericArgumentsResolvingSampleObject.MethodName, new[] { argType }, resultIdentifierType, expectedGenericArgs);
}

private static void Call_FindOverload_Generic(Type targetType, string methodName, Type[] argTypes, Type resultIdentifierType, Type[] expectedGenericArgs)
private void Call_FindOverload_Generic(Type targetType, string methodName, Type[] argTypes, Type resultIdentifierType, Type[] expectedGenericArgs)
{
Expression target = new MethodGroupExpression() {
MethodName = methodName,
Expand All @@ -66,7 +75,7 @@ private static void Call_FindOverload_Generic(Type targetType, string methodName

var j = 0;
var arguments = argTypes.Select(s => Expression.Parameter(s, $"param_{j++}")).ToArray();
var expression = ExpressionHelper.Call(target, arguments) as MethodCallExpression;
var expression = memberExpressionFactory.Call(target, arguments) as MethodCallExpression;
Assert.IsNotNull(expression);
Assert.AreEqual(resultIdentifierType, expression.Method.GetResultType());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public string CompileBinding(string expression, Type[] contexts, Type expectedTy
{
context = DataContextStack.Create(contexts[i], context);
}
var parser = new BindingExpressionBuilder(configuration.ServiceProvider.GetRequiredService<CompiledAssemblyCache>());
var parser = new BindingExpressionBuilder(configuration.ServiceProvider.GetRequiredService<CompiledAssemblyCache>(), configuration.ServiceProvider.GetRequiredService<MemberExpressionFactory>());
var parsedExpression = parser.ParseWithLambdaConversion(expression, context, BindingParserOptions.Create<ValueBindingExpression>(), expectedType);
var expressionTree =
TypeConversion.MagicLambdaConversion(parsedExpression, expectedType) ??
Expand Down Expand Up @@ -368,9 +368,11 @@ public void JsTranslator_ArrayIndexer()
}

[TestMethod]
public void JsTranslator_EnumerableWhere()
[DataRow("Enumerable.Where(LongArray, (long item) => item % 2 == 0)", DisplayName = "Regular call of Enumerable.Where")]
[DataRow("LongArray.Where((long item) => item % 2 == 0)", DisplayName = "Syntax sugar - extension method")]
public void JsTranslator_EnumerableWhere(string binding)
{
var result = CompileBinding("Enumerable.Where(LongArray, (long item) => item % 2 == 0)", new[] { typeof(TestViewModel) });
var result = CompileBinding(binding, new[] { typeof(TestViewModel) });
Assert.AreEqual("LongArray().filter(function(item){return ko.unwrap(item)%2==0;})", result);
}

Expand All @@ -382,12 +384,21 @@ public void JsTranslator_NestedEnumerableMethods()
}

[TestMethod]
public void JsTranslator_EnumerableSelect()
[DataRow("Enumerable.Select(LongArray, (long item) => -item)", DisplayName = "Regular call of Enumerable.Select")]
[DataRow("LongArray.Select((long item) => -item)", DisplayName = "Syntax sugar - extension method")]
public void JsTranslator_EnumerableSelect(string binding)
{
var result = CompileBinding("Enumerable.Select(LongArray, (long item) => -item)", new[] { typeof(TestViewModel) });
var result = CompileBinding(binding, new[] { typeof(TestViewModel) });
Assert.AreEqual("LongArray().map(function(item){return -ko.unwrap(item);})", result);
}

[TestMethod]
public void JsTranslator_ValidMethod_UnsupportedTranslation()
{
Assert.ThrowsException<NotSupportedException>(() =>
CompileBinding("Enumerable.Skip<long>(LongArray, 2)", new[] { typeof(TestViewModel) }));
}

[TestMethod]
public void JavascriptCompilation_GuidToString()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public string CompileBinding(string expression, bool niceMode, Type[] contexts,
var options = BindingParserOptions.Create<ValueBindingExpression>()
.AddImports(configuration.Markup.ImportedNamespaces);

var parser = new BindingExpressionBuilder(configuration.ServiceProvider.GetRequiredService<CompiledAssemblyCache>());
var parser = new BindingExpressionBuilder(configuration.ServiceProvider.GetRequiredService<CompiledAssemblyCache>(), configuration.ServiceProvider.GetRequiredService<MemberExpressionFactory>());
var expressionTree = parser.ParseWithLambdaConversion(expression, context, options, expectedType);
var jsExpression =
configuration.ServiceProvider.GetRequiredService<StaticCommandBindingCompiler>().CompileToJavascript(context, expressionTree);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ public void ResolvedTree_ViewModel_InvalidAssemblyQualified()
{
var root = ParseSource(@"@viewModel System.String, whatever");
Assert.IsTrue(root.Directives.Any(d => d.Value.Any(dd => dd.DothtmlNode.HasNodeErrors)));
Assert.AreEqual(typeof(ExpressionHelper.UnknownTypeSentinel), root.DataContextTypeStack.DataContextType);
Assert.AreEqual(typeof(UnknownTypeSentinel), root.DataContextTypeStack.DataContextType);
}

private ResolvedBinding[] GetLiteralBindings(ResolvedContentNode node) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ namespace DotVVM.Framework.Compilation.Binding
public class BindingExpressionBuilder : IBindingExpressionBuilder
{
private readonly CompiledAssemblyCache compiledAssemblyCache;
private readonly MemberExpressionFactory memberExpressionFactory;

public BindingExpressionBuilder(CompiledAssemblyCache compiledAssemblyCache)
public BindingExpressionBuilder(CompiledAssemblyCache compiledAssemblyCache, MemberExpressionFactory memberExpressionFactory)
{
this.compiledAssemblyCache = compiledAssemblyCache;
this.memberExpressionFactory = memberExpressionFactory;
}

public Expression Parse(string expression, DataContextStack dataContexts, BindingParserOptions options, params KeyValuePair<string, Expression>[] additionalSymbols)
Expand Down Expand Up @@ -49,7 +51,7 @@ public Expression Parse(string expression, DataContextStack dataContexts, Bindin
symbols = symbols.AddSymbols(options.ExtensionParameters.Select(p => CreateParameter(dataContexts, p.Identifier, p)));
symbols = symbols.AddSymbols(additionalSymbols);

var visitor = new ExpressionBuildingVisitor(symbols);
var visitor = new ExpressionBuildingVisitor(symbols, memberExpressionFactory);
visitor.Scope = symbols.Resolve(options.ScopeParameter);
return visitor.Visit(node);
}
Expand Down Expand Up @@ -116,7 +118,7 @@ static ParameterExpression CreateParameter(DataContextStack stackItem, string na
(extensionParameter == null
? stackItem.DataContextType
: ResolvedTypeDescriptor.ToSystemType(extensionParameter.ParameterType))
?? typeof(ExpressionHelper.UnknownTypeSentinel)
?? typeof(UnknownTypeSentinel)
, name)
.AddParameterAnnotation(new BindingParameterAnnotation(stackItem, extensionParameter));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;

namespace DotVVM.Framework.Compilation.Binding
{
public class DefaultExtensionsProvider : IExtensionsProvider
{
private readonly List<Type> typesLookup;

public DefaultExtensionsProvider()
{
typesLookup = new List<Type>();
AddTypeForExtensionsLookup(typeof(Enumerable));
}

protected void AddTypeForExtensionsLookup(Type type)
{
typesLookup.Add(type);
}

public virtual IEnumerable<MethodInfo> GetExtensionMethods()
{
foreach (var registeredType in typesLookup)
foreach (var method in registeredType.GetMethods(BindingFlags.Public | BindingFlags.Static))
yield return method;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ public class ExpressionBuildingVisitor : BindingParserNodeVisitor<Expression>
public bool ResolveOnlyTypeName { get; set; }

private List<Exception> currentErrors;
private readonly MemberExpressionFactory memberExpressionFactory;

public ExpressionBuildingVisitor(TypeRegistry registry)
public ExpressionBuildingVisitor(TypeRegistry registry, MemberExpressionFactory memberExpressionFactory)
{
Registry = registry;
this.memberExpressionFactory = memberExpressionFactory;
}

protected T HandleErrors<T, TNode>(TNode node, Func<TNode, T> action, string defaultErrorMessage = "Binding compilation failed", bool allowResultNull = true)
Expand Down Expand Up @@ -120,7 +122,7 @@ protected override Expression VisitUnaryOperator(UnaryOperatorBindingParserNode
default:
throw new NotSupportedException($"unary operator { node.Operator } is not supported");
}
return ExpressionHelper.GetUnaryOperator(operand, eop);
return memberExpressionFactory.GetUnaryOperator(operand, eop);
}

protected override Expression VisitBinaryOperator(BinaryOperatorBindingParserNode node)
Expand Down Expand Up @@ -188,7 +190,7 @@ protected override Expression VisitBinaryOperator(BinaryOperatorBindingParserNod
throw new NotSupportedException($"unary operator { node.Operator } is not supported");
}

return ExpressionHelper.GetBinaryOperator(left, right, eop);
return memberExpressionFactory.GetBinaryOperator(left, right, eop);
}

protected override Expression VisitArrayAccess(ArrayAccessBindingParserNode node)
Expand All @@ -210,7 +212,7 @@ protected override Expression VisitFunctionCall(FunctionCallBindingParserNode no
}
ThrowOnErrors();

return ExpressionHelper.Call(target, args);
return memberExpressionFactory.Call(target, args);
}

protected override Expression VisitSimpleName(SimpleNameBindingParserNode node)
Expand Down Expand Up @@ -260,7 +262,7 @@ protected override Expression VisitMemberAccess(MemberAccessBindingParserNode no
return resolvedTypeExpression;
}

return ExpressionHelper.GetMember(target, nameNode.Name, typeParameters, onlyMemberTypes: ResolveOnlyTypeName);
return memberExpressionFactory.GetMember(target, nameNode.Name, typeParameters, onlyMemberTypes: ResolveOnlyTypeName);
}

protected override Expression VisitGenericName(GenericNameBindingParserNode node)
Expand Down Expand Up @@ -336,11 +338,11 @@ private Expression GetMemberOrTypeExpression(IdentifierNameBindingParserNode nod
var expr =
Scope == null
? Registry.Resolve(node.Name, throwOnNotFound: false)
: (ExpressionHelper.GetMember(Scope, node.Name, typeParameters, throwExceptions: false, onlyMemberTypes: ResolveOnlyTypeName)
: (memberExpressionFactory.GetMember(Scope, node.Name, typeParameters, throwExceptions: false, onlyMemberTypes: ResolveOnlyTypeName)
?? Registry.Resolve(node.Name, throwOnNotFound: false));

if (expr == null) return new UnknownStaticClassIdentifierExpression(node.Name);
if (expr is ParameterExpression && expr.Type == typeof(ExpressionHelper.UnknownTypeSentinel)) throw new Exception($"Type of '{expr}' could not be resolved.");
if (expr is ParameterExpression && expr.Type == typeof(UnknownTypeSentinel)) throw new Exception($"Type of '{expr}' could not be resolved.");
return expr;
}

Expand Down
Loading