Skip to content

Commit

Permalink
Pretty printing for DataLoader (#122)
Browse files Browse the repository at this point in the history
* pretty printing for DataLoader

* tidy, tests
  • Loading branch information
mcabbott authored Oct 2, 2022
1 parent a73692e commit e247fb5
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
38 changes: 38 additions & 0 deletions src/eachobs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,41 @@ end
e.parallel && throw(ArgumentError("Transducer fold protocol not supported on parallel data loads"))
_dataloader_foldl1(rf, val, e, ObsView(e.data))
end

# Base uses this function for composable array printing, e.g. adjoint(view(::Matrix)))
function Base.showarg(io::IO, e::DataLoader, toplevel)
print(io, "DataLoader(")
Base.showarg(io, e.data, false)
e.buffer == false || print(io, ", buffer=", e.buffer)
e.parallel == false || print(io, ", parallel=", e.parallel)
e.shuffle == false || print(io, ", shuffle=", e.shuffle)
e.batchsize == 1 || print(io, ", batchsize=", e.batchsize)
e.partial == true || print(io, ", partial=", e.partial)
e.collate == Val(nothing) || print(io, ", collate=", e.collate)
e.rng == Random.GLOBAL_RNG || print(io, ", rng=", e.rng)
print(io, ")")
end

Base.show(io::IO, e::DataLoader) = Base.showarg(io, e, false)

function Base.show(io::IO, m::MIME"text/plain", e::DataLoader)
if Base.haslength(e)
print(io, length(e), "-element ")
else
print(io, "Unknown-length ")
end
Base.showarg(io, e, false)
print(io, "\n with first element:")
print(io, "\n ", _expanded_summary(first(e)))
end

_expanded_summary(x) = summary(x)
function _expanded_summary(xs::Tuple)
parts = [_expanded_summary(x) for x in xs]
"(" * join(parts, ", ") * ",)"
end
function _expanded_summary(xs::NamedTuple)
parts = ["$k = "*_expanded_summary(x) for (k,x) in zip(keys(xs), xs)]
"(; " * join(parts, ", ") * ")"
end

21 changes: 21 additions & 0 deletions test/dataloader.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,4 +214,25 @@
dloader = DataLoader(1:1000; batchsize = 2, shuffle = true)
@test copy(Map(x -> x[1]), Vector{Int}, dloader) != collect(1:2:1000)
end

@testset "printing" begin
X2 = reshape(Float32[1:10;], (2, 5))
Y2 = [1:5;]

d = DataLoader((X2, Y2), batchsize=3)

@test contains(repr(d), "DataLoader(::Tuple{Matrix")
@test contains(repr(d), "batchsize=3")

@test contains(repr(MIME"text/plain"(), d), "2-element DataLoader")
@test contains(repr(MIME"text/plain"(), d), "2×3 Matrix{Float32}, 3-element Vector")

d2 = DataLoader((x = X2, y = Y2), batchsize=2, partial=false)

@test contains(repr(d2), "DataLoader(::NamedTuple")
@test contains(repr(d2), "partial=false")

@test contains(repr(MIME"text/plain"(), d2), "2-element DataLoader(::NamedTuple")
@test contains(repr(MIME"text/plain"(), d2), "x = 2×2 Matrix{Float32}, y = 2-element Vector")
end
end

0 comments on commit e247fb5

Please sign in to comment.