From 009a17ca28553ec5743170f240d404ba8a0a7f96 Mon Sep 17 00:00:00 2001 From: benedict-96 Date: Fri, 20 Oct 2023 13:12:10 +0800 Subject: [PATCH] Fixed typos; am now explicitly importing loss in these files. --- test/data_loader/data_loader.jl | 2 +- test/data_loader/mnist_utils.jl | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/test/data_loader/data_loader.jl b/test/data_loader/data_loader.jl index c28d887d7..880212f5b 100644 --- a/test/data_loader/data_loader.jl +++ b/test/data_loader/data_loader.jl @@ -5,7 +5,7 @@ function test_data_loader(sys_dim, n_time_steps, n_params, T=Float32) dl = DataLoader(data) # first argument is sys_dim, second is number of heads, third is number of units - model = Transformer(dl.sys_dim, 2, 1) + model = Transformer(dl.input_dim, 2, 1) ps = initialparameters(CPU(), T, model) dx = Zygote.gradient(ps -> GeometricMachineLearning.loss(model, ps, dl), ps)[1] ps_copy = deepcopy(ps) diff --git a/test/data_loader/mnist_utils.jl b/test/data_loader/mnist_utils.jl index 171523e3a..fe685c647 100644 --- a/test/data_loader/mnist_utils.jl +++ b/test/data_loader/mnist_utils.jl @@ -34,7 +34,7 @@ end function test_onehotbatch(V::AbstractVector{T}) where {T<:Integer} V_encoded = onehotbatch(V) for i in length(V) - @test sum(V_encoded[:,i]) == 1 + @test sum(V_encoded[:,1,i]) == 1 end end @@ -47,13 +47,14 @@ train_y = Int.(ceil.(10*rand(Float32, 100))) .- 1 dl = DataLoader(train_x, train_y) -model = Dense(49, 10, tanh) +activation_function(x) = tanh.(x) +model = Classification(49, 10, activation_function) ps = initialparameters(CPU(), Float32, model) -loss₁ = loss(model, ps, dl) +loss₁ = GeometricMachineLearning.loss(model, ps, dl) opt = Optimizer(GradientOptimizer(), ps) dx = Zygote.gradient(ps -> GeometricMachineLearning.loss(model, ps, dl), ps)[1] optimization_step!(opt, model, ps, dx) -loss₂ = loss(model, ps, dl) +loss₂ = GeometricMachineLearning.loss(model, ps, dl) @test loss₂ < loss₁