From 4910703e7e51ffc817b308168d2ec2efbedfbe75 Mon Sep 17 00:00:00 2001 From: Atif Aziz Date: Thu, 26 Jan 2023 17:58:06 +0100 Subject: [PATCH] Fix collection-optimized paths to be at iteration-time --- MoreLinq.Test/AssertCountTest.cs | 18 +++++---- MoreLinq.Test/CountDownTest.cs | 9 +++++ MoreLinq.Test/FallbackIfEmptyTest.cs | 46 +++++++++++------------ MoreLinq.Test/PadStartTest.cs | 8 ++++ MoreLinq.Test/SkipLastTest.cs | 10 +++++ MoreLinq.Test/TakeLastTest.cs | 9 +++++ MoreLinq/AssertCount.cs | 13 ++++--- MoreLinq/CollectionLike.cs | 52 ++++++++++++++++++++++++++ MoreLinq/CountDown.cs | 9 +++-- MoreLinq/CountMethods.cs | 8 ++-- MoreLinq/EndsWith.cs | 4 +- MoreLinq/Experimental/TrySingle.cs | 6 +-- MoreLinq/FallbackIfEmpty.cs | 24 +++++------- MoreLinq/MoreEnumerable.cs | 6 +-- MoreLinq/PadStart.cs | 56 +++++++++++++++------------- MoreLinq/SkipLast.cs | 1 - MoreLinq/StartsWith.cs | 4 +- MoreLinq/TakeLast.cs | 8 ++-- 18 files changed, 189 insertions(+), 102 deletions(-) create mode 100644 MoreLinq/CollectionLike.cs diff --git a/MoreLinq.Test/AssertCountTest.cs b/MoreLinq.Test/AssertCountTest.cs index 73aee530f..b8dfe839d 100644 --- a/MoreLinq.Test/AssertCountTest.cs +++ b/MoreLinq.Test/AssertCountTest.cs @@ -18,6 +18,7 @@ namespace MoreLinq.Test { using System; + using System.Collections.Generic; using NUnit.Framework; [TestFixture] @@ -113,13 +114,6 @@ public void AssertCountWithCollectionIsLazy() _ = new BreakingCollection(new int[5]).AssertCount(0); } - [Test] - public void AssertCountWithMatchingCollectionCount() - { - var xs = new[] { 123, 456, 789 }; - Assert.That(xs, Is.SameAs(xs.AssertCount(3))); - } - [TestCase(3, 2, "Sequence contains too many elements when exactly 2 were expected.")] [TestCase(3, 4, "Sequence contains too few elements when exactly 4 were expected.")] public void AssertCountWithMismatchingCollectionCount(int sourceCount, int count, string message) @@ -135,5 +129,15 @@ public void AssertCountWithReadOnlyCollectionIsLazy() { _ = new BreakingReadOnlyCollection(5).AssertCount(0); } + + [Test] + public void AssertCountUsesCollectionCountAtIterationTime() + { + var stack = new Stack(Enumerable.Range(1, 3)); + var result = stack.AssertCount(4); + stack.Push(4); + result.Consume(); + Assert.Pass(); + } } } diff --git a/MoreLinq.Test/CountDownTest.cs b/MoreLinq.Test/CountDownTest.cs index 5d047e177..b7c635667 100644 --- a/MoreLinq.Test/CountDownTest.cs +++ b/MoreLinq.Test/CountDownTest.cs @@ -211,5 +211,14 @@ public ReadOnlyCollection(ICollection collection, protected override IEnumerable Items => _collection; } } + + [Test] + public void UsesCollectionCountAtIterationTime() + { + var stack = new Stack(Enumerable.Range(1, 3)); + var result = stack.CountDown(2, (_, cd) => cd); + stack.Push(4); + result.AssertSequenceEqual(null, null, 1, 0); + } } } diff --git a/MoreLinq.Test/FallbackIfEmptyTest.cs b/MoreLinq.Test/FallbackIfEmptyTest.cs index 65ccfc788..451484606 100644 --- a/MoreLinq.Test/FallbackIfEmptyTest.cs +++ b/MoreLinq.Test/FallbackIfEmptyTest.cs @@ -17,6 +17,7 @@ namespace MoreLinq.Test { + using System.Collections.Generic; using NUnit.Framework; [TestFixture] @@ -36,31 +37,6 @@ public void FallbackIfEmptyWithEmptySequence() // ReSharper restore PossibleMultipleEnumeration } - [TestCase(SourceKind.BreakingCollection)] - [TestCase(SourceKind.BreakingReadOnlyCollection)] - public void FallbackIfEmptyPreservesSourceCollectionIfPossible(SourceKind sourceKind) - { - var source = new[] { 1 }.ToSourceKind(sourceKind); - // ReSharper disable PossibleMultipleEnumeration - Assert.That(source.FallbackIfEmpty(12), Is.SameAs(source)); - Assert.That(source.FallbackIfEmpty(12, 23), Is.SameAs(source)); - Assert.That(source.FallbackIfEmpty(12, 23, 34), Is.SameAs(source)); - Assert.That(source.FallbackIfEmpty(12, 23, 34, 45), Is.SameAs(source)); - Assert.That(source.FallbackIfEmpty(12, 23, 34, 45, 56), Is.SameAs(source)); - Assert.That(source.FallbackIfEmpty(12, 23, 34, 45, 56, 67), Is.SameAs(source)); - // ReSharper restore PossibleMultipleEnumeration - } - - [TestCase(SourceKind.BreakingCollection)] - [TestCase(SourceKind.BreakingReadOnlyCollection)] - public void FallbackIfEmptyPreservesFallbackCollectionIfPossible(SourceKind sourceKind) - { - var source = new int[0].ToSourceKind(sourceKind); - var fallback = new[] { 1 }; - Assert.That(source.FallbackIfEmpty(fallback), Is.SameAs(fallback)); - Assert.That(source.FallbackIfEmpty(fallback.AsEnumerable()), Is.SameAs(fallback)); - } - [Test] public void FallbackIfEmptyWithEmptyNullableSequence() { @@ -68,5 +44,25 @@ public void FallbackIfEmptyWithEmptyNullableSequence() var fallback = (int?)null; source.FallbackIfEmpty(fallback).AssertSequenceEqual(fallback); } + + [Test] + public void FallbackUsesCollectionCountAtIterationTime() + { + var source = new List(); + + var results = new[] + { + source.FallbackIfEmpty(-1), + source.FallbackIfEmpty(-1, -2), + source.FallbackIfEmpty(-1, -2, -3), + source.FallbackIfEmpty(-1, -2, -3, -4), + source.FallbackIfEmpty(-1, -2, -3, -4, -5), + }; + + source.Add(123); + + foreach (var result in results) + result.AssertSequenceEqual(123); + } } } diff --git a/MoreLinq.Test/PadStartTest.cs b/MoreLinq.Test/PadStartTest.cs index 7e504f5f4..78d45ef4d 100644 --- a/MoreLinq.Test/PadStartTest.cs +++ b/MoreLinq.Test/PadStartTest.cs @@ -133,6 +133,14 @@ public void ReferenceTypeElements(ICollection source, int width, IEnumer } } + [Test] + public void PadStartUsesCollectionCountAtIterationTime() + { + var queue = new Queue(Enumerable.Range(1, 3)); + var result = queue.PadStart(4, -1); + queue.Enqueue(4); + result.AssertSequenceEqual(1, 2, 3, 4); + } static void AssertEqual(ICollection input, Func, IEnumerable> op, IEnumerable expected) { diff --git a/MoreLinq.Test/SkipLastTest.cs b/MoreLinq.Test/SkipLastTest.cs index dfdbad739..df7b19680 100644 --- a/MoreLinq.Test/SkipLastTest.cs +++ b/MoreLinq.Test/SkipLastTest.cs @@ -17,6 +17,7 @@ namespace MoreLinq.Test { + using System.Collections.Generic; using NUnit.Framework; [TestFixture] @@ -56,5 +57,14 @@ public void SkipLastIsLazy() { _ = new BreakingSequence().SkipLast(1); } + + [Test] + public void SkipLastUsesCollectionCountAtIterationTime() + { + var list = new List { 1, 2, 3, 4 }; + var result = list.SkipLast(2); + list.Add(5); + result.AssertSequenceEqual(1, 2, 3); + } } } diff --git a/MoreLinq.Test/TakeLastTest.cs b/MoreLinq.Test/TakeLastTest.cs index fdbd8f923..5779f4cfc 100644 --- a/MoreLinq.Test/TakeLastTest.cs +++ b/MoreLinq.Test/TakeLastTest.cs @@ -70,6 +70,15 @@ public void TakeLastOptimizedForCollections(SourceKind sourceKind) sequence.TakeLast(3).AssertSequenceEqual(8, 9, 10); } + [Test] + public void TakeLastUsesCollectionCountAtIterationTime() + { + var list = new List { 1, 2, 3, 4 }; + var result = list.TakeLast(3); + list.Add(5); + result.AssertSequenceEqual(3, 4, 5); + } + static void AssertTakeLast(ICollection input, int count, Action> action) { // Test that the behaviour does not change whether a collection diff --git a/MoreLinq/AssertCount.cs b/MoreLinq/AssertCount.cs index 916400abf..0273bf060 100644 --- a/MoreLinq/AssertCount.cs +++ b/MoreLinq/AssertCount.cs @@ -86,13 +86,14 @@ static IEnumerable AssertCountImpl(IEnumerable source if (count < 0) throw new ArgumentOutOfRangeException(nameof(count)); if (errorSelector == null) throw new ArgumentNullException(nameof(errorSelector)); - return - source.TryGetCollectionCount() is { } collectionCount - ? collectionCount == count - ? source - : From(() => throw errorSelector(collectionCount.CompareTo(count), count)) - : _(); IEnumerable _() + return _(); IEnumerable _() { + if (source.TryAsCollectionLike() is { Count: var collectionCount } + && collectionCount.CompareTo(count) is var comparison && comparison != 0) + { + throw errorSelector(comparison, count); + } + var iterations = 0; foreach (var element in source) { diff --git a/MoreLinq/CollectionLike.cs b/MoreLinq/CollectionLike.cs new file mode 100644 index 000000000..16286cdb3 --- /dev/null +++ b/MoreLinq/CollectionLike.cs @@ -0,0 +1,52 @@ +#region License and Terms +// MoreLINQ - Extensions to LINQ to Objects +// Copyright (c) 2023 Atif Aziz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#endregion + +namespace MoreLinq +{ + using System; + using System.Collections.Generic; + using System.Linq; + + /// + /// Represents a union over list types implementing either or , + /// allowing both to be treated the same. + /// + + readonly struct CollectionLike + { + readonly ICollection? _rw; + readonly IReadOnlyCollection? _rx; + + public CollectionLike(ICollection collection) + { + _rw = collection ?? throw new ArgumentNullException(nameof(collection)); + _rx = null; + } + + public CollectionLike(IReadOnlyCollection collection) + { + _rw = null; + _rx = collection ?? throw new ArgumentNullException(nameof(collection)); + } + + public int Count => _rw?.Count ?? _rx?.Count ?? 0; + + public IEnumerator GetEnumerator() => + _rw?.GetEnumerator() ?? _rx?.GetEnumerator() ?? Enumerable.Empty().GetEnumerator(); + } +} diff --git a/MoreLinq/CountDown.cs b/MoreLinq/CountDown.cs index f5cb867f0..f1467cc0f 100644 --- a/MoreLinq/CountDown.cs +++ b/MoreLinq/CountDown.cs @@ -59,8 +59,8 @@ public static IEnumerable CountDown(this IEnumerable sou return source.TryAsListLike() is { } listLike ? IterateList(listLike) - : source.TryGetCollectionCount() is { } collectionCount - ? IterateCollection(collectionCount) + : source.TryAsCollectionLike() is { } collectionLike + ? IterateCollection(collectionLike) : IterateSequence(); IEnumerable IterateList(IListLike list) @@ -72,9 +72,10 @@ IEnumerable IterateList(IListLike list) yield return resultSelector(list[i], listCount - i <= count ? --countdown : null); } - IEnumerable IterateCollection(int i) + IEnumerable IterateCollection(CollectionLike collection) { - foreach (var item in source) + var i = collection.Count; + foreach (var item in collection) yield return resultSelector(item, i-- <= count ? i : null); } diff --git a/MoreLinq/CountMethods.cs b/MoreLinq/CountMethods.cs index c2eef6abb..9e9818ff2 100644 --- a/MoreLinq/CountMethods.cs +++ b/MoreLinq/CountMethods.cs @@ -136,7 +136,7 @@ static bool QuantityIterator(IEnumerable source, int limit, int min, int m { if (source == null) throw new ArgumentNullException(nameof(source)); - var count = source.TryGetCollectionCount() ?? source.CountUpTo(limit); + var count = source.TryAsCollectionLike()?.Count ?? source.CountUpTo(limit); return count >= min && count <= max; } @@ -167,11 +167,11 @@ public static int CompareCount(this IEnumerable first, if (first == null) throw new ArgumentNullException(nameof(first)); if (second == null) throw new ArgumentNullException(nameof(second)); - if (first.TryGetCollectionCount() is { } firstCount) + if (first.TryAsCollectionLike() is { Count: var firstCount }) { - return firstCount.CompareTo(second.TryGetCollectionCount() ?? second.CountUpTo(firstCount + 1)); + return firstCount.CompareTo(second.TryAsCollectionLike()?.Count ?? second.CountUpTo(firstCount + 1)); } - else if (second.TryGetCollectionCount() is { } secondCount) + else if (second.TryAsCollectionLike() is { Count: var secondCount }) { return first.CountUpTo(secondCount + 1).CompareTo(secondCount); } diff --git a/MoreLinq/EndsWith.cs b/MoreLinq/EndsWith.cs index 3ff560bf5..d414a3dc0 100644 --- a/MoreLinq/EndsWith.cs +++ b/MoreLinq/EndsWith.cs @@ -75,8 +75,8 @@ public static bool EndsWith(this IEnumerable first, IEnumerable second, List secondList; #pragma warning disable IDE0075 // Simplify conditional expression (makes it worse) - return second.TryGetCollectionCount() is { } secondCount - ? first.TryGetCollectionCount() is { } firstCount && secondCount > firstCount + return second.TryAsCollectionLike() is { Count: var secondCount } + ? first.TryAsCollectionLike() is { Count: var firstCount } && secondCount > firstCount ? false : Impl(second, secondCount) : Impl(secondList = second.ToList(), secondList.Count); diff --git a/MoreLinq/Experimental/TrySingle.cs b/MoreLinq/Experimental/TrySingle.cs index 881e4ee75..06f2115c9 100644 --- a/MoreLinq/Experimental/TrySingle.cs +++ b/MoreLinq/Experimental/TrySingle.cs @@ -101,11 +101,11 @@ public static TResult TrySingle(this IEnumerable so if (source == null) throw new ArgumentNullException(nameof(source)); if (resultSelector == null) throw new ArgumentNullException(nameof(resultSelector)); - switch (source.TryGetCollectionCount()) + switch (source.TryAsCollectionLike()) { - case 0: + case { Count: 0 }: return resultSelector(zero, default); - case 1: + case { Count: 1 }: { var item = source switch { diff --git a/MoreLinq/FallbackIfEmpty.cs b/MoreLinq/FallbackIfEmpty.cs index 2acbeae26..159a7de4f 100644 --- a/MoreLinq/FallbackIfEmpty.cs +++ b/MoreLinq/FallbackIfEmpty.cs @@ -161,14 +161,11 @@ static IEnumerable FallbackIfEmptyImpl(IEnumerable source, int? count, T? fallback1, T? fallback2, T? fallback3, T? fallback4, IEnumerable? fallback) { - return source.TryGetCollectionCount() is { } collectionCount - ? collectionCount == 0 ? Fallback() : source - : _(); - - IEnumerable _() + return _(); IEnumerable _() { - using (var e = source.GetEnumerator()) + if (source.TryAsCollectionLike() is null or { Count: > 0 }) { + using var e = source.GetEnumerator(); if (e.MoveNext()) { do { yield return e.Current; } @@ -177,15 +174,14 @@ IEnumerable _() } } - foreach (var item in Fallback()) - yield return item; - } - - IEnumerable Fallback() - { - return fallback ?? FallbackOnArgs(); + if (fallback is { } someFallback) + { + Debug.Assert(count is null); - IEnumerable FallbackOnArgs() + foreach (var item in someFallback) + yield return item; + } + else { Debug.Assert(count is >= 1 and <= 4); diff --git a/MoreLinq/MoreEnumerable.cs b/MoreLinq/MoreEnumerable.cs index b80983c3e..18aa05dd4 100644 --- a/MoreLinq/MoreEnumerable.cs +++ b/MoreLinq/MoreEnumerable.cs @@ -27,12 +27,12 @@ namespace MoreLinq public static partial class MoreEnumerable { - internal static int? TryGetCollectionCount(this IEnumerable source) => + internal static CollectionLike? TryAsCollectionLike(this IEnumerable source) => source switch { null => throw new ArgumentNullException(nameof(source)), - ICollection collection => collection.Count, - IReadOnlyCollection collection => collection.Count, + ICollection collection => new CollectionLike(collection), + IReadOnlyCollection collection => new CollectionLike(collection), _ => null }; diff --git a/MoreLinq/PadStart.cs b/MoreLinq/PadStart.cs index 5fdfd7a55..2979da6b3 100644 --- a/MoreLinq/PadStart.cs +++ b/MoreLinq/PadStart.cs @@ -19,7 +19,6 @@ namespace MoreLinq { using System; using System.Collections.Generic; - using System.Linq; static partial class MoreEnumerable { @@ -118,42 +117,47 @@ public static IEnumerable PadStart(this IEnumerable s static IEnumerable PadStartImpl(IEnumerable source, int width, T? padding, Func? paddingSelector) { - return - source.TryGetCollectionCount() is { } collectionCount - ? collectionCount >= width - ? source - : Enumerable.Range(0, width - collectionCount) - .Select(i => paddingSelector != null ? paddingSelector(i) : padding!) - .Concat(source) - : _(); IEnumerable _() + return _(); IEnumerable _() { - var array = new T[width]; - var count = 0; + if (source.TryAsCollectionLike() is { Count: var collectionCount } && collectionCount < width) + { + var paddingCount = width - collectionCount; + for (var i = 0; i < paddingCount; i++) + yield return paddingSelector is { } selector ? selector(i) : padding!; - using (var e = source.GetEnumerator()) + foreach (var item in source) + yield return item; + } + else { - for (; count < width && e.MoveNext(); count++) - array[count] = e.Current; + var array = new T[width]; + var count = 0; - if (count == width) + using (var e = source.GetEnumerator()) { - for (var i = 0; i < count; i++) - yield return array[i]; + for (; count < width && e.MoveNext(); count++) + array[count] = e.Current; + + if (count == width) + { + for (var i = 0; i < count; i++) + yield return array[i]; - while (e.MoveNext()) - yield return e.Current; + while (e.MoveNext()) + yield return e.Current; - yield break; + yield break; + } } - } - var len = width - count; + var len = width - count; - for (var i = 0; i < len; i++) - yield return paddingSelector != null ? paddingSelector(i) : padding!; + for (var i = 0; i < len; i++) + yield return paddingSelector != null ? paddingSelector(i) : padding!; - for (var i = 0; i < count; i++) - yield return array[i]; + for (var i = 0; i < count; i++) + yield return array[i]; + } } } } diff --git a/MoreLinq/SkipLast.cs b/MoreLinq/SkipLast.cs index e97a29c37..0483f3e74 100644 --- a/MoreLinq/SkipLast.cs +++ b/MoreLinq/SkipLast.cs @@ -38,7 +38,6 @@ public static IEnumerable SkipLast(this IEnumerable source, int count) if (source == null) throw new ArgumentNullException(nameof(source)); return count < 1 ? source - : source.TryGetCollectionCount() is { } collectionCount ? source.Take(collectionCount - count) : source.CountDown(count, (e, cd) => (Element: e, Countdown: cd)) .TakeWhile(e => e.Countdown == null) .Select(e => e.Element); diff --git a/MoreLinq/StartsWith.cs b/MoreLinq/StartsWith.cs index bdfcc17e7..f3b8ee5f0 100644 --- a/MoreLinq/StartsWith.cs +++ b/MoreLinq/StartsWith.cs @@ -73,8 +73,8 @@ public static bool StartsWith(this IEnumerable first, IEnumerable secon if (first == null) throw new ArgumentNullException(nameof(first)); if (second == null) throw new ArgumentNullException(nameof(second)); - if (first.TryGetCollectionCount() is { } firstCount && - second.TryGetCollectionCount() is { } secondCount && + if (first.TryAsCollectionLike() is { Count: var firstCount } && + second.TryAsCollectionLike() is { Count: var secondCount } && secondCount > firstCount) { return false; diff --git a/MoreLinq/TakeLast.cs b/MoreLinq/TakeLast.cs index 0512342dd..6619ff196 100644 --- a/MoreLinq/TakeLast.cs +++ b/MoreLinq/TakeLast.cs @@ -51,11 +51,9 @@ public static IEnumerable TakeLast(this IEnumerable s if (source == null) throw new ArgumentNullException(nameof(source)); return count < 1 ? Enumerable.Empty() - : source.TryGetCollectionCount() is { } collectionCount - ? source.Slice(Math.Max(0, collectionCount - count), int.MaxValue) - : source.CountDown(count, (e, cd) => (Element: e, Countdown: cd)) - .SkipWhile(e => e.Countdown == null) - .Select(e => e.Element); + : source.CountDown(count, (e, cd) => (Element: e, Countdown: cd)) + .SkipWhile(e => e.Countdown == null) + .Select(e => e.Element); } } }