Skip to content

Commit

Permalink
Fixed typos; am now explicitly importing loss in these files.
Browse files Browse the repository at this point in the history
  • Loading branch information
benedict-96 committed Oct 20, 2023
1 parent 0f77270 commit 009a17c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion test/data_loader/data_loader.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ function test_data_loader(sys_dim, n_time_steps, n_params, T=Float32)
dl = DataLoader(data)

# first argument is sys_dim, second is number of heads, third is number of units
model = Transformer(dl.sys_dim, 2, 1)
model = Transformer(dl.input_dim, 2, 1)
ps = initialparameters(CPU(), T, model)
dx = Zygote.gradient(ps -> GeometricMachineLearning.loss(model, ps, dl), ps)[1]
ps_copy = deepcopy(ps)
Expand Down
9 changes: 5 additions & 4 deletions test/data_loader/mnist_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ end
function test_onehotbatch(V::AbstractVector{T}) where {T<:Integer}
V_encoded = onehotbatch(V)
for i in length(V)
@test sum(V_encoded[:,i]) == 1
@test sum(V_encoded[:,1,i]) == 1
end
end

Expand All @@ -47,13 +47,14 @@ train_y = Int.(ceil.(10*rand(Float32, 100))) .- 1

dl = DataLoader(train_x, train_y)

model = Dense(49, 10, tanh)
activation_function(x) = tanh.(x)
model = Classification(49, 10, activation_function)
ps = initialparameters(CPU(), Float32, model)
loss₁ = loss(model, ps, dl)
loss₁ = GeometricMachineLearning.loss(model, ps, dl)

opt = Optimizer(GradientOptimizer(), ps)
dx = Zygote.gradient(ps -> GeometricMachineLearning.loss(model, ps, dl), ps)[1]
optimization_step!(opt, model, ps, dx)
loss₂ = loss(model, ps, dl)
loss₂ = GeometricMachineLearning.loss(model, ps, dl)

@test loss₂ < loss₁

0 comments on commit 009a17c

Please sign in to comment.