diff --git a/src/observation.jl b/src/observation.jl index d686012..ff4d7fd 100644 --- a/src/observation.jl +++ b/src/observation.jl @@ -2,17 +2,24 @@ numobs(data) Return the total number of observations contained in `data`. +If `data` does not have `numobs` defined, then this function +falls back to `length(data)`. See also [`getobs`](@ref) """ function numobs end +# Generic Fallbacks +numobs(data) = length(data) + """ getobs(data, [idx]) Return the observations corresponding to the observation-index `idx`. Note that `idx` can be any type as long as `data` has defined `getobs` for that type. +If `data` does not have `getobs` defined, then this function +falls back to `data[idx]`. The returned observation(s) should be in the form intended to be passed as-is to some learning algorithm. There is no strict @@ -29,7 +36,7 @@ function getobs end # Generic Fallbacks getobs(data) = data -# getobs(data, idx) = data[idx] +getobs(data, idx) = data[idx] """ getobs!(buffer, data, idx) diff --git a/test/observation.jl b/test/observation.jl index d61a107..a0369bc 100644 --- a/test/observation.jl +++ b/test/observation.jl @@ -1,3 +1,9 @@ +@testset "fallbacks" begin + x = FallbackType() + @test getobs(x, 3) == 1234 + @test numobs(x) == 5678 +end + @testset "array" begin a = rand(2,3) @test numobs(a) == 3 diff --git a/test/runtests.jl b/test/runtests.jl index 297e777..4c0cf2f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,6 +31,10 @@ const Y1 = collect(1:15) struct EmptyType end +struct FallbackType end +Base.getindex(::FallbackType, i) = 1234 +Base.length(::FallbackType) = 5678 + struct CustomType end MLUtils.numobs(::CustomType) = 15 MLUtils.getobs(::CustomType, i::Int) = i