diff --git a/src/DotVVM.Framework.Tests.Common/Binding/BindingCompilationTests.cs b/src/DotVVM.Framework.Tests.Common/Binding/BindingCompilationTests.cs index d8fd602c85..591522a6fc 100755 --- a/src/DotVVM.Framework.Tests.Common/Binding/BindingCompilationTests.cs +++ b/src/DotVVM.Framework.Tests.Common/Binding/BindingCompilationTests.cs @@ -28,7 +28,7 @@ public class BindingCompilationTests private BindingCompilationService bindingService; [TestInitialize] - public void INIT() + public void Init() { this.configuration = DotvvmTestHelper.DefaultConfig; this.bindingService = configuration.ServiceProvider.GetRequiredService(); @@ -200,6 +200,14 @@ public void BindingCompiler_Invalid_LambdaParameters(string expr) Assert.ThrowsException(() => ExecuteBinding(expr, viewModel)); } + [TestMethod] + public void BindingCompiler_Valid_ExtensionMethods() + { + var viewModel = new TestViewModel(); + var result = (long[])ExecuteBinding("LongArray.Where((long item) => item % 2 != 0).ToArray()", viewModel); + CollectionAssert.AreEqual(viewModel.LongArray.Where(item => item % 2 != 0).ToArray(), result); + } + class MoqComponent : DotvvmBindableObject { public object Property diff --git a/src/DotVVM.Framework.Tests.Common/Binding/CommandResolverTests.cs b/src/DotVVM.Framework.Tests.Common/Binding/CommandResolverTests.cs index b485631e3b..01abc3dc4c 100644 --- a/src/DotVVM.Framework.Tests.Common/Binding/CommandResolverTests.cs +++ b/src/DotVVM.Framework.Tests.Common/Binding/CommandResolverTests.cs @@ -24,7 +24,7 @@ public class CommandResolverTests private BindingCompilationService bindingService; [TestInitialize] - public void INIT() + public void Init() { this.configuration = DotvvmTestHelper.DefaultConfig; this.bindingService = configuration.ServiceProvider.GetRequiredService(); diff --git a/src/DotVVM.Framework.Tests.Common/Binding/CustomExtensionMethodTests.cs b/src/DotVVM.Framework.Tests.Common/Binding/CustomExtensionMethodTests.cs new file mode 100644 index 0000000000..9d9474e6f1 --- /dev/null +++ b/src/DotVVM.Framework.Tests.Common/Binding/CustomExtensionMethodTests.cs @@ -0,0 +1,53 @@ +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Text; +using DotVVM.Framework.Compilation.Binding; +using DotVVM.Framework.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace DotVVM.Framework.Tests.Common.Binding +{ + [TestClass] + public class CustomExtensionMethodTests + { + private DotvvmConfiguration configuration; + private MemberExpressionFactory memberExpressionFactory; + + [TestInitialize] + public void Init() + { + this.configuration = DotvvmTestHelper.CreateConfiguration(services => services.AddScoped()); + this.memberExpressionFactory = configuration.ServiceProvider.GetRequiredService(); + } + + [TestMethod] + public void Call_FindCustomExtensionMethod() + { + var target = new MethodGroupExpression() + { + MethodName = nameof(TestExtensions.Increment), + Target = Expression.Constant(11) + }; + + var expression = memberExpressionFactory.Call(target, Array.Empty()); + var result = Expression.Lambda>(expression).Compile().Invoke(); + Assert.AreEqual(12, result); + } + } + + static class TestExtensions + { + public static int Increment(this int number) + => ++number; + } + + class TestExtensionsProvider : DefaultExtensionsProvider + { + public TestExtensionsProvider() + { + AddTypeForExtensionsLookup(typeof(TestExtensions)); + } + } +} diff --git a/src/DotVVM.Framework.Tests.Common/Binding/ExpressionHelperTests.cs b/src/DotVVM.Framework.Tests.Common/Binding/ExpressionHelperTests.cs index 050fcfaefd..2c27949dff 100644 --- a/src/DotVVM.Framework.Tests.Common/Binding/ExpressionHelperTests.cs +++ b/src/DotVVM.Framework.Tests.Common/Binding/ExpressionHelperTests.cs @@ -9,6 +9,7 @@ using DotVVM.Framework.Utils; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CSharp.RuntimeBinder; +using Microsoft.Extensions.DependencyInjection; using Microsoft.VisualStudio.TestTools.UnitTesting; namespace DotVVM.Framework.Tests.Common.Binding @@ -17,13 +18,21 @@ namespace DotVVM.Framework.Tests.Common.Binding public class ExpressionHelperTests { public TestContext TestContext { get; set; } + private MemberExpressionFactory memberExpressionFactory; + + [TestInitialize] + public void Init() + { + var configuration = DotvvmTestHelper.CreateConfiguration(); + memberExpressionFactory = configuration.ServiceProvider.GetRequiredService(); + } [TestMethod] public void UpdateMember_GetValue() { var cP = Expression.Parameter(typeof(DotvvmControl), "c"); var newValueP = Expression.Parameter(typeof(object), "newValue"); - var updateExpr = ExpressionHelper.UpdateMember(ExpressionUtils.Replace((DotvvmControl c) => c.GetValue(DotvvmBindableObject.DataContextProperty, true), cP), newValueP); + var updateExpr = memberExpressionFactory.UpdateMember(ExpressionUtils.Replace((DotvvmControl c) => c.GetValue(DotvvmBindableObject.DataContextProperty, true), cP), newValueP); Assert.IsNotNull(updateExpr); Assert.AreEqual("c.SetValue(DotvvmBindableObject.DataContextProperty, newValue)", updateExpr.ToString()); } @@ -33,7 +42,7 @@ public void UpdateMember_NormalProperty() { var vmP = Expression.Parameter(typeof(Tests.Binding.TestViewModel), "vm"); var newValueP = Expression.Parameter(typeof(DateTime), "newValue"); - var updateExpr = ExpressionHelper.UpdateMember(ExpressionUtils.Replace((Tests.Binding.TestViewModel c) => c.DateFrom, vmP), newValueP); + var updateExpr = memberExpressionFactory.UpdateMember(ExpressionUtils.Replace((Tests.Binding.TestViewModel c) => c.DateFrom, vmP), newValueP); Assert.IsNotNull(updateExpr); Assert.AreEqual("(vm.DateFrom = Convert(newValue, Nullable`1))", updateExpr.ToString()); } @@ -43,7 +52,7 @@ public void UpdateMember_ReadOnlyProperty() { var vmP = Expression.Parameter(typeof(Tests.Binding.TestViewModel), "vm"); var newValueP = Expression.Parameter(typeof(long[]), "newValue"); - var updateExpr = ExpressionHelper.UpdateMember(ExpressionUtils.Replace((Tests.Binding.TestViewModel c) => c.LongArray, vmP), newValueP); + var updateExpr = memberExpressionFactory.UpdateMember(ExpressionUtils.Replace((Tests.Binding.TestViewModel c) => c.LongArray, vmP), newValueP); Assert.IsNull(updateExpr); } @@ -57,7 +66,7 @@ public void Call_FindOverload_Generic_FirstLevel(Type resultIdentifierType, Type Call_FindOverload_Generic(typeof(MethodsGenericArgumentsResolvingSampleObject), MethodsGenericArgumentsResolvingSampleObject.MethodName, new[] { argType }, resultIdentifierType, expectedGenericArgs); } - private static void Call_FindOverload_Generic(Type targetType, string methodName, Type[] argTypes, Type resultIdentifierType, Type[] expectedGenericArgs) + private void Call_FindOverload_Generic(Type targetType, string methodName, Type[] argTypes, Type resultIdentifierType, Type[] expectedGenericArgs) { Expression target = new MethodGroupExpression() { MethodName = methodName, @@ -66,7 +75,7 @@ private static void Call_FindOverload_Generic(Type targetType, string methodName var j = 0; var arguments = argTypes.Select(s => Expression.Parameter(s, $"param_{j++}")).ToArray(); - var expression = ExpressionHelper.Call(target, arguments) as MethodCallExpression; + var expression = memberExpressionFactory.Call(target, arguments) as MethodCallExpression; Assert.IsNotNull(expression); Assert.AreEqual(resultIdentifierType, expression.Method.GetResultType()); diff --git a/src/DotVVM.Framework.Tests.Common/Binding/GenericPropertyResolverTests.cs b/src/DotVVM.Framework.Tests.Common/Binding/GenericPropertyResolverTests.cs index 8f5cc0a8d1..c1fe25b7c1 100644 --- a/src/DotVVM.Framework.Tests.Common/Binding/GenericPropertyResolverTests.cs +++ b/src/DotVVM.Framework.Tests.Common/Binding/GenericPropertyResolverTests.cs @@ -16,7 +16,7 @@ public class GenericPropertyResolverTests private BindingCompilationService bindingService; [TestInitialize] - public void INIT() + public void Init() { this.configuration = DotvvmTestHelper.DefaultConfig; this.bindingService = configuration.ServiceProvider.GetRequiredService(); diff --git a/src/DotVVM.Framework.Tests.Common/Binding/JavascriptCompilationTests.cs b/src/DotVVM.Framework.Tests.Common/Binding/JavascriptCompilationTests.cs index 8527f51322..af865fd74f 100644 --- a/src/DotVVM.Framework.Tests.Common/Binding/JavascriptCompilationTests.cs +++ b/src/DotVVM.Framework.Tests.Common/Binding/JavascriptCompilationTests.cs @@ -28,7 +28,7 @@ public class JavascriptCompilationTests private BindingCompilationService bindingService; [TestInitialize] - public void INIT() + public void Init() { this.configuration = DotvvmTestHelper.CreateConfiguration(); configuration.RegisterApiClient(typeof(TestApiClient), "http://server/api", "./apiscript.js", "_testApi"); @@ -47,7 +47,7 @@ public string CompileBinding(string expression, Type[] contexts, Type expectedTy { context = DataContextStack.Create(contexts[i], context); } - var parser = new BindingExpressionBuilder(configuration.ServiceProvider.GetRequiredService()); + var parser = new BindingExpressionBuilder(configuration.ServiceProvider.GetRequiredService(), configuration.ServiceProvider.GetRequiredService()); var parsedExpression = parser.ParseWithLambdaConversion(expression, context, BindingParserOptions.Create(), expectedType); var expressionTree = TypeConversion.MagicLambdaConversion(parsedExpression, expectedType) ?? @@ -368,9 +368,11 @@ public void JsTranslator_ArrayIndexer() } [TestMethod] - public void JsTranslator_EnumerableWhere() + [DataRow("Enumerable.Where(LongArray, (long item) => item % 2 == 0)", DisplayName = "Regular call of Enumerable.Where")] + [DataRow("LongArray.Where((long item) => item % 2 == 0)", DisplayName = "Syntax sugar - extension method")] + public void JsTranslator_EnumerableWhere(string binding) { - var result = CompileBinding("Enumerable.Where(LongArray, (long item) => item % 2 == 0)", new[] { typeof(TestViewModel) }); + var result = CompileBinding(binding, new[] { typeof(TestViewModel) }); Assert.AreEqual("LongArray().filter(function(item){return ko.unwrap(item)%2==0;})", result); } @@ -382,12 +384,21 @@ public void JsTranslator_NestedEnumerableMethods() } [TestMethod] - public void JsTranslator_EnumerableSelect() + [DataRow("Enumerable.Select(LongArray, (long item) => -item)", DisplayName = "Regular call of Enumerable.Select")] + [DataRow("LongArray.Select((long item) => -item)", DisplayName = "Syntax sugar - extension method")] + public void JsTranslator_EnumerableSelect(string binding) { - var result = CompileBinding("Enumerable.Select(LongArray, (long item) => -item)", new[] { typeof(TestViewModel) }); + var result = CompileBinding(binding, new[] { typeof(TestViewModel) }); Assert.AreEqual("LongArray().map(function(item){return -ko.unwrap(item);})", result); } + [TestMethod] + public void JsTranslator_ValidMethod_UnsupportedTranslation() + { + Assert.ThrowsException(() => + CompileBinding("Enumerable.Skip(LongArray, 2)", new[] { typeof(TestViewModel) })); + } + [TestMethod] public void JavascriptCompilation_GuidToString() { diff --git a/src/DotVVM.Framework.Tests.Common/Binding/StaticCommandCompilationTests.cs b/src/DotVVM.Framework.Tests.Common/Binding/StaticCommandCompilationTests.cs index f3ac30bc4c..ce4f7e0cc9 100644 --- a/src/DotVVM.Framework.Tests.Common/Binding/StaticCommandCompilationTests.cs +++ b/src/DotVVM.Framework.Tests.Common/Binding/StaticCommandCompilationTests.cs @@ -70,7 +70,7 @@ public string CompileBinding(string expression, bool niceMode, Type[] contexts, var options = BindingParserOptions.Create() .AddImports(configuration.Markup.ImportedNamespaces); - var parser = new BindingExpressionBuilder(configuration.ServiceProvider.GetRequiredService()); + var parser = new BindingExpressionBuilder(configuration.ServiceProvider.GetRequiredService(), configuration.ServiceProvider.GetRequiredService()); var expressionTree = parser.ParseWithLambdaConversion(expression, context, options, expectedType); var jsExpression = configuration.ServiceProvider.GetRequiredService().CompileToJavascript(context, expressionTree); diff --git a/src/DotVVM.Framework.Tests.Common/Runtime/ControlTree/DefaultControlTreeResolverTests.cs b/src/DotVVM.Framework.Tests.Common/Runtime/ControlTree/DefaultControlTreeResolverTests.cs index a8bbd2d0aa..a29a36f5dd 100755 --- a/src/DotVVM.Framework.Tests.Common/Runtime/ControlTree/DefaultControlTreeResolverTests.cs +++ b/src/DotVVM.Framework.Tests.Common/Runtime/ControlTree/DefaultControlTreeResolverTests.cs @@ -649,7 +649,7 @@ public void ResolvedTree_ViewModel_InvalidAssemblyQualified() { var root = ParseSource(@"@viewModel System.String, whatever"); Assert.IsTrue(root.Directives.Any(d => d.Value.Any(dd => dd.DothtmlNode.HasNodeErrors))); - Assert.AreEqual(typeof(ExpressionHelper.UnknownTypeSentinel), root.DataContextTypeStack.DataContextType); + Assert.AreEqual(typeof(UnknownTypeSentinel), root.DataContextTypeStack.DataContextType); } private ResolvedBinding[] GetLiteralBindings(ResolvedContentNode node) => diff --git a/src/DotVVM.Framework/Compilation/Binding/BindingExpressionBuilder.cs b/src/DotVVM.Framework/Compilation/Binding/BindingExpressionBuilder.cs index 1ff72eeed4..3473974eec 100644 --- a/src/DotVVM.Framework/Compilation/Binding/BindingExpressionBuilder.cs +++ b/src/DotVVM.Framework/Compilation/Binding/BindingExpressionBuilder.cs @@ -17,10 +17,12 @@ namespace DotVVM.Framework.Compilation.Binding public class BindingExpressionBuilder : IBindingExpressionBuilder { private readonly CompiledAssemblyCache compiledAssemblyCache; + private readonly MemberExpressionFactory memberExpressionFactory; - public BindingExpressionBuilder(CompiledAssemblyCache compiledAssemblyCache) + public BindingExpressionBuilder(CompiledAssemblyCache compiledAssemblyCache, MemberExpressionFactory memberExpressionFactory) { this.compiledAssemblyCache = compiledAssemblyCache; + this.memberExpressionFactory = memberExpressionFactory; } public Expression Parse(string expression, DataContextStack dataContexts, BindingParserOptions options, params KeyValuePair[] additionalSymbols) @@ -49,7 +51,7 @@ public Expression Parse(string expression, DataContextStack dataContexts, Bindin symbols = symbols.AddSymbols(options.ExtensionParameters.Select(p => CreateParameter(dataContexts, p.Identifier, p))); symbols = symbols.AddSymbols(additionalSymbols); - var visitor = new ExpressionBuildingVisitor(symbols); + var visitor = new ExpressionBuildingVisitor(symbols, memberExpressionFactory); visitor.Scope = symbols.Resolve(options.ScopeParameter); return visitor.Visit(node); } @@ -116,7 +118,7 @@ static ParameterExpression CreateParameter(DataContextStack stackItem, string na (extensionParameter == null ? stackItem.DataContextType : ResolvedTypeDescriptor.ToSystemType(extensionParameter.ParameterType)) - ?? typeof(ExpressionHelper.UnknownTypeSentinel) + ?? typeof(UnknownTypeSentinel) , name) .AddParameterAnnotation(new BindingParameterAnnotation(stackItem, extensionParameter)); } diff --git a/src/DotVVM.Framework/Compilation/Binding/DefaultExtensionsProvider.cs b/src/DotVVM.Framework/Compilation/Binding/DefaultExtensionsProvider.cs new file mode 100644 index 0000000000..155844a41d --- /dev/null +++ b/src/DotVVM.Framework/Compilation/Binding/DefaultExtensionsProvider.cs @@ -0,0 +1,36 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading.Tasks; + +namespace DotVVM.Framework.Compilation.Binding +{ + public class DefaultExtensionsProvider : IExtensionsProvider + { + private readonly List typesLookup; + private readonly List methodsLookup; + + public DefaultExtensionsProvider() + { + typesLookup = new List(); + methodsLookup = new List(); + AddTypeForExtensionsLookup(typeof(Enumerable)); + } + + protected void AddTypeForExtensionsLookup(Type type) + { + foreach (var method in type.GetMethods(BindingFlags.Public | BindingFlags.Static).Where(m => m.GetCustomAttribute(typeof(ExtensionAttribute)) != null)) + methodsLookup.Add(method); + + typesLookup.Add(type); + } + + public virtual IEnumerable GetExtensionMethods() + { + return methodsLookup; + } + } +} diff --git a/src/DotVVM.Framework/Compilation/Binding/ExpressionBuildingVisitor.cs b/src/DotVVM.Framework/Compilation/Binding/ExpressionBuildingVisitor.cs index 44b6b410a1..e655219247 100644 --- a/src/DotVVM.Framework/Compilation/Binding/ExpressionBuildingVisitor.cs +++ b/src/DotVVM.Framework/Compilation/Binding/ExpressionBuildingVisitor.cs @@ -16,10 +16,12 @@ public class ExpressionBuildingVisitor : BindingParserNodeVisitor public bool ResolveOnlyTypeName { get; set; } private List currentErrors; + private readonly MemberExpressionFactory memberExpressionFactory; - public ExpressionBuildingVisitor(TypeRegistry registry) + public ExpressionBuildingVisitor(TypeRegistry registry, MemberExpressionFactory memberExpressionFactory) { Registry = registry; + this.memberExpressionFactory = memberExpressionFactory; } protected T HandleErrors(TNode node, Func action, string defaultErrorMessage = "Binding compilation failed", bool allowResultNull = true) @@ -120,7 +122,7 @@ protected override Expression VisitUnaryOperator(UnaryOperatorBindingParserNode default: throw new NotSupportedException($"unary operator { node.Operator } is not supported"); } - return ExpressionHelper.GetUnaryOperator(operand, eop); + return memberExpressionFactory.GetUnaryOperator(operand, eop); } protected override Expression VisitBinaryOperator(BinaryOperatorBindingParserNode node) @@ -188,7 +190,7 @@ protected override Expression VisitBinaryOperator(BinaryOperatorBindingParserNod throw new NotSupportedException($"unary operator { node.Operator } is not supported"); } - return ExpressionHelper.GetBinaryOperator(left, right, eop); + return memberExpressionFactory.GetBinaryOperator(left, right, eop); } protected override Expression VisitArrayAccess(ArrayAccessBindingParserNode node) @@ -210,7 +212,7 @@ protected override Expression VisitFunctionCall(FunctionCallBindingParserNode no } ThrowOnErrors(); - return ExpressionHelper.Call(target, args); + return memberExpressionFactory.Call(target, args); } protected override Expression VisitSimpleName(SimpleNameBindingParserNode node) @@ -260,7 +262,7 @@ protected override Expression VisitMemberAccess(MemberAccessBindingParserNode no return resolvedTypeExpression; } - return ExpressionHelper.GetMember(target, nameNode.Name, typeParameters, onlyMemberTypes: ResolveOnlyTypeName); + return memberExpressionFactory.GetMember(target, nameNode.Name, typeParameters, onlyMemberTypes: ResolveOnlyTypeName); } protected override Expression VisitGenericName(GenericNameBindingParserNode node) @@ -336,11 +338,11 @@ private Expression GetMemberOrTypeExpression(IdentifierNameBindingParserNode nod var expr = Scope == null ? Registry.Resolve(node.Name, throwOnNotFound: false) - : (ExpressionHelper.GetMember(Scope, node.Name, typeParameters, throwExceptions: false, onlyMemberTypes: ResolveOnlyTypeName) + : (memberExpressionFactory.GetMember(Scope, node.Name, typeParameters, throwExceptions: false, onlyMemberTypes: ResolveOnlyTypeName) ?? Registry.Resolve(node.Name, throwOnNotFound: false)); if (expr == null) return new UnknownStaticClassIdentifierExpression(node.Name); - if (expr is ParameterExpression && expr.Type == typeof(ExpressionHelper.UnknownTypeSentinel)) throw new Exception($"Type of '{expr}' could not be resolved."); + if (expr is ParameterExpression && expr.Type == typeof(UnknownTypeSentinel)) throw new Exception($"Type of '{expr}' could not be resolved."); return expr; } diff --git a/src/DotVVM.Framework/Compilation/Binding/ExpressionHelper.cs b/src/DotVVM.Framework/Compilation/Binding/ExpressionHelper.cs index feb3ee6f1e..f8f7684368 100644 --- a/src/DotVVM.Framework/Compilation/Binding/ExpressionHelper.cs +++ b/src/DotVVM.Framework/Compilation/Binding/ExpressionHelper.cs @@ -20,365 +20,6 @@ namespace DotVVM.Framework.Compilation.Binding public static class ExpressionHelper { - public static Expression GetMember(Expression target, string name, Type[] typeArguments = null, bool throwExceptions = true, bool onlyMemberTypes = false) - { - if (target is MethodGroupExpression) - throw new Exception("Can not access member on method group."); - - var type = target.Type; - if (type == typeof(UnknownTypeSentinel)) if (throwExceptions) throw new Exception($"Type of '{target}' could not be resolved."); else return null; - - var isStatic = target is StaticClassIdentifierExpression; - - var isGeneric = typeArguments != null && typeArguments.Length != 0; - var genericName = isGeneric ? $"{name}`{typeArguments.Length}" : name; - - if (!isGeneric && !onlyMemberTypes && typeof(DotvvmBindableObject).IsAssignableFrom(target.Type) && - GetDotvvmPropertyMember(target, name) is Expression result) return result; - - var members = type.GetAllMembers(BindingFlags.Public | (isStatic ? BindingFlags.Static : BindingFlags.Instance)) - .Where(m => ((isGeneric && m is TypeInfo) ? genericName : name) == m.Name) - .ToArray(); - - if (members.Length == 0) - { - if (throwExceptions) throw new Exception($"Could not find { (isStatic ? "static" : "instance") } member { name } on type { type.FullName }."); - else return null; - } - if (members.Length == 1) - { - if (!(members[0] is TypeInfo) && onlyMemberTypes) { throw new Exception("Only type names are supported."); } - - var instance = isStatic ? null : target; - if (members[0] is PropertyInfo) - { - var property = members[0] as PropertyInfo; - return Expression.Property(instance, property); - } - else if (members[0] is FieldInfo) - { - var field = members[0] as FieldInfo; - return Expression.Field(instance, field); - } - else if (members[0] is TypeInfo) - { - var nonGenericType = (TypeInfo)members[0]; - return isGeneric - ? new StaticClassIdentifierExpression(nonGenericType.MakeGenericType(typeArguments)) - : new StaticClassIdentifierExpression(nonGenericType.UnderlyingSystemType); - } - } - return new MethodGroupExpression() { MethodName = name, Target = target, TypeArgs = typeArguments }; - } - - static Expression GetDotvvmPropertyMember(Expression target, string name) - { - var property = DotvvmProperty.ResolveProperty(target.Type, name); - if (property == null) return null; - - var field = property.DeclaringType.GetField(property.Name + "Property", BindingFlags.Static | BindingFlags.Public); - if (field == null) return null; - - return Expression.Convert( - Expression.Call(target, "GetValue", Type.EmptyTypes, - Expression.Field(null, field), - Expression.Constant(true) - ), - property.PropertyType - ); - } - - /// - /// Creates an expression that updates the member inside with a - /// new . - /// - /// - /// Should contain a call to the - /// method, it will be - /// replaced with a - /// call. - /// - public static Expression UpdateMember(Expression node, Expression value) - { - if ((node.NodeType == ExpressionType.MemberAccess - && node is MemberExpression member - && member.Member is PropertyInfo property - && property.CanWrite) - || node.NodeType == ExpressionType.Parameter - || node.NodeType == ExpressionType.Index) - { - return Expression.Assign(node, Expression.Convert(value, node.Type)); - } - - var current = node; - while (current.NodeType == ExpressionType.Convert - && current is UnaryExpression unary) - { - current = unary.Operand; - } - - if (current.NodeType == ExpressionType.Call - && current is MethodCallExpression call - && call.Method.DeclaringType == typeof(DotvvmBindableObject) - && call.Method.Name == nameof(DotvvmBindableObject.GetValue) - && call.Arguments.Count == 2 - && call.Arguments[0].Type == typeof(DotvvmProperty) - && call.Arguments[1].Type == typeof(bool)) - { - var propertyArgument = call.Arguments[0]; - var setValue = typeof(DotvvmBindableObject) - .GetMethod(nameof(DotvvmBindableObject.SetValue), - new[] { typeof(DotvvmProperty), typeof(object) }); - return Expression.Call(call.Object, setValue, propertyArgument, value); - } - - return null; - } - - public static Expression Call(Expression target, Expression[] arguments) - { - if (target is MethodGroupExpression) - { - return ((MethodGroupExpression)target).CreateMethodCall(arguments); - } - return Expression.Invoke(target, arguments); - } - - public static Expression CallMethod(Expression target, BindingFlags flags, string name, Type[] typeArguments, Expression[] arguments, IDictionary namedArgs = null) - { - // the following piece of code is nicer and more readable than method recognition done in roslyn, C# dynamic and also expression evaluator :) - var method = FindValidMethodOveloads(target.Type, name, flags, typeArguments, arguments, namedArgs); - return Expression.Call(target, method.Method, method.Arguments); - } - - public static Expression CallMethod(Type target, BindingFlags flags, string name, Type[] typeArguments, Expression[] arguments, IDictionary namedArgs = null) - { - // the following piece of code is nicer and more readable than method recognition done in roslyn, C# dynamic and also expression evaluator :) - var method = FindValidMethodOveloads(target, name, flags, typeArguments, arguments, namedArgs); - return Expression.Call(method.Method, method.Arguments); - } - - - private static MethodRecognitionResult FindValidMethodOveloads(Type type, string name, BindingFlags flags, Type[] typeArguments, Expression[] arguments, IDictionary namedArgs) - { - var methods = FindValidMethodOveloads(type.GetAllMembers(flags).OfType().Where(m => m.Name == name), typeArguments, arguments, namedArgs).ToList(); - - if (methods.Count == 1) return methods.FirstOrDefault(); - if (methods.Count == 0) throw new InvalidOperationException($"Could not find overload of method '{name}'."); - else - { - methods = methods.OrderBy(s => s.CastCount).ThenBy(s => s.AutomaticTypeArgCount).ToList(); - var method = methods.FirstOrDefault(); - var method2 = methods.Skip(1).FirstOrDefault(); - if (method.AutomaticTypeArgCount == method2.AutomaticTypeArgCount && method.CastCount == method2.CastCount) - { - // TODO: this behavior is not completed. Implement the same behavior as in roslyn. - throw new InvalidOperationException($"Found ambiguous overloads of method '{name}'."); - } - return method; - } - } - - private static IEnumerable FindValidMethodOveloads(IEnumerable methods, Type[] typeArguments, Expression[] arguments, IDictionary namedArgs) - => from m in methods - let r = TryCallMethod(m, typeArguments, arguments, namedArgs) - where r != null - orderby r.CastCount descending, r.AutomaticTypeArgCount - select r; - - class MethodRecognitionResult - { - public int AutomaticTypeArgCount { get; set; } - public int CastCount { get; set; } - public Expression[] Arguments { get; set; } - public MethodInfo Method { get; set; } - } - - private static MethodRecognitionResult TryCallMethod(MethodInfo method, Type[] typeArguments, Expression[] positionalArguments, IDictionary namedArguments) - { - var parameters = method.GetParameters(); - - int castCount = 0; - if (parameters.Length < positionalArguments.Length) return null; - var args = new Expression[parameters.Length]; - Array.Copy(positionalArguments, args, positionalArguments.Length); - int namedArgCount = 0; - for (int i = positionalArguments.Length; i < args.Length; i++) - { - if (namedArguments?.ContainsKey(parameters[i].Name) == true) - { - args[i] = namedArguments[parameters[i].Name]; - namedArgCount++; - } - else if (parameters[i].HasDefaultValue) - { - castCount++; - args[i] = Expression.Constant(parameters[i].DefaultValue, parameters[i].ParameterType); - } - else return null; - } - - // some named arguments were not used - if (namedArguments != null && namedArgCount != namedArguments.Count) return null; - - int automaticTypeArgs = 0; - // resolve generic parameters - if (method.ContainsGenericParameters) - { - var genericArguments = method.GetGenericArguments(); - var typeArgs = new Type[genericArguments.Length]; - if (typeArguments != null) - { - if (typeArguments.Length > typeArgs.Length) return null; - Array.Copy(typeArguments, typeArgs, typeArgs.Length); - } - for (int genericArgumentPosition = 0; genericArgumentPosition < typeArgs.Length; genericArgumentPosition++) - { - if (typeArgs[genericArgumentPosition] == null) - { - // try to resolve from arguments - var argType = GetGenericParameterType(genericArguments[genericArgumentPosition], parameters.Select(s => s.ParameterType).ToArray(), args.Select(s => s.Type).ToArray()); - automaticTypeArgs++; - if (argType != null) typeArgs[genericArgumentPosition] = argType; - else return null; - } - } - method = method.MakeGenericMethod(typeArgs); - parameters = method.GetParameters(); - } - else if (typeArguments != null) return null; - - // cast arguments - for (int i = 0; i < args.Length; i++) - { - var casted = TypeConversion.ImplicitConversion(args[i], parameters[i].ParameterType); - if (casted == null) return null; - if (casted != args[i]) - { - castCount++; - args[i] = casted; - } - } - - return new MethodRecognitionResult { - CastCount = castCount, - AutomaticTypeArgCount = automaticTypeArgs, - Method = method, - Arguments = args - }; - } - - private static Type GetGenericParameterType(Type genericArg, Type[] searchedGenericTypes, Type[] expressionTypes) - { - for (var i = 0; i < searchedGenericTypes.Length; i++) - { - if (expressionTypes.Length <= i) return null; - var sgt = searchedGenericTypes[i]; - if (sgt == genericArg) - { - return expressionTypes[i]; - } - if (sgt.IsArray) - { - var elementType = sgt.GetElementType(); - var expressionElementType = expressionTypes[i].GetElementType(); - if (elementType == genericArg) - return expressionElementType; - else - return GetGenericParameterType(genericArg, searchedGenericTypes[i].GetGenericArguments(), expressionTypes[i].GetGenericArguments()); - } - else if (sgt.IsGenericType) - { - Type[] genericArguments; - var expression = expressionTypes[i]; - - // Arrays need to be handled in a special way to obtain instantiation - if (expression.IsArray) - genericArguments = new[] { expression.GetElementType() }; - else - genericArguments = expression.GetGenericArguments(); - - var value = GetGenericParameterType(genericArg, sgt.GetGenericArguments(), genericArguments); - if (value is Type) return value; - } - } - return null; - } - - public static Expression EqualsMethod(Expression left, Expression right) - { - Expression equatable = null; - Expression theOther = null; - if (typeof(IEquatable<>).IsAssignableFrom(left.Type)) - { - equatable = left; - theOther = right; - } - else if (typeof(IEquatable<>).IsAssignableFrom(right.Type)) - { - equatable = right; - theOther = left; - } - - if (equatable != null) - { - var m = CallMethod(equatable, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.FlattenHierarchy, "Equals", null, new[] { theOther }); - if (m != null) return m; - } - - if (left.Type.GetTypeInfo().IsValueType) - { - equatable = left; - theOther = right; - } - else if (left.Type.GetTypeInfo().IsValueType) - { - equatable = right; - theOther = left; - } - - if (equatable != null) - { - theOther = TypeConversion.ImplicitConversion(theOther, equatable.Type); - if (theOther != null) return Expression.Equal(equatable, theOther); - } - - return CallMethod(left, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.FlattenHierarchy, "Equals", null, new[] { right }); - } - - public static Expression CompareMethod(Expression left, Expression right) - { - Type compareType = typeof(object); - Expression equatable = null; - Expression theOther = null; - if (typeof(IComparable<>).IsAssignableFrom(left.Type)) - { - equatable = left; - theOther = right; - } - else if (typeof(IComparable<>).IsAssignableFrom(right.Type)) - { - equatable = right; - theOther = left; - } - else if (typeof(IComparable).IsAssignableFrom(left.Type)) - { - equatable = left; - theOther = right; - } - else if (typeof(IComparable).IsAssignableFrom(right.Type)) - { - equatable = right; - theOther = left; - } - - if (equatable != null) - { - return CallMethod(equatable, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.FlattenHierarchy, "Compare", null, new[] { theOther }); - } - throw new NotSupportedException("IComparable is not implemented on any of specified types"); - } - public static Expression RewriteTaskSequence(Expression left, Expression right) { // if the left side is a task, make the right side also a task and join them @@ -412,44 +53,9 @@ public static Expression RewriteTaskSequence(Expression left, Expression right) } } - public static Expression UnwrapNullable(this Expression expression) => expression.Type.IsNullable() ? Expression.Property(expression, "Value") : expression; - public static Expression GetBinaryOperator(Expression left, Expression right, ExpressionType operation) - { - if (operation == ExpressionType.Coalesce) return Expression.Coalesce(left, right); - if (operation == ExpressionType.Assign) - { - return Expression.Assign(left, TypeConversion.ImplicitConversion(right, left.Type, true, true)); - } - - // TODO: type conversions - if (operation == ExpressionType.AndAlso) return Expression.AndAlso(left, right); - else if (operation == ExpressionType.OrElse) return Expression.OrElse(left, right); - - var binder = (DynamicMetaObjectBinder)Microsoft.CSharp.RuntimeBinder.Binder.BinaryOperation( - CSharpBinderFlags.None, operation, typeof(object), GetBinderArguments(2)); - var result = ApplyBinder(binder, false, left, right); - if (result != null) return result; - if (operation == ExpressionType.Equal) return EqualsMethod(left, right); - if (operation == ExpressionType.NotEqual) return Expression.Not(EqualsMethod(left, right)); - - // lift the operator - if (left.Type.IsNullable() || right.Type.IsNullable()) - return GetBinaryOperator(left.UnwrapNullable(), right.UnwrapNullable(), operation); - - throw new Exception($"could not apply { operation } binary operator to { left } and { right }"); - // TODO: comparison operators - } - - public static Expression GetUnaryOperator(Expression expr, ExpressionType operation) - { - var binder = (DynamicMetaObjectBinder)Microsoft.CSharp.RuntimeBinder.Binder.UnaryOperation( - CSharpBinderFlags.None, operation, typeof(object), GetBinderArguments(1)); - return ApplyBinder(binder, true, expr); - } - public static Expression GetIndexer(Expression expr, Expression index) { if (expr.Type.IsArray) return Expression.ArrayIndex(expr, index); @@ -498,7 +104,5 @@ public static Expression ApplyBinder(DynamicMetaObjectBinder binder, bool throwE } return result.Expression; } - - public sealed class UnknownTypeSentinel { } } } diff --git a/src/DotVVM.Framework/Compilation/Binding/GeneralBindingPropertyResolvers.cs b/src/DotVVM.Framework/Compilation/Binding/GeneralBindingPropertyResolvers.cs index 08fc216de9..67d4f0a5db 100644 --- a/src/DotVVM.Framework/Compilation/Binding/GeneralBindingPropertyResolvers.cs +++ b/src/DotVVM.Framework/Compilation/Binding/GeneralBindingPropertyResolvers.cs @@ -28,13 +28,15 @@ public class BindingPropertyResolvers private readonly IBindingExpressionBuilder bindingParser; private readonly StaticCommandBindingCompiler staticCommandBindingCompiler; private readonly JavascriptTranslator javascriptTranslator; + private readonly MemberExpressionFactory memberExpressionFactory; - public BindingPropertyResolvers(IBindingExpressionBuilder bindingParser, StaticCommandBindingCompiler staticCommandBindingCompiler, JavascriptTranslator javascriptTranslator, DotvvmConfiguration configuration) + public BindingPropertyResolvers(IBindingExpressionBuilder bindingParser, StaticCommandBindingCompiler staticCommandBindingCompiler, JavascriptTranslator javascriptTranslator, DotvvmConfiguration configuration, MemberExpressionFactory memberExpressionFactory) { this.configuration = configuration; this.bindingParser = bindingParser; this.staticCommandBindingCompiler = staticCommandBindingCompiler; this.javascriptTranslator = javascriptTranslator; + this.memberExpressionFactory = memberExpressionFactory; } public ActionFiltersBindingProperty GetActionFilters(ParsedExpressionBindingProperty parsedExpression) @@ -71,7 +73,7 @@ public Expression CompileToUpdateDelegate(ParsedExpressio { var valueParameter = Expression.Parameter(typeof(object), "value"); var body = BindingCompiler.ReplaceParameters(binding.Expression, dataContext); - body = ExpressionHelper.UpdateMember(body, valueParameter); + body = memberExpressionFactory.UpdateMember(body, valueParameter); if (body == null) { return null; diff --git a/src/DotVVM.Framework/Compilation/Binding/IExtensionsProvider.cs b/src/DotVVM.Framework/Compilation/Binding/IExtensionsProvider.cs new file mode 100644 index 0000000000..877ceb2a3a --- /dev/null +++ b/src/DotVVM.Framework/Compilation/Binding/IExtensionsProvider.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; + +namespace DotVVM.Framework.Compilation.Binding +{ + public interface IExtensionsProvider + { + IEnumerable GetExtensionMethods(); + } +} diff --git a/src/DotVVM.Framework/Compilation/Binding/MemberExpressionFactory.cs b/src/DotVVM.Framework/Compilation/Binding/MemberExpressionFactory.cs new file mode 100644 index 0000000000..77fc4944e6 --- /dev/null +++ b/src/DotVVM.Framework/Compilation/Binding/MemberExpressionFactory.cs @@ -0,0 +1,457 @@ +using System; +using System.Collections.Generic; +using System.Dynamic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; +using DotVVM.Framework.Binding; +using DotVVM.Framework.Controls; +using DotVVM.Framework.Utils; +using Microsoft.CSharp.RuntimeBinder; +using Microsoft.Extensions.DependencyInjection; + +namespace DotVVM.Framework.Compilation.Binding +{ + public class MemberExpressionFactory + { + private readonly IExtensionsProvider extensionsProvider; + + public MemberExpressionFactory(IServiceProvider serviceProvider) + { + extensionsProvider = serviceProvider.GetService(); + if (extensionsProvider == null) + extensionsProvider = new DefaultExtensionsProvider(); + } + + public Expression GetMember(Expression target, string name, Type[] typeArguments = null, bool throwExceptions = true, bool onlyMemberTypes = false) + { + if (target is MethodGroupExpression) + throw new Exception("Can not access member on method group."); + + var type = target.Type; + if (type == typeof(UnknownTypeSentinel)) if (throwExceptions) throw new Exception($"Type of '{target}' could not be resolved."); else return null; + + var isStatic = target is StaticClassIdentifierExpression; + + var isGeneric = typeArguments != null && typeArguments.Length != 0; + var genericName = isGeneric ? $"{name}`{typeArguments.Length}" : name; + + if (!isGeneric && !onlyMemberTypes && typeof(DotvvmBindableObject).IsAssignableFrom(target.Type) && + GetDotvvmPropertyMember(target, name) is Expression result) return result; + + var members = type.GetAllMembers(BindingFlags.Public | (isStatic ? BindingFlags.Static : BindingFlags.Instance)) + .Where(m => ((isGeneric && m is TypeInfo) ? genericName : name) == m.Name) + .ToArray(); + + if (members.Length == 0) + { + // We did not find any match in regular methods => try extension methods + var extensions = extensionsProvider.GetExtensionMethods() + .Where(m => m.Name == name).ToArray(); + members = extensions; + + if (members.Length == 0 && throwExceptions) + throw new Exception($"Could not find { (isStatic ? "static" : "instance") } member { name } on type { type.FullName }."); + else if (members.Length == 0 && !throwExceptions) + return null; + } + if (members.Length == 1) + { + if (!(members[0] is TypeInfo) && onlyMemberTypes) { throw new Exception("Only type names are supported."); } + + var instance = isStatic ? null : target; + if (members[0] is PropertyInfo) + { + var property = members[0] as PropertyInfo; + return Expression.Property(instance, property); + } + else if (members[0] is FieldInfo) + { + var field = members[0] as FieldInfo; + return Expression.Field(instance, field); + } + else if (members[0] is TypeInfo) + { + var nonGenericType = (TypeInfo)members[0]; + return isGeneric + ? new StaticClassIdentifierExpression(nonGenericType.MakeGenericType(typeArguments)) + : new StaticClassIdentifierExpression(nonGenericType.UnderlyingSystemType); + } + } + return new MethodGroupExpression() { MethodName = name, Target = target, TypeArgs = typeArguments }; + } + + private Expression GetDotvvmPropertyMember(Expression target, string name) + { + var property = DotvvmProperty.ResolveProperty(target.Type, name); + if (property == null) return null; + + var field = property.DeclaringType.GetField(property.Name + "Property", BindingFlags.Static | BindingFlags.Public); + if (field == null) return null; + + return Expression.Convert( + Expression.Call(target, "GetValue", Type.EmptyTypes, + Expression.Field(null, field), + Expression.Constant(true) + ), + property.PropertyType + ); + } + + /// + /// Creates an expression that updates the member inside with a + /// new . + /// + /// + /// Should contain a call to the + /// method, it will be + /// replaced with a + /// call. + /// + public Expression UpdateMember(Expression node, Expression value) + { + if ((node.NodeType == ExpressionType.MemberAccess + && node is MemberExpression member + && member.Member is PropertyInfo property + && property.CanWrite) + || node.NodeType == ExpressionType.Parameter + || node.NodeType == ExpressionType.Index) + { + return Expression.Assign(node, Expression.Convert(value, node.Type)); + } + + var current = node; + while (current.NodeType == ExpressionType.Convert + && current is UnaryExpression unary) + { + current = unary.Operand; + } + + if (current.NodeType == ExpressionType.Call + && current is MethodCallExpression call + && call.Method.DeclaringType == typeof(DotvvmBindableObject) + && call.Method.Name == nameof(DotvvmBindableObject.GetValue) + && call.Arguments.Count == 2 + && call.Arguments[0].Type == typeof(DotvvmProperty) + && call.Arguments[1].Type == typeof(bool)) + { + var propertyArgument = call.Arguments[0]; + var setValue = typeof(DotvvmBindableObject) + .GetMethod(nameof(DotvvmBindableObject.SetValue), + new[] { typeof(DotvvmProperty), typeof(object) }); + return Expression.Call(call.Object, setValue, propertyArgument, value); + } + + return null; + } + + public Expression Call(Expression target, Expression[] arguments) + { + if (target is MethodGroupExpression) + { + return ((MethodGroupExpression)target).CreateMethodCall(arguments, this); + } + return Expression.Invoke(target, arguments); + } + + public Expression CallMethod(Expression target, BindingFlags flags, string name, Type[] typeArguments, Expression[] arguments, IDictionary namedArgs = null) + { + // the following piece of code is nicer and more readable than method recognition done in roslyn, C# dynamic and also expression evaluator :) + var method = FindValidMethodOveloads(target, target.Type, name, flags, typeArguments, arguments, namedArgs); + + if (method.IsExtension) + { + // Change to a static call + var newArguments = new[] { target }.Concat(arguments); + return Expression.Call(method.Method, newArguments); + } + + return Expression.Call(target, method.Method, method.Arguments); + } + + public Expression CallMethod(Type target, BindingFlags flags, string name, Type[] typeArguments, Expression[] arguments, IDictionary namedArgs = null) + { + // the following piece of code is nicer and more readable than method recognition done in roslyn, C# dynamic and also expression evaluator :) + var method = FindValidMethodOveloads(null, target, name, flags, typeArguments, arguments, namedArgs); + return Expression.Call(method.Method, method.Arguments); + } + + + private MethodRecognitionResult FindValidMethodOveloads(Expression target, Type type, string name, BindingFlags flags, Type[] typeArguments, Expression[] arguments, IDictionary namedArgs) + { + var methods = FindValidMethodOveloads(type.GetAllMembers(flags).OfType().Where(m => m.Name == name), typeArguments, arguments, namedArgs).ToList(); + + if (methods.Count == 1) return methods.FirstOrDefault(); + if (methods.Count == 0) + { + // We did not find any match in regular methods => try extension methods + if (target != null) + { + // Change to a static call + var newArguments = new[] { target }.Concat(arguments).ToArray(); + var extensions = FindValidMethodOveloads(extensionsProvider.GetExtensionMethods().OfType().Where(m => m.Name == name), typeArguments, newArguments, namedArgs) + .Select(method => { method.IsExtension = true; return method; }).ToList(); + + // We found an extension method + if (extensions.Count == 1) + return extensions.FirstOrDefault(); + + target = null; + methods = extensions; + arguments = newArguments; + } + + if (methods.Count == 0) + throw new InvalidOperationException($"Could not find method overload nor extension method that matched '{name}'."); + } + + // There are multiple method candidates + methods = methods.OrderBy(s => s.CastCount).ThenBy(s => s.AutomaticTypeArgCount).ToList(); + var method = methods.FirstOrDefault(); + var method2 = methods.Skip(1).FirstOrDefault(); + if (method.AutomaticTypeArgCount == method2.AutomaticTypeArgCount && method.CastCount == method2.CastCount) + { + // TODO: this behavior is not completed. Implement the same behavior as in roslyn. + throw new InvalidOperationException($"Found ambiguous overloads of method '{name}'."); + } + return method; + } + + private IEnumerable FindValidMethodOveloads(IEnumerable methods, Type[] typeArguments, Expression[] arguments, IDictionary namedArgs) + => from m in methods + let r = TryCallMethod(m, typeArguments, arguments, namedArgs) + where r != null + orderby r.CastCount descending, r.AutomaticTypeArgCount + select r; + + class MethodRecognitionResult + { + public int AutomaticTypeArgCount { get; set; } + public int CastCount { get; set; } + public Expression[] Arguments { get; set; } + public MethodInfo Method { get; set; } + public bool IsExtension { get; set; } + } + + private MethodRecognitionResult TryCallMethod(MethodInfo method, Type[] typeArguments, Expression[] positionalArguments, IDictionary namedArguments) + { + var parameters = method.GetParameters(); + + int castCount = 0; + if (parameters.Length < positionalArguments.Length) return null; + var args = new Expression[parameters.Length]; + Array.Copy(positionalArguments, args, positionalArguments.Length); + int namedArgCount = 0; + for (int i = positionalArguments.Length; i < args.Length; i++) + { + if (namedArguments?.ContainsKey(parameters[i].Name) == true) + { + args[i] = namedArguments[parameters[i].Name]; + namedArgCount++; + } + else if (parameters[i].HasDefaultValue) + { + castCount++; + args[i] = Expression.Constant(parameters[i].DefaultValue, parameters[i].ParameterType); + } + else return null; + } + + // some named arguments were not used + if (namedArguments != null && namedArgCount != namedArguments.Count) return null; + + int automaticTypeArgs = 0; + // resolve generic parameters + if (method.ContainsGenericParameters) + { + var genericArguments = method.GetGenericArguments(); + var typeArgs = new Type[genericArguments.Length]; + if (typeArguments != null) + { + if (typeArguments.Length > typeArgs.Length) return null; + Array.Copy(typeArguments, typeArgs, typeArgs.Length); + } + for (int genericArgumentPosition = 0; genericArgumentPosition < typeArgs.Length; genericArgumentPosition++) + { + if (typeArgs[genericArgumentPosition] == null) + { + // try to resolve from arguments + var argType = GetGenericParameterType(genericArguments[genericArgumentPosition], parameters.Select(s => s.ParameterType).ToArray(), args.Select(s => s.Type).ToArray()); + automaticTypeArgs++; + if (argType != null) typeArgs[genericArgumentPosition] = argType; + else return null; + } + } + method = method.MakeGenericMethod(typeArgs); + parameters = method.GetParameters(); + } + else if (typeArguments != null) return null; + + // cast arguments + for (int i = 0; i < args.Length; i++) + { + var casted = TypeConversion.ImplicitConversion(args[i], parameters[i].ParameterType); + if (casted == null) return null; + if (casted != args[i]) + { + castCount++; + args[i] = casted; + } + } + + return new MethodRecognitionResult { + CastCount = castCount, + AutomaticTypeArgCount = automaticTypeArgs, + Method = method, + Arguments = args + }; + } + + private Type GetGenericParameterType(Type genericArg, Type[] searchedGenericTypes, Type[] expressionTypes) + { + for (var i = 0; i < searchedGenericTypes.Length; i++) + { + if (expressionTypes.Length <= i) return null; + var sgt = searchedGenericTypes[i]; + if (sgt == genericArg) + { + return expressionTypes[i]; + } + if (sgt.IsArray) + { + var elementType = sgt.GetElementType(); + var expressionElementType = expressionTypes[i].GetElementType(); + if (elementType == genericArg) + return expressionElementType; + else + return GetGenericParameterType(genericArg, searchedGenericTypes[i].GetGenericArguments(), expressionTypes[i].GetGenericArguments()); + } + else if (sgt.IsGenericType) + { + Type[] genericArguments; + var expression = expressionTypes[i]; + + // Arrays need to be handled in a special way to obtain instantiation + if (expression.IsArray) + genericArguments = new[] { expression.GetElementType() }; + else + genericArguments = expression.GetGenericArguments(); + + var value = GetGenericParameterType(genericArg, sgt.GetGenericArguments(), genericArguments); + if (value is Type) return value; + } + } + return null; + } + + public Expression EqualsMethod(Expression left, Expression right) + { + Expression equatable = null; + Expression theOther = null; + if (typeof(IEquatable<>).IsAssignableFrom(left.Type)) + { + equatable = left; + theOther = right; + } + else if (typeof(IEquatable<>).IsAssignableFrom(right.Type)) + { + equatable = right; + theOther = left; + } + + if (equatable != null) + { + var m = CallMethod(equatable, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.FlattenHierarchy, "Equals", null, new[] { theOther }); + if (m != null) return m; + } + + if (left.Type.GetTypeInfo().IsValueType) + { + equatable = left; + theOther = right; + } + else if (left.Type.GetTypeInfo().IsValueType) + { + equatable = right; + theOther = left; + } + + if (equatable != null) + { + theOther = TypeConversion.ImplicitConversion(theOther, equatable.Type); + if (theOther != null) return Expression.Equal(equatable, theOther); + } + + return CallMethod(left, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.FlattenHierarchy, "Equals", null, new[] { right }); + } + + public Expression CompareMethod(Expression left, Expression right) + { + Type compareType = typeof(object); + Expression equatable = null; + Expression theOther = null; + if (typeof(IComparable<>).IsAssignableFrom(left.Type)) + { + equatable = left; + theOther = right; + } + else if (typeof(IComparable<>).IsAssignableFrom(right.Type)) + { + equatable = right; + theOther = left; + } + else if (typeof(IComparable).IsAssignableFrom(left.Type)) + { + equatable = left; + theOther = right; + } + else if (typeof(IComparable).IsAssignableFrom(right.Type)) + { + equatable = right; + theOther = left; + } + + if (equatable != null) + { + return CallMethod(equatable, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.FlattenHierarchy, "Compare", null, new[] { theOther }); + } + throw new NotSupportedException("IComparable is not implemented on any of specified types"); + } + + public Expression GetUnaryOperator(Expression expr, ExpressionType operation) + { + var binder = (DynamicMetaObjectBinder)Microsoft.CSharp.RuntimeBinder.Binder.UnaryOperation( + CSharpBinderFlags.None, operation, typeof(object), ExpressionHelper.GetBinderArguments(1)); + return ExpressionHelper.ApplyBinder(binder, true, expr); + } + + public Expression GetBinaryOperator(Expression left, Expression right, ExpressionType operation) + { + if (operation == ExpressionType.Coalesce) return Expression.Coalesce(left, right); + if (operation == ExpressionType.Assign) + { + return Expression.Assign(left, TypeConversion.ImplicitConversion(right, left.Type, true, true)); + } + + // TODO: type conversions + if (operation == ExpressionType.AndAlso) return Expression.AndAlso(left, right); + else if (operation == ExpressionType.OrElse) return Expression.OrElse(left, right); + + var binder = (DynamicMetaObjectBinder)Microsoft.CSharp.RuntimeBinder.Binder.BinaryOperation( + CSharpBinderFlags.None, operation, typeof(object), ExpressionHelper.GetBinderArguments(2)); + var result = ExpressionHelper.ApplyBinder(binder, false, left, right); + if (result != null) return result; + if (operation == ExpressionType.Equal) return EqualsMethod(left, right); + if (operation == ExpressionType.NotEqual) return Expression.Not(EqualsMethod(left, right)); + + // lift the operator + if (left.Type.IsNullable() || right.Type.IsNullable()) + return GetBinaryOperator(left.UnwrapNullable(), right.UnwrapNullable(), operation); + + throw new Exception($"could not apply { operation } binary operator to { left } and { right }"); + // TODO: comparison operators + } + } +} diff --git a/src/DotVVM.Framework/Compilation/Binding/MethodGroupExpression.cs b/src/DotVVM.Framework/Compilation/Binding/MethodGroupExpression.cs index 8dca9c2793..9055ec7213 100644 --- a/src/DotVVM.Framework/Compilation/Binding/MethodGroupExpression.cs +++ b/src/DotVVM.Framework/Compilation/Binding/MethodGroupExpression.cs @@ -72,16 +72,16 @@ public Expression CreateDelegateExpression() return Expression.Call(CreateDelegateMethodInfo, Expression.Constant(delegateType), Target, Expression.Constant(methodInfo)) .Apply(e => Expression.Convert(e, delegateType)); } - public Expression CreateMethodCall(IEnumerable args) + public Expression CreateMethodCall(IEnumerable args, MemberExpressionFactory memberExpressionFactory) { var argsArray = args.ToArray(); if (IsStatic) { - return ExpressionHelper.CallMethod((Target as StaticClassIdentifierExpression).Type, BindingFlags.Static | BindingFlags.Public | BindingFlags.FlattenHierarchy, MethodName, TypeArgs, argsArray); + return memberExpressionFactory.CallMethod((Target as StaticClassIdentifierExpression).Type, BindingFlags.Static | BindingFlags.Public | BindingFlags.FlattenHierarchy, MethodName, TypeArgs, argsArray); } else { - return ExpressionHelper.CallMethod(Target, BindingFlags.Instance | BindingFlags.Public | BindingFlags.FlattenHierarchy, MethodName, TypeArgs, argsArray); + return memberExpressionFactory.CallMethod(Target, BindingFlags.Instance | BindingFlags.Public | BindingFlags.FlattenHierarchy, MethodName, TypeArgs, argsArray); } } diff --git a/src/DotVVM.Framework/Compilation/Binding/UnknownTypeSentinel.cs b/src/DotVVM.Framework/Compilation/Binding/UnknownTypeSentinel.cs new file mode 100644 index 0000000000..8a7af3a60c --- /dev/null +++ b/src/DotVVM.Framework/Compilation/Binding/UnknownTypeSentinel.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace DotVVM.Framework.Compilation.Binding +{ + internal sealed class UnknownTypeSentinel + { + } +} diff --git a/src/DotVVM.Framework/Compilation/ControlTree/DefaultControlTreeResolver.cs b/src/DotVVM.Framework/Compilation/ControlTree/DefaultControlTreeResolver.cs index a257960cf7..3c2b880f39 100644 --- a/src/DotVVM.Framework/Compilation/ControlTree/DefaultControlTreeResolver.cs +++ b/src/DotVVM.Framework/Compilation/ControlTree/DefaultControlTreeResolver.cs @@ -43,7 +43,7 @@ protected override IDataContextStack CreateDataContextTypeStack(ITypeDescriptor? { return DataContextStack.Create( - ResolvedTypeDescriptor.ToSystemType(viewModelType) ?? typeof(ExpressionHelper.UnknownTypeSentinel), + ResolvedTypeDescriptor.ToSystemType(viewModelType) ?? typeof(UnknownTypeSentinel), parentDataContextStack as DataContextStack, namespaceImports, extensionParameters); } diff --git a/src/DotVVM.Framework/Compilation/ControlTree/Resolved/ResolvedTreeBuilder.cs b/src/DotVVM.Framework/Compilation/ControlTree/Resolved/ResolvedTreeBuilder.cs index 9329e5e6d1..89ce337982 100644 --- a/src/DotVVM.Framework/Compilation/ControlTree/Resolved/ResolvedTreeBuilder.cs +++ b/src/DotVVM.Framework/Compilation/ControlTree/Resolved/ResolvedTreeBuilder.cs @@ -17,11 +17,13 @@ public class ResolvedTreeBuilder : IAbstractTreeBuilder { private readonly BindingCompilationService bindingService; private readonly CompiledAssemblyCache compiledAssemblyCache; + private readonly MemberExpressionFactory memberExpressionFactory; - public ResolvedTreeBuilder(BindingCompilationService bindingService, CompiledAssemblyCache compiledAssemblyCache) + public ResolvedTreeBuilder(BindingCompilationService bindingService, CompiledAssemblyCache compiledAssemblyCache, MemberExpressionFactory memberExpressionFactory) { this.bindingService = bindingService; this.compiledAssemblyCache = compiledAssemblyCache; + this.memberExpressionFactory = memberExpressionFactory; } public IAbstractTreeRoot BuildTreeRoot(IControlTreeResolver controlTreeResolver, IControlResolverMetadata metadata, DothtmlRootNode node, IDataContextStack dataContext, IReadOnlyDictionary> directives) @@ -162,7 +164,7 @@ public IAbstractBaseTypeDirective BuildBaseTypeDirective(DothtmlDirectiveNode di registry = TypeRegistry.DirectivesDefault(compiledAssemblyCache); } - var visitor = new ExpressionBuildingVisitor(registry) { + var visitor = new ExpressionBuildingVisitor(registry, memberExpressionFactory) { ResolveOnlyTypeName = true, Scope = null }; diff --git a/src/DotVVM.Framework/DependencyInjection/DotVVMServiceCollectionExtensions.cs b/src/DotVVM.Framework/DependencyInjection/DotVVMServiceCollectionExtensions.cs index 1b2f3628ae..9c70fca0f2 100644 --- a/src/DotVVM.Framework/DependencyInjection/DotVVMServiceCollectionExtensions.cs +++ b/src/DotVVM.Framework/DependencyInjection/DotVVMServiceCollectionExtensions.cs @@ -75,6 +75,7 @@ public static IServiceCollection RegisterDotVVMServices(IServiceCollection servi services.TryAddScoped(); services.TryAddScoped(); services.TryAddScoped(); + services.TryAddSingleton(); services.TryAddSingleton(s => DotvvmConfiguration.CreateDefault(s)); services.TryAddSingleton(s => s.GetRequiredService().Markup); services.TryAddSingleton(s => s.GetRequiredService().Resources); diff --git a/src/DotVVM.Framework/Utils/ReflectionUtils.cs b/src/DotVVM.Framework/Utils/ReflectionUtils.cs index 50744b47a6..6a0af27375 100644 --- a/src/DotVVM.Framework/Utils/ReflectionUtils.cs +++ b/src/DotVVM.Framework/Utils/ReflectionUtils.cs @@ -13,6 +13,7 @@ using System.Text; using System.Globalization; using System.Collections.Concurrent; +using DotVVM.Framework.Compilation.Binding; namespace DotVVM.Framework.Utils { @@ -63,7 +64,6 @@ public static IEnumerable GetAllMembers(this Type type, BindingFlags return type.GetMembers(flags); } - /// /// Gets filesystem path of assembly CodeBase /// http://stackoverflow.com/questions/52797/how-do-i-get-the-path-of-the-assembly-the-code-is-in diff --git a/src/DotVVM.Samples.Common/Views/FeatureSamples/LambdaExpressions/LambdaExpressions.dothtml b/src/DotVVM.Samples.Common/Views/FeatureSamples/LambdaExpressions/LambdaExpressions.dothtml index 528e754483..24f7e394be 100644 --- a/src/DotVVM.Samples.Common/Views/FeatureSamples/LambdaExpressions/LambdaExpressions.dothtml +++ b/src/DotVVM.Samples.Common/Views/FeatureSamples/LambdaExpressions/LambdaExpressions.dothtml @@ -24,12 +24,21 @@

Operations

+

+ Showcasing LINQ and JsTranslator
+ Note: you can use either explicit expressions (example: Enumerable.Where(Collection, ...)), + but also extension-method-like calls (example: Collection.Where(...)) +

+ + <%--Click="{command: SetResult(Enumerable.Where(Array, (int item) => item % 2 == 0))}" />--%> + Click="{command: SetResult(Array.Where((int item) => item % 2 == 0))}" /> + <%--Click="{command: SetResult(Enumerable.Where(Array, (int item) => item % 2 == 1))}" />--%> + Click="{command: SetResult(Array.Where((int item) => item % 2 == 1))}" /> + <%--Click="{command: SetResult(Enumerable.Select(Array, (int item) => -item))}" />--%> + Click="{command: SetResult(Array.Select((int item) => -item))}" />