diff --git a/MoreLinq.Test/BatchTest.cs b/MoreLinq.Test/BatchTest.cs index c5c9186e7..a459ee2dd 100644 --- a/MoreLinq.Test/BatchTest.cs +++ b/MoreLinq.Test/BatchTest.cs @@ -23,18 +23,12 @@ namespace MoreLinq.Test [TestFixture] public class BatchTest { - [Test] - public void BatchZeroSize() - { - AssertThrowsArgument.OutOfRangeException("size",() => - new object[0].Batch(0)); - } - - [Test] - public void BatchNegativeSize() + [TestCase(0)] + [TestCase(-1)] + public void BatchBadSize(int size) { - AssertThrowsArgument.OutOfRangeException("size",() => - new object[0].Batch(-1)); + AssertThrowsArgument.OutOfRangeException("size", () => + new object[0].Batch(size)); } [Test] @@ -141,3 +135,259 @@ public void BatchEmptySource(SourceKind kind) } } } + +#if NETCOREAPP3_1_OR_GREATER + +namespace MoreLinq.Test +{ + using System; + using System.Buffers; + using System.Collections.Generic; + using MoreLinq.Experimental; + using NUnit.Framework; + + [TestFixture] + public class BatchPoolTest + { + [TestCase(0)] + [TestCase(-1)] + public void BatchBadSize(int size) + { + AssertThrowsArgument.OutOfRangeException("size", () => + new object[0].Batch(size, ArrayPool.Shared, + BreakingFunc.Of, IEnumerable>(), + BreakingFunc.Of, object>())); + } + + [Test] + public void BatchEvenlyDivisibleSequence() + { + using var input = TestingSequence.Of(1, 2, 3, 4, 5, 6, 7, 8, 9); + using var pool = new TestArrayPool(); + + var result = input.Batch(3, pool, Enumerable.ToArray); + + using var reader = result.Read(); + reader.Read().AssertSequenceEqual(1, 2, 3); + reader.Read().AssertSequenceEqual(4, 5, 6); + reader.Read().AssertSequenceEqual(7, 8, 9); + reader.ReadEnd(); + } + + [Test] + public void BatchUnevenlyDivisibleSequence() + { + using var input = TestingSequence.Of(1, 2, 3, 4, 5, 6, 7, 8, 9); + using var pool = new TestArrayPool(); + + var result = input.Batch(4, pool, Enumerable.ToArray); + + using var reader = result.Read(); + reader.Read().AssertSequenceEqual(1, 2, 3, 4); + reader.Read().AssertSequenceEqual(5, 6, 7, 8); + reader.Read().AssertSequenceEqual(9); + reader.ReadEnd(); + } + + [Test] + public void BatchIsLazy() + { + var input = new BreakingSequence(); + _ = input.Batch(1, ArrayPool.Shared, + BreakingFunc.Of, IEnumerable>(), + BreakingFunc.Of, object>()); + } + + [TestCase(SourceKind.BreakingList , 0)] + [TestCase(SourceKind.BreakingReadOnlyList, 0)] + [TestCase(SourceKind.BreakingList , 1)] + [TestCase(SourceKind.BreakingReadOnlyList, 1)] + [TestCase(SourceKind.BreakingList , 2)] + [TestCase(SourceKind.BreakingReadOnlyList, 2)] + public void BatchCollectionSmallerThanSize(SourceKind kind, int oversize) + { + var xs = new[] { 1, 2, 3, 4, 5 }; + using var pool = new TestArrayPool(); + + var result = xs.ToSourceKind(kind) + .Batch(xs.Length + oversize, pool, Enumerable.ToArray); + + using var reader = result.Read(); + reader.Read().AssertSequenceEqual(1, 2, 3, 4, 5); + reader.ReadEnd(); + } + + [Test] + public void BatchReadOnlyCollectionSmallerThanSize() + { + var collection = ReadOnlyCollection.From(1, 2, 3, 4, 5); + using var pool = new TestArrayPool(); + + var result = collection.Batch(collection.Count * 2, pool, + Enumerable.ToArray); + + using var reader = result.Read(); + reader.Read().AssertSequenceEqual(1, 2, 3, 4, 5); + reader.ReadEnd(); + } + + [TestCase(SourceKind.Sequence)] + [TestCase(SourceKind.BreakingList)] + [TestCase(SourceKind.BreakingReadOnlyList)] + [TestCase(SourceKind.BreakingReadOnlyCollection)] + [TestCase(SourceKind.BreakingCollection)] + public void BatchEmptySource(SourceKind kind) + { + using var pool = new TestArrayPool(); + + var result = Enumerable.Empty() + .ToSourceKind(kind) + .Batch(100, pool, Enumerable.ToArray); + + Assert.That(result, Is.Empty); + } + + [Test] + public void BatchFilterBucket() + { + const int scale = 2; + var input = TestingSequence.Of(1, 2, 3, 4, 5, 6, 7, 8, 9); + using var pool = new TestArrayPool(); + + var result = input.Batch(3, pool, + current => from n in current + where n % 2 == 0 + select n * scale, + Enumerable.ToArray); + + using var reader = result.Read(); + reader.Read().AssertSequenceEqual(2 * scale); + reader.Read().AssertSequenceEqual(4 * scale, 6 * scale); + reader.Read().AssertSequenceEqual(8 * scale); + reader.ReadEnd(); + } + + [Test] + public void BatchSumBucket() + { + var input = TestingSequence.Of(1, 2, 3, 4, 5, 6, 7, 8, 9); + using var pool = new TestArrayPool(); + + var result = input.Batch(3, pool, Enumerable.Sum); + + using var reader = result.Read(); + Assert.That(reader.Read(), Is.EqualTo(1 + 2 + 3)); + Assert.That(reader.Read(), Is.EqualTo(4 + 5 + 6)); + Assert.That(reader.Read(), Is.EqualTo(7 + 8 + 9)); + reader.ReadEnd(); + } + + /// + /// This test does not exercise the intended usage! + /// + + [Test] + public void BatchUpdatesCurrentListInPlace() + { + var input = TestingSequence.Of(1, 2, 3, 4, 5, 6, 7, 8, 9); + using var pool = new TestArrayPool(); + + var result = input.Batch(4, pool, current => current, current => (ICurrentBuffer)current); + + using var reader = result.Read(); + var current = reader.Read(); + current.AssertSequenceEqual(1, 2, 3, 4); + _ = reader.Read(); + current.AssertSequenceEqual(5, 6, 7, 8); + _ = reader.Read(); + current.AssertSequenceEqual(9); + + reader.ReadEnd(); + + Assert.That(current, Is.Empty); + } + + [Test] + public void BatchCallsBucketSelectorBeforeIteratingSource() + { + var iterations = 0; + IEnumerable Source() + { + iterations++; + yield break; + } + + var input = Source(); + using var pool = new TestArrayPool(); + var initIterations = -1; + + var result = input.Batch(4, pool, + current => + { + initIterations = iterations; + return current; + }, + _ => 0); + + using var enumerator = result.GetEnumerator(); + Assert.That(enumerator.MoveNext(), Is.False); + Assert.That(initIterations, Is.Zero); + } + + [Test] + public void BatchBucketSelectorCurrentList() + { + var input = TestingSequence.Of(1, 2, 3, 4, 5, 6, 7, 8, 9); + using var pool = new TestArrayPool(); + int[] bucketSelectorItems = null; + + var result = input.Batch(4, pool, current => bucketSelectorItems = current.ToArray(), _ => 0); + + using var reader = result.Read(); + _ = reader.Read(); + Assert.That(bucketSelectorItems, Is.Not.Null); + Assert.That(bucketSelectorItems, Is.Empty); + } + + /// + /// An implementation for testing purposes that holds only + /// one array in the pool and ensures that it is returned when the pool is disposed. + /// + + sealed class TestArrayPool : ArrayPool, IDisposable + { + T[] _pooledArray; + T[] _rentedArray; + + public override T[] Rent(int minimumLength) + { + if (_pooledArray is null && _rentedArray is null) + _pooledArray = new T[minimumLength * 2]; + + if (_pooledArray is null) + throw new InvalidOperationException("The pool is exhausted."); + + (_pooledArray, _rentedArray) = (null, _pooledArray); + + return _rentedArray; + } + + public override void Return(T[] array, bool clearArray = false) + { + if (_rentedArray is null) + throw new InvalidOperationException("Cannot return when nothing has been rented from this pool."); + + if (array != _rentedArray) + throw new InvalidOperationException("Cannot return what has not been rented from this pool."); + + _pooledArray = array; + _rentedArray = null; + } + + public void Dispose() => + Assert.That(_rentedArray, Is.Null); + } + } +} + +#endif // NETCOREAPP3_1_OR_GREATER diff --git a/MoreLinq/Experimental/Batch.cs b/MoreLinq/Experimental/Batch.cs new file mode 100644 index 000000000..ae40954ff --- /dev/null +++ b/MoreLinq/Experimental/Batch.cs @@ -0,0 +1,294 @@ +#region License and Terms +// MoreLINQ - Extensions to LINQ to Objects +// Copyright (c) 2022 Atif Aziz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#endregion + +#if !NO_BUFFERS + +namespace MoreLinq.Experimental +{ + using System; + using System.Buffers; + using System.Collections.Generic; + using System.Diagnostics; + using System.Linq; + + static partial class ExperimentalEnumerable + { + /// + /// Batches the source sequence into sized buckets using an array pool + /// to rent arrays to back each bucket and returns a sequence of + /// elements projected from each bucket. + /// + /// + /// Type of elements in sequence. + /// + /// Type of elements of the resulting sequence. + /// + /// The source sequence. + /// Size of buckets. + /// The pool used to rent the array for each bucket. + /// A function that projects a result from + /// the current bucket. + /// + /// A sequence whose elements are projected from each bucket (returned by + /// ). + /// + /// + /// + /// This operator uses deferred execution and streams its results + /// (however, each bucket provided to is + /// buffered). + /// + /// + /// Each bucket is backed by a rented array that may be at least + /// in length. + /// + /// + /// When more than one bucket is produced, all buckets except the last + /// is guaranteed to have elements. The last + /// bucket may be smaller depending on the remaining elements in the + /// sequence. + /// Each bucket is pre-allocated to elements. + /// If is set to a very large value, e.g. + /// to effectively disable batching by just + /// hoping for a single bucket, then it can lead to memory exhaustion + /// (). + /// + /// + + public static IEnumerable + Batch(this IEnumerable source, int size, + ArrayPool pool, + Func, TResult> resultSelector) + { + if (source == null) throw new ArgumentNullException(nameof(source)); + if (pool == null) throw new ArgumentNullException(nameof(pool)); + if (size <= 0) throw new ArgumentOutOfRangeException(nameof(size)); + if (resultSelector == null) throw new ArgumentNullException(nameof(resultSelector)); + + return source.Batch(size, pool, current => current, + current => resultSelector((ICurrentBuffer)current)); + } + + /// + /// Batches the source sequence into sized buckets using an array pool + /// to rent arrays to back each bucket and returns a sequence of + /// elements projected from each bucket. + /// + /// + /// Type of elements in sequence. + /// + /// Type of elements in the sequence returned by . + /// + /// Type of elements of the resulting sequence. + /// + /// The source sequence. + /// Size of buckets. + /// The pool used to rent the array for each bucket. + /// A function that returns a + /// sequence projection to use for each bucket. It is called initially + /// before iterating over , but the resulting + /// projection is evaluated for each bucket. This has the same effect as + /// calling for each bucket, + /// but allows initialization of the transformation to happen only once. + /// + /// A function that projects a result from + /// the input sequence produced over a bucket. + /// + /// A sequence whose elements are projected from each bucket (returned by + /// ). + /// + /// + /// + /// This operator uses deferred execution and streams its results + /// (however, each bucket is buffered). + /// + /// + /// Each bucket is backed by a rented array that may be at least + /// in length. + /// + /// + /// When more than one bucket is produced, all buckets except the last + /// is guaranteed to have elements. The last + /// bucket may be smaller depending on the remaining elements in the + /// sequence. + /// Each bucket is pre-allocated to elements. + /// If is set to a very large value, e.g. + /// to effectively disable batching by just + /// hoping for a single bucket, then it can lead to memory exhaustion + /// (). + /// + /// + + public static IEnumerable + Batch( + this IEnumerable source, int size, ArrayPool pool, + Func, IEnumerable> bucketProjectionSelector, + Func, TResult> resultSelector) + { + if (source == null) throw new ArgumentNullException(nameof(source)); + if (pool == null) throw new ArgumentNullException(nameof(pool)); + if (size <= 0) throw new ArgumentOutOfRangeException(nameof(size)); + if (bucketProjectionSelector == null) throw new ArgumentNullException(nameof(bucketProjectionSelector)); + if (resultSelector == null) throw new ArgumentNullException(nameof(resultSelector)); + + return _(); IEnumerable _() + { + using var batch = source.Batch(size, pool); + var bucket = bucketProjectionSelector(batch.CurrentBuffer); + while (batch.UpdateWithNext()) + yield return resultSelector(bucket); + } + } + + static ICurrentBufferProvider + Batch(this IEnumerable source, int size, ArrayPool pool) + { + if (source == null) throw new ArgumentNullException(nameof(source)); + if (pool == null) throw new ArgumentNullException(nameof(pool)); + if (size <= 0) throw new ArgumentOutOfRangeException(nameof(size)); + + ICurrentBufferProvider Cursor(IEnumerator<(T[], int)> source) => + new CurrentPoolArrayProvider(source, pool); + + switch (source) + { + case ICollection { Count: 0 }: + { + return Cursor(Enumerable.Empty <(T[], int)>().GetEnumerator()); + } + case ICollection collection when collection.Count <= size: + { + var bucket = pool.Rent(collection.Count); + collection.CopyTo(bucket, 0); + return Cursor(MoreEnumerable.Return((bucket, collection.Count)).GetEnumerator()); + } + case IReadOnlyCollection { Count: 0 }: + { + return Cursor(Enumerable.Empty <(T[], int)>().GetEnumerator()); + } + case IReadOnlyList list when list.Count <= size: + { + return Cursor(_()); IEnumerator<(T[], int)> _() + { + var bucket = pool.Rent(list.Count); + for (var i = 0; i < list.Count; i++) + bucket[i] = list[i]; + yield return (bucket, list.Count); + } + } + case IReadOnlyCollection collection when collection.Count <= size: + { + return Cursor(Batch(collection.Count)); + } + default: + { + return Cursor(Batch(size)); + } + } + + IEnumerator<(T[], int)> Batch(int size) + { + T[]? bucket = null; + var count = 0; + + foreach (var item in source) + { + bucket ??= pool.Rent(size); + bucket[count++] = item; + + // The bucket is fully buffered before it's yielded + if (count != size) + continue; + + yield return (bucket, size); + + bucket = null; + count = 0; + } + + // Return the last bucket with all remaining elements + if (bucket is { } someBucket && count > 0) + yield return (someBucket, count); + } + } + + sealed class CurrentPoolArrayProvider : CurrentBuffer, ICurrentBufferProvider + { + bool _rented; + T[] _array = Array.Empty(); + int _count; + IEnumerator<(T[], int)>? _rental; + ArrayPool? _pool; + + public CurrentPoolArrayProvider(IEnumerator<(T[], int)> rental, ArrayPool pool) => + (_rental, _pool) = (rental, pool); + + ICurrentBuffer ICurrentBufferProvider.CurrentBuffer => this; + + public bool UpdateWithNext() + { + if (_rental is { Current: var (array, _) } rental) + { + Debug.Assert(_pool is not null); + if (_rented) + { + _pool.Return(array); + _rented = false; + } + + if (!rental.MoveNext()) + { + Dispose(); + return false; + } + + _rented = true; + (_array, _count) = rental.Current; + return true; + } + + return false; + } + + public override int Count => _count; + + public override T this[int index] + { + get => index >= 0 && index < Count ? _array[index] : throw new IndexOutOfRangeException(); + set => throw new NotSupportedException(); + + } + + public void Dispose() + { + if (_rental is { Current: var (array, _) } enumerator) + { + Debug.Assert(_pool is not null); + if (_rented) + _pool.Return(array); + enumerator.Dispose(); + _array = Array.Empty(); + _count = 0; + _rental = null; + _pool = null; + } + } + } + } +} + +#endif // !NO_BUFFERS diff --git a/MoreLinq/Experimental/CurrentBuffer.cs b/MoreLinq/Experimental/CurrentBuffer.cs new file mode 100644 index 000000000..e6cb526a1 --- /dev/null +++ b/MoreLinq/Experimental/CurrentBuffer.cs @@ -0,0 +1,107 @@ +#region License and Terms +// MoreLINQ - Extensions to LINQ to Objects +// Copyright (c) 2022 Atif Aziz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#endregion + +#if !NO_BUFFERS + +namespace MoreLinq.Experimental +{ + using System; + using System.Collections; + using System.Collections.Generic; + using System.Linq; + + /// + /// Represents a current buffered view of a larger result and which + /// is updated in-place (thus current) as it is moved through the overall + /// result. + /// + /// Type of elements in the list. + + public interface ICurrentBuffer : IList { } + + /// + /// A provider of current buffer that updates it in-place. + /// + /// Type of elements in the list. + + interface ICurrentBufferProvider : IDisposable + { + /// + /// Gets the current items of the list. + /// + /// + /// The returned list is updated in-place when + /// is called. + /// + + ICurrentBuffer CurrentBuffer { get; } + + /// + /// Update this instance with the next set of elements from the source. + /// + /// + /// A Boolean that is true if this instance was updated with + /// new elements; otherwise false to indicate that the end of + /// the bucket source has been reached. + /// + + bool UpdateWithNext(); + } + + abstract class CurrentBuffer : ICurrentBuffer + { + public abstract int Count { get; } + public abstract T this[int index] { get; set; } + + public virtual bool IsReadOnly => false; + + public virtual int IndexOf(T item) + { + var comparer = EqualityComparer.Default; + + for (var i = 0; i < Count; i++) + { + if (comparer.Equals(this[i], item)) + return i; + } + + return -1; + } + + public virtual bool Contains(T item) => IndexOf(item) >= 0; + + public virtual void CopyTo(T[] array, int arrayIndex) + { + if (arrayIndex < 0) throw new ArgumentOutOfRangeException(nameof(arrayIndex), arrayIndex, null); + if (arrayIndex + Count > array.Length) throw new ArgumentException(null, nameof(arrayIndex)); + + for (int i = 0, j = arrayIndex; i < Count; i++, j++) + array[j] = this[i]; + } + + public virtual IEnumerator GetEnumerator() => this.Take(Count).GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + void IList.Insert(int index, T item) => throw new NotSupportedException(); + void IList.RemoveAt(int index) => throw new NotSupportedException(); + void ICollection.Add(T item) => throw new NotSupportedException(); + void ICollection.Clear() => throw new NotSupportedException(); + bool ICollection.Remove(T item) => throw new NotSupportedException(); + } +} + +#endif // !NO_BUFFERS diff --git a/MoreLinq/MoreLinq.csproj b/MoreLinq/MoreLinq.csproj index 35e2e7c3f..f428cf3e3 100644 --- a/MoreLinq/MoreLinq.csproj +++ b/MoreLinq/MoreLinq.csproj @@ -119,7 +119,7 @@ en-US 3.3.2 MoreLINQ Developers. - net451;netstandard1.0;netstandard2.0;net6.0 + net451;netstandard1.0;netstandard2.0;netstandard2.1;net6.0 enable @@ -192,8 +192,12 @@ $(DefineConstants);MORELINQ + + $(DefineConstants);NO_BUFFERS + + - $(DefineConstants);NO_SERIALIZATION_ATTRIBUTES;NO_EXCEPTION_SERIALIZATION;NO_TRACING;NO_COM;NO_ASYNC + $(DefineConstants);NO_BUFFERS;NO_SERIALIZATION_ATTRIBUTES;NO_EXCEPTION_SERIALIZATION;NO_TRACING;NO_COM;NO_ASYNC diff --git a/README.md b/README.md index 420dc324d..0104175ee 100644 --- a/README.md +++ b/README.md @@ -141,7 +141,7 @@ the third-last element and so on. Batches the source sequence into sized buckets. -This method has 2 overloads. +This method has 4 overloads, 2 of which are experimental. ### Cartesian