Skip to content

Commit

Permalink
Merge pull request #931 from riganti/feature/lambdas-extension-methods
Browse files Browse the repository at this point in the history
Add support for using extension methods in bindings
  • Loading branch information
quigamdev authored Feb 23, 2021
2 parents d26226d + 1997ff5 commit 3280b1a
Show file tree
Hide file tree
Showing 22 changed files with 656 additions and 434 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class BindingCompilationTests
private BindingCompilationService bindingService;

[TestInitialize]
public void INIT()
public void Init()
{
this.configuration = DotvvmTestHelper.DefaultConfig;
this.bindingService = configuration.ServiceProvider.GetRequiredService<BindingCompilationService>();
Expand Down Expand Up @@ -200,6 +200,14 @@ public void BindingCompiler_Invalid_LambdaParameters(string expr)
Assert.ThrowsException<AggregateException>(() => ExecuteBinding(expr, viewModel));
}

[TestMethod]
public void BindingCompiler_Valid_ExtensionMethods()
{
var viewModel = new TestViewModel();
var result = (long[])ExecuteBinding("LongArray.Where((long item) => item % 2 != 0).ToArray()", viewModel);
CollectionAssert.AreEqual(viewModel.LongArray.Where(item => item % 2 != 0).ToArray(), result);
}

