Skip to content

Commit

Permalink
Merge pull request #59 from JuliaML/darsnack/getindex-fix
Browse files Browse the repository at this point in the history
Make `getindex`/`length` the default interface
  • Loading branch information
CarloLucibello authored Feb 23, 2022
2 parents e1f4c6b + c0a19e3 commit 3e6f259
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/MLUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ export mapobs,
groupobs,
joinobs,
shuffleobs

include("batchview.jl")
export batchsize,
BatchView
Expand Down
6 changes: 5 additions & 1 deletion src/batchview.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ Return the fixed size of each batch in `data`.
"""
batchsize(A::BatchView) = A.batchsize

numobs(A::BatchView) = A.count
Base.length(A::BatchView) = A.count
getobs(A::BatchView) = getobs(A.data)
getobs(A::BatchView, i::Int) = getobs(A.data, _batchrange(A, i))

Expand All @@ -119,6 +119,10 @@ function Base.getindex(A::BatchView, is::AbstractVector)
obsview(A.data, obsindices)
end

# Override the AbstractDataContainer default `iterate` (which goes through
# `getobs`): iterating a BatchView yields `A[state]` instead.
# `state` is the 1-based index of the next batch; iteration ends by returning
# `nothing` once `state` exceeds `numobs(A)`.
Base.iterate(A::BatchView, state = 1) =
(state > numobs(A)) ? nothing : (A[state], state + 1)

# `obsview` on a BatchView: without an index the BatchView itself is returned
# unchanged; with an index `i` the call is forwarded to `A[i]`.
obsview(A::BatchView) = A
obsview(A::BatchView, i) = A[i]

Expand Down
23 changes: 15 additions & 8 deletions src/observation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@
numobs(data)
Return the total number of observations contained in `data`.
If `data` does not have `numobs` defined, then this function
falls back to `length(data)`.
Authors of custom data containers should implement
`Base.length` for their type instead of `numobs`.
`numobs` should only be implemented for types where there is a
difference between `numobs` and `Base.length`
(such as multi-dimensional arrays).
See also [`getobs`](@ref)
"""
Expand All @@ -18,16 +24,20 @@ numobs(data) = length(data)
Return the observations corresponding to the observation-index `idx`.
Note that `idx` can be any type as long as `data` has defined
`getobs` for that type.
If `data` does not have `getobs` defined, then this function
falls back to `data[idx]`.
Authors of custom data containers should implement
`Base.getindex` for their type instead of `getobs`.
`getobs` should only be implemented for types where there is a
difference between `getobs` and `Base.getindex`
(such as multi-dimensional arrays).
The returned observation(s) should be in the form intended to
be passed as-is to some learning algorithm. There is no strict
interface requirement on what this "actual data" must look like.
Every author behind some custom data container can make this
decision themselves.
The output should be consistent whether `idx` is a scalar or a vector.
See also [`getobs!`](@ref) and [`numobs`](@ref)
Expand Down Expand Up @@ -64,13 +74,10 @@ getobs!(buffer, data, idx) = getobs(data, idx)

abstract type AbstractDataContainer end

Base.getindex(x::AbstractDataContainer, i) = getobs(x, i)
Base.length(x::AbstractDataContainer) = numobs(x)
Base.size(x::AbstractDataContainer) = (length(x),)

Base.size(x::AbstractDataContainer) = (numobs(x),)
Base.iterate(x::AbstractDataContainer, state = 1) =
(state > length(x)) ? nothing : (x[state], state + 1)
Base.lastindex(x::AbstractDataContainer) = length(x)
(state > numobs(x)) ? nothing : (getobs(x, state), state + 1)
Base.lastindex(x::AbstractDataContainer) = numobs(x)

# --------------------------------------------------------------------
# Arrays
Expand Down
20 changes: 10 additions & 10 deletions src/obstransform.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@

# mapobs

struct MappedData{F,D}
struct MappedData{F,D} <: AbstractDataContainer
f::F
data::D
end

# Display a MappedData in the form `mapobs(f, data)`.
# The generic method prints `summary` of the wrapped data; when the wrapped
# data is an AbstractArray it is printed through `ShowLimit(...; limit=80)`
# instead (presumably to truncate long output; confirm against ShowLimit).
Base.show(io::IO, data::MappedData) = print(io, "mapobs($(data.f), $(summary(data.data)))")
Base.show(io::IO, data::MappedData{F,<:AbstractArray}) where {F} =
print(io, "mapobs($(data.f), $(ShowLimit(data.data, limit=80)))")
numobs(data::MappedData) = numobs(data.data)
getobs(data::MappedData, idx::Int) = data.f(getobs(data.data, idx))
getobs(data::MappedData, idxs::AbstractVector) = data.f.(getobs(data.data, idxs))
Base.length(data::MappedData) = numobs(data.data)
Base.getindex(data::MappedData, idx::Int) = data.f(getobs(data.data, idx))
Base.getindex(data::MappedData, idxs::AbstractVector) = data.f.(getobs(data.data, idxs))


"""
Expand All @@ -38,14 +38,14 @@ Returns a tuple of transformed data containers.
mapobs(fs::Tuple, data) = Tuple(mapobs(f, data) for f in fs)


struct NamedTupleData{TData,F}
struct NamedTupleData{TData,F} <: AbstractDataContainer
data::TData
namedfs::NamedTuple{F}
end

numobs(data::NamedTupleData) = numobs(getfield(data, :data))
Base.length(data::NamedTupleData) = numobs(getfield(data, :data))

function getobs(data::NamedTupleData{TData,F}, idx::Int) where {TData,F}
function Base.getindex(data::NamedTupleData{TData,F}, idx::Int) where {TData,F}
obs = getobs(getfield(data, :data), idx)
namedfs = getfield(data, :namedfs)
return NamedTuple{F}(f(obs) for f in namedfs)
Expand Down Expand Up @@ -126,16 +126,16 @@ end

# joinobs

struct JoinedData{T,N}
struct JoinedData{T,N} <: AbstractDataContainer
datas::NTuple{N,T}
ns::NTuple{N,Int}
end

JoinedData(datas) = JoinedData(datas, numobs.(datas))

numobs(data::JoinedData) = sum(data.ns)
Base.length(data::JoinedData) = sum(data.ns)

function getobs(data::JoinedData, idx)
function Base.getindex(data::JoinedData, idx)
for (i, n) in enumerate(data.ns)
if idx <= n
return getobs(data.datas[i], idx)
Expand Down
3 changes: 1 addition & 2 deletions src/obsview.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,10 @@ end

Base.IteratorEltype(::Type{<:ObsView}) = Base.EltypeUnknown()

# override AbstractDataContainer defaults
Base.getindex(subset::ObsView, idx) =
obsview(subset.data, subset.indices[idx])

numobs(subset::ObsView) = length(subset.indices)
Base.length(subset::ObsView) = length(subset.indices)

getobs(subset::ObsView) = getobs(subset.data, subset.indices)
getobs(subset::ObsView, idx) = getobs(subset.data, subset.indices[idx])
Expand Down

0 comments on commit 3e6f259

Please sign in to comment.