From e247fb56fdb0256cfa760b788e5198fc54e0405f Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 1 Oct 2022 22:49:11 -0400 Subject: [PATCH] Pretty printing for `DataLoader` (#122) * pretty printing for DataLoader * tidy, tests --- src/eachobs.jl | 38 ++++++++++++++++++++++++++++++++++++++ test/dataloader.jl | 21 +++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/src/eachobs.jl b/src/eachobs.jl index 7f49306..99c96c4 100644 --- a/src/eachobs.jl +++ b/src/eachobs.jl @@ -255,3 +255,41 @@ end e.parallel && throw(ArgumentError("Transducer fold protocol not supported on parallel data loads")) _dataloader_foldl1(rf, val, e, ObsView(e.data)) end + +# Base uses this function for composable array printing, e.g. adjoint(view(::Matrix))) +function Base.showarg(io::IO, e::DataLoader, toplevel) + print(io, "DataLoader(") + Base.showarg(io, e.data, false) + e.buffer == false || print(io, ", buffer=", e.buffer) + e.parallel == false || print(io, ", parallel=", e.parallel) + e.shuffle == false || print(io, ", shuffle=", e.shuffle) + e.batchsize == 1 || print(io, ", batchsize=", e.batchsize) + e.partial == true || print(io, ", partial=", e.partial) + e.collate == Val(nothing) || print(io, ", collate=", e.collate) + e.rng == Random.GLOBAL_RNG || print(io, ", rng=", e.rng) + print(io, ")") +end + +Base.show(io::IO, e::DataLoader) = Base.showarg(io, e, false) + +function Base.show(io::IO, m::MIME"text/plain", e::DataLoader) + if Base.haslength(e) + print(io, length(e), "-element ") + else + print(io, "Unknown-length ") + end + Base.showarg(io, e, false) + print(io, "\n with first element:") + print(io, "\n ", _expanded_summary(first(e))) +end + +_expanded_summary(x) = summary(x) +function _expanded_summary(xs::Tuple) + parts = [_expanded_summary(x) for x in xs] + "(" * join(parts, ", ") * ",)" +end +function _expanded_summary(xs::NamedTuple) + parts = ["$k = "*_expanded_summary(x) for (k,x) in zip(keys(xs), xs)] + "(; " * join(parts, ", ") * ")" +end + diff --git a/test/dataloader.jl b/test/dataloader.jl index 45b2a2c..dc569d7 100644 --- a/test/dataloader.jl +++ b/test/dataloader.jl @@ -214,4 +214,25 @@ dloader = DataLoader(1:1000; batchsize = 2, shuffle = true) @test copy(Map(x -> x[1]), Vector{Int}, dloader) != collect(1:2:1000) end + + @testset "printing" begin + X2 = reshape(Float32[1:10;], (2, 5)) + Y2 = [1:5;] + + d = DataLoader((X2, Y2), batchsize=3) + + @test contains(repr(d), "DataLoader(::Tuple{Matrix") + @test contains(repr(d), "batchsize=3") + + @test contains(repr(MIME"text/plain"(), d), "2-element DataLoader") + @test contains(repr(MIME"text/plain"(), d), "2×3 Matrix{Float32}, 3-element Vector") + + d2 = DataLoader((x = X2, y = Y2), batchsize=2, partial=false) + + @test contains(repr(d2), "DataLoader(::NamedTuple") + @test contains(repr(d2), "partial=false") + + @test contains(repr(MIME"text/plain"(), d2), "2-element DataLoader(::NamedTuple") + @test contains(repr(MIME"text/plain"(), d2), "x = 2×2 Matrix{Float32}, y = 2-element Vector") + end end