-
Notifications
You must be signed in to change notification settings - Fork 78
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #93 from salesforce/ddpg
Ddpg
- Loading branch information
Showing
30 changed files
with
3,226 additions
and
251 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
110 changes: 110 additions & 0 deletions
110
example_envs/single_agent/classic_control/pendulum/pendulum.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import numpy as np | ||
from warp_drive.utils.constants import Constants | ||
from warp_drive.utils.data_feed import DataFeed | ||
from warp_drive.utils.gpu_environment_context import CUDAEnvironmentContext | ||
|
||
from example_envs.single_agent.base import SingleAgentEnv, map_to_single_agent, get_action_for_single_agent | ||
from gym.envs.classic_control.pendulum import PendulumEnv | ||
|
||
# Short module-level aliases for the names exposed on ``Constants``; they are
# used below when listing the arguments handed to the GPU step function.
_OBSERVATIONS = Constants.OBSERVATIONS
_ACTIONS = Constants.ACTIONS
_REWARDS = Constants.REWARDS
|
||
|
||
class ClassicControlPendulumEnv(SingleAgentEnv):
    """Single-agent wrapper around gym's classic-control ``PendulumEnv``.

    The wrapped gym environment is created with ``g=9.81``. Its action and
    observation spaces, and every observation/reward it produces, are passed
    through ``map_to_single_agent`` before being exposed.
    """

    name = "ClassicControlPendulumEnv"

    def __init__(self, episode_length, env_backend="cpu", reset_pool_size=0, seed=None):
        """Create the underlying gym pendulum and publish its spaces.

        Args:
            episode_length: number of steps after which ``step`` reports done.
            env_backend: forwarded unchanged to ``SingleAgentEnv``.
            reset_pool_size: when below 2, ``reset`` always reseeds with
                ``seed``; otherwise ``reset`` passes ``seed=None`` to gym.
            seed: forwarded to ``SingleAgentEnv`` (read back as ``self.seed``).
        """
        super().__init__(episode_length, env_backend, reset_pool_size, seed=seed)

        self.gym_env = PendulumEnv(g=9.81)

        self.action_space = map_to_single_agent(self.gym_env.action_space)
        self.observation_space = map_to_single_agent(self.gym_env.observation_space)

    def step(self, action=None):
        """Advance the gym environment by one step.

        Returns:
            ``(obs, rew, done, info)`` where ``obs`` and ``rew`` are the gym
            outputs passed through ``map_to_single_agent``, ``done`` is a dict
            with the single key ``"__all__"``, and ``info`` is an empty dict.
        """
        self.timestep += 1
        gym_action = get_action_for_single_agent(action)
        # gym returns a 5-tuple; the last two entries are not used here.
        observation, reward, terminated, _unused_a, _unused_b = self.gym_env.step(gym_action)

        episode_over = self.timestep >= self.episode_length or terminated

        return (
            map_to_single_agent(observation),
            map_to_single_agent(reward),
            {"__all__": episode_over},
            {},
        )

    def reset(self):
        """Reset the timestep counter and the gym environment.

        Returns:
            The initial gym observation passed through ``map_to_single_agent``.
        """
        self.timestep = 0
        # With a reset pool smaller than 2 the same seed is reused on every
        # reset, giving a fixed initial state; otherwise gym gets seed=None.
        reset_seed = self.seed if self.reset_pool_size < 2 else None
        initial_obs, _ = self.gym_env.reset(seed=reset_seed)

        return map_to_single_agent(initial_obs)
