Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix collection-optimized paths to be at iteration-time #946

Merged
merged 1 commit into from
Jan 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions MoreLinq.Test/AssertCountTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
namespace MoreLinq.Test
{
using System;
using System.Collections.Generic;
using NUnit.Framework;

[TestFixture]
Expand Down Expand Up @@ -113,13 +114,6 @@ public void AssertCountWithCollectionIsLazy()
_ = new BreakingCollection<int>(new int[5]).AssertCount(0);
}

[Test]
public void AssertCountWithMatchingCollectionCount()
{
var xs = new[] { 123, 456, 789 };
Assert.That(xs, Is.SameAs(xs.AssertCount(3)));
}

[TestCase(3, 2, "Sequence contains too many elements when exactly 2 were expected.")]
[TestCase(3, 4, "Sequence contains too few elements when exactly 4 were expected.")]
public void AssertCountWithMismatchingCollectionCount(int sourceCount, int count, string message)
Expand All @@ -135,5 +129,15 @@ public void AssertCountWithReadOnlyCollectionIsLazy()
{
_ = new BreakingReadOnlyCollection<object>(5).AssertCount(0);
}

[Test]
public void AssertCountUsesCollectionCountAtIterationTime()
{
var stack = new Stack<int>(Enumerable.Range(1, 3));
var result = stack.AssertCount(4);
stack.Push(4);
result.Consume();
Assert.Pass();
}
}
}
9 changes: 9 additions & 0 deletions MoreLinq.Test/CountDownTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -211,5 +211,14 @@ public ReadOnlyCollection(ICollection<T> collection,
protected override IEnumerable<T> Items => _collection;
}
}

[Test]
public void UsesCollectionCountAtIterationTime()
{
var stack = new Stack<int>(Enumerable.Range(1, 3));
var result = stack.CountDown(2, (_, cd) => cd);
stack.Push(4);
result.AssertSequenceEqual(null, null, 1, 0);
}
}
}
46 changes: 21 additions & 25 deletions MoreLinq.Test/FallbackIfEmptyTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

namespace MoreLinq.Test
{
using System.Collections.Generic;
using NUnit.Framework;

[TestFixture]
Expand All @@ -36,37 +37,32 @@ public void FallbackIfEmptyWithEmptySequence()
// ReSharper restore PossibleMultipleEnumeration
}

[TestCase(SourceKind.BreakingCollection)]
[TestCase(SourceKind.BreakingReadOnlyCollection)]
public void FallbackIfEmptyPreservesSourceCollectionIfPossible(SourceKind sourceKind)
{
var source = new[] { 1 }.ToSourceKind(sourceKind);
// ReSharper disable PossibleMultipleEnumeration
Assert.That(source.FallbackIfEmpty(12), Is.SameAs(source));
Assert.That(source.FallbackIfEmpty(12, 23), Is.SameAs(source));
Assert.That(source.FallbackIfEmpty(12, 23, 34), Is.SameAs(source));
Assert.That(source.FallbackIfEmpty(12, 23, 34, 45), Is.SameAs(source));
Assert.That(source.FallbackIfEmpty(12, 23, 34, 45, 56), Is.SameAs(source));
Assert.That(source.FallbackIfEmpty(12, 23, 34, 45, 56, 67), Is.SameAs(source));
// ReSharper restore PossibleMultipleEnumeration
}

[TestCase(SourceKind.BreakingCollection)]
[TestCase(SourceKind.BreakingReadOnlyCollection)]
public void FallbackIfEmptyPreservesFallbackCollectionIfPossible(SourceKind sourceKind)
{
var source = new int[0].ToSourceKind(sourceKind);
var fallback = new[] { 1 };
Assert.That(source.FallbackIfEmpty(fallback), Is.SameAs(fallback));
Assert.That(source.FallbackIfEmpty(fallback.AsEnumerable()), Is.SameAs(fallback));
}

[Test]
public void FallbackIfEmptyWithEmptyNullableSequence()
{
var source = Enumerable.Empty<int?>().Select(x => x);
var fallback = (int?)null;
source.FallbackIfEmpty(fallback).AssertSequenceEqual(fallback);
}

[Test]
public void FallbackUsesCollectionCountAtIterationTime()
{
var source = new List<int>();

var results = new[]
{
source.FallbackIfEmpty(-1),
source.FallbackIfEmpty(-1, -2),
source.FallbackIfEmpty(-1, -2, -3),
source.FallbackIfEmpty(-1, -2, -3, -4),
source.FallbackIfEmpty(-1, -2, -3, -4, -5),
};

source.Add(123);

foreach (var result in results)
result.AssertSequenceEqual(123);
}
}
}
8 changes: 8 additions & 0 deletions MoreLinq.Test/PadStartTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,14 @@ public void ReferenceTypeElements(ICollection<string> source, int width, IEnumer
}
}

