From 8e2b350f8f897c0c09999f8a5f89f034953863b8 Mon Sep 17 00:00:00 2001 From: Neven Sajko Date: Sat, 2 Nov 2024 21:41:04 +0100 Subject: [PATCH] support sorting tuples Uses merge sort, as an obvious choice for a stable sort of tuples. A recursive data structure of singleton type, representing Peano natural numbers, is used to help with splitting a tuple into two halves in the merge sort. An alternative design would use a reference tuple, but this would require relying on `tail`, which seems more harsh on the compiler. With the recursive datastructure the predecessor operation and the successor operation are both trivial. Allows inference to preserve inferred element type even when tuple length is not known. Follow-up PRs may add further improvements, such as the ability to select an unstable sorting algorithm. The added file, typedomainnumbers.jl is not specific to sorting, thus making it a separate file. Xref #55571. Fixes #54489 --- NEWS.md | 1 + base/Base.jl | 2 + base/sort.jl | 85 +++++++++++++++++++++ base/typedomainnumbers.jl | 156 ++++++++++++++++++++++++++++++++++++++ test/sorting.jl | 44 +++++++++++ 5 files changed, 288 insertions(+) create mode 100644 base/typedomainnumbers.jl diff --git a/NEWS.md b/NEWS.md index ba9ca1c521c55b..2c8bd77ac8278c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -119,6 +119,7 @@ New library features * `Base.require_one_based_indexing` and `Base.has_offset_axes` are now public ([#56196]) * New `ltruncate`, `rtruncate` and `ctruncate` functions for truncating strings to text width, accounting for char widths ([#55351]) * `isless` (and thus `cmp`, sorting, etc.) is now supported for zero-dimensional `AbstractArray`s ([#55772]) +* `sort` now sorts tuples (#56425) Standard library changes ------------------------ diff --git a/base/Base.jl b/base/Base.jl index 39507b625660d6..c737bb2c69f969 100644 --- a/base/Base.jl +++ b/base/Base.jl @@ -106,6 +106,8 @@ include("cartesian.jl") using .Cartesian include("multidimensional.jl") +include("typedomainnumbers.jl") + include("broadcast.jl") using .Broadcast using .Broadcast: broadcasted, broadcasted_kwsyntax, materialize, materialize!, diff --git a/base/sort.jl b/base/sort.jl index 6991f12551ab4d..106abbc8d42a5d 100644 --- a/base/sort.jl +++ b/base/sort.jl @@ -1736,6 +1736,91 @@ julia> v """ sort(v::AbstractVector; kws...) = sort!(copymutable(v); kws...) +module _SortTupleStable + using + Base._TypeDomainNumbers.PositiveIntegers, Base._TypeDomainNumbers.IntegersGreaterThanOne, + Base._TypeDomainNumbers.Utils, Base._TypeDomainNumberTupleUtils, Base._TupleTypeByLength + using Base: tail + using Base.Order: Ordering, lt + export sort_tuple_stable + function merge_recursive((@nospecialize ord::Ordering), a::Tuple, b::Tuple) + if a isa Tuple1OrMore + a + else + b + end + end + function merge_recursive(ord::Ordering, a::Tuple1OrMore, b::Tuple1OrMore) + l = first(a) + r = first(b) + x = tail(a) + y = tail(b) + if lt(ord, r, l) + let rec = merge_recursive(ord, a, y) + (r, rec...) + end + else + let rec = merge_recursive(ord, x, b) + (l, rec...) + end + end + end + function merge_nontrivial(ord::Ordering, a::Tuple1OrMore, b::Tuple1OrMore) + merge_recursive(ord, a, b) + end + function split_tuple(@nospecialize tup::Tuple2OrMore) + len = tuple_type_domain_length(tup) + len_l = half_floor_nontrivial(len) + len_r = half_ceiling_nontrivial(len) + tup_l = skip_from_tail_nontrivial(tup, len_r) + tup_r = skip_from_front_nontrivial(tup, len_l) + (tup_l, tup_r) + end + function sort_recursive((@nospecialize ord::Ordering), @nospecialize tup::Tuple{Any}) + tup + end + function sort_recursive(ord::Ordering, tup::Tuple2OrMore) + (tup_l, tup_r) = split_tuple(tup) + sorted_l = sort_recursive(ord, tup_l) + sorted_r = sort_recursive(ord, tup_r) + merge_nontrivial(ord, sorted_l, sorted_r) + end + function sort_tuple_stable_2_or_more(ord::Ordering, tup::Tuple2OrMore) + sort_recursive(ord, tup) + end + function sort_tuple_array_fallback(ord::Ordering, tup::Tuple2OrMore) + vec = if tup isa NTuple + [tup...] + else + Any[tup...] + end + sort!(vec; order = ord) + (vec...,) + end + function sort_tuple_stable((@nospecialize ord::Ordering), @nospecialize tup::Tuple) + if tup isa Tuple2OrMore + if tup isa Tuple32OrMore + sort_tuple_array_fallback(ord, tup) + else + sort_tuple_stable_2_or_more(ord, tup) + end + else + tup + end + end +end + +function sort( + tup::Tuple; + lt = isless, + by = identity, + rev::Union{Nothing, Bool} = nothing, + order::Ordering = Forward, +) + o = ord(lt, by, rev, order) + _SortTupleStable.sort_tuple_stable(o, tup) +end + ## partialsortperm: the permutation to sort the first k elements of an array ## """ diff --git a/base/typedomainnumbers.jl b/base/typedomainnumbers.jl new file mode 100644 index 00000000000000..e9f6ff24bc8acd --- /dev/null +++ b/base/typedomainnumbers.jl @@ -0,0 +1,156 @@ +# This file is a part of Julia. License is MIT: https://julialang.org/license + +# Adapted from the TypeDomainNaturalNumbers.jl package. +module _TypeDomainNumbers + module Zeros + export Zero + struct Zero end + end + + module PositiveIntegers + module RecursiveStep + using ...Zeros + export recursive_step + function recursive_step(@nospecialize t::Type) + Union{Zero, t} + end + end + module UpperBounds + using ..RecursiveStep + abstract type A end + abstract type B{P <: recursive_step(A)} <: A end + abstract type C{P <: recursive_step(B)} <: B{P} end + abstract type D{P <: recursive_step(C)} <: C{P} end + end + using .RecursiveStep + const PositiveIntegerUpperBound = UpperBounds.A + const PositiveIntegerUpperBoundTighter = UpperBounds.D + export + natural_successor, natural_predecessor, + NonnegativeInteger, NonnegativeIntegerUpperBound, + PositiveInteger, PositiveIntegerUpperBound + struct PositiveInteger{ + Predecessor <: recursive_step(PositiveIntegerUpperBoundTighter), + } <: PositiveIntegerUpperBoundTighter{Predecessor} + predecessor::Predecessor + global const NonnegativeInteger = recursive_step(PositiveInteger) + global const NonnegativeIntegerUpperBound = recursive_step(PositiveIntegerUpperBound) + global function natural_successor(p::P) where {P <: NonnegativeInteger} + new{P}(p) + end + end + function natural_predecessor(@nospecialize o::PositiveInteger) + getfield(o, :predecessor) # avoid specializing `getproperty` for each number + end + end + + module IntegersGreaterThanOne + using ..PositiveIntegers + export + IntegerGreaterThanOne, IntegerGreaterThanOneUpperBound, + natural_predecessor_predecessor + const IntegerGreaterThanOne = let t = PositiveInteger + t{P} where {P <: t} + end + const IntegerGreaterThanOneUpperBound = let t = PositiveIntegerUpperBound + PositiveIntegers.UpperBounds.B{P} where {P <: t} + end + function natural_predecessor_predecessor(@nospecialize x::IntegerGreaterThanOne) + natural_predecessor(natural_predecessor(x)) + end + end + + module Constants + using ..Zeros, ..PositiveIntegers + export n0, n1 + const n0 = Zero() + const n1 = natural_successor(n0) + end + + module Utils + using ..PositiveIntegers, ..IntegersGreaterThanOne, ..Constants + using Base: @assume_effects + export half_floor, half_ceiling, half_floor_nontrivial, half_ceiling_nontrivial + @assume_effects :foldable :nothrow function half_floor(@nospecialize m::NonnegativeInteger) + if m isa IntegerGreaterThanOneUpperBound + let n = natural_predecessor_predecessor(m), rec = half_floor(n) + natural_successor(rec) + end + else + n0 + end + end + @assume_effects :foldable :nothrow function half_ceiling(@nospecialize m::NonnegativeInteger) + if m isa IntegerGreaterThanOneUpperBound + let n = natural_predecessor_predecessor(m), rec = half_ceiling(n) + natural_successor(rec) + end + else + if m isa PositiveIntegerUpperBound + n1 + else + n0 + end + end + end + function half_floor_nontrivial(@nospecialize m::IntegerGreaterThanOne) + half_floor(m) + end + function half_ceiling_nontrivial(@nospecialize m::IntegerGreaterThanOne) + half_ceiling(m) + end + end +end + +module _TupleTypeByLength + export Tuple1OrMore, Tuple2OrMore, Tuple3OrMore, Tuple4OrMore, Tuple32OrMore + const Tuple1OrMore = Tuple{Any, Vararg} + const Tuple2OrMore = Tuple{Any, Any, Vararg} + const Tuple3OrMore = Tuple{Any, Any, Any, Vararg} + const Tuple4OrMore = Tuple{Any, Any, Any, Any, Vararg} + const Tuple32OrMore = Base.Any32 +end + +module _TypeDomainNumberTupleUtils + using + .._TypeDomainNumbers.PositiveIntegers, .._TypeDomainNumbers.IntegersGreaterThanOne, + .._TypeDomainNumbers.Constants, .._TupleTypeByLength + using Base: @assume_effects, front, tail + export + tuple_type_domain_length, + skip_from_front, skip_from_tail, + skip_from_front_nontrivial, skip_from_tail_nontrivial + @assume_effects :foldable :nothrow function tuple_type_domain_length(@nospecialize tup::Tuple) + if tup isa Tuple1OrMore + let t = tail(tup), rec = tuple_type_domain_length(t) + natural_successor(rec) + end + else + n0 + end + end + @assume_effects :foldable function skip_from_front((@nospecialize tup::Tuple), @nospecialize skip_count::NonnegativeInteger) + if skip_count isa PositiveIntegerUpperBound + let cm1 = natural_predecessor(skip_count), t = tail(tup) + @inline skip_from_front(t, cm1) + end + else + tup + end + end + @assume_effects :foldable function skip_from_tail((@nospecialize tup::Tuple), @nospecialize skip_count::NonnegativeInteger) + if skip_count isa PositiveIntegerUpperBound + let cm1 = natural_predecessor(skip_count), t = front(tup) + @inline skip_from_tail(t, cm1) + end + else + tup + end + end + function skip_from_front_nontrivial((@nospecialize tup::Tuple2OrMore), @nospecialize skip_count::PositiveInteger) + skip_from_front(tup, skip_count) + end + function skip_from_tail_nontrivial((@nospecialize tup::Tuple2OrMore), @nospecialize skip_count::PositiveInteger) + skip_from_tail(tup, skip_count) + end +end diff --git a/test/sorting.jl b/test/sorting.jl index 93e0cdd7de5ba2..d6c5f36df992a7 100644 --- a/test/sorting.jl +++ b/test/sorting.jl @@ -92,6 +92,50 @@ end end @test sort(1:2000, by=x->x÷100, rev=true) == sort(1:2000, by=x->-x÷100) == vcat(2000, (x:x+99 for x in 1900:-100:100)..., 1:99) + @testset "tuples" begin + tup = Tuple(0:9) + @test tup === sort(tup; by = _ -> 0) + @test (0, 2, 4, 6, 8, 1, 3, 5, 7, 9) === sort(tup; by = x -> isodd(x)) + @test (1, 3, 5, 7, 9, 0, 2, 4, 6, 8) === sort(tup; by = x -> iseven(x)) + end +end + +@testset "tuple sorting" begin + max_unrolled_length = 31 + @testset "correctness" begin + tup = Tuple(0:9) + tup_rev = reverse(tup) + @test tup === @inferred sort(tup) + @test tup === sort(tup; rev = false) + @test tup_rev === sort(tup; rev = true) + @test tup_rev === sort(tup; lt = >) + end + @testset "inference" begin + known_length = (Tuple{Vararg{Int, max_unrolled_length}}, Tuple{Vararg{Float64, max_unrolled_length}}) + unknown_length = (Tuple{Vararg{Int}}, Tuple{Vararg{Float64}}) + for Tup ∈ (known_length..., unknown_length...) + @test Tup == Base.infer_return_type(sort, Tuple{Tup}) + end + for Tup ∈ (known_length...,) + @test Core.Compiler.is_foldable(Base.infer_effects(sort, Tuple{Tup})) + end + end + @testset "alloc" begin + function test_zero_allocated(tup::Tuple) + @test iszero(@allocated sort(tup)) + end + test_zero_allocated(ntuple(identity, max_unrolled_length)) + end + @testset "heterogeneous" begin + @testset "stability" begin + tup = (0, 0x0, 0x000) + @test tup === sort(tup) + end + tup = (1, 2, 3, missing, missing) + for t ∈ (tup, (1, missing, 2, missing, 3), (missing, missing, 1, 2, 3)) + @test tup === @inferred sort(t) + end + end end @testset "partialsort" begin