Merge pull request #129 from JuliaGNI/loss_routines

Loss routines

Showing 7 changed files with 189 additions and 53 deletions.
@@ -0,0 +1,83 @@
@doc raw"""
Optimize for an entire epoch. For this you have to supply:
- an instance of the optimizer,
- the neural network model,
- the parameters of the model,
- the data (in the form of a `DataLoader`),
- an instance of `Batch` that contains `batch_size` (and optionally `seq_length`).
With the optional argument:
- the loss, which takes the `model`, the parameters `ps` and an instance of `DataLoader` as input.
The output of `optimize_for_one_epoch!` is the average loss over all batches of the epoch:
```math
output = \frac{1}{\mathtt{steps\_per\_epoch}}\sum_{t=1}^\mathtt{steps\_per\_epoch}loss(\theta^{(t-1)}).
```
This average comes at no extra cost because any **reverse differentiation** routine always has two outputs: a pullback and the value of the function it is differentiating. In the case of `Zygote`: `loss_value, pullback = Zygote.pullback(ps -> loss(ps), ps)` (if the loss only depends on the parameters).
"""
function optimize_for_one_epoch!(opt::Optimizer, model, ps::Union{Tuple, NamedTuple}, dl::DataLoader{T, AT, BT}, batch::Batch, loss) where {T, T1, AT<:AbstractArray{T, 3}, BT<:AbstractArray{T1, 3}}
    count = 0
    total_error = T(0)
    batches = batch(dl)
    @views for batch_indices in batches
        count += 1
        # these `copy`s should not be necessary! coming from a Zygote problem!
        input_batch = copy(dl.input[:, :, batch_indices])
        output_batch = copy(dl.output[:, :, batch_indices])
        loss_value, pullback = Zygote.pullback(ps -> loss(model, ps, input_batch, output_batch), ps)
        total_error += loss_value
        dp = pullback(one(loss_value))[1]
        optimization_step!(opt, model, ps, dp)
    end
    total_error / count
end

function optimize_for_one_epoch!(opt::Optimizer, model, ps::Union{Tuple, NamedTuple}, dl::DataLoader{T, CT, Nothing}, batch::Batch, loss::NetworkLoss) where {T, AT<:AbstractArray{T, 3}, BT<:NamedTuple{(:q, :p), Tuple{AT, AT}}, CT<:Union{AT, BT}}
    count = 0
    total_error = T(0)
    batches = batch(dl)
    @views for batch_indices in batches
        count += 1
        # these `copy`s should not be necessary! coming from a Zygote problem!
        input_nt, output_nt = convert_input_and_batch_indices_to_array(dl, batch, batch_indices)
        loss_value, pullback = Zygote.pullback(ps -> loss(model, ps, input_nt, output_nt), ps)
        total_error += loss_value
        dp = pullback(one(loss_value))[1]
        optimization_step!(opt, model, ps, dp)
    end
    total_error / count
end

@doc raw"""
A functor for `Optimizer`. It is called with:
- `nn::NeuralNetwork`
- `dl::DataLoader`
- `batch::Batch`
- `n_epochs::Int`
- `loss`
The last argument is a function through which `Zygote` differentiates. This argument is optional; if it is not supplied, `GeometricMachineLearning` defaults to an appropriate loss for the `DataLoader`.
"""
function (o::Optimizer)(nn::NeuralNetwork, dl::DataLoader, batch::Batch, n_epochs::Int, loss::NetworkLoss)
    progress_object = ProgressMeter.Progress(n_epochs; enabled=true)
    loss_array = zeros(n_epochs)
    for i in 1:n_epochs
        loss_array[i] = optimize_for_one_epoch!(o, nn.model, nn.params, dl, batch, loss)
        ProgressMeter.next!(progress_object; showvalues = [(:TrainingLoss, loss_array[i])])
    end
    loss_array
end

#=
function (o::Optimizer)(nn::NeuralNetwork{<:TransformerIntegrator}, dl::DataLoader, batch::Batch{Int}, n_epochs::Int=1)
    loss = TransformerLoss(batch)
    o(nn, dl, batch, n_epochs, loss)
end

function (o::Optimizer)(nn::NeuralNetwork{<:NeuralNetworkIntegrator}, dl::DataLoader, batch::Batch{Int}, n_epochs::Int=1)
    loss = FeedForwardLoss()
    o(nn, dl, batch, n_epochs, loss)
end

(o::Optimizer)(nn::NeuralNetwork{<:NeuralNetworkIntegrator}, dl::DataLoader, batch::Batch{Nothing}, n_epochs::Int=1) = o(nn, dl, Batch(batch.batch_size, 1), n_epochs)
=#
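The docstring above leans on the fact that reverse-mode differentiation returns the loss value alongside the pullback, which is why accumulating `total_error` costs nothing extra. A minimal standalone sketch of that pattern (the quadratic toy loss and the parameter vector are illustrative, not part of the package):

```julia
using Zygote

# A toy loss that only depends on the parameters, as in the docstring above.
loss(ps) = sum(abs2, ps)
ps = [1.0, 2.0, 3.0]

# `Zygote.pullback` returns the value of the function *and* a pullback closure.
loss_value, pullback = Zygote.pullback(loss, ps)

# Seeding the pullback with one(loss_value) yields the gradient with respect
# to each argument of `loss`; `[1]` picks out the gradient for `ps`.
dp = pullback(one(loss_value))[1]

loss_value  # 14.0
dp          # [2.0, 4.0, 6.0]
```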
@@ -0,0 +1,42 @@
abstract type NetworkLoss end

function (loss::NetworkLoss)(nn::NeuralNetwork, input::AT, output::AT) where {AT <: AbstractArray}
    loss(nn.model, nn.params, input, output)
end

@doc raw"""
The loss for a transformer network (especially a transformer integrator). The constructor is called with:
- `seq_length::Int`
- `prediction_window::Int` (default is 1).
"""
struct TransformerLoss <: NetworkLoss
    seq_length::Int
    prediction_window::Int
end

TransformerLoss(seq_length::Int) = TransformerLoss(seq_length, 1)

@doc raw"""
This crops the output array of the neural network so that it conforms with the output it should be compared to. This is needed for the transformer loss.
"""
function crop_array_for_transformer_loss(nn_output::AT, output::AT) where {T, AT<:AbstractArray{T, 3}}
    @view nn_output[axes(output, 1), axes(output, 2) .+ size(nn_output, 2) .- size(output, 2), axes(output, 3)]
end

function (loss::TransformerLoss)(model::Chain, ps::Union{Tuple, NamedTuple}, input::AT, output::AT) where {T, AT <: Union{AbstractArray{T, 2}, AbstractArray{T, 3}}}
    input_dim, input_seq_length = size(input)
    output_dim, output_prediction_window = size(output)
    @assert input_dim == output_dim
    @assert input_seq_length == loss.seq_length
    @assert output_prediction_window == loss.prediction_window

    predicted_output_uncropped = model(input, ps)
    predicted_output_cropped = crop_array_for_transformer_loss(predicted_output_uncropped, output)
    norm(predicted_output_cropped - output) / norm(output)
end

struct FeedForwardLoss <: NetworkLoss end

function (loss::FeedForwardLoss)(model::Chain, ps::Union{Tuple, NamedTuple}, input::AT, output::AT) where {AT <: AbstractArray}
    norm(model(input, ps) - output) / norm(output)
end
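As a shape-level illustration of `crop_array_for_transformer_loss`: a transformer maps a `(dim, seq_length, n_batch)` input to an output of the same size, but the loss only compares the last `prediction_window` time steps, so the cropping selects the trailing columns along the second axis. The array sizes below are arbitrary, chosen only for the example:

```julia
# nn_output: (dim, seq_length, n_batch); output: (dim, prediction_window, n_batch)
nn_output = reshape(collect(1.0:24.0), 2, 4, 3)  # dim = 2, seq_length = 4, n_batch = 3
output = ones(2, 1, 3)                           # prediction_window = 1

# the same indexing expression as in `crop_array_for_transformer_loss`
cropped = @view nn_output[axes(output, 1), axes(output, 2) .+ size(nn_output, 2) .- size(output, 2), axes(output, 3)]

size(cropped) == size(output)        # true
cropped == nn_output[:, end:end, :]  # true: only the last time step is kept
```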
@@ -0,0 +1,27 @@
using GeometricMachineLearning
using GeometricMachineLearning: FeedForwardLoss
using Test
import Random

Random.seed!(123)

const sin_vector = sin.(0:0.01:2π)
const dl = DataLoader(reshape(sin_vector, 1, length(sin_vector), 1))

function setup_network(dl::DataLoader{T}) where T
    arch = Chain(Dense(1, 5, tanh), ResNet(5, tanh), Dense(5, 1, identity))
    NeuralNetwork(arch, CPU(), T)
end

function train_network(; n_epochs=5)
    nn = setup_network(dl)
    loss = FeedForwardLoss()

    o = Optimizer(AdamOptimizer(), nn)
    batch = Batch(5, 1)
    loss_array = o(nn, dl, batch, n_epochs, loss)
    T = eltype(dl)
    @test loss_array[end] / loss_array[1] < T(0.1)
end

train_network()
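Since `FeedForwardLoss` is just the relative Frobenius error, the functor from the losses file can be cross-checked against the formula directly. A sketch along those lines (the input/output pairing below is arbitrary and only serves to exercise the formula):

```julia
using LinearAlgebra: norm

nn = setup_network(dl)
input = dl.input[:, 1:5, :]   # a small slice of the sine data
output = dl.input[:, 2:6, :]  # shifted by one step, purely for illustration

# `FeedForwardLoss` computes norm(model(input, ps) - output) / norm(output)
manual = norm(nn.model(input, nn.params) - output) / norm(output)
via_functor = FeedForwardLoss()(nn, input, output)
manual ≈ via_functor  # true: both evaluate the same relative error
```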