Compile training loop with Reactant #673

Closed · wants to merge 9 commits
5 changes: 4 additions & 1 deletion Project.toml
@@ -46,6 +46,7 @@ FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
@@ -60,14 +61,15 @@ LuxMPIExt = "MPI"
LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
LuxReverseDiffExt = ["FunctionWrappers", "ReverseDiff"]
LuxSimpleChainsExt = "SimpleChains"
LuxReactantExt = ["Enzyme", "Reactant"]
LuxTrackerExt = "Tracker"
LuxZygoteExt = "Zygote"

[compat]
ADTypes = "1.5"
Adapt = "4"
ArgCheck = "2.3"
ArrayInterface = "7.9"
ArrayInterface = "7.10"
CUDA = "5.3.2"
ChainRulesCore = "1.24"
Compat = "4.15"
@@ -96,6 +98,7 @@ NNlib = "0.9.21"
Optimisers = "0.3.3"
Preferences = "1.4.3"
Random = "1.10"
Reactant = "0.2"
Reexport = "1.2.2"
ReverseDiff = "1.15"
SIMDTypes = "0.1"
14 changes: 14 additions & 0 deletions ext/LuxReactantExt/LuxReactantExt.jl
@@ -0,0 +1,14 @@
module LuxReactantExt

using Enzyme: Enzyme, Active, Const, Duplicated
using Reactant: Reactant
using Static: Static, False
using Setfield: @set!

using Lux: Lux, ReactantBackend
using Lux.Training: TrainingBackendCache, TrainState
using LuxCore: LuxCore

include("training.jl")

end
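
Note: `LuxReactantExt` is a package extension, so it only loads once its trigger packages are available. A minimal check using standard Julia (≥ 1.9) tooling, not something introduced in this PR:

```julia
using Lux, Enzyme, Reactant

# Because Project.toml declares `LuxReactantExt = ["Enzyme", "Reactant"]`, the
# extension module is loaded automatically once both triggers are imported.
Base.get_extension(Lux, :LuxReactantExt) !== nothing  # true when the extension is active
```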
77 changes: 77 additions & 0 deletions ext/LuxReactantExt/training.jl
@@ -0,0 +1,77 @@
function Lux.Training.single_train_step!(
backend::ReactantBackend, obj_fn::F, data, ts::TrainState) where {F}
data = Reactant.to_rarray(data)
ps = Reactant.to_rarray(ts.parameters)
st = Reactant.to_rarray(ts.states)
st_opt = Reactant.to_rarray(ts.optimizer_state)

compiled_inference = if backend.input_prototype !== nothing
Reactant.compile(LuxCore.apply,
(ts.model, Reactant.to_rarray(backend.input_prototype),
ps, LuxCore.testmode(st)))
else
nothing
end

compiled_grad_and_step! = Reactant.compile(
internal_grad_and_step!, (obj_fn, ts.model, ps, st, st_opt, data, ts.optimizer))

loss, st_updated, stats = compiled_grad_and_step!(
obj_fn, ts.model, ps, st, st_opt, data, ts.optimizer)

cache = TrainingBackendCache(backend, False(), nothing, (; compiled_grad_and_step!,
compiled_inference))
@set! ts.cache = cache
@set! ts.objective_function = obj_fn
@set! ts.parameters = ps
@set! ts.states = st_updated
@set! ts.optimizer_state = st_opt
@set! ts.step = ts.step + 1

return nothing, loss, stats, ts # TODO: Return the gradients
end

function Lux.Training.single_train_step!(::ReactantBackend, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{<:ReactantBackend}, F}) where {F}
data = Reactant.to_rarray(data)

loss, st_updated, stats = ts.cache.extras.compiled_grad_and_step!(
obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, data, ts.optimizer)

@set! ts.objective_function = obj_fn
@set! ts.states = st_updated
@set! ts.step = ts.step + 1

return nothing, loss, stats, ts # TODO: Return the gradients
end

function internal_grad_and_step!(
obj_fn::F, model, ps, st, st_opt, data, optimizer) where {F}
dps = Lux.recursive_make_zero(ps)

_, (loss, st_updated, stats) = Enzyme.autodiff(
Enzyme.ReverseWithPrimal, obj_fn, Active, Const(model),
Duplicated(ps, dps), Const(st), Const(data))

Lux.simple_optimizers_apply!(optimizer, st_opt, ps, dps) # ps & st_opt are updated in-place

return loss, st_updated, stats
end

function (tstate::TrainState{<:TrainingBackendCache{<:ReactantBackend}})(data)
data_reactant = Reactant.to_rarray(data)
compiled_inference = if tstate.cache.extras.compiled_inference !== nothing
tstate.cache.extras.compiled_inference
else
@warn "Inference function not compiled before. This will trigger compilation on \
every inference call to `(::TrainState)(data)`. Please use \
`ReactantBackend(; input_prototype = data)` to compile the inference \
function on the first call to `single_train_step!` or \
`single_train_step`." maxlog=1
Reactant.compile(LuxCore.apply,
(tstate.model, data_reactant, tstate.parameters,
LuxCore.testmode(tstate.states)))
end
return compiled_inference(
tstate.model, data_reactant, tstate.parameters, LuxCore.testmode(tstate.states))
end
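
For context, here is a minimal end-to-end sketch of the flow this file implements. The model, data shapes, optimizer, loss, and step count are illustrative assumptions and are not taken from the PR:

```julia
using Lux, Reactant, Enzyme, Optimisers, Random

# Assumed toy model and batch shapes, for illustration only.
model = Chain(Dense(4 => 16, tanh), Dense(16 => 2))
ps, st = Lux.setup(Random.default_rng(), model)

x = rand(Float32, 4, 32)
y = rand(Float32, 2, 32)

# `input_prototype` lets the first training step also compile the inference path.
backend = ReactantBackend(; input_prototype = x)
tstate = Lux.Training.TrainState(model, ps, st, Descent(0.01f0))
loss_fn = MSELoss()

function train!(tstate, backend, loss_fn, data, nsteps)
    for _ in 1:nsteps
        # The first call hits the compiling method above and caches the compiled
        # gradient-and-update function; subsequent calls reuse the cache.
        _, loss, _, tstate = Lux.Training.single_train_step!(
            backend, loss_fn, data, tstate)
    end
    return tstate
end

tstate = train!(tstate, backend, loss_fn, (x, y), 10)

# Inference through the compiled function; `testmode` is applied automatically
# and the returned state is discarded here.
y_pred, _ = tstate(x)
```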
21 changes: 16 additions & 5 deletions src/Lux.jl
@@ -14,7 +14,12 @@ using GPUArraysCore: @allowscalar
using LossFunctions: LossFunctions
using Markdown: @doc_str
using NNlib: NNlib
using Optimisers: Optimisers, Leaf, Descent
using Preferences: load_preference, has_preference
using Random: Random, AbstractRNG
using Static: StaticBool, StaticInt, StaticSymbol, True, False, static, known, dynamic
using Reexport: Reexport, @reexport
@@ -46,12 +51,20 @@ include("extended_ops.jl")
# Training Helpers
include("helpers/training.jl")

# Compilers
include("compilers.jl")

# Experimental
include("contrib/contrib.jl")

# Pretty Printing
include("layers/display.jl")

# Transform to and from other frameworks
include("transform/types.jl")
include("transform/flux.jl")
include("transform/simplechains.jl")

# Layer Implementations
include("layers/basic.jl")
include("layers/containers.jl")
@@ -70,16 +83,12 @@ include("helpers/losses.jl")
include("helpers/recursive_ops.jl")
include("helpers/match_eltype.jl")
include("helpers/size_propagator.jl")
include("helpers/simple_optimizers.jl")

# AutoDiff
include("autodiff/api.jl")
include("autodiff/autodiff.jl")

# Transform to and from other frameworks
include("transform/types.jl")
include("transform/flux.jl")
include("transform/simplechains.jl")

# Distributed Training
include("distributed/backend.jl")
include("distributed/public_api.jl")
@@ -106,6 +115,8 @@ export jacobian_vector_product, vector_jacobian_product
export batched_jacobian
export AutoEnzyme, AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote

export ReactantBackend

export BinaryCrossEntropyLoss, BinaryFocalLoss, CrossEntropyLoss, DiceCoeffLoss, FocalLoss,
HingeLoss, HuberLoss, KLDivergenceLoss, L1Loss, L2Loss, MAELoss, MSELoss, MSLELoss,
PoissonLoss, SiameseContrastiveLoss, SquaredHingeLoss
29 changes: 29 additions & 0 deletions src/compilers.jl
@@ -0,0 +1,29 @@
abstract type AbstractCompilerBackend end

"""
ReactantBackend(; input_prototype = nothing)

Compile Lux model and gradient computation to MLIR/XLA via `Reactant.jl`.

!!! tip "Newly Added Feature!"

This has been added to Lux very recently and is undergoing rapid development.
Currently, only a limited subset of Lux models can be compiled via `Reactant.jl`. If you
encounter any issues, please report them on the `Lux.jl` or `Reactant.jl` GitHub
repositories.

## Keyword Arguments

- `input_prototype`: Input data representative of the data that will be used for
inference. If this is provided, the inference function will be compiled with
`Reactant.jl` on the first call to [`Lux.Experimental.single_train_step!`](@ref) or
[`Lux.Experimental.single_train_step`](@ref). If it is not provided, the inference
function has to be recompiled on every call to `(::TrainState)(data)`, which is
prohibitively expensive.

See [`Lux.Experimental.single_train_step!`](@ref) or
[`Lux.Experimental.single_train_step`](@ref) for information on how to use this backend.
"""
@kwdef @concrete struct ReactantBackend <: AbstractCompilerBackend
input_prototype = nothing
end
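
A hedged illustration of the two construction modes; the batch shape below is a placeholder, not something specified by this PR:

```julia
# With a prototype: the inference function is compiled alongside the first train step.
backend = ReactantBackend(; input_prototype = rand(Float32, 4, 32))

# Without a prototype: every `(::TrainState)(data)` call triggers a fresh compilation
# (and the warning emitted in `ext/LuxReactantExt/training.jl`).
backend_lazy = ReactantBackend()
```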
14 changes: 14 additions & 0 deletions src/helpers/simple_optimizers.jl
@@ -0,0 +1,14 @@
# These are meant to be used internally for compiling certain Lux optimization routines
function simple_optimizers_apply!(ps, gs, leaf::Leaf{<:Descent})
@. ps -= leaf.rule.eta * gs
end

for opt in (Descent,)
@eval function simple_optimizers_apply!(::$(opt), st_opt, ps, gs)
recursive_map(simple_optimizers_apply!, ps, gs, st_opt)
end
end

function simple_optimizers_apply!(opt, st_opt, ps, gs)
throw(ArgumentError("Optimizer $(typeof(opt)) not yet supported."))
end
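
For reference, a standalone sketch (not part of this PR) showing that the in-place rule applied at every parameter leaf matches the reference `Descent` update from Optimisers.jl:

```julia
using Optimisers

η = 0.1f0
ps = (; weight = randn(Float32, 3, 3), bias = zeros(Float32, 3))
gs = (; weight = ones(Float32, 3, 3), bias = ones(Float32, 3))

# In-place rule used by `simple_optimizers_apply!` at every leaf: `@. p -= η * g`.
ps_inplace = deepcopy(ps)
for (p, g) in zip(values(ps_inplace), values(gs))
    @. p -= η * g
end

# Reference (allocating) update via Optimisers.jl.
opt_state = Optimisers.setup(Descent(η), ps)
_, ps_ref = Optimisers.update(opt_state, ps, gs)

ps_ref.weight ≈ ps_inplace.weight && ps_ref.bias ≈ ps_inplace.bias  # true
```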
22 changes: 19 additions & 3 deletions src/helpers/training.jl
@@ -14,7 +14,7 @@ using LuxCore: LuxCore, AbstractLuxLayer
"""
TrainState

Training State containing:
## Training State containing:

- `model`: `Lux` model.
- `parameters`: Trainable Variables of the `model`.
@@ -23,7 +23,7 @@ Training State containing:
- `optimizer_state`: Optimizer State.
- `step`: Number of updates of the parameters made.

Internal fields:
## Internal fields:

- `cache`: Cached values. Implementations are free to use this for whatever they want.
- `objective_function`: Objective function might be cached.
@@ -32,6 +32,12 @@ Internal fields:

Constructing this object directly shouldn't be considered a stable API. Use the
version with the Optimisers API.

## Special Features

To run inference with the current parameters and states, simply call the `TrainState`
with the input data as `tstate(data)`. This automatically applies `Lux.testmode`. Note,
however, that `tstate.states` is not updated with the returned state.
"""
@concrete struct TrainState
cache
@@ -65,6 +71,8 @@ function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.Abstr
return TrainState(nothing, nothing, model, ps, st, optimizer, st_opt, 0)
end

(ts::TrainState)(data) = ts.model(data, ts.parameters, Lux.testmode(ts.states))

@concrete struct TrainingBackendCache
backend
first_try <: StaticBool
@@ -237,12 +245,16 @@ Perform a single training step. Computes the gradients using [`compute_gradients`
updates the parameters using [`apply_gradients!`](@ref). All backends supported via
[`compute_gradients`](@ref) are supported here.

## Additional Backends

- [`ReactantBackend`](@ref): Compiles the training loop to MLIR/XLA via `Reactant.jl`.

## Return

Returned values are the same as [`compute_gradients`](@ref). Note that despite the `!`,
only the parameters in `ts` are updated inplace. Users should be using the returned `ts`
object for further training steps, else there is no caching and performance will be
suboptimal (and absolutely terrible for backends like `AutoReactant`).
suboptimal (and absolutely terrible for backends like `ReactantBackend`).
"""
function single_train_step! end

@@ -253,6 +265,10 @@ Perform a single training step. Computes the gradients using [`compute_gradients`
updates the parameters using [`apply_gradients`](@ref). All backends supported via
[`compute_gradients`](@ref) are supported here.

## Additional Backends

- [`ReactantBackend`](@ref): Compiles the training loop to MLIR/XLA via `Reactant.jl`.

In most cases you should use [`single_train_step!`](@ref) instead of this function.

## Return
3 changes: 2 additions & 1 deletion test/qa_tests.jl
@@ -12,7 +12,8 @@ end

@testitem "Explicit Imports: Quality Assurance" setup=[SharedTestSetup] tags=[:others] begin
# Load all trigger packages
import Lux, ComponentArrays, ReverseDiff, SimpleChains, Tracker, Zygote, Enzyme
import Lux, ComponentArrays, ReverseDiff, Flux, SimpleChains, Tracker, Zygote, Enzyme,
Reactant
using ExplicitImports

# Skip our own packages