Skip to content

Commit

Permalink
Merge pull request #56 from JuliaML/darsnack/fallbacks
Browse files Browse the repository at this point in the history
Add generic fallbacks for getobs and numobs
  • Loading branch information
CarloLucibello authored Feb 17, 2022
2 parents 791fb77 + aad5c72 commit 6a593d4
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/observation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,24 @@
numobs(data)
Return the total number of observations contained in `data`.
If `data` does not have `numobs` defined, then this function
falls back to `length(data)`.
See also [`getobs`](@ref)
"""
function numobs end

# Generic Fallbacks
numobs(data) = length(data)

"""
getobs(data, [idx])
Return the observations corresponding to the observation-index `idx`.
Note that `idx` can be any type as long as `data` has defined
`getobs` for that type.
If `data` does not have `getobs` defined, then this function
falls back to `data[idx]`.
The returned observation(s) should be in the form intended to
be passed as-is to some learning algorithm. There is no strict
Expand All @@ -29,7 +36,7 @@ function getobs end

# Generic Fallbacks
getobs(data) = data
# getobs(data, idx) = data[idx]
getobs(data, idx) = data[idx]

"""
getobs!(buffer, data, idx)
Expand Down
6 changes: 6 additions & 0 deletions test/observation.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
@testset "fallbacks" begin
x = FallbackType()
@test getobs(x, 3) == 1234
@test numobs(x) == 5678
end

@testset "array" begin
a = rand(2,3)
@test numobs(a) == 3
Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ const Y1 = collect(1:15)

struct EmptyType end

struct FallbackType end
Base.getindex(::FallbackType, i) = 1234
Base.length(::FallbackType) = 5678

struct CustomType end
MLUtils.numobs(::CustomType) = 15
MLUtils.getobs(::CustomType, i::Int) = i
Expand Down

0 comments on commit 6a593d4

Please sign in to comment.