
RecurrentLayers

Stable Dev Build Status Coverage Aqua

Caution

Currently still under HEAVY development. Use at your own risk, and beware.

Warning

Tests and benchmarks are still missing. Layers may not work as intended yet.

Overview

RecurrentLayers.jl extends Flux.jl's recurrent layer offerings by providing implementations of bleeding-edge recurrent layers that are not commonly available in base deep learning libraries. It is designed for seamless integration with the larger Flux ecosystem, enabling researchers and practitioners to leverage the latest developments in recurrent neural networks.

Features

Layers currently available, together with short-term work in progress (a minimal sketch of the MGU recurrence follows the list):

  • Minimal gated unit (MGU) arxiv
  • Light gated recurrent unit (LiGRU) arxiv
  • Minimal gated recurrent unit (minGRU) and minimal long short-term memory (minLSTM) arxiv
  • Independently recurrent neural networks (IndRNN) arxiv
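
For reference, here is a minimal sketch of the MGU recurrence as described in the cited paper. The function and weight names below are illustrative assumptions and do not reflect the package's internal implementation:

σ(z) = 1 / (1 + exp(-z))

# One MGU step: a single forget gate blends the previous hidden state with a candidate state.
function mgu_step(Wf, Uf, bf, Wh, Uh, bh, h, x)
    f     = σ.(Wf * x .+ Uf * h .+ bf)            # forget gate
    hcand = tanh.(Wh * x .+ Uh * (f .* h) .+ bh)  # candidate state
    return (1 .- f) .* h .+ f .* hcand            # updated hidden state
end

In the package these details are handled internally; the layers are meant to be used like any other Flux recurrent layer (see Getting started below).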

Installation

RecurrentLayers.jl is not yet registered. You can install it directly from the GitHub repository:

using Pkg
Pkg.add(url="https://github.com/MartinuzziFrancesco/RecurrentLayers.jl")
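
Equivalently, you can use the Pkg REPL mode (press ] at the julia> prompt):

pkg> add https://github.com/MartinuzziFrancesco/RecurrentLayers.jl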

Getting started

The workflow is identical to that of any recurrent Flux layer:

using Flux
using RecurrentLayers

input_size = 2
hidden_size = 5
output_size = 3
sequence_length = 100
epochs = 10

model = Chain(
    MGU(input_size, hidden_size),
    Dense(hidden_size, output_size)
)

# dummy data: a Float32 input sequence and a matching target for the MSE loss
X = rand(Float32, input_size, sequence_length)
Y = rand(Float32, output_size, sequence_length)

# loss function
loss_fn(x, y) = Flux.mse(model(x), y)

# optimizer
opt = Adam()

# training 
for epoch in 1:epochs
    # gradients
    gs = gradient(Flux.params(model)) do
        loss = loss_fn(X, Y)
        return loss
    end
    # update parameters
    Flux.update!(opt, Flux.params(model), gs)
    # loss at epoch
    current_loss = loss_fn(X, Y)
    println("Epoch $epoch, Loss: $(current_loss)")
end
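
After training, the model is applied to new data in the same way. A small usage sketch, reusing the dummy shapes from above (the variable names here are illustrative):

# apply the trained model to a new dummy sequence
X_new = rand(Float32, input_size, sequence_length)
Y_pred = model(X_new)
println(size(Y_pred))  # expected: (output_size, sequence_length)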