-
Notifications
You must be signed in to change notification settings - Fork 21
/
batchview.jl
182 lines (148 loc) · 6.11 KB
/
batchview.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
"""
    BatchView(data, batchsize; partial=true, collate=nothing)
    BatchView(data; batchsize=1, partial=true, collate=nothing)

Create a view of the given `data` that represents it as a vector
of batches. Every batch contains `batchsize` observations, except
possibly the last one (see `partial` below).

If the number of observations in `data` is not divisible by the
specified `batchsize`, the remaining observations are dropped when
`partial=false`. When `partial=true` they instead form a final,
smaller batch. If `batchsize` exceeds the (nonzero) number of
observations, a warning is emitted and the batch-size is reduced to
the number of observations.

Note that any data access is delayed until `getindex` is called.

If used as an iterator, the object will iterate over the dataset
once, effectively denoting an epoch.

For `BatchView` to work on some data structure, the type of the
given variable `data` must implement the data container
interface. See [`ObsView`](@ref) for more info.

# Arguments

- **`data`** : The object describing the dataset. Can be of any
    type as long as it implements [`getobs`](@ref) and
    [`numobs`](@ref).

- **`batchsize`** : The batch-size of each batch.
    It is the number of observations that each batch must contain
    (except possibly for the last one).

- **`partial`** : If `partial=false` and the number of observations is
    not divisible by the batch-size, then the last mini-batch is dropped.

- **`collate`**: Batching behavior. If `nothing` (default), a batch
    is `getobs(data, indices)`. If `false`, each batch is
    `[getobs(data, i) for i in indices]`. When `true`, applies [`batch`](@ref)
    to the vector of observations in a batch, recursively collating
    arrays in the last dimensions. See [`batch`](@ref) for more information
    and examples.

# Examples

```julia
using MLUtils
X, Y = MLUtils.load_iris()

A = BatchView(X, batchsize=30)
@assert typeof(A) <: BatchView <: AbstractVector
@assert eltype(A) <: SubArray{Float64,2}
@assert length(A) == 5 # Iris has 150 observations
@assert size(A[1]) == (4,30) # Iris has 4 features

# 5 batches of size 30 observations
for x in BatchView(X, batchsize=30)
    @assert typeof(x) <: SubArray{Float64,2}
    @assert numobs(x) === 30
end

# 7 batches of size 20 observations
# Note that the iris dataset has 150 observations,
# which means that with a batchsize of 20, the last
# 10 observations will be ignored
for (x, y) in BatchView((X, Y), batchsize=20, partial=false)
    @assert typeof(x) <: SubArray{Float64,2}
    @assert typeof(y) <: SubArray{String,1}
    @assert numobs(x) == numobs(y) == 20
end

# collate tuple observations
for (x, y) in BatchView((rand(10, 3), ["a", "b", "c"]), batchsize=2, collate=true, partial=false)
    @assert size(x) == (10, 2)
    @assert size(y) == (2,)
end

# randomly assign observations to one and only one batch.
for (x, y) in BatchView(shuffleobs((X, Y)), batchsize=20)
    @assert typeof(x) <: SubArray{Float64,2}
    @assert typeof(y) <: SubArray{String,1}
end
```
"""
struct BatchView{TElem,TData,TCollate} <: AbstractDataContainer
    # TElem: batch type reported by `eltype`; TCollate: `Val` of the `collate`
    # option, used to select a `_getbatch` method by dispatch.
    data::TData         # the wrapped data container
    batchsize::Int      # observations per full batch
    count::Int          # number of batches exposed by the view
    partial::Bool       # whether a trailing, smaller batch is kept
end
# Keyword constructor: validates the options, infers the batch element type
# and precomputes the number of batches.
#
# Throws `ArgumentError` for a non-positive `batchsize` or an unsupported
# `collate` value. Empty `data` yields a view with zero batches.
function BatchView(
    data::T; batchsize::Int=1, partial::Bool=true, collate=Val(nothing)
) where {T}
    batchsize >= 1 ||
        throw(ArgumentError("`batchsize` must be a positive integer, got $batchsize."))
    n = numobs(data)
    # Shrink an oversized batch-size so the view still yields one batch. Empty
    # data is skipped: shrinking to 0 would make `cld`/`fld` below divide by zero.
    if 0 < n < batchsize
        @warn "Number of observations less than batch-size, decreasing the batch-size to $n"
        batchsize = n
    end
    # Lift `collate` into the type domain so `_getbatch` can dispatch on it.
    collate = collate isa Val ? collate : Val(collate)
    if !(collate ∈ (Val(nothing), Val(true), Val(false)))
        throw(ArgumentError("`collate` must be one of `nothing`, `true` or `false`."))
    end
    E = _batchviewelemtype(data, collate)
    # `partial` keeps the trailing short batch (round up); otherwise drop it.
    count = partial ? cld(n, batchsize) : fld(n, batchsize)
    return BatchView{E,T,typeof(collate)}(data, batchsize, count, partial)
end
# Infer the batch type for each `collate` mode. Only types are involved:
# `Core.Compiler.return_type` asks inference for the result type without
# calling `getobs`/`batch` on the actual data.

