feat: allow no grad option for reactant (#1190)
avik-pal authored Jan 8, 2025
1 parent 10ea255 commit c81629b
Showing 4 changed files with 53 additions and 24 deletions.
3 changes: 2 additions & 1 deletion examples/ConditionalVAE/main.jl
@@ -244,7 +244,8 @@ function main(; batchsize=128, image_size=(64, 64), num_latent_dims=8, max_num_f
start_time = time()
for (i, X) in enumerate(train_dataloader)
(_, loss, _, train_state) = Training.single_train_step!(
AutoEnzyme(), loss_function, X, train_state)
AutoEnzyme(), loss_function, X, train_state; return_gradients=Val(false)
)

loss_total += loss
total_samples += size(X, ndims(X))
2 changes: 1 addition & 1 deletion ext/LuxReactantExt/LuxReactantExt.jl
@@ -4,7 +4,7 @@ using Enzyme: Enzyme, Const, Duplicated, Active
using Optimisers: Optimisers
using Reactant: Reactant, @compile, @code_hlo, AnyTracedRArray, TracedRArray, TracedRNumber
using Setfield: @set!
using Static: False
using Static: True, False

using Lux: Lux, LuxOps, Training, Utils
using Lux.Training: TrainingBackendCache, ReactantBackend
27 changes: 18 additions & 9 deletions ext/LuxReactantExt/training.jl
@@ -55,7 +55,7 @@ function Lux.Training.compute_gradients_impl(
end

function Lux.Training.compute_gradients_impl(::ReactantBackend, obj_fn::F, data,
ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend}, F}) where {F}
grads, loss, stats, st = ts.cache.extras.compiled_gradient_function(
obj_fn, ts.model, data, ts.parameters, ts.states)
@set! ts.states = st
@@ -70,7 +70,7 @@ for inplace in ("!", "")

# Ideally users never hit this dispatch but it is still good to have as a fallback
@eval function Lux.Training.$(apply_gradients_fn)(
ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}}, grads
ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend}}, grads
)
if hasfield(typeof(ts.cache.extras), :update_function)
update_function = ts.cache.extras.update_function
@@ -94,15 +94,15 @@
@eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F,
data, ts::Training.TrainState) where {F}
maybe_dump_to_mlir_file!($(internal_fn), objective_function, ts.model, data,
ts.parameters, ts.states, ts.optimizer_state)
ts.parameters, ts.states, ts.optimizer_state, backend.return_gradients)

compiled_grad_and_step_function = @compile $(internal_fn)(
objective_function, ts.model, data, ts.parameters, ts.states,
ts.optimizer_state)
ts.optimizer_state, backend.return_gradients)

grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function(
objective_function, ts.model, data, ts.parameters, ts.states,
ts.optimizer_state)
ts.optimizer_state, backend.return_gradients)

cache = TrainingBackendCache(
backend, False(), nothing, (; compiled_grad_and_step_function))
@@ -116,10 +116,11 @@
return grads, loss, stats, ts
end

@eval function Lux.Training.$(fname)(::ReactantBackend, obj_fn::F, data,
ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
@eval function Lux.Training.$(fname)(backend::ReactantBackend, obj_fn::F, data,
ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend}, F}) where {F}
grads, ps, loss, stats, st, opt_state = ts.cache.extras.compiled_grad_and_step_function(
obj_fn, ts.model, data, ts.parameters, ts.states, ts.optimizer_state)
obj_fn, ts.model, data, ts.parameters, ts.states,
ts.optimizer_state, backend.return_gradients)

@set! ts.states = st
@set! ts.parameters = ps
Expand All @@ -131,7 +132,15 @@ for inplace in ("!", "")

# XXX: Inplace version not actually inplace
@eval function $(internal_fn)(
objective_function::F, model, data, ps, st, opt_state) where {F}
objective_function::F, model, data, ps, st, opt_state, ::False) where {F}
dps, loss, stats, stₙ = compute_gradients_internal(
objective_function, model, data, ps, st)
opt_state, ps = Optimisers.$(update_fn)(opt_state, ps, dps)
return nothing, ps, loss, stats, stₙ, opt_state
end

@eval function $(internal_fn)(
objective_function::F, model, data, ps, st, opt_state, ::True) where {F}
dps, loss, stats, stₙ = compute_gradients_internal(
objective_function, model, data, ps, st)
opt_state, ps = Optimisers.$(update_fn)(opt_state, ps, dps)
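
The two `$(internal_fn)` methods above dispatch on the static `return_gradients` flag carried by the backend, so whether gradients are materialized and returned is decided when the step is traced and compiled rather than at runtime. Below is a minimal standalone sketch of that dispatch pattern using Static.jl directly; `train_step` and `fake_grads` are made-up names for illustration, not Lux internals.

```julia
using Static: True, False

# Stand-in for a real gradient computation; `fake_grads` is a made-up helper.
fake_grads(ps) = map(x -> 2 .* x, ps)

# Gradients requested: compute and return them alongside the parameters.
train_step(ps, ::True) = (fake_grads(ps), ps)

# Gradients not requested: return `nothing`, so a compiler such as Reactant never
# needs to emit them in the generated program.
train_step(ps, ::False) = (nothing, ps)

ps = (; weight=[1.0, 2.0], bias=[0.5])
grads, _ = train_step(ps, True())    # grads == (; weight = [2.0, 4.0], bias = [1.0])
nograds, _ = train_step(ps, False()) # nograds === nothing
```
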
45 changes: 32 additions & 13 deletions src/helpers/training.jl
@@ -7,7 +7,7 @@ using FastClosures: @closure
using Functors: Functors, fmap
using Optimisers: Optimisers
using Setfield: @set!
using Static: StaticBool, Static, False, True
using Static: StaticBool, Static, False, True, static

using ..Lux: Lux, Utils, ReactantCompatibleOptimisers
using LuxCore: LuxCore, AbstractLuxLayer
@@ -104,7 +104,9 @@ function Base.show(io::IO, ::MIME"text/plain", ts::TrainState)
print(io, "\n objective_function: ", nameof(typeof(ts.objective_function)))
end

struct ReactantBackend end
@concrete struct ReactantBackend
return_gradients <: StaticBool
end

const APPLY_GRAD_DOCSTRING = """
## Arguments
@@ -198,10 +200,13 @@ function compute_gradients(ad, obj_fn::F, data, ts::TrainState) where {F}
return compute_gradients_impl(maybe_wrap_adtype(ad, dev_type), obj_fn, data, ts)
end

maybe_wrap_adtype(backend::ReactantBackend, _) = backend
maybe_wrap_adtype(ad::AbstractADType, _) = ad
function maybe_wrap_adtype(ad::AbstractADType, ::Type{ReactantDevice})
ad isa AutoEnzyme && return ReactantBackend()
maybe_wrap_adtype(backend::ReactantBackend, ::Any; kwargs...) = backend
maybe_wrap_adtype(ad::AbstractADType, ::Any; kwargs...) = ad
function maybe_wrap_adtype(
ad::AbstractADType, ::Type{ReactantDevice};
return_gradients::Utils.BoolType=True()
)
ad isa AutoEnzyme && return ReactantBackend(static(return_gradients))
throw(ArgumentError("Computing gradients for models on XLA is supported only with \
Enzyme.jl (`AutoEnzyme`)."))
end
@@ -258,39 +263,53 @@ function wrap_objective_function(
end

"""
single_train_step!(backend, obj_fn::F, data, ts::TrainState)
single_train_step!(backend, obj_fn::F, data, ts::TrainState; return_gradients=True())
Perform a single training step. Computes the gradients using [`compute_gradients`](@ref) and
updates the parameters using [`apply_gradients!`](@ref). All backends supported via
[`compute_gradients`](@ref) are supported here.
## Keyword Arguments
- `return_gradients`: If `True()`, the gradients are returned. If `False()`, the returned
gradients are `nothing`. Defaults to `True()`. This is only used for Reactant Backend.
## Return
Returned values are the same as [`compute_gradients`](@ref). Note that despite the `!`,
only the parameters in `ts` are updated inplace. Users should be using the returned `ts`
object for further training steps, else there is no caching and performance will be
suboptimal (and absolutely terrible for backends like `AutoReactant`).
"""
function single_train_step!(backend, obj_fn::F, data, ts::TrainState) where {F}
backend = maybe_wrap_adtype(backend, get_device_type((ts.parameters, ts.states)))
function single_train_step!(backend, obj_fn::F, data, ts::TrainState;
return_gradients::Utils.BoolType=True()) where {F}
backend = maybe_wrap_adtype(
backend, get_device_type((ts.parameters, ts.states)); return_gradients)
return single_train_step_impl!(backend, obj_fn, data, ts)
end

"""
single_train_step(backend, obj_fn::F, data, ts::TrainState)
single_train_step(backend, obj_fn::F, data, ts::TrainState; return_gradients=True())
Perform a single training step. Computes the gradients using [`compute_gradients`](@ref) and
updates the parameters using [`apply_gradients`](@ref). All backends supported via
[`compute_gradients`](@ref) are supported here.
In most cases you should use [`single_train_step!`](@ref) instead of this function.
## Keyword Arguments
- `return_gradients`: If `True()`, the gradients are returned. If `False()`, the returned
gradients are `nothing`. Defaults to `True()`. This is only used for Reactant Backend.
## Return
Returned values are the same as [`compute_gradients`](@ref).
Returned values are the same as [`single_train_step!`](@ref).
"""
function single_train_step(backend, obj_fn::F, data, ts::TrainState) where {F}
backend = maybe_wrap_adtype(backend, get_device_type((ts.parameters, ts.states)))
function single_train_step(backend, obj_fn::F, data, ts::TrainState;
return_gradients::Utils.BoolType=True()) where {F}
backend = maybe_wrap_adtype(
backend, get_device_type((ts.parameters, ts.states)); return_gradients)
return single_train_step_impl(backend, obj_fn, data, ts)
end

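The user-facing entry point for this feature is the new `return_gradients` keyword on `single_train_step!` / `single_train_step`. A minimal end-to-end usage sketch with the Reactant backend follows; the model, data, and optimizer are illustrative placeholders rather than part of this commit, and per the docstring the keyword only has an effect for the Reactant backend (which requires `AutoEnzyme()`).

```julia
using Lux, Reactant, Enzyme, Optimisers, Random

# Illustrative model and data; any Lux model and dataloader works the same way.
dev = reactant_device()
model = Chain(Dense(4 => 16, relu), Dense(16 => 2))
ps, st = Lux.setup(Random.default_rng(), model) |> dev
x = rand(Float32, 4, 32) |> dev
y = rand(Float32, 2, 32) |> dev

train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

# Objective function following the Training API contract: returns (loss, state, stats).
function loss_fn(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    return MSELoss()(ŷ, y), st_, (;)
end

# With return_gradients=Val(false) the compiled step does not return gradients;
# the first element of the returned tuple is `nothing`.
_, loss, _, train_state = Training.single_train_step!(
    AutoEnzyme(), loss_fn, (x, y), train_state; return_gradients=Val(false)
)
```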

2 comments on commit c81629b

@avik-pal (Member Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/122632

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.5.0 -m "<description of version>" c81629b7a62522809d91ab7aec0a9d6249dbd5aa
git push origin v1.5.0
