Skip to content

Commit

Permalink
Replaced two constructors with one. added hasseqlength for Batch.
Browse files Browse the repository at this point in the history
  • Loading branch information
benedict-96 committed Oct 25, 2023
1 parent 4070f22 commit 7c0dba1
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions src/data_loader/batch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,18 @@
The functor returns indices that are then used in the optimization step (always for an entire epoch).
"""
struct Batch{seq_length}
batch_size::Integer
seq_length::Union{Nothing, Integer}
struct Batch{seq_type <: Union{Nothing, Integer}}
batch_size::Int
seq_length::seq_type

function Batch(batch_size, seq_length)
new{true}(batch_size, seq_length)
end

function Batch(batch_size::Integer)
new{false}(batch_size, nothing)
function Batch(batch_size, seq_length = nothing)
new{typeof(seq_length)}(batch_size, seq_length)
end
end

hasseqlength(::Batch{<:Integer}) = true
hasseqlength(::Batch{<:Nothing}) = false


function (batch::Batch{false})(dl::DataLoader{T, AT}) where {T, AT<:AbstractArray{T}}
indices = shuffle(1:dl.n_params)
Expand Down

0 comments on commit 7c0dba1

Please sign in to comment.