# `collate = nothing`: a batch is `getobs(data, range)`.
function _batchviewelemtype(::TData, ::Val{nothing}) where {TData}
    return Core.Compiler.return_type(getobs, Tuple{TData,UnitRange{Int}})
end

# `collate = false`: a batch is a `Vector` of single observations.
function _batchviewelemtype(::TData, ::Val{false}) where {TData}
    obstype = Core.Compiler.return_type(getobs, Tuple{TData,Int})
    return Vector{obstype}
end

# `collate = true`: a batch is `batch` applied to the `collate = false` vector.
function _batchviewelemtype(data, ::Val{true})
    vectype = _batchviewelemtype(data, Val(false))
    return Core.Compiler.return_type(batch, Tuple{vectype})
end
"""
    batchsize(data::BatchView) -> Int

Return the number of observations in each full batch of `data`.
The last batch may contain fewer observations.

# Examples

```julia
using MLUtils
X, Y = MLUtils.load_iris()

A = BatchView(X, batchsize=30)
@assert batchsize(A) == 30
```
"""
function batchsize(A::BatchView)
    return A.batchsize
end
# One element per batch; the batch count is fixed when the view is constructed.
function Base.length(A::BatchView)
    return A.count
end
# Materialize every observation of the wrapped data as one batch, using the
# view's `collate` mode.
Base.@propagate_inbounds function getobs(A::BatchView)
    allobs = 1:numobs(A.data)
    return _getbatch(A, allobs)
end
# Fetch the `i`-th batch. The bounds check lives in `_batchrange` and may be
# elided by an `@inbounds` at the call site via `@propagate_inbounds`.
Base.@propagate_inbounds Base.getindex(A::BatchView, i::Int) =
    _getbatch(A, _batchrange(A, i))
# Fetch several batches at once, merged into a single batch.
#
# The observations of the requested batches appear in request order. A batch
# index that is requested more than once contributes its observations only
# once. An empty `is` yields an empty batch.
Base.@propagate_inbounds function Base.getindex(A::BatchView, is::AbstractVector)
    obsindices = Int[]
    sizehint!(obsindices, length(is) * A.batchsize)
    # Plain loop (no generator closure) so `@inbounds` from the caller can
    # propagate into `_batchrange`, and no varargs splat over `is`.
    for i in is
        append!(obsindices, _batchrange(A, i))
    end
    unique!(obsindices)
    return _getbatch(A, obsindices)
end
# Assemble the observations at `obsindices` into one batch. The `collate` mode
# is the third type parameter of `BatchView`, so the method is picked by dispatch.

# `collate = true`: fetch observations one by one, then merge them with `batch`.
function _getbatch(A::BatchView{<:Any,<:Any,Val{true}}, obsindices)
    observations = [getobs(A.data, idx) for idx in obsindices]
    return batch(observations)
end

# `collate = false`: the batch is the plain vector of individual observations.
function _getbatch(A::BatchView{<:Any,<:Any,Val{false}}, obsindices)
    observations = [getobs(A.data, idx) for idx in obsindices]
    return observations
end

# `collate = nothing`: let the data container fetch all indices at once.
function _getbatch(A::BatchView{<:Any,<:Any,Val{nothing}}, obsindices)
    return getobs(A.data, obsindices)
end
# The wrapped data container.
function Base.parent(A::BatchView)
    return A.data
end

# The batch type, fixed by the first type parameter at construction.
function Base.eltype(::BatchView{E}) where {E}
    return E
end
# Override the AbstractDataContainer default: iterate batch by batch. The state
# is the index of the next batch to yield.
function Base.iterate(A::BatchView, state = 1)
    state <= numobs(A) || return nothing
    return (A[state], state + 1)
end
# Translate a 1-based batch index into the contiguous range of observation
# indices that batch covers. The last batch is truncated at the number of
# observations, so it may be shorter than `A.batchsize`.
#
# Throws `BoundsError` unless `1 <= batchindex <= A.count` (the check is
# skippable via `@inbounds`).
@inline function _batchrange(A::BatchView, batchindex::Int)
    # Lower bound is 1: index 0 would otherwise produce a range starting at or
    # below zero.
    @boundscheck (1 <= batchindex <= A.count) || throw(BoundsError(A, batchindex))
    startidx = (batchindex - 1) * A.batchsize + 1
    endidx = min(numobs(parent(A)), startidx + A.batchsize - 1)
    return startidx:endidx
end
# Print `BatchView(<data>, batchsize=…, partial=…)`, e.g. as part of `summary`.
# At top level the element type is appended. For brevity, only the type's name
# is printed when it has one. Types without a single name (a `Union`, including
# `Union{}` from failed inference) are printed in full, because `nameof` has no
# method for them and would throw.
function Base.showarg(io::IO, A::BatchView, toplevel)
    print(io, "BatchView(")
    Base.showarg(io, parent(A), false)
    print(io, ", batchsize=", A.batchsize, ", partial=", A.partial, ')')
    if toplevel
        T = eltype(A)
        named = Base.unwrap_unionall(T)
        print(io, " with eltype ", named isa DataType ? nameof(named) : T)
    end
    return nothing
end
# --------------------------------------------------------------------