diff --git a/src/BitzArt.LinqExtensions/Enums/FilterOperation.cs b/src/BitzArt.LinqExtensions/Enums/FilterOperation.cs new file mode 100644 index 0000000..b59b662 --- /dev/null +++ b/src/BitzArt.LinqExtensions/Enums/FilterOperation.cs @@ -0,0 +1,11 @@ +namespace System.Linq; + +public enum FilterOperation : byte +{ + Equal = 0, + NotEqual = 1, + GreaterThan = 2, + GreaterThanOrEqual = 3, + LessThan = 4, + LessThanOrEqual = 5 +} diff --git a/src/BitzArt.LinqExtensions/Extensions/AddFilterExtension.cs b/src/BitzArt.LinqExtensions/Extensions/AddFilterExtension.cs index efe7260..efb15e3 100644 --- a/src/BitzArt.LinqExtensions/Extensions/AddFilterExtension.cs +++ b/src/BitzArt.LinqExtensions/Extensions/AddFilterExtension.cs @@ -5,15 +5,15 @@ namespace System.Linq; public static class AddFilterExtension { - public static IQueryable AddFilter(this IQueryable source, Expression> expression, TProperty? filter) + public static IQueryable AddFilter(this IQueryable source, Expression> expression, TProperty? filter, FilterOperation filterOperation = FilterOperation.Equal) where TProperty : class { if (filter is null) return source; - return BuildExpression(source, filter, expression); + return BuildExpression(source, filter, expression, filterOperation); } - public static IQueryable AddFilter(this IQueryable source, Expression> expression, TProperty? filter) + public static IQueryable AddFilter(this IQueryable source, Expression> expression, TProperty? filter, FilterOperation filterOperation = FilterOperation.Equal) where TProperty : struct { if (filter is null) return source; @@ -22,15 +22,25 @@ public static IQueryable AddFilter(this IQueryable< Expression> getValueExpression = x => x!.Value; var valueExpression = expression.Compose(getValueExpression); - return BuildExpression(source, filterValue, valueExpression); + return BuildExpression(source, filterValue, valueExpression, filterOperation); } - private static IQueryable BuildExpression(IQueryable source, TProperty filter, Expression> expression) + private static IQueryable BuildExpression(IQueryable source, TProperty filter, Expression> expression, FilterOperation filterOperation) { var argument = Expression.Parameter(typeof(TSource)); var left = Expression.Invoke(expression, argument); var right = Expression.Constant(filter); - var eq = Expression.Equal(left, right); + + var eq = filterOperation switch + { + FilterOperation.Equal => Expression.Equal(left, right), + FilterOperation.NotEqual => Expression.NotEqual(left, right), + FilterOperation.GreaterThan => Expression.GreaterThan(left, right), + FilterOperation.GreaterThanOrEqual => Expression.GreaterThanOrEqual(left, right), + FilterOperation.LessThan => Expression.LessThan(left, right), + FilterOperation.LessThanOrEqual => Expression.LessThanOrEqual(left, right), + _ => throw new NotImplementedException($"Unsupported Filter Operation: '{filterOperation}'") + }; var lambda = Expression .Lambda>(eq, new[] { argument }); diff --git a/tests/BitzArt.LinqExtensions.Tests/AddFilterTests.cs b/tests/BitzArt.LinqExtensions.Tests/AddFilterTests.cs index c2e5f5b..2cadf06 100644 --- a/tests/BitzArt.LinqExtensions.Tests/AddFilterTests.cs +++ b/tests/BitzArt.LinqExtensions.Tests/AddFilterTests.cs @@ -69,5 +69,101 @@ public void StructFilter_FilterNull_DoesNotFilter() Assert.Equal(models, filtered); } + + [Theory] + [InlineData(1)] + [InlineData(2)] + [InlineData(3)] + [InlineData(4)] + [InlineData(5)] + [InlineData(6)] + [InlineData(7)] + [InlineData(8)] + [InlineData(9)] + [InlineData(10)] + public void NotEqual_Filter_Filters(int notEqualTo) + { + var range = Enumerable.Range(1, 10); + var models = range.Select(x => new TestModelStruct { Id = x }).ToList(); + var queryable = models.AsQueryable(); + + int? filter = notEqualTo; + + var filtered = queryable.AddFilter(x => x.Id, filter, FilterOperation.NotEqual).ToList(); + + Assert.Equal(9, filtered.Count); + + if (notEqualTo != 1) Assert.True(filtered.First().Id == 1); + else Assert.True(filtered.First().Id == 2); + + if (notEqualTo != 10) Assert.True(filtered.Last().Id == 10); + else Assert.True(filtered.Last().Id == 9); + + Assert.DoesNotContain(filtered, x => x.Id == filter); + } + + [Fact] + public void GreaterThan_Filter_Filters() + { + var range = Enumerable.Range(1, 10); + var models = range.Select(x => new TestModelStruct { Id = x }).ToList(); + var queryable = models.AsQueryable(); + + int? filter = 5; + + var filtered = queryable.AddFilter(x => x.Id, filter, FilterOperation.GreaterThan).ToList(); + + Assert.Equal(5, filtered.Count); + Assert.True(filtered.First().Id == 6); + Assert.True(filtered.Last().Id == 10); + } + + [Fact] + public void GreaterThanOrEqual_Filter_Filters() + { + var range = Enumerable.Range(1, 10); + var models = range.Select(x => new TestModelStruct { Id = x }).ToList(); + var queryable = models.AsQueryable(); + + int? filter = 5; + + var filtered = queryable.AddFilter(x => x.Id, filter, FilterOperation.GreaterThanOrEqual).ToList(); + + Assert.Equal(6, filtered.Count); + Assert.True(filtered.First().Id == 5); + Assert.True(filtered.Last().Id == 10); + } + + [Fact] + public void LessThan_Filter_Filters() + { + var range = Enumerable.Range(1, 10); + var models = range.Select(x => new TestModelStruct { Id = x }).ToList(); + var queryable = models.AsQueryable(); + + int? filter = 5; + + var filtered = queryable.AddFilter(x => x.Id, filter, FilterOperation.LessThan).ToList(); + + Assert.Equal(4, filtered.Count); + Assert.True(filtered.First().Id == 1); + Assert.True(filtered.Last().Id == 4); + } + + [Fact] + public void LessThanOrEqual_Filter_Filters() + { + var range = Enumerable.Range(1, 10); + var models = range.Select(x => new TestModelStruct { Id = x }).ToList(); + var queryable = models.AsQueryable(); + + int? filter = 5; + + var filtered = queryable.AddFilter(x => x.Id, filter, FilterOperation.LessThanOrEqual).ToList(); + + Assert.Equal(5, filtered.Count); + Assert.True(filtered.First().Id == 1); + Assert.True(filtered.Last().Id == 5); + } } } \ No newline at end of file