Skip to content

Commit

Permalink
Implements a flatmap functor, and iteration methods for Some and Nothing
Browse files Browse the repository at this point in the history
  • Loading branch information
nlw0 committed Apr 2, 2022
1 parent dbe41d4 commit 0c57e66
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
53 changes: 53 additions & 0 deletions base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,59 @@ end
reverse(f::Flatten) = Flatten(reverse(itr) for itr in reverse(f.it))
last(f::Flatten) = last(last(f.it))

"""
Iterators.flatmap(f, iterators...)
Equivalent to flatten(map(f, iterators...)).
# Examples
```jldoctest
julia> flatmap(n->-n:2:n, 1:3) |> collect
9-element Vector{Int64}:
-1
1
-2
0
2
-3
-1
1
3
julia> flatmap(x -> (x+1)*x % 4 == 0 ? x*x : nothing, 1:11) |> collect
5-element Vector{Int64}:
9
16
49
64
121
julia> [(j,k) for j in 1:3 for k in 1:3 if j>k]
3-element Vector{Tuple{Int64, Int64}}:
(2, 1)
(3, 1)
(3, 2)
julia> flatmap(1:3) do j
flatmap(1:3) do k
j>k ? Some((j,k)) : nothing
end
end |> collect
3-element Vector{Tuple{Int64, Int64}}:
(2, 1)
(3, 1)
(3, 2)
```
"""
# flatmap(f, c...) = flatten(map(f, c...))
flatmap = flatten map

# Allows filtering through `flatten` (or `flatmap`) by removing `nothing` values
iterate(_::Nothing) = nothing
iterate(x::Some{T}) where T = (something(x), nothing)
iterate(x::Some{T}, state::Nothing) where T = nothing
length(x::Some{T}) where T = 1

"""
partition(collection, n)
Expand Down
10 changes: 10 additions & 0 deletions test/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -469,13 +469,23 @@ end
@test length(flatten(1:6)) == 6
@test collect(flatten(Any[])) == Any[]
@test collect(flatten(())) == Union{}[]
@test collect(flatten([Some(1)])) == [1]
@test collect(flatten([nothing])) == Any[]
@test_throws ArgumentError length(flatten(NTuple[(1,), ()])) # #16680
@test_throws ArgumentError length(flatten([[1], [1]]))

@test Base.IteratorEltype(Base.Flatten((i for i=1:2) for j=1:1)) == Base.EltypeUnknown()
# see #29112, #29464, #29548
@test Base.return_types(Base.IteratorEltype, Tuple{Array}) == [Base.HasEltype]

# flatmap
# -------
@test flatmap(1:3) do j
flatmap(1:3) do k
j>k ? Some((j,k)) : nothing
end
end |> collect == [(j,k) for j in 1:3 for k in 1:3 if j>k]

# partition(c, n)
let v = collect(partition([1,2,3,4,5], 1))
@test all(i->v[i][1] == i, v)
Expand Down

0 comments on commit 0c57e66

Please sign in to comment.