Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements flatmap #44792

Merged
merged 14 commits into from
Apr 7, 2022
25 changes: 24 additions & 1 deletion base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import .Base:
getindex, setindex!, get, iterate,
popfirst!, isdone, peek

export enumerate, zip, rest, countfrom, take, drop, takewhile, dropwhile, cycle, repeated, product, flatten, partition
export enumerate, zip, rest, countfrom, take, drop, takewhile, dropwhile, cycle, repeated, product, flatten, partition, flatmap

"""
Iterators.map(f, iterators...)
Expand Down Expand Up @@ -1162,6 +1162,29 @@ end
reverse(f::Flatten) = Flatten(reverse(itr) for itr in reverse(f.it))
last(f::Flatten) = last(last(f.it))

"""
Iterators.flatmap(f, iterators...)

Equivalent to `flatten(map(f, iterators...))`.

# Examples
```jldoctest
julia> Iterators.flatmap(n->-n:2:n, 1:3) |> collect
9-element Vector{Int64}:
-1
1
-2
0
2
-3
-1
1
3
```
"""
# flatmap = flatten ∘ map
nlw0 marked this conversation as resolved.
Show resolved Hide resolved
flatmap(f, c...) = flatten(map(f, c...))

"""
partition(collection, n)

Expand Down
23 changes: 23 additions & 0 deletions test/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,29 @@ end
# see #29112, #29464, #29548
@test Base.return_types(Base.IteratorEltype, Tuple{Array}) == [Base.HasEltype]

# flatmap
# -------
@test flatmap(1:3) do j flatmap(1:3) do k
j!=k ? ((j,k),) : ()
end end |> collect == [(j,k) for j in 1:3 for k in 1:3 if j!=k]
# Test inspired by the monad associativity law
fmf(x) = x<0 ? () : (x^2,)
fmg(x) = x<1 ? () : (x/2,)
fmdata = -2:0.75:2
fmv1 = flatmap(tuple.(fmdata)) do h
flatmap(h) do x
gx = fmg(x)
flatmap(gx) do x
fmf(x)
end
end
end
fmv2 = flatmap(tuple.(fmdata)) do h
gh = flatmap(h) do x fmg(x) end
flatmap(gh) do x fmf(x) end
end
@test all(fmv1 .== fmv2)

# partition(c, n)
let v = collect(partition([1,2,3,4,5], 1))
@test all(i->v[i][1] == i, v)
Expand Down