class MoqComponent : DotvvmBindableObject
{
public object Property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public class CommandResolverTests
private BindingCompilationService bindingService;

[TestInitialize]
public void INIT()
public void Init()
{
this.configuration = DotvvmTestHelper.DefaultConfig;
this.bindingService = configuration.ServiceProvider.GetRequiredService<BindingCompilationService>();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Text;
using DotVVM.Framework.Compilation.Binding;
using DotVVM.Framework.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace DotVVM.Framework.Tests.Common.Binding
{
[TestClass]
public class CustomExtensionMethodTests
{
private DotvvmConfiguration configuration;
private MemberExpressionFactory memberExpressionFactory;

[TestInitialize]
public void Init()
{
this.configuration = DotvvmTestHelper.CreateConfiguration(services => services.AddScoped<IExtensionsProvider, TestExtensionsProvider>());
this.memberExpressionFactory = configuration.ServiceProvider.GetRequiredService<MemberExpressionFactory>();
}

[TestMethod]
public void Call_FindCustomExtensionMethod()
{
var target = new MethodGroupExpression()
{
MethodName = nameof(TestExtensions.Increment),
Target = Expression.Constant(11)
};

var expression = memberExpressionFactory.Call(target, Array.Empty<Expression>());
var result = Expression.Lambda<Func<int>>(expression).Compile().Invoke();
Assert.AreEqual(12, result);
}
}

static class TestExtensions
{
public static int Increment(this int number)
=> ++number;
}

class TestExtensionsProvider : DefaultExtensionsProvider
{
public TestExtensionsProvider()
{
AddTypeForExtensionsLookup(typeof(TestExtensions));
}
}
}
19 changes: 14 additions & 5 deletions src/DotVVM.Framework.Tests.Common/Binding/ExpressionHelperTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using DotVVM.Framework.Utils;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CSharp.RuntimeBinder;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace DotVVM.Framework.Tests.Common.Binding
Expand All @@ -17,13 +18,21 @@ namespace DotVVM.Framework.Tests.Common.Binding
public class ExpressionHelperTests
{
public TestContext TestContext { get; set; }
private MemberExpressionFactory memberExpressionFactory;

[TestInitialize]
public void Init()
{
var configuration = DotvvmTestHelper.CreateConfiguration();
memberExpressionFactory = configuration.ServiceProvider.GetRequiredService<MemberExpressionFactory>();
}

[TestMethod]
public void UpdateMember_GetValue()
{
var cP = Expression.Parameter(typeof(DotvvmControl), "c");
var newValueP = Expression.Parameter(typeof(object), "newValue");
var updateExpr = ExpressionHelper.UpdateMember(ExpressionUtils.Replace((DotvvmControl c) => c.GetValue(DotvvmBindableObject.DataContextProperty, true), cP), newValueP);
var updateExpr = memberExpressionFactory.UpdateMember(ExpressionUtils.Replace((DotvvmControl c) => c.GetValue(DotvvmBindableObject.DataContextProperty, true), cP), newValueP);
Assert.IsNotNull(updateExpr);
Assert.AreEqual("c.SetValue(DotvvmBindableObject.DataContextProperty, newValue)", updateExpr.ToString());
}
Expand All @@ -33,7 +42,7 @@ public void UpdateMember_NormalProperty()
{
var vmP = Expression.Parameter(typeof(Tests.Binding.TestViewModel), "vm");
var newValueP = Expression.Parameter(typeof(DateTime), "newValue");
var updateExpr = ExpressionHelper.UpdateMember(ExpressionUtils.Replace((Tests.Binding.TestViewModel c) => c.DateFrom, vmP), newValueP);
var updateExpr = memberExpressionFactory.UpdateMember(ExpressionUtils.Replace((Tests.Binding.TestViewModel c) => c.DateFrom, vmP), newValueP);
Assert.IsNotNull(updateExpr);
Assert.AreEqual("(vm.DateFrom = Convert(newValue, Nullable`1))", updateExpr.ToString());
}
Expand All @@ -43,7 +52,7 @@ public void UpdateMember_ReadOnlyProperty()
{
var vmP = Expression.Parameter(typeof(Tests.Binding.TestViewModel), "vm");
var newValueP = Expression.Parameter(typeof(long[]), "newValue");
var updateExpr = ExpressionHelper.UpdateMember(ExpressionUtils.Replace((Tests.Binding.TestViewModel c) => c.LongArray, vmP), newValueP);
var updateExpr = memberExpressionFactory.UpdateMember(ExpressionUtils.Replace((Tests.Binding.TestViewModel c) => c.LongArray, vmP), newValueP);
Assert.IsNull(updateExpr);
}

Expand All @@ -57,7 +66,7 @@ public void Call_FindOverload_Generic_FirstLevel(Type resultIdentifierType, Type
Call_FindOverload_Generic(typeof(MethodsGenericArgumentsResolvingSampleObject), MethodsGenericArgumentsResolvingSampleObject.MethodName, new[] { argType }, resultIdentifierType, expectedGenericArgs);
}

private static void Call_FindOverload_Generic(Type targetType, string methodName, Type[] argTypes, Type resultIdentifierType, Type[] expectedGenericArgs)
private void Call_FindOverload_Generic(Type targetType, string methodName, Type[] argTypes, Type resultIdentifierType, Type[] expectedGenericArgs)
{
Expression target = new MethodGroupExpression() {
MethodName = methodName,
Expand All @@ -66,7 +75,7 @@ private static void Call_FindOverload_Generic(Type targetType, string methodName

var j = 0;
var arguments = argTypes.Select(s => Expression.Parameter(s, $"param_{j++}")).ToArray();
var expression = ExpressionHelper.Call(target, arguments) as MethodCallExpression;
var expression = memberExpressionFactory.Call(target, arguments) as MethodCallExpression;
Assert.IsNotNull(expression);
Assert.AreEqual(resultIdentifierType, expression.Method.GetResultType());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public class GenericPropertyResolverTests
private BindingCompilationService bindingService;

[TestInitialize]
public void INIT()
public void Init()
{
this.configuration = DotvvmTestHelper.DefaultConfig;
this.bindingService = configuration.ServiceProvider.GetRequiredService<BindingCompilationService>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class JavascriptCompilationTests
private BindingCompilationService bindingService;

[TestInitialize]
public void INIT()
public void Init()
{
this.configuration = DotvvmTestHelper.CreateConfiguration();
configuration.RegisterApiClient(typeof(TestApiClient), "http://server/api", "./apiscript.js", "_testApi");
Expand All @@ -47,7 +47,7 @@ public string CompileBinding(string expression, Type[] contexts, Type expectedTy
{
context = DataContextStack.Create(contexts[i], context);
}
var parser = new BindingExpressionBuilder(configuration.ServiceProvider.GetRequiredService<CompiledAssemblyCache>());
var parser = new BindingExpressionBuilder(configuration.ServiceProvider.GetRequiredService<CompiledAssemblyCache>(), configuration.ServiceProvider.GetRequiredService<MemberExpressionFactory>());
var parsedExpression = parser.ParseWithLambdaConversion(expression, context, BindingParserOptions.Create<ValueBindingExpression>(), expectedType);
var expressionTree =
TypeConversion.MagicLambdaConversion(parsedExpression, expectedType) ??
Expand Down Expand Up @@ -368,9 +368,11 @@ public void JsTranslator_ArrayIndexer()
}

[TestMethod]
public void JsTranslator_EnumerableWhere()
[DataRow("Enumerable.Where(LongArray, (long item) => item % 2 == 0)", DisplayName = "Regular call of Enumerable.Where")]
[DataRow("LongArray.Where((long item) => item % 2 == 0)", DisplayName = "Syntax sugar - extension method")]
public void JsTranslator_EnumerableWhere(string binding)
{
var result = CompileBinding("Enumerable.Where(LongArray, (long item) => item % 2 == 0)", new[] { typeof(TestViewModel) });
var result = CompileBinding(binding, new[] { typeof(TestViewModel) });
Assert.AreEqual("LongArray().filter(function(item){return ko.unwrap(item)%2==0;})", result);
}

Expand All @@ -382,12 +384,21 @@ public void JsTranslator_NestedEnumerableMethods()
}

[TestMethod]
public void JsTranslator_EnumerableSelect()
[DataRow("Enumerable.Select(LongArray, (long item) => -item)", DisplayName = "Regular call of Enumerable.Select")]
[DataRow("LongArray.Select((long item) => -item)", DisplayName = "Syntax sugar - extension method")]
public void JsTranslator_EnumerableSelect(string binding)
{
var result = CompileBinding("Enumerable.Select(LongArray, (long item) => -item)", new[] { typeof(TestViewModel) });
var result = CompileBinding(binding, new[] { typeof(TestViewModel) });
Assert.AreEqual("LongArray().map(function(item){return -ko.unwrap(item);})", result);
}

[TestMethod]
public void JsTranslator_ValidMethod_UnsupportedTranslation()
{
Assert.ThrowsException<NotSupportedException>(() =>
CompileBinding("Enumerable.Skip<long>(LongArray, 2)", new[] { typeof(TestViewModel) }));
}

[TestMethod]
public void JavascriptCompilation_GuidToString()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public string CompileBinding(string expression, bool niceMode, Type[] contexts,
var options = BindingParserOptions.Create<ValueBindingExpression>()
.AddImports(configuration.Markup.ImportedNamespaces);

var parser = new BindingExpressionBuilder(configuration.ServiceProvider.GetRequiredService<CompiledAssemblyCache>());
var parser = new BindingExpressionBuilder(configuration.ServiceProvider.GetRequiredService<CompiledAssemblyCache>(), configuration.ServiceProvider.GetRequiredService<MemberExpressionFactory>());
var expressionTree = parser.ParseWithLambdaConversion(expression, context, options, expectedType);
var jsExpression =
configuration.ServiceProvider.GetRequiredService<StaticCommandBindingCompiler>().CompileToJavascript(context, expressionTree);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ public void ResolvedTree_ViewModel_InvalidAssemblyQualified()
{
var root = ParseSource(@"@viewModel System.String, whatever");
Assert.IsTrue(root.Directives.Any(d => d.Value.Any(dd => dd.DothtmlNode.HasNodeErrors)));
Assert.AreEqual(typeof(ExpressionHelper.UnknownTypeSentinel), root.DataContextTypeStack.DataContextType);
Assert.AreEqual(typeof(UnknownTypeSentinel), root.DataContextTypeStack.DataContextType);
}

private ResolvedBinding[] GetLiteralBindings(ResolvedContentNode node) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ namespace DotVVM.Framework.Compilation.Binding
public class BindingExpressionBuilder : IBindingExpressionBuilder
{
private readonly CompiledAssemblyCache compiledAssemblyCache;
private readonly MemberExpressionFactory memberExpressionFactory;

public BindingExpressionBuilder(CompiledAssemblyCache compiledAssemblyCache)
public BindingExpressionBuilder(CompiledAssemblyCache compiledAssemblyCache, MemberExpressionFactory memberExpressionFactory)
{
this.compiledAssemblyCache = compiledAssemblyCache;
this.memberExpressionFactory = memberExpressionFactory;
}

public Expression Parse(string expression, DataContextStack dataContexts, BindingParserOptions options, params KeyValuePair<string, Expression>[] additionalSymbols)
Expand Down Expand Up @@ -49,7 +51,7 @@ public Expression Parse(string expression, DataContextStack dataContexts, Bindin
symbols = symbols.AddSymbols(options.ExtensionParameters.Select(p => CreateParameter(dataContexts, p.Identifier, p)));
symbols = symbols.AddSymbols(additionalSymbols);

var visitor = new ExpressionBuildingVisitor(symbols);
var visitor = new ExpressionBuildingVisitor(symbols, memberExpressionFactory);
visitor.Scope = symbols.Resolve(options.ScopeParameter);
return visitor.Visit(node);
}
Expand Down Expand Up @@ -116,7 +118,7 @@ static ParameterExpression CreateParameter(DataContextStack stackItem, string na
(extensionParameter == null
? stackItem.DataContextType
: ResolvedTypeDescriptor.ToSystemType(extensionParameter.ParameterType))
?? typeof(ExpressionHelper.UnknownTypeSentinel)
?? typeof(UnknownTypeSentinel)
, name)
.AddParameterAnnotation(new BindingParameterAnnotation(stackItem, extensionParameter));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading.Tasks;

namespace DotVVM.Framework.Compilation.Binding
{
public class DefaultExtensionsProvider : IExtensionsProvider
{
private readonly List<Type> typesLookup;
private readonly List<MethodInfo> methodsLookup;

public DefaultExtensionsProvider()
{
typesLookup = new List<Type>();
methodsLookup = new List<MethodInfo>();
AddTypeForExtensionsLookup(typeof(Enumerable));
}

protected void AddTypeForExtensionsLookup(Type type)
{
foreach (var method in type.GetMethods(BindingFlags.Public | BindingFlags.Static).Where(m => m.GetCustomAttribute(typeof(ExtensionAttribute)) != null))
methodsLookup.Add(method);

typesLookup.Add(type);
}

public virtual IEnumerable<MethodInfo> GetExtensionMethods()
{
return methodsLookup;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ public class ExpressionBuildingVisitor : BindingParserNodeVisitor<Expression>
public bool ResolveOnlyTypeName { get; set; }

private List<Exception> currentErrors;
private readonly MemberExpressionFactory memberExpressionFactory;

public ExpressionBuildingVisitor(TypeRegistry registry)
public ExpressionBuildingVisitor(TypeRegistry registry, MemberExpressionFactory memberExpressionFactory)
{
Registry = registry;
this.memberExpressionFactory = memberExpressionFactory;
}

protected T HandleErrors<T, TNode>(TNode node, Func<TNode, T> action, string defaultErrorMessage = "Binding compilation failed", bool allowResultNull = true)
Expand Down Expand Up @@ -120,7 +122,7 @@ protected override Expression VisitUnaryOperator(UnaryOperatorBindingParserNode
default:
throw new NotSupportedException($"unary operator { node.Operator } is not supported");
}
return ExpressionHelper.GetUnaryOperator(operand, eop);
return memberExpressionFactory.GetUnaryOperator(operand, eop);
}

protected override Expression VisitBinaryOperator(BinaryOperatorBindingParserNode node)
Expand Down Expand Up @@ -188,7 +190,7 @@ protected override Expression VisitBinaryOperator(BinaryOperatorBindingParserNod
throw new NotSupportedException($"unary operator { node.Operator } is not supported");
}

return ExpressionHelper.GetBinaryOperator(left, right, eop);
return memberExpressionFactory.GetBinaryOperator(left, right, eop);
}

protected override Expression VisitArrayAccess(ArrayAccessBindingParserNode node)
Expand All @@ -210,7 +212,7 @@ protected override Expression VisitFunctionCall(FunctionCallBindingParserNode no
}
ThrowOnErrors();

return ExpressionHelper.Call(target, args);
return memberExpressionFactory.Call(target, args);
}

protected override Expression VisitSimpleName(SimpleNameBindingParserNode node)
Expand Down Expand Up @@ -260,7 +262,7 @@ protected override Expression VisitMemberAccess(MemberAccessBindingParserNode no
return resolvedTypeExpression;
}

return ExpressionHelper.GetMember(target, nameNode.Name, typeParameters, onlyMemberTypes: ResolveOnlyTypeName);
return memberExpressionFactory.GetMember(target, nameNode.Name, typeParameters, onlyMemberTypes: ResolveOnlyTypeName);
}

protected override Expression VisitGenericName(GenericNameBindingParserNode node)
Expand Down Expand Up @@ -336,11 +338,11 @@ private Expression GetMemberOrTypeExpression(IdentifierNameBindingParserNode nod
var expr =
Scope == null
? Registry.Resolve(node.Name, throwOnNotFound: false)
: (ExpressionHelper.GetMember(Scope, node.Name, typeParameters, throwExceptions: false, onlyMemberTypes: ResolveOnlyTypeName)
: (memberExpressionFactory.GetMember(Scope, node.Name, typeParameters, throwExceptions: false, onlyMemberTypes: ResolveOnlyTypeName)
?? Registry.Resolve(node.Name, throwOnNotFound: false));

if (expr == null) return new UnknownStaticClassIdentifierExpression(node.Name);
if (expr is ParameterExpression && expr.Type == typeof(ExpressionHelper.UnknownTypeSentinel)) throw new Exception($"Type of '{expr}' could not be resolved.");
if (expr is ParameterExpression && expr.Type == typeof(UnknownTypeSentinel)) throw new Exception($"Type of '{expr}' could not be resolved.");
return expr;
}

Expand Down
Loading

0 comments on commit 3280b1a

Please sign in to comment.