Skip to content

Commit

Permalink
Merge pull request #168 from JuliaGNI/new_data_loader_constructor
Browse files Browse the repository at this point in the history
Implemented new constructor to deal with input & output.
  • Loading branch information
michakraus authored Aug 21, 2024
2 parents 12f366e + dec3906 commit 5536a75
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 1 deletion.
20 changes: 20 additions & 0 deletions src/data_loader/batch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,4 +281,24 @@ function convert_input_and_batch_indices_to_array(dl::DataLoader{T, BT, OT}, bat
@views output_batch = dl.output[:, :, parameter_indices]

input_batch, output_batch
end

function convert_input_and_batch_indices_to_array(dl::DataLoader{T, BT, BT}, batch::Batch, batch_indices_tuple::Vector{Tuple{Int, Int}}) where {T, BT<:AbstractArray{T, 3}}
backend = KernelAbstractions.get_backend(dl.input)

# the batch size is smaller for the last batch
_batch_size = length(batch_indices_tuple)

batch_indices = convert_vector_of_tuples_to_matrix(backend, batch_indices_tuple)

input = KernelAbstractions.allocate(backend, T, dl.input_dim, batch.seq_length, _batch_size)

assign_input_from_vector_of_tuples! = assign_input_from_vector_of_tuples_kernel!(backend)
assign_input_from_vector_of_tuples!(input, dl.input, batch_indices, ndrange=(dl.input_dim, batch.seq_length, _batch_size))

output = KernelAbstractions.allocate(backend, T, dl.output_dim, batch.prediction_window, _batch_size)

assign_input_from_vector_of_tuples!(output, dl.output, batch_indices, ndrange=(dl.output_dim, batch.prediction_window, _batch_size))

input, output
end
17 changes: 17 additions & 0 deletions src/data_loader/data_loader.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,23 @@ function DataLoader(data::AbstractVector; autoencoder=true, suppress_info = fals
DataLoader(reshape(data, 1, length(data)); autoencoder = autoencoder, suppress_info = suppress_info)
end

function DataLoader(input::AbstractArray{T, 3}, output::AbstractArray{T, 3}; suppress_info = false) where T
@assert size(input, 3) == size(output, 3)
if !suppress_info
@info "You have provided an input and an output."
end

DataLoader{T, typeof(input), typeof(output), :TimeSeries}(input, output, size(input, 1), size(input, 2), size(input, 3), size(output, 1), size(output, 2))
end

function DataLoader(input::AbstractMatrix{T}, output::AbstractMatrix{T}; suppress_info = false) where T
DataLoader(reshape(input, size(input)..., 1), reshape(output, size(output)..., 1); suppress_info = suppress_info)
end

function DataLoader(input::AbstractVector{T}, output::AbstractVector{T}; suppress_info = false) where T
DataLoader(reshape(input, 1, length(input)), reshape(output, 1, length(output)); suppress_info = suppress_info)
end

@doc raw"""
DataLoader(data::AbstractArray{T, 3}, target::AbstractVector)
Expand Down
25 changes: 25 additions & 0 deletions test/data_loader/data_loader_for_input_and_output.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using GeometricMachineLearning
using Test
import Random
Random.seed!(123)

f1(x) = 1cos(x) - 2sin(x) + 1.5cos(2x) - 2.5sin(2x) + 2cos(3x) - 1sin(4x)

nsamples = 1000
xsamples = Float32.(collect(range(0, 5, nsamples)))
ysamples = Float32.(f1.(xsamples))

nwidth = 64
nbatch = 10
nepochs = 2000

model = Chain(Dense(1, nwidth), Dense(nwidth, 1))
nn = NeuralNetwork(model, Float32)
dl = DataLoader(xsamples, ysamples)
o = Optimizer(AdamOptimizer(), nn)
batch = Batch(nbatch, 1, 1)

loss = FeedForwardLoss()

loss_array = o(nn, dl, batch, nepochs, loss)
@test loss_array[end] < 0.9
4 changes: 3 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,6 @@ using Documenter: doctest
@safetestset "Volume-Preserving Transformer (cayley-transform tests) " begin include("volume_preserving_attention/test_cayley_transforms.jl") end

@safetestset "Linear Symplectic Attention " begin include("linear_symplectic_attention.jl") end
@safetestset "Linear Symplectic Transformer " begin include("linear_symplectic_transformer.jl") end
@safetestset "Linear Symplectic Transformer " begin include("linear_symplectic_transformer.jl") end

@safetestset "DataLoader for input and output " begin include("data_loader/data_loader_for_input_and_output.jl") end

0 comments on commit 5536a75

Please sign in to comment.