From 7c0dba1afd2894b1b697084e399909d215c7f3cc Mon Sep 17 00:00:00 2001 From: benedict-96 Date: Wed, 25 Oct 2023 14:40:07 +0800 Subject: [PATCH] Replaced two constructors with one. added hasseqlength for Batch. --- src/data_loader/batch.jl | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/data_loader/batch.jl b/src/data_loader/batch.jl index 4ce942fb6..c82bbb0a2 100644 --- a/src/data_loader/batch.jl +++ b/src/data_loader/batch.jl @@ -3,19 +3,18 @@ The functor returns indices that are then used in the optimization step (always for an entire epoch). """ -struct Batch{seq_length} - batch_size::Integer - seq_length::Union{Nothing, Integer} +struct Batch{seq_type <: Union{Nothing, Integer}} + batch_size::Int + seq_length::seq_type - function Batch(batch_size, seq_length) - new{true}(batch_size, seq_length) - end - - function Batch(batch_size::Integer) - new{false}(batch_size, nothing) + function Batch(batch_size, seq_length = nothing) + new{typeof(seq_length)}(batch_size, seq_length) end end +hasseqlength(::Batch{<:Integer}) = true +hasseqlength(::Batch{<:Nothing}) = false + function (batch::Batch{false})(dl::DataLoader{T, AT}) where {T, AT<:AbstractArray{T}} indices = shuffle(1:dl.n_params)