[Test]
public void PadStartUsesCollectionCountAtIterationTime()
{
var queue = new Queue<int>(Enumerable.Range(1, 3));
var result = queue.PadStart(4, -1);
queue.Enqueue(4);
result.AssertSequenceEqual(1, 2, 3, 4);
}

static void AssertEqual<T>(ICollection<T> input, Func<IEnumerable<T>, IEnumerable<T>> op, IEnumerable<T> expected)
{
Expand Down
10 changes: 10 additions & 0 deletions MoreLinq.Test/SkipLastTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

namespace MoreLinq.Test
{
using System.Collections.Generic;
using NUnit.Framework;

[TestFixture]
Expand Down Expand Up @@ -56,5 +57,14 @@ public void SkipLastIsLazy()
{
_ = new BreakingSequence<object>().SkipLast(1);
}

[Test]
public void SkipLastUsesCollectionCountAtIterationTime()
{
var list = new List<int> { 1, 2, 3, 4 };
var result = list.SkipLast(2);
list.Add(5);
result.AssertSequenceEqual(1, 2, 3);
}
}
}
9 changes: 9 additions & 0 deletions MoreLinq.Test/TakeLastTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ public void TakeLastOptimizedForCollections(SourceKind sourceKind)
sequence.TakeLast(3).AssertSequenceEqual(8, 9, 10);
}

[Test]
public void TakeLastUsesCollectionCountAtIterationTime()
{
var list = new List<int> { 1, 2, 3, 4 };
var result = list.TakeLast(3);
list.Add(5);
result.AssertSequenceEqual(3, 4, 5);
}