|
||
|
||
class CUDAClassicControlPendulumEnv(ClassicControlPendulumEnv, CUDAEnvironmentContext):
    """GPU-backed pendulum environment; only ``env_backend == "numba"`` is accepted by ``step``."""

    def get_data_dictionary(self):
        """Build the ``DataFeed`` holding the initial pendulum ``state``.

        Returns:
            A ``DataFeed`` with one entry named ``"state"``.
        """
        # reset() returns the processed observation rather than the raw state,
        # so the state is read from the gym env's ``state`` attribute instead.
        self.gym_env.reset(seed=self.seed)
        raw_state = self.gym_env.state

        # Without a usable reset pool (size < 2) the initial state is the one
        # to be saved and re-applied at reset.
        keep_copy_for_reset = self.reset_pool_size < 2

        feed = DataFeed()
        feed.add_data(
            name="state",
            data=np.atleast_2d(raw_state),
            save_copy_and_apply_at_reset=keep_copy_for_reset,
        )
        return feed

    def get_tensor_dictionary(self):
        """Return an empty ``DataFeed``; this environment registers no tensors here."""
        return DataFeed()

    def get_reset_pool_dictionary(self):
        """Build the pool of initial states used as reset targets for ``"state"``.

        Returns:
            A ``DataFeed``; empty when ``reset_pool_size`` is below 2,
            otherwise holding ``"state_reset_pool"`` targeting ``"state"``.
        """
        pool_feed = DataFeed()
        if self.reset_pool_size < 2:
            return pool_feed

        sampled_states = []
        for _ in range(self.reset_pool_size):
            # seed=None is passed so gym is not reseeded with the fixed seed.
            self.gym_env.reset(seed=None)
            sampled_states.append(np.atleast_2d(self.gym_env.state))
        state_reset_pool = np.stack(sampled_states, axis=0)
        # Expect a 3-D array whose last axis has the two state components.
        assert state_reset_pool.ndim == 3 and state_reset_pool.shape[2] == 2

        pool_feed.add_pool_for_reset(
            name="state_reset_pool",
            data=state_reset_pool,
            reset_target="state",
        )
        return pool_feed

    def step(self, actions=None):
        """Launch the GPU step function for one timestep.

        Raises:
            Exception: if ``env_backend`` is not ``"numba"``.
        """
        self.timestep += 1
        # Order presumably mirrors the step kernel's parameter list - TODO confirm.
        args = [
            "state",
            _ACTIONS,
            "_done_",
            _REWARDS,
            _OBSERVATIONS,
            "_timestep_",
            ("episode_length", "meta"),
        ]
        if self.env_backend != "numba":
            raise Exception("CUDAClassicControlPendulumEnv expects env_backend = 'numba' ")

        configured_step = self.cuda_step[
            self.cuda_function_manager.grid, self.cuda_function_manager.block
        ]
        configured_step(*self.cuda_step_function_feed(args))
|
72 changes: 72 additions & 0 deletions
72
example_envs/single_agent/classic_control/pendulum/pendulum_step_numba.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import numba | ||
import numba.cuda as numba_driver | ||
import numpy as np | ||
import math | ||
|
||
# Default values; not referenced by the step kernel in this module.
DEFAULT_X = np.pi
DEFAULT_Y = 1.0

# Limits applied by the step kernel via _clip().
max_speed = 8       # bound on the angular velocity (thdot)
max_torque = 2.0    # bound on the applied action (u)

# Parameters used in the kernel's angular-velocity update.
dt = 0.05
g = 9.81
m = 1.0
l = 1.0
|
||
@numba_driver.jit
def _clip(v, min, max):
    """Return ``v`` limited to the closed interval ``[min, max]``.

    The lower bound is checked first, so ``min`` is returned whenever
    ``v < min`` regardless of ``max``.
    """
    if v < min:
        return min
    return max if v > max else v
|
||
|
||
@numba_driver.jit
def angle_normalize(x):
    """Wrap the angle ``x`` (radians) into the interval ``[-pi, pi)``."""
    full_turn = 2 * np.pi
    shifted = (x + np.pi) % full_turn
    return shifted - np.pi
|
||
|
||
@numba_driver.jit
def NumbaClassicControlPendulumEnvStep(
        state_arr,
        action_arr,
        done_arr,
        reward_arr,
        observation_arr,
        env_timestep_arr,
        episode_length):
    """Advance one pendulum environment by a single timestep on the GPU.

    The environment is selected by the CUDA block index and the agent by the
    thread index (which must be 0). The kernel updates, in place:

    * ``env_timestep_arr[env]``: incremented by one;
    * ``state_arr[env, agent, 0:2]``: new angle and angular velocity;
    * ``observation_arr[env, agent, 0:3]``: cos(angle), sin(angle), velocity;
    * ``reward_arr[env, agent]``: negative cost computed from the *previous*
      state and the clipped action;
    * ``done_arr[env]``: set to 1 once the timestep equals ``episode_length``.
    """
    env_id = numba_driver.blockIdx.x
    agent_id = numba_driver.threadIdx.x

    assert agent_id == 0, "We only have one agent per environment"

    env_timestep_arr[env_id] += 1

    assert 0 < env_timestep_arr[env_id] <= episode_length

    # State before the update.
    th = state_arr[env_id, agent_id, 0]
    thdot = state_arr[env_id, agent_id, 1]

    # The applied action is the raw action limited to [-max_torque, max_torque].
    raw_action = action_arr[env_id, agent_id, 0]
    u = _clip(raw_action, -max_torque, max_torque)

    # The cost uses the pre-update angle/velocity together with the clipped action.
    costs = angle_normalize(th) ** 2 + 0.1 * thdot ** 2 + 0.001 * (u ** 2)

    # Update the velocity first, clip it, then advance the angle with it.
    angular_accel = 3 * g / (2 * l) * math.sin(th) + 3.0 / (m * l ** 2) * u
    newthdot = _clip(thdot + angular_accel * dt, -max_speed, max_speed)
    newth = th + newthdot * dt

    state_arr[env_id, agent_id, 0] = newth
    state_arr[env_id, agent_id, 1] = newthdot

    observation_arr[env_id, agent_id, 0] = math.cos(newth)
    observation_arr[env_id, agent_id, 1] = math.sin(newth)
    observation_arr[env_id, agent_id, 2] = newthdot

    reward_arr[env_id, agent_id] = -costs

    if env_timestep_arr[env_id] == episode_length:
        done_arr[env_id] = 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,7 +14,7 @@ | |
