diff --git a/src/data_loader/data_loader.jl b/src/data_loader/data_loader.jl index a92373606..1989f12c9 100644 --- a/src/data_loader/data_loader.jl +++ b/src/data_loader/data_loader.jl @@ -100,11 +100,11 @@ function loss(model::Union{Chain, AbstractExplicitLayer}, ps::Union{Tuple, Named end function loss(model::Chain, ps::Tuple, dl::DataLoader{T, BT, Nothing}) where {T, BT<:AbstractArray{T, 3}} - loss(model, ps, dl) + loss(model, ps, dl.input) end function loss(model::Chain, ps::Tuple, dl::DataLoader{T, BT, Nothing}) where {T, BT<:AbstractArray{T, 2}} - loss(model, ps, dl) + loss(model, ps, dl.input) end @doc raw"""