In some settings you might want to run multiple steps before doing an update (aka gradient accumulation), e.g. to reduce the variance of the stochastic gradient without increasing memory requirements.
Perhaps this could be implemented with a unified API, but we'd need to think carefully about it.
One option might be to change the API from `update(state: TensorTree, batch: TensorTree)` to `update(state: TensorTree, batch: TensorTree | AccumulateBatch[TensorTree])`, where `AccumulateBatch` is just a NamedTuple so we can differentiate it from a plain TensorTree, and then e.g.:
if not isinstance(batch, AccumulateBatch):
    batch = accumulate(batch)  # convert batch to an AccumulateBatch of length 1

vals = []
aux = []
for b in batch:
    v, a = log_posterior(state.params, b)
    vals.append(v)
    aux.append(a)
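For concreteness, here is a minimal sketch of what the wrapper and helper might look like (the names `AccumulateBatch`, `accumulate`, and the single `sub_batches` field are illustrative assumptions, not an existing API):

```python
from typing import Any, NamedTuple, Tuple

TensorTree = Any  # stand-in for the library's TensorTree alias (assumption)


class AccumulateBatch(NamedTuple):
    """Marks a collection of sub-batches to be accumulated over inside update()."""
    sub_batches: Tuple[TensorTree, ...]

    def __iter__(self):
        # Iterate over the wrapped sub-batches (not the NamedTuple fields),
        # so `for b in batch` in the snippet above works as written.
        return iter(self.sub_batches)


def accumulate(*batches: TensorTree) -> AccumulateBatch:
    """Wrap one or more batches so update() treats them as one accumulated step."""
    return AccumulateBatch(sub_batches=tuple(batches))
```

With this, `accumulate(batch)` above yields an `AccumulateBatch` of length 1, and `vals` / `aux` could then be averaged or stacked before the parameter update, depending on the method.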
Some discussion on gradient accumulation here:
https://discuss.pytorch.org/t/why-do-we-need-to-set-the-gradients-manually-to-zero-in-pytorch/4903/19?u=alband
https://wandb.ai/wandb_fc/tips/reports/How-To-Implement-Gradient-Accumulation-in-PyTorch--VmlldzoyMjMwOTk5
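For reference, the pattern those threads describe boils down to something like the following in plain PyTorch (a hedged sketch; the toy model, data, and `accum_steps` value are assumptions for illustration):

```python
import torch
from torch import nn

# Toy setup purely for illustration; none of these names come from the issue.
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()
data_loader = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(8)]
accum_steps = 4  # number of sub-batches accumulated per optimizer step

optimizer.zero_grad()
for i, (inputs, targets) in enumerate(data_loader):
    # Scale the loss so the accumulated gradient is an average over sub-batches.
    loss = loss_fn(model(inputs), targets) / accum_steps
    loss.backward()  # gradients accumulate in param.grad across backward() calls
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```

The same idea would live inside update() here: accumulate over the sub-batches of an AccumulateBatch and only then apply the step.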