Add single_train_step
avik-pal committed Jun 1, 2024
1 parent 60c595e commit 0c35c52
Showing 8 changed files with 145 additions and 25 deletions.
6 changes: 5 additions & 1 deletion Project.toml
@@ -40,6 +40,7 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
@@ -57,6 +58,7 @@ LuxMLUtilsExt = "MLUtils"
LuxMPIExt = "MPI"
LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
LuxOptimisersExt = "Optimisers"
LuxReactantExt = ["Enzyme", "Reactant"]
LuxReverseDiffExt = "ReverseDiff"
LuxSimpleChainsExt = "SimpleChains"
LuxTrackerExt = "Tracker"
@@ -102,6 +104,7 @@ Pkg = "1.10"
PrecompileTools = "1.2"
Preferences = "1.4.3"
Random = "1.10"
Reactant = "0.1.1"
ReTestItems = "1.23.1"
Reexport = "1.2.2"
ReverseDiff = "1.15"
@@ -134,6 +137,7 @@ MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
@@ -144,4 +148,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ComponentArrays", "DynamicExpressions", "Enzyme", "ExplicitImports", "FiniteDifferences", "Flux", "ForwardDiff", "Logging", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "MLUtils", "MPI", "NCCL", "Optimisers", "Pkg", "ReTestItems", "ReverseDiff", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"]
test = ["Aqua", "ComponentArrays", "DynamicExpressions", "Enzyme", "ExplicitImports", "FiniteDifferences", "Flux", "ForwardDiff", "Logging", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "MLUtils", "MPI", "NCCL", "Optimisers", "Pkg", "Reactant", "ReTestItems", "ReverseDiff", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"]
2 changes: 2 additions & 0 deletions docs/src/api/Lux/contrib.md
@@ -36,6 +36,8 @@ Lux.Experimental.TrainState
Lux.Experimental.compute_gradients
Lux.Experimental.apply_gradients
Lux.Experimental.apply_gradients!
Lux.Experimental.single_train_step
Lux.Experimental.single_train_step!
```

## Parameter Freezing
18 changes: 18 additions & 0 deletions ext/LuxReactantExt/LuxReactantExt.jl
@@ -0,0 +1,18 @@
module LuxReactantExt

using Adapt: adapt
using ArgCheck: @argcheck
using ConcreteStructs: @concrete
using Enzyme: Enzyme, Active, Const, Duplicated
using Functors: fmapstructure, fmap
using Random: AbstractRNG, Xoshiro
using Reactant: Reactant
using Lux: Lux, LuxEltypeAdaptor, AutoReactant
using LuxCore: LuxCore, AbstractExplicitLayer

include("utils.jl")

# compile the entire training loop
include("train.jl")

end
Empty file added ext/LuxReactantExt/train.jl
Empty file.
20 changes: 20 additions & 0 deletions ext/LuxReactantExt/utils.jl
@@ -0,0 +1,20 @@
@inline Lux.__make_reactant_array(x::Reactant.RArray) = x
@inline function Lux.__make_reactant_array(x::AbstractArray)
    hasmethod(Reactant.ArrayToConcrete, Tuple{typeof(x)}) &&
        return Reactant.ConcreteRArray(x)
    return __make_tracer(x)
end
@inline Lux.__make_reactant_array(x) = __make_tracer(x)

@inline function __make_tracer(x)
    return Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete, nothing)
end

@inline function __try_similar_structure(x::AbstractArray, y::NamedTuple{()})
    length(x) == 0 && return y
    throw(DimensionMismatch(lazy"Expected empty array, got $(size(x))."))
end
@inline function __try_similar_structure(x::AbstractArray, y::AbstractArray)
    return parent(x) !== x ? copy(x) : x # unview arrays and such
end
@inline __try_similar_structure(x, y) = fmap(__try_similar_structure, x, y)
20 changes: 12 additions & 8 deletions src/Lux.jl
@@ -3,7 +3,8 @@ module Lux
using PrecompileTools: @recompile_invalidations

@recompile_invalidations begin
using ADTypes: AbstractADType, AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote
using ADTypes: AbstractADType, AutoEnzyme, AutoForwardDiff, AutoReverseDiff,
AutoTracker, AutoZygote
using Adapt: Adapt, adapt
using ArgCheck: @argcheck
using ArrayInterface: ArrayInterface
@@ -16,7 +17,7 @@ using PrecompileTools: @recompile_invalidations
using Markdown: @doc_str
using OhMyThreads: tmapreduce
using Preferences: @load_preference
using Random: Random, AbstractRNG
using Random: Random, AbstractRNG, Xoshiro
using Reexport: @reexport

using LuxCore, LuxLib, LuxDeviceUtils, WeightInitializers
@@ -48,6 +49,12 @@ const DISABLE_AUTOMATIC_NESTED_AD_SWITCH = @load_preference("DisableAutomaticNes
# Utilities
include("utils.jl")

# Transform to and from other frameworks
include("transform/types.jl")
include("transform/flux.jl")
include("transform/simplechains.jl")
include("transform/reactant.jl")

# Layer Implementations
include("layers/basic.jl")
include("layers/containers.jl")
@@ -72,11 +79,6 @@ include("helpers/compact.jl")
include("helpers/autodiff.jl")
include("helpers/nested_ad.jl")

# Transform to and from other frameworks
include("transform/types.jl")
include("transform/flux.jl")
include("transform/simplechains.jl")

# Distributed Training
include("distributed/backend.jl")
include("distributed/public_api.jl")
Expand All @@ -103,13 +105,15 @@ export @compact, CompactLuxLayer

export jacobian_vector_product, vector_jacobian_product
export batched_jacobian
export AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote
export AutoEnzyme, AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote
export AutoReactant

export f16, f32, f64

export transform
export FromFluxAdaptor, FluxLayer
export ToSimpleChainsAdaptor, SimpleChainsLayer
export ToReactantAdaptor, ReactantLayer
export DynamicExpressionsLayer

export MPIBackend, NCCLBackend, DistributedUtils
81 changes: 65 additions & 16 deletions src/contrib/training.jl
@@ -13,6 +13,11 @@ Internal fields:
- `cache`: Cached values. Implementations are free to use this for whatever they want.
- `objective_function`: Objective function might be cached.
!!! warning
    Constructing this object directly shouldn't be considered a stable API. Use the
    version with the Optimisers API.
"""
@concrete struct TrainState{C, F}
    cache::C
@@ -27,30 +32,33 @@ end
function Base.show(io::IO, ts::TrainState)
println(io, "TrainState")
println(io, " model: ", ts.model)
println(io, " parameters: ", Lux.parameterlength(ts.parameters))
println(io, " states: ", Lux.statelength(ts.states))
println(io, " # of parameters: ", Lux.parameterlength(ts.parameters))
println(io, " # of states: ", Lux.statelength(ts.states))
println(io, " optimizer_state: ", ts.optimizer_state)
print(io, " step: ", ts.step)
ts.cache !== nothing && print(io, "\n cache: ", nameof(typeof(ts.cache)))
ts.objective_function !== nothing &&
print(io, "\n objective_function: ", nameof(typeof(ts.objective_function)))
end

"""
    apply_gradients(ts::TrainState, grads)
Update the parameters stored in `ts` using the gradients `grads`.
const APPLY_GRAD_DOCSTRING = """
## Arguments
- `ts`: [`TrainState`](@ref) object.
- `grads`: Gradients of the loss function wrt `ts.params`.
- `update_inplace`: Whether to update the parameters inplace or not.
## Returns
Updated [`TrainState`](@ref) object.
"""

"""
    apply_gradients(ts::TrainState, grads)
Update the parameters stored in `ts` using the gradients `grads`.
$(APPLY_GRAD_DOCSTRING)
"""
function apply_gradients end

"""
@@ -59,14 +67,7 @@ function apply_gradients end
Update the parameters stored in `ts` using the gradients `grads`. This is an inplace version
of [`apply_gradients`](@ref).
## Arguments
- `ts`: [`TrainState`](@ref) object.
- `grads`: Gradients of the loss function wrt `ts.params`.
## Returns
Updated [`TrainState`](@ref) object.
$(APPLY_GRAD_DOCSTRING)
"""
function apply_gradients! end

@@ -146,3 +147,51 @@ end

    return wrapped_objective_function, st_updated, stats
end

"""
    single_train_step!(backend, obj_fn::F, data, ts::TrainState)
Perform a single training step. Computes the gradients using [`compute_gradients`](@ref) and
updates the parameters using [`apply_gradients!`](@ref). All backends supported via
[`compute_gradients`](@ref) are supported here.
## Additional Backends
- [`AutoReactant`](@ref): Compiles the training loop to MLIR/XLA via `Reactant.jl`.
## Return
Returned values are the same as [`compute_gradients`](@ref). Note that despite the `!`,
only the parameters in `ts` are updated inplace. Use the returned `ts` object for further
training steps; otherwise the cached state is discarded and performance will be suboptimal
(and especially poor for backends like `AutoReactant`).
"""
function single_train_step! end

"""
    single_train_step(backend, obj_fn::F, data, ts::TrainState)
Perform a single training step. Computes the gradients using [`compute_gradients`](@ref) and
updates the parameters using [`apply_gradients`](@ref). All backends supported via
[`compute_gradients`](@ref) are supported here.
## Additional Backends
- [`AutoReactant`](@ref): Compiles the training loop to MLIR/XLA via `Reactant.jl`.
In most cases you should use [`single_train_step!`](@ref) instead of this function.
## Return
Returned values are the same as [`compute_gradients`](@ref).
"""
function single_train_step end

for inplace in ("!", "")
step, apply_fn = Symbol(:single_train_step, inplace), Symbol(:apply_gradients, inplace)
@eval function $(step)(backend, obj_fn::F, data, ts::TrainState) where {F}
grads, loss, stats, ts = compute_gradients(backend, obj_fn, data, ts)
ts = $apply_fn(ts, grads)
return grads, loss, stats, ts
end
end
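
For orientation, here is a minimal usage sketch of the new single-step API (not part of this commit). It assumes the pre-existing `Lux.Experimental.TrainState(rng, model, optimiser)` convenience constructor and the `obj_fn(model, ps, st, data) -> (loss, st, stats)` objective-function convention used by `compute_gradients`; the model, data, and `mse` loss are purely illustrative.

```julia
using Lux, Optimisers, Random, Zygote

# Illustrative model and train state; the TrainState(rng, model, opt) constructor
# is assumed from the pre-existing Lux.Experimental API.
model = Dense(2 => 1)
ts = Lux.Experimental.TrainState(Xoshiro(0), model, Adam(0.01f0))

# Objective functions follow the (model, ps, st, data) -> (loss, st, stats) convention.
function mse(model, ps, st, (x, y))
    y_pred, st_ = model(x, ps, st)
    return sum(abs2, y_pred .- y), st_, (;)
end

x, y = rand(Float32, 2, 16), rand(Float32, 1, 16)

function train(ts, data, nsteps)
    for _ in 1:nsteps
        # Reuse the returned `ts` so any cached state is carried across steps.
        _, loss, _, ts = Lux.Experimental.single_train_step!(AutoZygote(), mse, data, ts)
    end
    return ts
end

ts = train(ts, (x, y), 100)
```

As the generated methods above show, `single_train_step!` is simply `compute_gradients` followed by `apply_gradients!`; the two-call form remains available when gradients need to be inspected or modified before the update.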
23 changes: 23 additions & 0 deletions src/transform/reactant.jl
@@ -0,0 +1,23 @@
"""
    AutoReactant()
Compile the training loop to MLIR/XLA via `Reactant.jl`.
This has been added to Lux very recently and is undergoing rapid development. Currently,
only a limited subset of Lux models can be compiled via `Reactant.jl`. If you encounter any
issues, please report them on the `Lux.jl` or `Reactant.jl` GitHub repository.
"""
struct AutoReactant end

"""
    __make_reactant_array(x)
Converts `x` to a `Reactant.ConcreteRArray` if it is not already one.
"""
function __make_reactant_array end

@inline function __make_reactant_array(nt::NamedTuple{names}) where {names}
    return NamedTuple{names}(map(__make_reactant_array, values(nt)))
end
@inline __make_reactant_array(t::Tuple) = map(__make_reactant_array, t)
@inline __make_reactant_array(x::AbstractExplicitLayer) = x
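
A quick illustrative sketch of the conversion helpers above (not part of this commit): with both Reactant and Enzyme loaded so that `LuxReactantExt` is active, `__make_reactant_array` should walk a parameter `NamedTuple` and convert the leaf arrays. The parameter structure is made up for the example; the `AutoReactant` training path itself lives in `ext/LuxReactantExt/train.jl`, which is still empty in this commit.

```julia
using Lux, Random
using Enzyme, Reactant  # loading both triggers the LuxReactantExt extension

# Hypothetical parameter structure; __make_reactant_array recurses through
# Tuples/NamedTuples, converts plain arrays via Reactant, and leaves
# AbstractExplicitLayer objects untouched.
ps = (; weight=rand(Float32, 4, 2), bias=zeros(Float32, 4))
ps_ra = Lux.__make_reactant_array(ps)

typeof(ps_ra.weight)  # expected: Reactant.ConcreteRArray (or a traced equivalent)
```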