static void AssertTakeLast<T>(ICollection<T> input, int count, Action<IEnumerable<T>> action)
{
// Test that the behaviour does not change whether a collection
Expand Down
13 changes: 7 additions & 6 deletions MoreLinq/AssertCount.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,14 @@ static IEnumerable<TSource> AssertCountImpl<TSource>(IEnumerable<TSource> source
if (count < 0) throw new ArgumentOutOfRangeException(nameof(count));
if (errorSelector == null) throw new ArgumentNullException(nameof(errorSelector));

return
source.TryGetCollectionCount() is { } collectionCount
? collectionCount == count
? source
: From<TSource>(() => throw errorSelector(collectionCount.CompareTo(count), count))
: _(); IEnumerable<TSource> _()
return _(); IEnumerable<TSource> _()
{
if (source.TryAsCollectionLike() is { Count: var collectionCount }
&& collectionCount.CompareTo(count) is var comparison && comparison != 0)
{
throw errorSelector(comparison, count);
}

var iterations = 0;
foreach (var element in source)
{
Expand Down
52 changes: 52 additions & 0 deletions MoreLinq/CollectionLike.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#region License and Terms
// MoreLINQ - Extensions to LINQ to Objects
// Copyright (c) 2023 Atif Aziz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#endregion

namespace MoreLinq
{
using System;
using System.Collections.Generic;
using System.Linq;

/// <summary>
/// Represents a union over list types implementing either <see
/// cref="ICollection{T}"/> or <see cref="IReadOnlyCollection{T}"/>,
/// allowing both to be treated the same.
/// </summary>

readonly struct CollectionLike<T>
{
readonly ICollection<T>? _rw;
readonly IReadOnlyCollection<T>? _rx;

public CollectionLike(ICollection<T> collection)
{
_rw = collection ?? throw new ArgumentNullException(nameof(collection));
_rx = null;
}

public CollectionLike(IReadOnlyCollection<T> collection)
{
_rw = null;
_rx = collection ?? throw new ArgumentNullException(nameof(collection));
}

public int Count => _rw?.Count ?? _rx?.Count ?? 0;

public IEnumerator<T> GetEnumerator() =>
_rw?.GetEnumerator() ?? _rx?.GetEnumerator() ?? Enumerable.Empty<T>().GetEnumerator();
}
}
9 changes: 5 additions & 4 deletions MoreLinq/CountDown.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ public static IEnumerable<TResult> CountDown<T, TResult>(this IEnumerable<T> sou

return source.TryAsListLike() is { } listLike
? IterateList(listLike)
: source.TryGetCollectionCount() is { } collectionCount
? IterateCollection(collectionCount)
: source.TryAsCollectionLike() is { } collectionLike
? IterateCollection(collectionLike)
: IterateSequence();

IEnumerable<TResult> IterateList(IListLike<T> list)
Expand All @@ -72,9 +72,10 @@ IEnumerable<TResult> IterateList(IListLike<T> list)
yield return resultSelector(list[i], listCount - i <= count ? --countdown : null);
}

IEnumerable<TResult> IterateCollection(int i)
IEnumerable<TResult> IterateCollection(CollectionLike<T> collection)
{
foreach (var item in source)
var i = collection.Count;
foreach (var item in collection)
yield return resultSelector(item, i-- <= count ? i : null);
}

Expand Down
8 changes: 4 additions & 4 deletions MoreLinq/CountMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ static bool QuantityIterator<T>(IEnumerable<T> source, int limit, int min, int m
{
if (source == null) throw new ArgumentNullException(nameof(source));

var count = source.TryGetCollectionCount() ?? source.CountUpTo(limit);
var count = source.TryAsCollectionLike()?.Count ?? source.CountUpTo(limit);

return count >= min && count <= max;
}
Expand Down Expand Up @@ -167,11 +167,11 @@ public static int CompareCount<TFirst, TSecond>(this IEnumerable<TFirst> first,
if (first == null) throw new ArgumentNullException(nameof(first));
if (second == null) throw new ArgumentNullException(nameof(second));

if (first.TryGetCollectionCount() is { } firstCount)
if (first.TryAsCollectionLike() is { Count: var firstCount })
{
return firstCount.CompareTo(second.TryGetCollectionCount() ?? second.CountUpTo(firstCount + 1));
return firstCount.CompareTo(second.TryAsCollectionLike()?.Count ?? second.CountUpTo(firstCount + 1));
}
else if (second.TryGetCollectionCount() is { } secondCount)
else if (second.TryAsCollectionLike() is { Count: var secondCount })
{
return first.CountUpTo(secondCount + 1).CompareTo(secondCount);
}
Expand Down
4 changes: 2 additions & 2 deletions MoreLinq/EndsWith.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ public static bool EndsWith<T>(this IEnumerable<T> first, IEnumerable<T> second,

List<T> secondList;
#pragma warning disable IDE0075 // Simplify conditional expression (makes it worse)
return second.TryGetCollectionCount() is { } secondCount
? first.TryGetCollectionCount() is { } firstCount && secondCount > firstCount
return second.TryAsCollectionLike() is { Count: var secondCount }
? first.TryAsCollectionLike() is { Count: var firstCount } && secondCount > firstCount
? false
: Impl(second, secondCount)
: Impl(secondList = second.ToList(), secondList.Count);
Expand Down
6 changes: 3 additions & 3 deletions MoreLinq/Experimental/TrySingle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ public static TResult TrySingle<T, TCardinality, TResult>(this IEnumerable<T> so
if (source == null) throw new ArgumentNullException(nameof(source));
if (resultSelector == null) throw new ArgumentNullException(nameof(resultSelector));

switch (source.TryGetCollectionCount())
switch (source.TryAsCollectionLike())
{
case 0:
case { Count: 0 }:
return resultSelector(zero, default);
case 1:
case { Count: 1 }:
{
var item = source switch
{
Expand Down
24 changes: 10 additions & 14 deletions MoreLinq/FallbackIfEmpty.cs
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,11 @@ static IEnumerable<T> FallbackIfEmptyImpl<T>(IEnumerable<T> source,
int? count, T? fallback1, T? fallback2, T? fallback3, T? fallback4,
IEnumerable<T>? fallback)
{
return source.TryGetCollectionCount() is { } collectionCount
? collectionCount == 0 ? Fallback() : source
: _();

IEnumerable<T> _()
return _(); IEnumerable<T> _()
{
using (var e = source.GetEnumerator())
if (source.TryAsCollectionLike() is null or { Count: > 0 })
{
using var e = source.GetEnumerator();
if (e.MoveNext())
{
do { yield return e.Current; }
Expand All @@ -177,15 +174,14 @@ IEnumerable<T> _()
}
}

foreach (var item in Fallback())
yield return item;
}

IEnumerable<T> Fallback()
{
return fallback ?? FallbackOnArgs();
if (fallback is { } someFallback)
{
Debug.Assert(count is null);

IEnumerable<T> FallbackOnArgs()
foreach (var item in someFallback)
yield return item;
}
else
{
Debug.Assert(count is >= 1 and <= 4);

Expand Down
6 changes: 3 additions & 3 deletions MoreLinq/MoreEnumerable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ namespace MoreLinq

public static partial class MoreEnumerable
{
internal static int? TryGetCollectionCount<T>(this IEnumerable<T> source) =>
internal static CollectionLike<T>? TryAsCollectionLike<T>(this IEnumerable<T> source) =>
source switch
{
null => throw new ArgumentNullException(nameof(source)),
ICollection<T> collection => collection.Count,
IReadOnlyCollection<T> collection => collection.Count,
ICollection<T> collection => new CollectionLike<T>(collection),
IReadOnlyCollection<T> collection => new CollectionLike<T>(collection),
_ => null
};

Expand Down
Loading