From 125f6cc1a8f785df15e94759af48db0509b32fd8 Mon Sep 17 00:00:00 2001 From: Fritz449 Date: Thu, 3 Mar 2022 17:41:19 +0300 Subject: [PATCH] Fixed ppo.ipynb --- week09_policy_II/ppo.ipynb | 577 +++++++++++++++++++++++++++++-------- 1 file changed, 458 insertions(+), 119 deletions(-) diff --git a/week09_policy_II/ppo.ipynb b/week09_policy_II/ppo.ipynb index 4d1fd67d7..ee96c9b77 100644 --- a/week09_policy_II/ppo.ipynb +++ b/week09_policy_II/ppo.ipynb @@ -50,7 +50,8 @@ "outputs": [], "source": [ "!git clone https://github.com/benelot/pybullet-gym lib/pybullet-gym\n", - "!pip install -e lib/pybullet-gym" + "!pip install -e lib/pybullet-gym\n", + "!pip install wandb" ] }, { @@ -70,6 +71,8 @@ "metadata": {}, "outputs": [], "source": [ + "CWD = './'\n", + "\n", "import gym \n", "import pybulletgym\n", "\n", @@ -86,15 +89,90 @@ "metadata": {}, "outputs": [], "source": [ - "class Summaries(gym.Wrapper):\n", - " \"\"\" Wrapper to write summaries. \"\"\"\n", - " def step(self, action):\n", - " # TODO: implement writing summaries\n", - " return self.env.step(action)\n", - " \n", - " def reset(self, **kwargs):\n", - " # TODO: implement writing summaries\n", - " return self.env.reset(**kwargs)" + "import wandb\n", + "from collections import defaultdict, deque\n", + "import numpy as np\n", + "\n", + "\n", + "class SummariesBase(gym.Wrapper):\n", + " \"\"\" Env summaries writer base.\"\"\"\n", + "\n", + " def __init__(self, env, prefix=None, running_mean_size=100):\n", + " super().__init__(env)\n", + " self.episode_counter = 0\n", + " self.prefix = prefix\n", + "\n", + " nenvs = getattr(self.env.unwrapped, \"nenvs\", 1)\n", + " self.rewards = np.zeros(nenvs)\n", + " self.had_ended_episodes = np.zeros(nenvs, dtype=bool)\n", + " self.episode_lengths = np.zeros(nenvs)\n", + " self.reward_queues = [deque([], maxlen=running_mean_size)\n", + " for _ in range(nenvs)]\n", + "\n", + " def should_write_summaries(self):\n", + " \"\"\" Returns true if it's time to write summaries. \"\"\"\n", + " return np.all(self.had_ended_episodes)\n", + "\n", + " def add_summaries(self):\n", + " \"\"\" Writes summaries. 
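Reports the mean total reward of the most recently finished episodes, a running-mean reward over the last `running_mean_size` episodes, the mean episode length and, for batched envs, the min/max episode reward, all under the `prefix` namespace.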
\"\"\"\n", + " self.add_summary_scalar(\n", + " f\"{self.prefix}/total_reward\",\n", + " np.mean([q[-1] for q in self.reward_queues]))\n", + " self.add_summary_scalar(\n", + " f\"{self.prefix}/reward_mean_{self.reward_queues[0].maxlen}\",\n", + " np.mean([np.mean(q) for q in self.reward_queues]))\n", + " self.add_summary_scalar(\n", + " f\"{self.prefix}/episode_length\",\n", + " np.mean(self.episode_lengths))\n", + " if self.had_ended_episodes.size > 1:\n", + " self.add_summary_scalar(\n", + " f\"{self.prefix}/min_reward\",\n", + " min(q[-1] for q in self.reward_queues))\n", + " self.add_summary_scalar(\n", + " f\"{self.prefix}/max_reward\",\n", + " max(q[-1] for q in self.reward_queues))\n", + " self.episode_lengths.fill(0)\n", + " self.had_ended_episodes.fill(False)\n", + "\n", + " def step(self, action):\n", + " obs, rew, done, info = self.env.step(action)\n", + " self.rewards += rew\n", + " self.episode_lengths[~self.had_ended_episodes] += 1\n", + "\n", + " info_collection = [info] if isinstance(info, dict) else info\n", + " done_collection = [done] if isinstance(done, bool) else done\n", + " done_indices = [i for i, info in enumerate(info_collection)\n", + " if info.get(\"real_done\", done_collection[i])]\n", + " for i in done_indices:\n", + " if not self.had_ended_episodes[i]:\n", + " self.had_ended_episodes[i] = True\n", + " self.reward_queues[i].append(self.rewards[i])\n", + " self.rewards[i] = 0\n", + "\n", + " if self.should_write_summaries():\n", + " self.add_summaries()\n", + " return obs, rew, done, info\n", + "\n", + " def reset(self, **kwargs):\n", + " self.rewards.fill(0)\n", + " self.episode_lengths.fill(0)\n", + " self.had_ended_episodes.fill(False)\n", + " return self.env.reset(**kwargs)\n", + "\n", + " \n", + "class TorchSummaries(SummariesBase):\n", + " \"\"\" Wrapper to write summaries. \"\"\"\n", + " def __init__(self, env, prefix=None, running_mean_size=100):\n", + " super().__init__(env, prefix, running_mean_size)\n", + " self.__step = 0\n", + " \n", + " def set_step(self, step):\n", + " self.__step = step\n", + "\n", + " def add_summary_scalar(self, name, value, step=None):\n", + " if step is None:\n", + " step = self.__step\n", + " wandb.log({name: value}, step=step)" ] }, { @@ -113,8 +191,150 @@ "source": [ "from mujoco_wrappers import Normalize\n", "\n", - "env = Normalize(Summaries(gym.make(\"HalfCheetahMuJoCoEnv-v0\")));\n", - "env.unwrapped.seed(0);" + "ENV_NAME = \"HalfCheetahMuJoCoEnv-v0\"\n", + "\n", + "def create_mujoco_env():\n", + " env = gym.make(ENV_NAME)\n", + " return env\n", + "\n", + "NENVS = 1\n", + "env = Normalize(TorchSummaries(create_mujoco_env(), prefix=ENV_NAME))\n", + "\n", + "STATE_DIM = env.observation_space.shape[0]\n", + "N_ACTIONS = env.action_space.shape[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class AsArray:\n", + " \"\"\" \n", + " Converts lists of interactions to ndarray.\n", + " \"\"\"\n", + " def __call__(self, trajectory):\n", + " # Modify trajectory inplace. \n", + " for k, v in filter(lambda kv: kv[0] != \"state\",\n", + " trajectory.items()):\n", + " trajectory[k] = np.asarray(v)\n", + " \n", + "class AsTensor:\n", + " \"\"\" \n", + " Converts lists of interactions to DEVICE torch.Tensor.\n", + " \"\"\"\n", + " def __call__(self, trajectory):\n", + " # Modify trajectory inplace. 
\n", + " for k, v in filter(lambda kv: kv[0] != \"state\",\n", + " trajectory.items()):\n", + " trajectory[k] = torch.Tensor(v).to(DEVICE)\n", + "\n", + "\"\"\" RL env runner \"\"\"\n", + "from collections import defaultdict\n", + "\n", + "import numpy as np\n", + "\n", + "\n", + "class EnvRunner:\n", + " \"\"\" Reinforcement learning runner in an environment with given policy \"\"\"\n", + "\n", + " def __init__(self, env, policy, nsteps, transforms=None, step_var=None):\n", + " self.env = env\n", + " self.policy = policy\n", + " self.nsteps = nsteps\n", + " self.transforms = transforms or []\n", + " self.step_var = step_var if step_var is not None else 0\n", + " self.state = {\"latest_observation\": self.env.reset()}\n", + "\n", + " @property\n", + " def nenvs(self):\n", + " \"\"\" Returns number of batched envs or `None` if env is not batched \"\"\"\n", + " return getattr(self.env.unwrapped, \"nenvs\", None)\n", + "\n", + " def reset(self):\n", + " \"\"\" Resets env and runner states. \"\"\"\n", + " self.state[\"latest_observation\"] = self.env.reset()\n", + " self.policy.reset()\n", + "\n", + " def get_next(self):\n", + " \"\"\" Runs the agent in the environment. \"\"\"\n", + " trajectory = defaultdict(list, {\"actions\": []})\n", + " observations = []\n", + " rewards = []\n", + " resets = []\n", + " self.state[\"env_steps\"] = self.nsteps\n", + "\n", + " for i in range(self.nsteps):\n", + " observations.append(self.state[\"latest_observation\"])\n", + " act = self.policy.act(self.state[\"latest_observation\"])\n", + " if \"actions\" not in act:\n", + " raise ValueError(\"result of policy.act must contain 'actions' \"\n", + " f\"but has keys {list(act.keys())}\")\n", + " for key, val in act.items():\n", + " trajectory[key].append(val)\n", + "\n", + " obs, rew, done, _ = self.env.step(trajectory[\"actions\"][-1])\n", + " self.state[\"latest_observation\"] = obs\n", + " rewards.append(rew)\n", + " resets.append(done)\n", + " self.step_var += self.nenvs or 1\n", + "\n", + " # Only reset if the env is not batched. Batched envs should\n", + " # auto-reset.\n", + " if not self.nenvs and np.all(done):\n", + " self.state[\"env_steps\"] = i + 1\n", + " self.state[\"latest_observation\"] = self.env.reset()\n", + "\n", + " trajectory.update(\n", + " observations=observations,\n", + " rewards=rewards,\n", + " resets=resets)\n", + " trajectory[\"state\"] = self.state\n", + "\n", + " for transform in self.transforms:\n", + " transform(trajectory)\n", + " return trajectory" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will use `EnvRunner` to perform interactions with an environment with a policy for a fixed number of timesteps. Calling `.get_next()` on a runner will return a trajectory — dictionary \n", + "containing keys\n", + "\n", + "* `\"observations\"`\n", + "* `\"rewards\"` \n", + "* `\"resets\"`\n", + "* `\"actions\"`\n", + "* all other keys that you defined in `Policy`,\n", + "\n", + "under each of these keys there is a `np.ndarray` of specified length $T$ — the size of partial trajectory. \n", + "\n", + "Additionally, before returning a trajectory this runner can apply a list of transformations. \n", + "Each transformation is simply a callable that should modify passed trajectory in-place." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "class DummyPolicy:\n", + " def act(self, inputs, training=False):\n", + " assert not training\n", + " N = inputs.shape[0]\n", + " return {\"actions\": np.random.randn(6,), \"values\": np.nan * np.ones(N)}\n", + "\n", + "runner = EnvRunner(env, DummyPolicy(), 3,\n", + " transforms=[AsArray()])\n", + "trajectory = runner.get_next()\n", + "\n", + "{k: v.shape for k, v in trajectory.items() if k != \"state\"}" ] }, { @@ -142,10 +362,41 @@ "metadata": {}, "outputs": [], "source": [ - "# import tensorflow as tf\n", - "# import torch\n", + "import torch\n", + "from torch import nn\n", + "\n", + "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "\n", "\n", - "" + "class PPOAgent(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.policy_nn = nn.Sequential(nn.Linear(STATE_DIM, 64),\n", + " nn.ReLU(),\n", + " nn.Linear(64, 64),\n", + " nn.ReLU(),\n", + " nn.Linear(64, N_ACTIONS),\n", + " nn.Tanh())\n", + " self.value_nn = nn.Sequential(nn.Linear(STATE_DIM, 64),\n", + " nn.ReLU(),\n", + " nn.Linear(64, 64),\n", + " nn.ReLU(),\n", + " nn.Linear(64, 1))\n", + " self.covariance_vec = nn.Parameter(torch.zeros(N_ACTIONS, requires_grad=True))\n", + " self.__initialize_net_weights(self.policy_nn)\n", + " self.__initialize_net_weights(self.value_nn)\n", + "\n", + " def __initialize_net_weights(self, net):\n", + " for p in net.parameters():\n", + " if p.ndim < 2:\n", + " nn.init.zeros_(p)\n", + " else:\n", + " nn.init.orthogonal_(p, 2 ** 0.5)\n", + " \n", + " def forward(self, observations):\n", + " policy_mean = self.policy_nn(observations)\n", + " value = self.value_nn(observations)\n", + " return policy_mean, torch.exp(self.covariance_vec), value" ] }, { @@ -182,69 +433,29 @@ "outputs": [], "source": [ "class Policy:\n", - " def __init__(self, model):\n", - " self.model = model\n", + " def __init__(self, model):\n", + " self.model = model\n", + " \n", + " def estimate_v(self, inputs):\n", + " with torch.no_grad():\n", + " return self.model.value_nn(inputs)\n", " \n", - " def act(self, inputs, training=False):\n", - " \n", - " # Should return a dict." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We will use `EnvRunner` to perform interactions with an environment with a policy for a fixed number of timesteps. Calling `.get_next()` on a runner will return a trajectory — dictionary \n", - "containing keys\n", - "\n", - "* `\"observations\"`\n", - "* `\"rewards\"` \n", - "* `\"resets\"`\n", - "* `\"actions\"`\n", - "* all other keys that you defined in `Policy`,\n", - "\n", - "under each of these keys there is a `np.ndarray` of specified length $T$ — the size of partial trajectory. \n", - "\n", - "Additionally, before returning a trajectory this runner can apply a list of transformations. \n", - "Each transformation is simply a callable that should modify passed trajectory in-place." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class AsArray:\n", - " \"\"\" \n", - " Converts lists of interactions to ndarray.\n", - " \"\"\"\n", - " def __call__(self, trajectory):\n", - " # Modify trajectory inplace. 
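# A hedged sanity check for the PPOAgent defined above (purely illustrative,
# not required by the assignment): instantiate the agent once and make sure a
# random observation yields a correctly shaped diagonal-Gaussian action and a
# value estimate. STATE_DIM, N_ACTIONS and DEVICE come from earlier cells.
model = PPOAgent().to(DEVICE)
with torch.no_grad():
    obs = torch.randn(1, STATE_DIM, device=DEVICE)
    mean, cov, value = model(obs)
    dist = torch.distributions.MultivariateNormal(mean, torch.diag(cov))
    action = dist.sample()
print(action.shape, value.shape)  # expected: (1, N_ACTIONS) and (1, 1)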
\n", - " for k, v in filter(lambda kv: kv[0] != \"state\",\n", - " trajectory.items()):\n", - " trajectory[k] = np.asarray(v)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "from runners import EnvRunner\n", - "\n", - "class DummyPolicy:\n", - " def act(self, inputs, training=False):\n", - " assert not training\n", - " return {\"actions\": np.random.randn(6), \"values\": np.nan}\n", - " \n", - "runner = EnvRunner(env, DummyPolicy(), 3,\n", - " transforms=[AsArray()])\n", - "trajectory = runner.get_next()\n", - "\n", - "{k: v.shape for k, v in trajectory.items() if k != \"state\"}" + " def act(self, inputs, training=False):\n", + " # Should return a dict.\n", + " if isinstance(inputs, np.ndarray):\n", + " inputs = torch.Tensor(inputs).to(DEVICE)\n", + " if training:\n", + " mean, cov, value = self.model(inputs)\n", + " return {'distribution': torch.distributions.MultivariateNormal(mean, torch.diag(cov)),\n", + " 'values': value}\n", + " else:\n", + " with torch.no_grad():\n", + " mean, cov, value = self.model(inputs)\n", + " dist = torch.distributions.MultivariateNormal(mean, torch.diag(cov))\n", + " actions = dist.sample()\n", + " return {'actions': actions.cpu().numpy(),\n", + " 'log_probs': dist.log_prob(actions).cpu().numpy(),\n", + " 'values': value.cpu().numpy()}" ] }, { @@ -285,13 +496,13 @@ "outputs": [], "source": [ "class GAE:\n", - " \"\"\" Generalized Advantage Estimator. \"\"\"\n", - " def __init__(self, policy, gamma=0.99, lambda_=0.95):\n", + " \"\"\" Generalized Advantage Estimator. \"\"\"\n", + " def __init__(self, policy, gamma=0.99, lambda_=0.95):\n", " self.policy = policy\n", " self.gamma = gamma\n", " self.lambda_ = lambda_\n", - " \n", - " def __call__(self, trajectory):\n", + "\n", + " def __call__(self, trajectory):\n", " " ] }, @@ -311,27 +522,58 @@ "metadata": {}, "outputs": [], "source": [ + "from math import ceil\n", + "from tqdm.notebook import tqdm\n", + "\n", "class TrajectorySampler:\n", - " \"\"\" Samples minibatches from trajectory for a number of epochs. \"\"\"\n", - " def __init__(self, runner, num_epochs, num_minibatches, transforms=None):\n", - " self.runner = runner\n", - " self.num_epochs = num_epochs\n", - " self.num_minibatches = num_minibatches\n", - " self.transforms = transforms or []\n", - " self.minibatch_count = 0\n", - " self.epoch_count = 0\n", - " self.trajectory = None\n", - " \n", - " def shuffle_trajectory(self):\n", - " \"\"\" Shuffles all elements in trajectory.\n", + " \"\"\" Samples minibatches from trajectory for a number of epochs. 
\"\"\"\n", + " def __init__(self, runner, num_epochs, num_minibatches, transforms=None):\n", + " self.runner = runner\n", + " self.num_epochs = num_epochs\n", + " self.num_minibatches = num_minibatches\n", + " self.transforms = transforms or []\n", + " self.minibatch_count = 0\n", + " self.epoch_count = 0\n", + " self.trajectory = None\n", + " self.trajectory_length = 0\n", + " self.permutation = None\n", + " self.sample_trajectory()\n", + " \n", + " def sample_trajectory(self):\n", + " self.trajectory = self.runner.get_next()\n", + " self.trajectory_length = self.trajectory['actions'].shape[0]\n", + " self.shuffle_trajectory()\n", + " \n", + " def choose_minibatch(self, idx):\n", + " mb_size = ceil(self.trajectory_length / self.num_minibatches)\n", + " permutation_slice = self.permutation[idx*mb_size:(idx+1)*mb_size]\n", + " minibatch = {}\n", + " for k, v in self.trajectory.items():\n", + " if k != 'state':\n", + " minibatch[k] = v[permutation_slice]\n", + " return minibatch\n", " \n", - " Should be called at the beginning of each epoch.\n", - " \"\"\"\n", - " \n", + " def shuffle_trajectory(self):\n", + " \"\"\" Shuffles all elements in trajectory.\n", + "\n", + " Should be called at the beginning of each epoch.\n", + " \"\"\"\n", + " self.permutation = torch.randperm(self.trajectory_length)\n", " \n", - " def get_next(self):\n", - " \"\"\" Returns next minibatch. \"\"\"\n", - " " + " def get_next(self):\n", + " \"\"\" Returns next minibatch. \"\"\"\n", + " if self.minibatch_count == self.num_minibatches:\n", + " self.minibatch_count = 0\n", + " self.epoch_count += 1\n", + " self.shuffle_trajectory()\n", + " if self.epoch_count == self.num_epochs:\n", + " self.epoch_count = 0\n", + " self.sample_trajectory()\n", + " minibatch = self.choose_minibatch(self.minibatch_count)\n", + " self.minibatch_count += 1\n", + " for tf in self.transforms:\n", + " tf(minibatch)\n", + " return minibatch" ] }, { @@ -348,8 +590,8 @@ "outputs": [], "source": [ "class NormalizeAdvantages:\n", - " \"\"\" Normalizes advantages to have zero mean and variance 1. \"\"\"\n", - " def __call__(self, trajectory):\n", + " \"\"\" Normalizes advantages to have zero mean and variance 1. \"\"\"\n", + " def __call__(self, trajectory):\n", " adv = trajectory[\"advantages\"]\n", " adv = (adv - adv.mean()) / (adv.std() + 1e-8)\n", " trajectory[\"advantages\"] = adv" @@ -371,17 +613,17 @@ "def make_ppo_runner(env, policy, num_runner_steps=2048,\n", " gamma=0.99, lambda_=0.95, \n", " num_epochs=10, num_minibatches=32):\n", - " \"\"\" Creates runner for PPO algorithm. \"\"\"\n", - " runner_transforms = [AsArray(),\n", + " \"\"\" Creates runner for PPO algorithm. 
\"\"\"\n", + " runner_transforms = [AsArray(),\n", " GAE(policy, gamma=gamma, lambda_=lambda_)]\n", - " runner = EnvRunner(env, policy, num_runner_steps, \n", + " runner = EnvRunner(env, policy, num_runner_steps, \n", " transforms=runner_transforms)\n", - " \n", - " sampler_transforms = [NormalizeAdvantages()]\n", - " sampler = TrajectorySampler(runner, num_epochs=num_epochs, \n", + "\n", + " sampler_transforms = [NormalizeAdvantages()]\n", + " sampler = TrajectorySampler(runner, num_epochs=num_epochs, \n", " num_minibatches=num_minibatches,\n", " transforms=sampler_transforms)\n", - " return sampler" + " return sampler" ] }, { @@ -433,7 +675,7 @@ "outputs": [], "source": [ "class PPO:\n", - " def __init__(self, policy, optimizer,\n", + " def __init__(self, policy, optimizer,\n", " cliprange=0.2,\n", " value_loss_coef=0.25,\n", " max_grad_norm=0.5):\n", @@ -443,22 +685,22 @@ " self.value_loss_coef = value_loss_coef\n", " # Note that we don't need entropy regularization for this env.\n", " self.max_grad_norm = max_grad_norm\n", - " \n", - " def policy_loss(self, trajectory, act):\n", + "\n", + " def policy_loss(self, trajectory, act):\n", " \"\"\" Computes and returns policy loss on a given trajectory. \"\"\"\n", " \n", - " \n", - " def value_loss(self, trajectory, act):\n", + "\n", + " def value_loss(self, trajectory, act):\n", " \"\"\" Computes and returns value loss on a given trajectory. \"\"\"\n", " \n", - " \n", - " def loss(self, trajectory):\n", + "\n", + " def loss(self, trajectory):\n", " act = self.policy.act(trajectory[\"observations\"], training=True)\n", " policy_loss = self.policy_loss(trajectory, act)\n", " value_loss = self.value_loss(trajectory, act)\n", " return policy_loss + self.value_loss_coef * value_loss\n", - " \n", - " def step(self, trajectory):\n", + "\n", + " def step(self, trajectory):\n", " \"\"\" Computes the loss function and performs a single gradient step. 
\"\"\"\n", " " ] @@ -491,13 +733,110 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from math import ceil\n", + "\n", + "NUM_RUNNER_STEPS = 2048\n", + "NUM_RUNNER_EPOCHS = 10\n", + "NUM_RUNNER_MINIBATCHES = 32\n", + "MINIBATCH_SIZE = NUM_RUNNER_STEPS * NENVS // NUM_RUNNER_MINIBATCHES\n", + "\n", + "TOTAL_MINIBATCHES = int(3e6) // NUM_RUNNER_STEPS // NENVS * NUM_RUNNER_EPOCHS * NUM_RUNNER_MINIBATCHES\n", + "MINIBATCHES_IN_EPOCH = NUM_RUNNER_EPOCHS * NUM_RUNNER_MINIBATCHES\n", + "TOTAL_EPOCHS = ceil(TOTAL_MINIBATCHES // MINIBATCHES_IN_EPOCH)\n", + "\n", + "START = 0\n", + "CKPT_FREQ = 3\n", + "AGENT_LR = 3e-4\n", + "\n", + "EXPERIMENT_NAME = 'PPO'\n", + "config = wandb.config\n", + "config.init_lr = AGENT_LR\n", + "config.num_runner_steps = NUM_RUNNER_STEPS\n", + "config.num_runner_epochs = NUM_RUNNER_EPOCHS\n", + "config.num_runner_minibatches = NUM_RUNNER_MINIBATCHES\n", + "config.minibatch_size = MINIBATCH_SIZE\n", + "config.nenvs = NENVS\n", + "\n", + "MY_NAME = \n", + "\n", + "wandb.init(project='ppo', entity=MY_NAME, name=EXPERIMENT_NAME, config=config)\n", + "\n", + "env.reset()\n", + "\n", + "policy = Policy(model)\n", + "runner = make_ppo_runner(env, policy, num_runner_steps=NUM_RUNNER_STEPS,\n", + " num_epochs=NUM_RUNNER_EPOCHS, num_minibatches=NUM_RUNNER_MINIBATCHES)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = torch.optim.Adam(model.parameters(), lr=AGENT_LR, eps=1e-5)\n", + "sched = torch.optim.lr_scheduler.LambdaLR(optimizer,\n", + " lambda epoch: (TOTAL_EPOCHS - epoch) / (TOTAL_EPOCHS - epoch + 1))\n", + "ppo = PPO(policy, optimizer)\n", + "ppo.step(runner.get_next())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "MODEL_BASE_PATH = os.path.join(CWD, 'models')\n", + "os.makedirs(MODEL_BASE_PATH, exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm.notebook import tqdm\n", + "\n", + "for ep in tqdm(range(START, TOTAL_EPOCHS), desc='Epoch'):\n", + " for mb_idx in range(MINIBATCHES_IN_EPOCH):\n", + " env.env.set_step((ep * MINIBATCHES_IN_EPOCH + mb_idx) * MINIBATCH_SIZE)\n", + " trajectory = runner.get_next()\n", + " losses = ppo.step(trajectory)\n", + " for k, v in losses.items():\n", + " env.env.add_summary_scalar(k, v)\n", + " env.env.add_summary_scalar('lr', sched.get_last_lr())\n", + " if (ep + 1) % CKPT_FREQ == 0:\n", + " sched.step()\n", + " state_dict = {'model': model.state_dict(),\n", + " 'optimizer': optimizer.state_dict(),\n", + " 'scheduler': sched.state_dict(),\n", + " 'ep': ep}\n", + " torch.save(state_dict, os.path.join(MODEL_BASE_PATH, 'ppo.pth'))" + ] } ], "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", "name": "python", - "pygments_lexer": "ipython3" + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" } }, "nbformat": 4,