Thanks for sharing this great project!

https://github.com/RWKV/RWKV-infctx-trainer/blob/70d02c4997578a027d110e3acb03a523d3986448/RWKV-v6/src/model.py#L291C1-L300C1

```python
# Add warning that bptt_truncated_learning is forced to be true
# due to incomplete implementation of CUDA kernel for bptt_learning
#
# @TODO : remove this warning once the CUDA kernel, with state gradient, is implemented
if self.bptt_truncated_learning == False:
    print("====================================================================")
    print("[WARNING]: bptt_truncated_learning is set as true (was configured as false), due to incomplete implementation of CUDA kernel for bptt_learning")
    print("====================================================================")
    self.bptt_truncated_learning = True
```
Just to confirm: when doing TBPTT, is this essentially the same gradient estimator as the one used in Transformer-XL?
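To make the question concrete, here is a toy sketch in pure Python. It uses a hypothetical scalar recurrence `h_t = a * h_{t-1} + x_t` with `loss = sum(h_t)`, not RWKV's actual state update or CUDA kernel, and the function name `loss_and_grad` is made up for illustration. When `chunk_len` is set, the state value is carried into the next chunk but treated as a constant for differentiation (detached). That is the same idea as the stop-gradient Transformer-XL applies to its cached previous-segment memory. The forward loss is identical in both modes; only the gradient estimate changes.

```python
def loss_and_grad(a, xs, chunk_len=None):
    """Return (loss, dloss/da) for the toy recurrence h_t = a*h_{t-1} + x_t.

    chunk_len=None -> full BPTT through the whole sequence.
    chunk_len=k    -> truncated BPTT: at every chunk boundary the incoming state
                      keeps its value but its dependence on `a` is dropped
                      (like tensor.detach(), or Transformer-XL's stop-gradient memory).
    """
    h, dh = 0.0, 0.0          # state value and its derivative w.r.t. a (forward mode)
    loss, grad = 0.0, 0.0
    for t, x in enumerate(xs):
        if chunk_len is not None and t > 0 and t % chunk_len == 0:
            dh = 0.0          # detach: value of h carries over, gradient path is cut
        dh = h + a * dh       # d/da (a*h + x), computed with the previous h
        h = a * h + x
        loss += h
        grad += dh
    return loss, grad


xs = [1.0, 1.0, 1.0, 1.0]
print(loss_and_grad(0.5, xs))               # full BPTT  -> (6.125, 5.75)
print(loss_and_grad(0.5, xs, chunk_len=2))  # truncated  -> (6.125, 5.0)
```

The missing `0.75` in the truncated gradient is exactly the contribution that would have flowed from the second chunk's loss back through the carried-over state into the first chunk.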