Skip to content

Commit

Permalink
improvements to stack (#125)
Browse files Browse the repository at this point in the history
* improvements to stack

* cleanup

* use Base definition of stack

* v0.4

* use stack in batch

* Compat bound
  • Loading branch information
CarloLucibello authored Nov 12, 2022
1 parent 1c50c62 commit 08ad0b7
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 63 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "MLUtils"
uuid = "f1d291b0-491e-4a28-83b9-f70985020b54"
authors = ["Carlo Lucibello <[email protected]> and contributors"]
version = "0.3.1"
version = "0.4.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
Expand All @@ -20,6 +21,7 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"

[compat]
ChainRulesCore = "1.0"
Compat = "4.2"
DataAPI = "1.0"
DelimitedFiles = "1.0"
FLoops = "0.2"
Expand Down
5 changes: 3 additions & 2 deletions src/MLUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ import NNlib

@traitdef IsTable{X}
@traitimpl IsTable{X} <- Tables.istable(X)


using Compat: stack

include("observation.jl")
export numobs,
Expand Down Expand Up @@ -75,7 +76,7 @@ export batch,
rand_like,
randn_like,
rpad_constant,
stack,
stack, # in Base since julia v1.9
unbatch,
unsqueeze,
unstack,
Expand Down
1 change: 0 additions & 1 deletion src/deprecations.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Deprecated in v0.2
@deprecate stack(x, dims) stack(x; dims=dims)
@deprecate unstack(x, dims) unstack(x; dims=dims)
@deprecate unsqueeze(x::AbstractArray, dims::Int) unsqueeze(x; dims=dims)
@deprecate unsqueeze(dims::Int) unsqueeze(dims=dims)
Expand Down
63 changes: 5 additions & 58 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Return `x` reshaped into an array one dimensionality higher than `x`,
where `dims` indicates in which dimension `x` is extended.
`dims` can be an integer between 1 and `ndims(x)+1`.
See also [`flatten`](@ref), [`stack`](@ref).
Expand Down Expand Up @@ -33,8 +34,9 @@ julia> unsqueeze(xs, dims=1)
[1, 2] [3, 4] [5, 6]
```
"""
function unsqueeze(x::AbstractArray; dims::Int)
sz = ntuple(i -> i < dims ? size(x, i) : i == dims ? 1 : size(x, i - 1), ndims(x) + 1)
function unsqueeze(x::AbstractArray{T,N}; dims::Int) where {T, N}
@assert 1 <= dims <= N + 1
sz = ntuple(i -> i < dims ? size(x, i) : i == dims ? 1 : size(x, i - 1), N + 1)
return reshape(x, sz)
end

Expand All @@ -55,51 +57,6 @@ _unsqueeze(x, dims) = unsqueeze(x; dims)

Base.show_function(io::IO, u::Base.Fix2{typeof(_unsqueeze)}, ::Bool) = print(io, "unsqueeze(dims=", u.x, ")")

"""
stack(xs; dims)
Concatenate the given array of arrays `xs` into a single array along the
given dimension `dims`.
See also [`stack`](@ref) and [`batch`](@ref).
# Examples
```jldoctest
julia> xs = [[1, 2], [3, 4], [5, 6]]
3-element Vector{Vector{Int64}}:
[1, 2]
[3, 4]
[5, 6]
julia> stack(xs, dims=1)
3×2 Matrix{Int64}:
1 2
3 4
5 6
julia> stack(xs, dims=2)
2×3 Matrix{Int64}:
1 3 5
2 4 6
julia> stack(xs, dims=3)
2×1×3 Array{Int64, 3}:
[:, :, 1] =
1
2
[:, :, 2] =
3
4
[:, :, 3] =
5
6
```
"""
stack(xs; dims::Int) = cat(unsqueeze.(xs; dims)...; dims)

"""
unstack(xs; dims)
Expand Down Expand Up @@ -329,17 +286,7 @@ end

batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i)

function batch(xs::AbstractArray{<:AbstractArray})
# Don't use stack(xs, dims=N+1), it is much slower.
# Here we do reduce(vcat, xs) along with some reshapes.
szxs = size(xs)
@assert length(xs) > 0 "Minimum batch size is 1."
szx = size(xs[1])
@assert all(x -> size(x) == szx, xs) "All arrays must be of the same size."
vxs = vec(vec.(xs))
y = reduce(vcat, vxs)
return reshape(y, szx..., szxs...)
end
batch(xs::AbstractArray{<:AbstractArray}) = stack(xs)

function batch(xs::Vector{<:Tuple})
@assert length(xs) > 0 "Input should be non-empty"
Expand Down
13 changes: 13 additions & 0 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@

"""
Test gradients through zygote.
# Arguments
- `f`: function to test
- `xs`: inputs to `f`
# Keyword Arguments
Keyword arguments are passed to `rrule`.
- `fkwargs`: keyword arguments to `f`
"""
function test_zygote(f, xs...; kws...)
config = ZygoteRuleConfig()
test_rrule(config, f, xs...; kws..., rrule_f = rrule_via_ad)
Expand Down
13 changes: 12 additions & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,30 @@
@test @inferred(unsqueeze(x; dims=4)) == reshape(x, 2, 3, 2, 1)

@test unsqueeze(dims=2)(x) == unsqueeze(x, dims=2)

@test_throws AssertionError unsqueeze(rand(2,2), dims=4)
end

@testset "stack and unstack" begin
x = randn(3,3)
stacked = stack([x, x], dims=2)
@test size(stacked) == (3,2,3)
@test_broken @inferred(stack([x, x], dims=2)) == stacked
@test @inferred(stack([x, x], dims=2)) == stacked

stacked_array=[ 8 9 3 5; 9 6 6 9; 9 1 7 2; 7 4 10 6 ]
unstacked_array=[[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]]
@test unstack(stacked_array, dims=2) == unstacked_array
@test stack(unstacked_array, dims=2) == stacked_array
@test stack(unstack(stacked_array, dims=1), dims=1) == stacked_array

for d in (1,2,3)
test_zygote(stack, [x,2x], fkwargs=(; dims=d), check_inferred=false)
end

# Issue #121
a = [[1] for i in 1:10000]
@test size(stack(a, dims=1)) == (10000, 1)
@test size(stack(a, dims=2)) == (1, 10000)
end

@testset "batch and unbatch" begin
Expand Down

0 comments on commit 08ad0b7

Please sign in to comment.