single waypoint reward is now working ... turned out to be a local minima thing?
mginoya committed Aug 26, 2024
1 parent 31389cc commit 1f5f123
Showing 5 changed files with 65 additions and 34 deletions.
44 changes: 30 additions & 14 deletions alfredo/agents/aant/aant.py
@@ -90,7 +90,7 @@ def reset(self, rng: jax.Array) -> State:
jcmd = self._sample_command(rng3)
#wcmd = self._sample_waypoint(rng3)

wcmd = jp.array([10.0, 10.0, 0.5])
wcmd = jp.array([0.0, 10.0])

q = self.sys.init_q + jax.random.uniform(
rng1, (self.sys.q_size(),), minval=low, maxval=hi
@@ -118,6 +118,9 @@ def reset(self, rng: jax.Array) -> State:
'pos_x_world_abs': zero,
'pos_y_world_abs': zero,
'pos_z_world_abs': zero,
'dist_goal_x': zero,
'dist_goal_y': zero,
#'dist_goal_z': zero,
}

return State(pipeline_state, obs, reward, done, metrics, state_info)
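Worth noting for the hunk above: any metric the environment reports from step() has to be pre-registered with a zero value here in reset(), since the carried Brax State is a JAX pytree and its metric keys must stay identical across reset and step for the jitted training rollout to keep a fixed structure. A minimal sketch of just the entries this commit adds, assuming `zero` is the scalar placeholder already used for the other metrics:

zero = jp.zeros(())              # assumption: matches how `zero` is defined earlier in reset()
metrics = {
    # ... existing metric entries (reward_ctrl, pos_x_world_abs, ...) ...
    'dist_goal_x': zero,         # signed x offset to the commanded waypoint
    'dist_goal_y': zero,         # signed y offset to the commanded waypoint
    # 'dist_goal_z': zero,       # left disabled: the waypoint command is now 2-D
}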
@@ -128,14 +131,16 @@ def step(self, state: State, action: jax.Array) -> State:
pipeline_state0 = state.pipeline_state
pipeline_state = self.pipeline_step(pipeline_state0, action)

#print(f"wcmd: {state.info['wcmd']}")
#print(f"x.pos[0]: {pipeline_state.x.pos[0]}")
waypoint_cost = rTracking_Waypoint(self.sys,
state.pipeline_state,
pipeline_state,
state.info['wcmd'],
weight=1.0,
weight=100.0,
focus_idx_range=0)

lin_vel_reward = rTracking_lin_vel(self.sys,
state.pipeline_state,
pipeline_state,
jp.array([0, 0, 0]), #dummy values for previous CoM
jp.array([0, 0, 0]), #dummy values for current CoM
self.dt,
@@ -144,30 +149,30 @@
focus_idx_range=(0,0))

yaw_vel_reward = rTracking_yaw_vel(self.sys,
state.pipeline_state,
pipeline_state,
state.info['jcmd'],
weight=10.8,
focus_idx_range=(0,0))

ctrl_cost = rControl_act_ss(self.sys,
state.pipeline_state,
pipeline_state,
action,
weight=0.0)

torque_cost = rTorques(self.sys,
state.pipeline_state,
pipeline_state,
action,
weight=0.0)

upright_reward = rUpright(self.sys,
state.pipeline_state,
pipeline_state,
weight=0.0)

healthy_reward = rHealthy_simple_z(self.sys,
state.pipeline_state,
pipeline_state,
self._healthy_z_range,
early_terminate=self._terminate_when_unhealthy,
weight=1.0,
weight=0.0,
focus_idx_range=(0, 2))
reward = 0.0
reward = healthy_reward[0]
@@ -181,15 +186,23 @@
pos_world = pipeline_state.x.pos[0]
abs_pos_world = jp.abs(pos_world)

print(f'true position in world: {pos_world}')
print(f'absolute position in world: {abs_pos_world}\n')
#print(f"wcmd: {state.info['wcmd']}")
#print(f"x.pos[0]: {pipeline_state.x.pos[0]}")
wcmd = state.info['wcmd']
dist_goal = pos_world[0:2] - wcmd
#print(dist_goal)

#print(f'true position in world: {pos_world}')
#print(f'absolute position in world: {abs_pos_world}')
#print(f"dist_goal: {dist_goal}\n")

obs = self._get_obs(pipeline_state, state.info)
# print(f"\n")
# print(f"healthy_reward? {healthy_reward}")
# print(f"\n")
done = 1.0 - healthy_reward[1] if self._terminate_when_unhealthy else 0.0

#done = 1.0 - healthy_reward[1] if self._terminate_when_unhealthy else 0.0
done = 0.0

state.metrics.update(
reward_ctrl = ctrl_cost,
reward_alive = healthy_reward[0],
@@ -201,6 +214,9 @@
pos_x_world_abs = abs_pos_world[0],
pos_y_world_abs = abs_pos_world[1],
pos_z_world_abs = abs_pos_world[2],
dist_goal_x = dist_goal[0],
dist_goal_y = dist_goal[1],
#dist_goal_z = dist_goal[2],
)

return state.replace(
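Net effect of the step() changes above: every reward term now reads the freshly stepped pipeline_state instead of the pre-step state.pipeline_state, the waypoint weight jumps from 1.0 to 100.0, the healthy-reward weight drops to 0.0, early termination is disabled (done = 0.0), and a signed planar offset to the goal is logged each step. The offset itself is just two subtractions; a minimal sketch, assuming pipeline_state.x.pos[0] is the torso's world position and wcmd is the 2-D waypoint stored in state.info:

pos_world = pipeline_state.x.pos[0]    # torso position in the world frame
wcmd = state.info['wcmd']              # e.g. jp.array([0.0, 10.0]) set in reset()
dist_goal = pos_world[0:2] - wcmd      # signed x/y offsets, logged as dist_goal_x / dist_goal_y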
15 changes: 11 additions & 4 deletions alfredo/rewards/rControl.py
@@ -70,8 +70,15 @@ def rTracking_Waypoint(sys: base.System,
# x_i = pipeline_state.x.vmap().do(
# base.Transform.create(pos=sys.link.inertia.transform.pos)
# )

#print(f"wcmd: {waypoint}")
#print(f"x.pos[0]: {pipeline_state.x.pos[0]}")
torso_pos = pipeline_state.x.pos[focus_idx_range]
pos_goal_diff = torso_pos[0:2] - waypoint[0:2]
#print(f"pos_goal_diff: {pos_goal_diff}")
pos_sum_abs_diff = -jp.sum(jp.abs(pos_goal_diff))
#inv_euclid_dist = -math.safe_norm(pos_goal_diff)
#print(f"pos_sum_abs_diff: {pos_sum_abs_diff}")

pos_goal_diff = pipeline_state.x.pos[focus_idx_range] - waypoint
inv_euclid_dist = -math.safe_norm(pos_goal_diff)

return weight*inv_euclid_dist
#return weight*inv_euclid_dist
return weight*pos_sum_abs_diff
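The waypoint tracking term itself switches from a negative Euclidean (L2) distance over the full 3-D offset to a negative sum of absolute planar offsets (an L1 distance over x and y only), matching the new 2-D waypoint command. Both forms peak at zero exactly on the waypoint, so the change reshapes and rescales the reward rather than moving the optimum. A side-by-side sketch, assuming the math helper is Brax's math module as the existing safe_norm call suggests (function names here are illustrative, not part of the codebase):

import jax.numpy as jp
from brax import math

def waypoint_reward_l1(torso_pos, waypoint, weight=100.0):
    # New form: -(|dx| + |dy|) over the planar components only.
    diff = torso_pos[0:2] - waypoint[0:2]
    return weight * -jp.sum(jp.abs(diff))

def waypoint_reward_l2(torso_pos, waypoint, weight=1.0):
    # Previous form (left commented out above): negative Euclidean distance
    # over the full 3-D offset to a 3-D waypoint.
    return weight * -math.safe_norm(torso_pos - waypoint)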
6 changes: 4 additions & 2 deletions experiments/AAnt-locomotion/one_physics_step.py
@@ -46,7 +46,7 @@
env_xml_path=env_xml_paths[0],
agent_xml_path=agent_xml_path)

rng = jax.random.PRNGKey(seed=0)
rng = jax.random.PRNGKey(seed=3)
state = env.reset(rng=rng) #state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))

normalize = lambda x, y: x
@@ -60,7 +60,7 @@
policy_params = (params[0], params[1])
inference_fn = make_policy(policy_params)

wcmd = jp.array([10.0, 10.0, 0.5])
wcmd = jp.array([0.0, 1000.0])
key_envs, _ = jax.random.split(rng)
state = env.reset(rng=key_envs)

print(f"q: {state.pipeline_state.q}")
print(f"\n")
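The one-step debug script also bumps the seed, commands a distant waypoint ([0.0, 1000.0]), and re-resets the environment from a split of the base key rather than reusing the key that already seeded the first reset, which is the usual JAX pattern for drawing independent randomness from a single seed. A minimal sketch, assuming env is the AAnt environment constructed earlier in this script:

rng = jax.random.PRNGKey(seed=3)
key_envs, rng = jax.random.split(rng)   # never reuse a key after it has been consumed
state = env.reset(rng=key_envs)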
30 changes: 18 additions & 12 deletions experiments/AAnt-locomotion/training.py
@@ -28,28 +28,34 @@
"backend": "positional",
"seed": 0,
"len_training": 1_500_000,
"batch_size": 1024,
"batch_size": 2048,
"episode_len": 1000,
},
)

normalize_fn = running_statistics.normalize

def progress(num_steps, metrics):
print(num_steps)
print(metrics)
epi_len = wandb.config.episode_len
wandb.log(
{
"step": num_steps,
"Total Reward": metrics["eval/episode_reward"],
"Waypoint Reward": metrics["eval/episode_reward_waypoint"],
"Total Reward": metrics["eval/episode_reward"]/epi_len,
"Waypoint Reward": metrics["eval/episode_reward_waypoint"]/epi_len,
#"Lin Vel Reward": metrics["eval/episode_reward_lin_vel"],
#"Yaw Vel Reward": metrics["eval/episode_reward_yaw_vel"],
"Alive Reward": metrics["eval/episode_reward_alive"],
"Ctrl Reward": metrics["eval/episode_reward_ctrl"],
"Upright Reward": metrics["eval/episode_reward_upright"],
"Torque Reward": metrics["eval/episode_reward_torque"],
"Abs Pos X World": metrics["eval/episode_pos_x_world_abs"],
"Abs Pos Y World": metrics["eval/episode_pos_y_world_abs"],
"Abs Pos Z World": metrics["eval/episode_pos_z_world_abs"],
"Alive Reward": metrics["eval/episode_reward_alive"]/epi_len,
"Ctrl Reward": metrics["eval/episode_reward_ctrl"]/epi_len,
"Upright Reward": metrics["eval/episode_reward_upright"]/epi_len,
"Torque Reward": metrics["eval/episode_reward_torque"]/epi_len,
"Abs Pos X World": metrics["eval/episode_pos_x_world_abs"]/epi_len,
"Abs Pos Y World": metrics["eval/episode_pos_y_world_abs"]/epi_len,
"Abs Pos Z World": metrics["eval/episode_pos_z_world_abs"]/epi_len,
"Dist Goal X": metrics["eval/episode_dist_goal_x"]/epi_len,
"Dist Goal Y": metrics["eval/episode_dist_goal_y"]/epi_len,
#"Dist Goal Z": metrics["eval/episode_dist_goal_z"]/epi_len,
}
)
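The progress() callback now reports per-step averages: each episode-summed eval metric coming out of Brax is divided by the configured episode length, so the wandb curves stay comparable across runs that use different episode_len values. A compact equivalent of the normalization, with the explicit key list replaced by a prefix filter (an assumption, not the actual code):

def per_step(metrics: dict, episode_len: int) -> dict:
    # Brax's eval wrapper reports these metrics as sums over the episode;
    # dividing by the episode length turns them into per-step averages.
    return {key: value / episode_len
            for key, value in metrics.items()
            if key.startswith("eval/episode_")}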

@@ -120,9 +126,9 @@ def progress(num_steps, metrics):
train_fn = functools.partial(
ppo.train,
num_timesteps=wandb.config.len_training,
num_evals=100,
num_evals=400,
reward_scaling=0.1,
episode_length=1000,
episode_length=wandb.config.episode_len,
normalize_observations=True,
action_repeat=1,
unroll_length=10,
4 changes: 2 additions & 2 deletions experiments/AAnt-locomotion/vis_traj.py
@@ -58,7 +58,7 @@
jit_env_step = jax.jit(env.step)

rollout = []
rng = jax.random.PRNGKey(seed=0)
rng = jax.random.PRNGKey(seed=13294)
state = jit_env_reset(rng=rng)

normalize = lambda x, y: x
@@ -79,7 +79,7 @@
#yaw_vel = 0.0 # rad/s
#jcmd = jp.array([x_vel, y_vel, yaw_vel])

wcmd = jp.array([10.0, 10.0, 0.5])
wcmd = jp.array([0.0, 10.0])

# generate policy rollout
for _ in range(episode_length):
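Finally, vis_traj.py re-seeds the rollout and points it at the same 2-D waypoint hard-coded in the environment's reset(). The rollout loop itself is collapsed in the diff; a minimal sketch of what such a loop typically looks like with the objects defined earlier in this file (feeding wcmd through state.info is an assumption about how the command reaches the policy observation):

act_rng = jax.random.PRNGKey(1)
for _ in range(episode_length):
    state.info['wcmd'] = wcmd                 # keep commanding the same waypoint
    act_rng, key = jax.random.split(act_rng)
    action, _ = inference_fn(state.obs, key)  # policy built from make_policy(policy_params)
    state = jit_env_step(state, action)
    rollout.append(state.pipeline_state)      # frames for the Brax HTML visualiser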
