# Adapted from Knet's src/data.jl (author: Deniz Yuret)

using Random: AbstractRNG, shuffle!, GLOBAL_RNG
using Base: @propagate_inbounds   # needed for the bare `@propagate_inbounds` used below

struct DataLoader{D,R<:AbstractRNG}
    data::D                # the underlying data container (array, tuple, or named tuple)
    batchsize::Int         # number of observations per mini-batch
    nobs::Int              # total number of observations (size of the last dimension)
    partial::Bool          # whether to keep a final, smaller mini-batch
    imax::Int              # largest valid batch start offset; iteration stops once i >= imax
    indices::Vector{Int}   # observation indices, shuffled in place when shuffle == true
    shuffle::Bool          # whether to reshuffle the observations at the start of each epoch
    rng::R                 # RNG used for shuffling
end

"""
Flux.DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)
An object that iterates over mini-batches of `data`,
each mini-batch containing `batchsize` observations
(except possibly the last one).
Takes as input a single data tensor, or a tuple (or a named tuple) of tensors.
The last dimension in each tensor is the observation dimension, i.e. the one
divided into mini-batches.
If `shuffle=true`, it shuffles the observations each time iterations are re-started.
If `partial=false` and the number of observations is not divisible by the batchsize,
then the last mini-batch is dropped.
The original data is preserved in the `data` field of the DataLoader.
# Examples
```jldoctest
julia> Xtrain = rand(10, 100);
julia> array_loader = Flux.DataLoader(Xtrain, batchsize=2);
julia> for x in array_loader
@assert size(x) == (10, 2)
# do something with x, 50 times
end
julia> array_loader.data === Xtrain
true
julia> tuple_loader = Flux.DataLoader((Xtrain,), batchsize=2); # similar, but yielding 1-element tuples
julia> for x in tuple_loader
@assert x isa Tuple{Matrix}
@assert size(x[1]) == (10, 2)
end
julia> Ytrain = rand('a':'z', 100); # now make a DataLoader yielding 2-element named tuples
julia> train_loader = Flux.DataLoader((data=Xtrain, label=Ytrain), batchsize=5, shuffle=true);
julia> for epoch in 1:100
for (x, y) in train_loader # access via tuple destructuring
@assert size(x) == (10, 5)
@assert size(y) == (5,)
# loss += f(x, y) # etc, runs 100 * 20 times
end
end
julia> first(train_loader).label isa Vector{Char} # access via property name
true
julia> first(train_loader).label == Ytrain[1:5] # because of shuffle=true
false
julia> foreach(println∘summary, Flux.DataLoader(rand(Int8, 10, 64), batchsize=30)) # partial=false would omit last
10×30 Matrix{Int8}
10×30 Matrix{Int8}
10×4 Matrix{Int8}
```
"""
function DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)
    batchsize > 0 || throw(ArgumentError("Need positive batchsize"))
    n = _nobs(data)
    if n < batchsize
        @warn "Number of observations less than batchsize, decreasing the batchsize to $n"
        batchsize = n
    end
    imax = partial ? n : n - batchsize + 1
    DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle, rng)
end
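
# Illustrative example of the `imax` cutoff above (values chosen for illustration only):
# with n = 10 observations and batchsize = 4,
#   partial = true  gives imax = 10 → batches start at i = 0, 4, 8 and have sizes 4, 4, 2;
#   partial = false gives imax = 7  → batches start at i = 0, 4    and both have size 4.
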
@propagate_inbounds function Base.iterate(d::DataLoader, i=0)     # returns data in d.indices[i+1:i+batchsize]
    i >= d.imax && return nothing
    if d.shuffle && i == 0
        # reshuffle the observation order at the start of each fresh iteration (epoch)
        shuffle!(d.rng, d.indices)
    end
    nexti = min(i + d.batchsize, d.nobs)
    ids = d.indices[i+1:nexti]
    batch = _getobs(d.data, ids)
    return (batch, nexti)
end

function Base.length(d::DataLoader)
    n = d.nobs / d.batchsize
    d.partial ? ceil(Int, n) : floor(Int, n)
end
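
# For example, with nobs = 100 and batchsize = 30:
#   partial = true  → ceil(100/30)  = 4 batches (the last one holds 10 observations),
#   partial = false → floor(100/30) = 3 batches.
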
_nobs(data::AbstractArray) = size(data)[end]

function _nobs(data::Union{Tuple, NamedTuple})
    length(data) > 0 || throw(ArgumentError("Need at least one data input"))
    n = _nobs(data[1])
    for i in keys(data)
        ni = _nobs(data[i])
        n == ni || throw(DimensionMismatch("All data inputs should have the same number of observations, i.e. size in the last dimension. " *
            "But data[$(repr(first(keys(data))))] ($(summary(data[1]))) has $n, while data[$(repr(i))] ($(summary(data[i]))) has $ni."))
    end
    return n
end

_getobs(data::AbstractArray, i) = data[ntuple(i -> Colon(), Val(ndims(data) - 1))..., i]
_getobs(data::Union{Tuple, NamedTuple}, i) = map(Base.Fix2(_getobs, i), data)
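
# Illustrative sketch of what `_getobs` yields (example values, not part of the API):
#   _getobs(rand(2, 3, 10), [1, 4])                   # 2×3×2 array: observations 1 and 4
#   _getobs((x = rand(2, 10), y = rand(10)), [1, 4])  # (x = 2×2 Matrix, y = 2-element Vector)
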
Base.eltype(::DataLoader{D}) where D = D
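
# A minimal usage sketch, assuming this file is included on its own;
# within Flux the same loader is reached as `Flux.DataLoader`:
#
#   X = rand(10, 100)                         # 100 observations, 10 features each
#   loader = DataLoader(X; batchsize=32, shuffle=true)
#   length(loader)                            # 4: three full batches plus a partial batch of 4
#   for x in loader
#       @assert size(x, 1) == 10
#   end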