|
||
setup( | ||
name="rl-warp-drive", | ||
version="2.6.2", | ||
version="2.7", | ||
author="Tian Lan, Sunil Srinivasa, Brenton Chu, Stephan Zheng", | ||
author_email="[email protected]", | ||
description="Framework for fast end-to-end " | ||
|
86 changes: 86 additions & 0 deletions
86
tests/example_envs/numba_tests/single_agent/classic_control/test_pendulum.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import unittest | ||
import numpy as np | ||
import torch | ||
|
||
from warp_drive.env_cpu_gpu_consistency_checker import EnvironmentCPUvsGPU | ||
from example_envs.single_agent.classic_control.pendulum.pendulum import \ | ||
ClassicControlPendulumEnv, CUDAClassicControlPendulumEnv | ||
from warp_drive.env_wrapper import EnvWrapper | ||
|
||
|
||
# Environment configurations for the CPU-vs-GPU consistency check, keyed by
# configuration name; each value holds the environment's constructor kwargs.
env_configs = {
    "test1": dict(
        episode_length=200,
        reset_pool_size=0,
        seed=32145,
    ),
}
|
||
|
||
class MyTestCase(unittest.TestCase):
    """
    CPU v GPU consistency unit tests
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Checker comparing the CPU and numba pendulum environments.
        self.testing_class = EnvironmentCPUvsGPU(
            cpu_env_class=ClassicControlPendulumEnv,
            cuda_env_class=CUDAClassicControlPendulumEnv,
            env_configs=env_configs,
            gpu_env_backend="numba",
            num_envs=5,
            num_episodes=2,
        )

    def test_env_consistency(self):
        """Run the CPU-vs-GPU reset/step consistency check."""
        try:
            self.testing_class.test_env_reset_and_step()
        except AssertionError:
            self.fail("ClassicControlPendulumEnv environment consistency tests failed")

    def test_reset_pool(self):
        """Check that done environments are reset from the pool and others are untouched."""
        env_wrapper = EnvWrapper(
            env_obj=CUDAClassicControlPendulumEnv(episode_length=100, reset_pool_size=8),
            num_envs=3,
            env_backend="numba",
        )
        data_manager = env_wrapper.cuda_data_manager
        resetter = env_wrapper.env_resetter

        env_wrapper.reset_all_envs()
        resetter.init_reset_pool(data_manager, seed=12345)
        self.assertTrue(data_manager.reset_target_to_pool["state"] == "state_reset_pool")

        # squeeze() the agent dimension which is 1 always
        state_after_initial_reset = data_manager.pull_data_from_device("state").squeeze()

        reset_pool = data_manager.pull_data_from_device(
            data_manager.get_reset_pool("state")
        )
        reset_pool_mean = reset_pool.mean(axis=0).squeeze()

        # The pool entries must not all be identical.
        self.assertTrue(reset_pool.std(axis=0).mean() > 1e-4)

        # Mark environments 0 and 1 as done; environment 2 stays not done.
        done_flags = torch.from_numpy(np.array([1, 1, 0])).cuda()
        data_manager.data_on_device_via_torch("_done_")[:] = done_flags

        # Done flags are left set (undo_done_after_reset=False), so envs 0 and 1
        # are reset on every iteration while env 2 is never reset.
        per_env_states = [[], [], []]
        for _ in range(10000):
            resetter.reset_when_done(
                data_manager, mode="if_done", undo_done_after_reset=False
            )
            snapshot = data_manager.pull_data_from_device("state")
            for env_id, history in enumerate(per_env_states):
                history.append(snapshot[env_id])

        state_values_env0_mean = np.stack(per_env_states[0]).mean(axis=0).squeeze()
        state_values_env1_mean = np.stack(per_env_states[1]).mean(axis=0).squeeze()
        state_values_env2_mean = np.stack(per_env_states[2]).mean(axis=0).squeeze()

        for i, pool_mean in enumerate(reset_pool_mean):
            pool_tolerance = 0.1 * abs(pool_mean)
            # Envs reset from the pool should average out to the pool mean.
            self.assertTrue(np.absolute(state_values_env0_mean[i] - pool_mean) < pool_tolerance)
            self.assertTrue(np.absolute(state_values_env1_mean[i] - pool_mean) < pool_tolerance)
            # The env that was never done should keep its initial state.
            initial_value = state_after_initial_reset[0][i]
            self.assertTrue(
                np.absolute(state_values_env2_mean[i] - initial_value)
                < 0.001 * abs(initial_value)
            )
|
||
|
Oops, something went wrong.