diff --git a/base/iterators.jl b/base/iterators.jl index 1b96a24a9c16f8..d2101c5e0bc30e 100644 --- a/base/iterators.jl +++ b/base/iterators.jl @@ -1162,6 +1162,59 @@ end reverse(f::Flatten) = Flatten(reverse(itr) for itr in reverse(f.it)) last(f::Flatten) = last(last(f.it)) +""" + Iterators.flatmap(f, iterators...) + +Equivalent to flatten(map(f, iterators...)). + +# Examples +```jldoctest +julia> flatmap(n->-n:2:n, 1:3) |> collect +9-element Vector{Int64}: + -1 + 1 + -2 + 0 + 2 + -3 + -1 + 1 + 3 + +julia> flatmap(x -> (x+1)*x % 4 == 0 ? x*x : nothing, 1:11) |> collect +5-element Vector{Int64}: + 9 + 16 + 49 + 64 + 121 + +julia> [(j,k) for j in 1:3 for k in 1:3 if j>k] +3-element Vector{Tuple{Int64, Int64}}: + (2, 1) + (3, 1) + (3, 2) + +julia> flatmap(1:3) do j + flatmap(1:3) do k + j>k ? Some((j,k)) : nothing + end + end |> collect +3-element Vector{Tuple{Int64, Int64}}: + (2, 1) + (3, 1) + (3, 2) +``` +""" +# flatmap(f, c...) = flatten(map(f, c...)) +flatmap = flatten ∘ map + +# Allows filtering through `flatten` (or `flatmap`) by removing `nothing` values +iterate(_::Nothing) = nothing +iterate(x::Some{T}) where T = (something(x), nothing) +iterate(x::Some{T}, state::Nothing) where T = nothing +length(x::Some{T}) where T = 1 + """ partition(collection, n) diff --git a/test/iterators.jl b/test/iterators.jl index 70ce6866f4be35..ff705d32fed2e5 100644 --- a/test/iterators.jl +++ b/test/iterators.jl @@ -469,6 +469,8 @@ end @test length(flatten(1:6)) == 6 @test collect(flatten(Any[])) == Any[] @test collect(flatten(())) == Union{}[] +@test collect(flatten([Some(1)])) == [1] +@test collect(flatten([nothing])) == Any[] @test_throws ArgumentError length(flatten(NTuple[(1,), ()])) # #16680 @test_throws ArgumentError length(flatten([[1], [1]])) @@ -476,6 +478,14 @@ end # see #29112, #29464, #29548 @test Base.return_types(Base.IteratorEltype, Tuple{Array}) == [Base.HasEltype] +# flatmap +# ------- +@test flatmap(1:3) do j + flatmap(1:3) do k + j>k ? Some((j,k)) : nothing + end + end |> collect == [(j,k) for j in 1:3 for k in 1:3 if j>k] + # partition(c, n) let v = collect(partition([1,2,3,4,5], 1)) @test all(i->v[i][1] == i, v)