From 42dbe2013b4105db8353c1b1fcd2fa69b8982d73 Mon Sep 17 00:00:00 2001 From: Atif Aziz Date: Fri, 4 Nov 2022 22:21:20 +0100 Subject: [PATCH] Rewrite switch statements as expressions (#865) --- MoreLinq.Test/FlattenTest.cs | 51 +++++++++--------------------- MoreLinq.Test/FromTest.cs | 15 ++++----- MoreLinq.Test/FullGroupJoinTest.cs | 17 ++++------ MoreLinq.Test/NullArgumentTest.cs | 10 +++--- MoreLinq.Test/TestExtensions.cs | 22 +++++-------- MoreLinq/Backsert.cs | 11 ++++--- MoreLinq/Exclude.cs | 11 ++++--- MoreLinq/Fold.cs | 38 +++++++++++----------- MoreLinq/MaxBy.cs | 18 +++++------ MoreLinq/OrderedMerge.cs | 30 ++++++++---------- 10 files changed, 98 insertions(+), 125 deletions(-) diff --git a/MoreLinq.Test/FlattenTest.cs b/MoreLinq.Test/FlattenTest.cs index a416df347..f7d0ea28a 100644 --- a/MoreLinq.Test/FlattenTest.cs +++ b/MoreLinq.Test/FlattenTest.cs @@ -325,21 +325,13 @@ public void FlattenSelector() } }; - var result = source.Flatten(obj => + var result = source.Flatten(obj => obj switch { - switch (obj) - { - case string: - return null; - case IEnumerable inner: - return inner; - case Series s: - return new object[] { s.Name, s.Attributes }; - case Attribute a: - return a.Values; - default: - return null; - } + string => null, + IEnumerable inner => inner, + Series s => new object[] { s.Name, s.Attributes }, + Attribute a => a.Values, + _ => null }); var expectations = new object[] { "series1", 1, 2, 3, 4, "series2", 5, 6 }; @@ -368,17 +360,11 @@ public void FlattenSelectorFilteringOnlyIntegers() 4, }; - var result = source.Flatten(obj => + var result = source.Flatten(obj => obj switch { - switch (obj) - { - case int: - return null; - case IEnumerable inner: - return inner; - default: - return Enumerable.Empty(); - } + int => null, + IEnumerable inner => inner, + _ => Enumerable.Empty() }); var expectations = new object[] { 1, 2, 3, 4 }; @@ -406,19 +392,12 @@ public void FlattenSelectorWithTree() ) ); - var result = new [] { source }.Flatten(obj => + var result = new[] { source }.Flatten(obj => obj switch { - switch (obj) - { - case int: - return null; - case Tree tree: - return new object[] { tree.Left, tree.Value, tree.Right }; - case IEnumerable inner: - return inner; - default: - return Enumerable.Empty(); - } + int => null, + Tree tree => new object[] { tree.Left, tree.Value, tree.Right }, + IEnumerable inner => inner, + _ => Enumerable.Empty() }); var expectations = Enumerable.Range(1, 7); diff --git a/MoreLinq.Test/FromTest.cs b/MoreLinq.Test/FromTest.cs index 9577fe231..8dbe01375 100644 --- a/MoreLinq.Test/FromTest.cs +++ b/MoreLinq.Test/FromTest.cs @@ -66,15 +66,14 @@ public void TestFromInvokesMethodsMultipleTimes(int numArgs) int F3() { evals[2]++; return -2; } int F4() { evals[3]++; return -2; } - IEnumerable results; - switch (numArgs) + var results = numArgs switch { - case 1: results = MoreEnumerable.From(F1); break; - case 2: results = MoreEnumerable.From(F1, F2); break; - case 3: results = MoreEnumerable.From(F1, F2, F3); break; - case 4: results = MoreEnumerable.From(F1, F2, F3, F4); break; - default: throw new ArgumentOutOfRangeException(nameof(numArgs)); - } + 1 => MoreEnumerable.From(F1), + 2 => MoreEnumerable.From(F1, F2), + 3 => MoreEnumerable.From(F1, F2, F3), + 4 => MoreEnumerable.From(F1, F2, F3, F4), + _ => throw new ArgumentOutOfRangeException(nameof(numArgs)) + }; results.Consume(); results.Consume(); diff --git a/MoreLinq.Test/FullGroupJoinTest.cs b/MoreLinq.Test/FullGroupJoinTest.cs index 40d8778c5..1b7838ee3 100644 --- a/MoreLinq.Test/FullGroupJoinTest.cs +++ b/MoreLinq.Test/FullGroupJoinTest.cs @@ -133,17 +133,12 @@ public void FullGroupPreservesOrder(OverloadCase overloadCase) } } - static IEnumerable<(int Key, IEnumerable First, IEnumerable Second)> FullGroupJoin(OverloadCase overloadCase, IEnumerable listA, IEnumerable listB, Func getKey) - { - switch (overloadCase) + static IEnumerable<(int Key, IEnumerable First, IEnumerable Second)> FullGroupJoin(OverloadCase overloadCase, IEnumerable listA, IEnumerable listB, Func getKey) => + overloadCase switch { - case CustomResult: - return listA.FullGroupJoin(listB, getKey, getKey, ValueTuple.Create, comparer: null); - case TupleResult: - return listA.FullGroupJoin(listB, getKey, getKey); - default: - throw new ArgumentOutOfRangeException(nameof(overloadCase)); - } - } + CustomResult => listA.FullGroupJoin(listB, getKey, getKey, ValueTuple.Create, comparer: null), + TupleResult => listA.FullGroupJoin(listB, getKey, getKey), + _ => throw new ArgumentOutOfRangeException(nameof(overloadCase)) + }; } } diff --git a/MoreLinq.Test/NullArgumentTest.cs b/MoreLinq.Test/NullArgumentTest.cs index 56f3eb8fc..8e8d0f36d 100644 --- a/MoreLinq.Test/NullArgumentTest.cs +++ b/MoreLinq.Test/NullArgumentTest.cs @@ -98,10 +98,12 @@ static Type InstantiateType(TypeInfo typeParameter) { var constraints = typeParameter.GetGenericParameterConstraints(); - if (constraints.Length == 0) return typeof (int); - if (constraints.Length == 1) return constraints.Single(); - - throw new NotImplementedException("NullArgumentTest.InstantiateType"); + return constraints.Length switch + { + 0 => typeof(int), + 1 => constraints.Single(), + _ => throw new NotImplementedException("NullArgumentTest.InstantiateType") + }; } static bool IsReferenceType(ParameterInfo parameter) => diff --git a/MoreLinq.Test/TestExtensions.cs b/MoreLinq.Test/TestExtensions.cs index 5331447dd..b6e12aee2 100644 --- a/MoreLinq.Test/TestExtensions.cs +++ b/MoreLinq.Test/TestExtensions.cs @@ -76,21 +76,15 @@ internal static IEnumerable> ArrangeCollectionTestCases(this I internal static IEnumerable ToSourceKind(this IEnumerable input, SourceKind sourceKind) { - switch (sourceKind) + return sourceKind switch { - case SourceKind.Sequence: - return input.Select(x => x); - case SourceKind.BreakingList: - return new BreakingList(input.ToList()); - case SourceKind.BreakingReadOnlyList: - return new BreakingReadOnlyList(input.ToList()); - case SourceKind.BreakingCollection: - return new BreakingCollection(input.ToList()); - case SourceKind.BreakingReadOnlyCollection: - return new BreakingReadOnlyCollection(input.ToList()); - default: - throw new ArgumentException(null, nameof(sourceKind)); - } + SourceKind.Sequence => input.Select(x => x), + SourceKind.BreakingList => new BreakingList(input.ToList()), + SourceKind.BreakingReadOnlyList => new BreakingReadOnlyList(input.ToList()), + SourceKind.BreakingCollection => new BreakingCollection(input.ToList()), + SourceKind.BreakingReadOnlyCollection => new BreakingReadOnlyCollection(input.ToList()), + _ => throw new ArgumentException(null, nameof(sourceKind)) + }; } } } diff --git a/MoreLinq/Backsert.cs b/MoreLinq/Backsert.cs index 1ac6a9d4c..3b0a0b4bf 100644 --- a/MoreLinq/Backsert.cs +++ b/MoreLinq/Backsert.cs @@ -60,12 +60,15 @@ public static IEnumerable Backsert(this IEnumerable first, IEnumerable< { if (first == null) throw new ArgumentNullException(nameof(first)); if (second == null) throw new ArgumentNullException(nameof(second)); - if (index < 0) throw new ArgumentOutOfRangeException(nameof(index), "Index cannot be negative."); - if (index == 0) - return first.Concat(second); + return index switch + { + < 0 => throw new ArgumentOutOfRangeException(nameof(index), "Index cannot be negative."), + 0 => first.Concat(second), + _ => _() + }; - return _(); IEnumerable _() + IEnumerable _() { using var e = first.CountDown(index, ValueTuple.Create).GetEnumerator(); diff --git a/MoreLinq/Exclude.cs b/MoreLinq/Exclude.cs index 6572753ca..4e3d7edc9 100644 --- a/MoreLinq/Exclude.cs +++ b/MoreLinq/Exclude.cs @@ -36,12 +36,15 @@ public static IEnumerable Exclude(this IEnumerable sequence, int startI { if (sequence == null) throw new ArgumentNullException(nameof(sequence)); if (startIndex < 0) throw new ArgumentOutOfRangeException(nameof(startIndex)); - if (count < 0) throw new ArgumentOutOfRangeException(nameof(count)); - if (count == 0) - return sequence; + return count switch + { + < 0 => throw new ArgumentOutOfRangeException(nameof(count)), + 0 => sequence, + _ => _() + }; - return _(); IEnumerable _() + IEnumerable _() { var index = -1; var endIndex = startIndex + count; diff --git a/MoreLinq/Fold.cs b/MoreLinq/Fold.cs index ae4337292..6698cec9a 100644 --- a/MoreLinq/Fold.cs +++ b/MoreLinq/Fold.cs @@ -67,26 +67,26 @@ static TResult FoldImpl(IEnumerable source, int count, foreach (var e in AssertCountImpl(source.Index(), count, OnFolderSourceSizeErrorSelector)) elements[e.Key] = e.Value; - switch (count) + return count switch { - case 1: return folder1 !(elements[0]); - case 2: return folder2 !(elements[0], elements[1]); - case 3: return folder3 !(elements[0], elements[1], elements[2]); - case 4: return folder4 !(elements[0], elements[1], elements[2], elements[3]); - case 5: return folder5 !(elements[0], elements[1], elements[2], elements[3], elements[4]); - case 6: return folder6 !(elements[0], elements[1], elements[2], elements[3], elements[4], elements[5]); - case 7: return folder7 !(elements[0], elements[1], elements[2], elements[3], elements[4], elements[5], elements[6]); - case 8: return folder8 !(elements[0], elements[1], elements[2], elements[3], elements[4], elements[5], elements[6], elements[7]); - case 9: return folder9 !(elements[0], elements[1], elements[2], elements[3], elements[4], elements[5], elements[6], elements[7], elements[8]); - case 10: return folder10!(elements[0], elements[1], elements[2], elements[3], elements[4], elements[5], elements[6], elements[7], elements[8], elements[9]); - case 11: return folder11!(elements[0], elements[1], elements[2], elements[3], elements[4], elements[5], elements[6], elements[7], elements[8], elements[9], elements[10]); - case 12: return folder12!(elements[0], elements[1], elements[2], elements[3], elements[4], elements[5], elements[6], elements[7], elements[8], elements[9], elements[10], elements[11]); - case 13: return folder13!(elements[0], elements[1], elements[2], elements[3], elements[4], elements[5], elements[6], elements[7], elements[8], elements[9], elements[10], elements[11], elements[12]); - case 14: return folder14!(elements[0], elements[1], elements[2], elements[3], elements[4], elements[5], elements[6], elements[7], elements[8], elements[9], elements[10], elements[11], elements[12], elements[13]); - case 15: return folder15!(elements[0], elements[1], elements[2], elements[3], elements[4], elements[5], elements[6], elements[7], elements[8], elements[9], elements[10], elements[11], elements[12], elements[13], elements[14]); - case 16: return folder16!(elements[0], elements[1], elements[2], elements[3], elements[4], elements[5], elements[6], elements[7], elements[8], elements[9], elements[10], elements[11], elements[12], elements[13], elements[14], elements[15]); - default: throw new NotSupportedException(); - } + 1 => folder1 !(elements[0]), + 2 => folder2 !(elements[0], elements[1]), + 3 => folder3 !(elements[0], elements[1], elements[2]), + 4 => folder4 !(elements[0], elements[1], elements[2], elements[3]), + 5 => folder5 !(elements[0], elements[1], elements[2], elements[3], elements[4]), + 6 => folder6 !(elements[0], elements[1], elements[2], elements[3], elements[4], elements[5]), + 7 => folder7 !(elements[0], elements[1], elements[2], elements[3], elements[4], elements[5], elements[6]), + 8 => folder8 !(elements[0], elements[1], elements[2], elements[3], elements[4], elements[5], elements[6], elements[7]), + 9 => folder9 !(elements[0], elements[1], elements[2], elements[3], elements[4], elements[5], elements[6], elements[7], elements[8]), + 10 => folder10!(elements[0], elements[1], elements[2], elements[3], elements[4], elements[5], elements[6], elements[7], elements[8], elements[9]), + 11 => folder11!(elements[0], elements[1], elements[2], elements[3], elements[4], elements[5], elements[6], elements[7], elements[8], elements[9], elements[10]), + 12 => folder12!(elements[0], elements[1], elements[2], elements[3], elements[4], elements[5], elements[6], elements[7], elements[8], elements[9], elements[10], elements[11]), + 13 => folder13!(elements[0], elements[1], elements[2], elements[3], elements[4], elements[5], elements[6], elements[7], elements[8], elements[9], elements[10], elements[11], elements[12]), + 14 => folder14!(elements[0], elements[1], elements[2], elements[3], elements[4], elements[5], elements[6], elements[7], elements[8], elements[9], elements[10], elements[11], elements[12], elements[13]), + 15 => folder15!(elements[0], elements[1], elements[2], elements[3], elements[4], elements[5], elements[6], elements[7], elements[8], elements[9], elements[10], elements[11], elements[12], elements[13], elements[14]), + 16 => folder16!(elements[0], elements[1], elements[2], elements[3], elements[4], elements[5], elements[6], elements[7], elements[8], elements[9], elements[10], elements[11], elements[12], elements[13], elements[14], elements[15]), + _ => throw new NotSupportedException() + }; } static readonly Func OnFolderSourceSizeErrorSelector = OnFolderSourceSizeError; diff --git a/MoreLinq/MaxBy.cs b/MoreLinq/MaxBy.cs index 098c15f00..3752c1645 100644 --- a/MoreLinq/MaxBy.cs +++ b/MoreLinq/MaxBy.cs @@ -340,16 +340,16 @@ IEnumerable Extrema() { var item = e.Current; var key = selector(item); - var comparison = comparer(key, extremaKey); - if (comparison > 0) + switch (comparer(key, extremaKey)) { - extrema.Restart(ref store); - extrema.Add(ref store, limit, item); - extremaKey = key; - } - else if (comparison == 0) - { - extrema.Add(ref store, limit, item); + case > 0: + extrema.Restart(ref store); + extrema.Add(ref store, limit, item); + extremaKey = key; + break; + case 0: + extrema.Add(ref store, limit, item); + break; } } diff --git a/MoreLinq/OrderedMerge.cs b/MoreLinq/OrderedMerge.cs index bb4551ec7..a42e08057 100644 --- a/MoreLinq/OrderedMerge.cs +++ b/MoreLinq/OrderedMerge.cs @@ -300,23 +300,21 @@ IEnumerable _(IComparer comparer) var key1 = firstKeySelector(element1); var element2 = e2.Current; var key2 = secondKeySelector(element2); - var comparison = comparer.Compare(key1, key2); - - if (comparison < 0) - { - yield return firstSelector(element1); - gotFirst = e1.MoveNext(); - } - else if (comparison > 0) - { - yield return secondSelector(element2); - gotSecond = e2.MoveNext(); - } - else + switch (comparer.Compare(key1, key2)) { - yield return bothSelector(element1, element2); - gotFirst = e1.MoveNext(); - gotSecond = e2.MoveNext(); + case < 0: + yield return firstSelector(element1); + gotFirst = e1.MoveNext(); + break; + case > 0: + yield return secondSelector(element2); + gotSecond = e2.MoveNext(); + break; + default: + yield return bothSelector(element1, element2); + gotFirst = e1.MoveNext(); + gotSecond = e2.MoveNext(); + break; } } else if (gotSecond)