TypeError when running R2D2 using CUDA #561
I have fixed this bug in the commit above; you can test this demo again.
Thanks for investigating this issue so quickly! I'm now running into another (probably related) exception:
I fixed this problem when processing
Thanks for this fix! The CartPole training now runs for some time, but eventually hits a new exception:
Can you always reproduce this IndexError?
This bug is not always reproducible.
This means that there is no group in the buffer. Another question: why use grouped sampling in CartPole training? Can you provide your main file?
Here is the Jupyter Notebook that I'm running on Colab.

# !pip install git+https://github.com/opendilab/DI-engine.git@main#egg=DI-engine
import ding
import gym
from ditk import logging
from ding.model import DRQN
from ding.policy import R2D2Policy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.data import DequeBuffer
from ding.config import compile_config
from ding.framework import task
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
    eps_greedy_handler, CkptSaver, nstep_reward_enhancer
from ding.utils import set_pkg_seed
from dizoo.classic_control.cartpole.config.cartpole_r2d2_config import main_config, create_config


def main():
    logging.getLogger().setLevel(logging.INFO)
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    cfg["policy"]["cuda"] = True
    with task.start(async_mode=False, ctx=OnlineRLContext()):
        collector_env = BaseEnvManagerV2(
            env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
            cfg=cfg.env.manager
        )
        evaluator_env = BaseEnvManagerV2(
            env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
            cfg=cfg.env.manager
        )
        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
        model = DRQN(**cfg.policy.model)
        buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
        policy = R2D2Policy(cfg.policy, model=model)
        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
        task.use(eps_greedy_handler(cfg))
        task.use(StepCollector(cfg, policy.collect_mode, collector_env))
        task.use(nstep_reward_enhancer(cfg))
        task.use(data_pusher(cfg, buffer_, group_by_env=True))
        task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
        task.use(CkptSaver(cfg, policy, train_freq=100))
        task.run()


if __name__ == "__main__":
    main()
unroll_len is set to 42 in the configuration of r2d2, which may be too large for some environments. As a result, not enough samples are collected and stored in the buffer. When sampling, groups with fewer than unroll_len transitions are actively filtered out, which can lead to the exception above. The solution is to reduce the value of unroll_len.
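A minimal sketch of how that override could be applied in the notebook above, before compile_config is called. The exact key path for unroll_len inside cartpole_r2d2_config and the value 20 are assumptions for illustration, so adjust them to the actual config:

# Sketch: lower unroll_len before compiling the config.
# The key path (policy.unroll_len) and the value 20 are illustrative assumptions.
from ding.config import compile_config
from dizoo.classic_control.cartpole.config.cartpole_r2d2_config import main_config, create_config

main_config.policy.unroll_len = 20  # reported default in this config is 42

cfg = compile_config(main_config, create_cfg=create_config, auto=True)
cfg["policy"]["cuda"] = True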
@MarcoMeter Are you still dealing with this issue?
I ran the CartPole training three times using an unroll_len of 20. The exception did not occur again. How can I determine an unroll_len that does not cause this exception? |
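One way to sanity-check a candidate value (a rough heuristic, not an official DI-engine rule): with data_pusher(..., group_by_env=True), each environment's group must hold at least unroll_len transitions before it can be sampled, so unroll_len should stay below the number of steps each collector env contributes per collection cycle. A small self-contained check, using illustrative numbers rather than the real cartpole_r2d2_config values:

def check_unroll_len(unroll_len: int, n_sample: int, collector_env_num: int) -> bool:
    """Rough heuristic: each per-env group needs at least unroll_len transitions
    before it can be sampled, so unroll_len should not exceed the steps each
    collector env contributes per collection cycle."""
    steps_per_env = n_sample // collector_env_num
    return unroll_len <= steps_per_env

# Illustrative numbers only, not necessarily the actual cartpole_r2d2_config values:
print(check_unroll_len(unroll_len=42, n_sample=256, collector_env_num=8))  # False -> risky
print(check_unroll_len(unroll_len=20, n_sample=256, collector_env_num=8))  # True  -> safer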
Log
Steps to reproduce:
Run the CartPole R2D2 example while setting cfg["policy"]["cuda"] = True.