Skip to content

Commit

Permalink
Adding RL tutorial (deepchem#3968)
Browse files Browse the repository at this point in the history
* Adding RL tutorial

* Porting existing RL tutorial to PyTorch

* Removing unnecessary files
  • Loading branch information
NimishaDey authored May 29, 2024
1 parent d4cc476 commit b348694
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 36 deletions.
7 changes: 4 additions & 3 deletions deepchem/rl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class GymEnvironment(Environment):
def __init__(self, name):
"""Create an Environment wrapping the OpenAI Gym environment with a specified name."""
import gym
self.env = gym.make(name)
self.env = gym.make(name, render_mode='human')
self.name = name
space = self.env.action_space
if 'n' in dir(space):
Expand All @@ -163,11 +163,12 @@ def __init__(self, name):
action_shape=space.shape)

def reset(self):
    """Reset the wrapped Gym environment and begin a new episode.

    Stores the initial observation in ``self._state`` and clears
    ``self._terminated``.  Returns nothing.
    """
    # Gym >= 0.26 returns ``(observation, info)`` from ``reset()``.  Unpacking
    # (rather than indexing with ``[0]``) fails loudly if the environment
    # still uses the old single-return API, instead of silently keeping the
    # first row of the observation array.
    observation, _info = self.env.reset()
    self._state = observation
    self._terminated = False

def step(self, action):
    """Advance the wrapped Gym environment by one step.

    Parameters
    ----------
    action
        The action to pass through to ``self.env.step``.

    Returns
    -------
    The reward reported by the environment for this step.

    Side effects: updates ``self._state`` with the new observation and
    ``self._terminated`` with whether the episode has ended.
    """
    observation, reward, terminated, truncated, _info = self.env.step(action)
    self._state = observation
    # Gym >= 0.26 splits the old ``done`` flag into ``terminated`` and
    # ``truncated``.  Either one means the episode is over and the environment
    # must be reset, so both are folded into the single flag kept here.
    self._terminated = bool(terminated or truncated)
    return reward

def __deepcopy__(self, memo):
Expand Down
193 changes: 160 additions & 33 deletions examples/tutorials/Using_Reinforcement_Learning_to_Play_Pong.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
Expand All @@ -33,7 +33,69 @@
"id": "-1kpETs2GnbI",
"outputId": "dc8d5ae6-a0d7-4236-8168-8b615806ce41"
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: deepchem in c:\\users\\hp\\deepchem_2 (2.8.1.dev20240501183346)\n",
"Requirement already satisfied: joblib in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from deepchem) (1.3.2)\n",
"Requirement already satisfied: numpy>=1.21 in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from deepchem) (1.26.4)\n",
"Requirement already satisfied: pandas in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages\\pandas-2.2.1-py3.10-win-amd64.egg (from deepchem) (2.2.1)\n",
"Requirement already satisfied: scikit-learn in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from deepchem) (1.4.1.post1)\n",
"Requirement already satisfied: sympy in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from deepchem) (1.12)\n",
"Requirement already satisfied: scipy>=1.10.1 in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from deepchem) (1.12.0)\n",
"Requirement already satisfied: rdkit in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages\\rdkit-2023.9.5-py3.10-win-amd64.egg (from deepchem) (2023.9.5)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from pandas->deepchem) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages\\pytz-2024.1-py3.10.egg (from pandas->deepchem) (2024.1)\n",
"Requirement already satisfied: tzdata>=2022.7 in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages\\tzdata-2024.1-py3.10.egg (from pandas->deepchem) (2024.1)\n",
"Requirement already satisfied: Pillow in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from rdkit->deepchem) (10.2.0)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from scikit-learn->deepchem) (3.3.0)\n",
"Requirement already satisfied: mpmath>=0.19 in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from sympy->deepchem) (1.3.0)\n",
"Requirement already satisfied: six>=1.5 in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from python-dateutil>=2.8.2->pandas->deepchem) (1.16.0)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"No normalization for SPS. Feature removed!\n",
"No normalization for AvgIpc. Feature removed!\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From c:\\Users\\HP\\anaconda3\\envs\\deep\\lib\\site-packages\\keras\\src\\losses.py:2976: The name tf.losses.sparse_softmax_cross_entropy is deprecated. Please use tf.compat.v1.losses.sparse_softmax_cross_entropy instead.\n",
"\n",
"WARNING:tensorflow:From c:\\Users\\HP\\anaconda3\\envs\\deep\\lib\\site-packages\\tensorflow\\python\\util\\deprecation.py:588: calling function (from tensorflow.python.eager.polymorphic_function.polymorphic_function) with experimental_relax_shapes is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"experimental_relax_shapes is deprecated, use reduce_retracing instead\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'dgl'\n",
"Skipped loading modules with transformers dependency. No module named 'transformers'\n",
"cannot import name 'HuggingFaceModel' from 'deepchem.models.torch_models' (c:\\users\\hp\\deepchem_2\\deepchem\\models\\torch_models\\__init__.py)\n",
"Skipped loading modules with pytorch-lightning dependency, missing a dependency. No module named 'lightning'\n",
"Skipped loading some Jax models, missing a dependency. No module named 'jax'\n"
]
},
{
"data": {
"text/plain": [
"'2.8.1.dev'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"!pip install --pre deepchem\n",
"import deepchem\n",
Expand All @@ -42,7 +104,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
Expand All @@ -52,9 +114,33 @@
"id": "9sv6kX_VsoZ1",
"outputId": "ce4206d5-7917-4cad-c716-238a41f78e2a"
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: gym[accept-rom-license,atari] in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (0.26.2)\n",
"Requirement already satisfied: numpy>=1.18.0 in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from gym[accept-rom-license,atari]) (1.26.4)\n",
"Requirement already satisfied: cloudpickle>=1.2.0 in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from gym[accept-rom-license,atari]) (3.0.0)\n",
"Requirement already satisfied: gym-notices>=0.0.4 in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from gym[accept-rom-license,atari]) (0.0.8)\n",
"Requirement already satisfied: ale-py~=0.8.0 in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from gym[accept-rom-license,atari]) (0.8.1)\n",
"Requirement already satisfied: autorom~=0.4.2 in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from autorom[accept-rom-license]~=0.4.2; extra == \"accept-rom-license\"->gym[accept-rom-license,atari]) (0.4.2)\n",
"Requirement already satisfied: importlib-resources in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from ale-py~=0.8.0->gym[accept-rom-license,atari]) (6.4.0)\n",
"Requirement already satisfied: typing-extensions in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from ale-py~=0.8.0->gym[accept-rom-license,atari]) (4.9.0)\n",
"Requirement already satisfied: click in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from autorom~=0.4.2->autorom[accept-rom-license]~=0.4.2; extra == \"accept-rom-license\"->gym[accept-rom-license,atari]) (8.1.7)\n",
"Requirement already satisfied: requests in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from autorom~=0.4.2->autorom[accept-rom-license]~=0.4.2; extra == \"accept-rom-license\"->gym[accept-rom-license,atari]) (2.31.0)\n",
"Requirement already satisfied: tqdm in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from autorom~=0.4.2->autorom[accept-rom-license]~=0.4.2; extra == \"accept-rom-license\"->gym[accept-rom-license,atari]) (4.66.2)\n",
"Requirement already satisfied: AutoROM.accept-rom-license in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from autorom[accept-rom-license]~=0.4.2; extra == \"accept-rom-license\"->gym[accept-rom-license,atari]) (0.6.1)\n",
"Requirement already satisfied: colorama in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from click->autorom~=0.4.2->autorom[accept-rom-license]~=0.4.2; extra == \"accept-rom-license\"->gym[accept-rom-license,atari]) (0.4.6)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from requests->autorom~=0.4.2->autorom[accept-rom-license]~=0.4.2; extra == \"accept-rom-license\"->gym[accept-rom-license,atari]) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from requests->autorom~=0.4.2->autorom[accept-rom-license]~=0.4.2; extra == \"accept-rom-license\"->gym[accept-rom-license,atari]) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\users\\hp\\anaconda3\\envs\\deep\\lib\\site-packages (from requests->autorom~=0.4.2->autorom[accept-rom-license]~=0.4.2; extra == \"accept-rom-license\"->gym[accept-rom-license,atari]) (2.2.1)\n",
"Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\hp\\appdata\\roaming\\python\\python310\\site-packages (from requests->autorom~=0.4.2->autorom[accept-rom-license]~=0.4.2; extra == \"accept-rom-license\"->gym[accept-rom-license,atari]) (2022.5.18.1)\n"
]
}
],
"source": [
"!pip install 'gym[atari]'"
"!pip install \"gym[atari,accept-rom-license]\""
]
},
{
Expand All @@ -74,7 +160,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {
"colab": {},
"colab_type": "code",
Expand All @@ -85,19 +171,21 @@
"import deepchem as dc\n",
"import numpy as np\n",
"\n",
"\n",
"class PongEnv(dc.rl.GymEnvironment):\n",
" def __init__(self):\n",
" super(PongEnv, self).__init__('Pong-v0')\n",
" super(PongEnv, self).__init__('Pong-v4')\n",
" self._state_shape = (80, 80)\n",
" \n",
"\n",
" @property\n",
" def state(self):\n",
" # Crop everything outside the play area, reduce the image size,\n",
" # and convert it to black and white.\n",
" cropped = np.array(self._state)[34:194, :, :]\n",
" state_array = self._state\n",
" cropped = state_array[34:194, :, :]\n",
" reduced = cropped[0:-1:2, 0:-1:2]\n",
" grayscale = np.sum(reduced, axis=2)\n",
" bw = np.zeros(grayscale.shape)\n",
" bw = np.zeros(grayscale.shape, dtype=np.float32)\n",
" bw[grayscale != 233] = 1\n",
" return bw\n",
"\n",
Expand Down Expand Up @@ -127,34 +215,52 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "BLdt8WAQsoaH"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow.keras.layers import Input, Concatenate, Conv2D, Dense, Flatten, GRU, Reshape\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class PongPolicy(dc.rl.Policy):\n",
" def __init__(self):\n",
" super(PongPolicy, self).__init__(['action_prob', 'value', 'rnn_state'], [np.zeros(16)])\n",
" super(PongPolicy, self).__init__(['action_prob', 'value', 'rnn_state'], [np.zeros(16, dtype=np.float32)])\n",
"\n",
" def create_model(self, **kwargs):\n",
" state = Input(shape=(80, 80))\n",
" rnn_state = Input(shape=(16,))\n",
" conv1 = Conv2D(16, kernel_size=8, strides=4, activation=tf.nn.relu)(Reshape((80, 80, 1))(state))\n",
" conv2 = Conv2D(32, kernel_size=4, strides=2, activation=tf.nn.relu)(conv1)\n",
" dense = Dense(256, activation=tf.nn.relu)(Flatten()(conv2))\n",
" gru, rnn_final_state = GRU(16, return_state=True, return_sequences=True, time_major=True)(\n",
" Reshape((-1, 256))(dense), initial_state=rnn_state)\n",
" concat = Concatenate()([dense, Reshape((16,))(gru)])\n",
" action_prob = Dense(env.n_actions, activation=tf.nn.softmax)(concat)\n",
" value = Dense(1)(concat)\n",
" return tf.keras.Model(inputs=[state, rnn_state], outputs=[action_prob, value, rnn_final_state])\n",
"\n",
" class TestModel(nn.Module):\n",
" def __init__(self):\n",
" super(TestModel, self).__init__()\n",
" # Convolutional layers\n",
" self.conv1 = nn.Conv2d(1, 16, kernel_size=8, stride=4)\n",
" self.conv2 = nn.Conv2d(16, 32, kernel_size=4, stride=2)\n",
" self.fc1 = nn.Linear(2048, 256)\n",
" self.gru = nn.GRU(256, 16, batch_first = True)\n",
" self.action_prob = nn.Linear(272, env.n_actions)\n",
" self.value = nn.Linear(272, 1)\n",
" def forward(self, inputs):\n",
" state = (torch.from_numpy((inputs[0])))\n",
" rnn_state = (torch.from_numpy(inputs[1]))\n",
" reshaped = state.view(-1, 1, 80, 80)\n",
" conv1 = F.relu(self.conv1(reshaped))\n",
" conv2 = F.relu(self.conv2(conv1))\n",
" conv2 = conv2.view(conv2.size(0), -1)\n",
" x = F.relu(self.fc1(conv2))\n",
" reshaped_x = x.view(1, -1, 256)\n",
" #x = torch.flatten(x, 1)\n",
" gru_out, rnn_final_state = self.gru(reshaped_x, rnn_state.unsqueeze(0))\n",
" rnn_final_state = rnn_final_state.view(-1,16)\n",
" gru_out = gru_out.view(-1, 16)\n",
" concat = torch.cat((x, gru_out), dim=1)\n",
" #concat = concat.view(-1, 272)\n",
" action_prob = F.softmax(self.action_prob(concat), dim=-1)\n",
" value = self.value(concat)\n",
" return action_prob, value, rnn_final_state\n",
" return TestModel()\n",
"policy = PongPolicy()"
]
},
Expand All @@ -170,7 +276,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {
"colab": {},
"colab_type": "code",
Expand All @@ -179,8 +285,11 @@
},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"from deepchem.rl.torch_rl.torch_a2c import A2C\n",
"\n",
"from deepchem.models.optimizers import Adam\n",
"a2c = dc.rl.A2C(env, policy, model_dir='model', optimizer=Adam(learning_rate=0.0002))"
"a2c = A2C(env, policy, model_dir='model', optimizer=Adam(learning_rate=0.0002))"
]
},
{
Expand All @@ -195,13 +304,22 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Wa18EQlmsoaV"
},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\HP\\anaconda3\\envs\\deep\\lib\\site-packages\\gym\\utils\\passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`. (Deprecated NumPy 1.24)\n",
" if not isinstance(terminated, (bool, np.bool8)):\n"
]
}
],
"source": [
"# Change this to train as many steps as you have patience for.\n",
"a2c.fit(1000)"
Expand All @@ -219,13 +337,22 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Ud6DB_ndsoab"
},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\HP\\anaconda3\\envs\\deep\\lib\\site-packages\\gym\\utils\\passive_env_checker.py:289: UserWarning: \u001b[33mWARN: No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), rendering may occur at inconsistent fps.\u001b[0m\n",
" logger.warn(\n"
]
}
],
"source": [
"# This code doesn't work well on Colab\n",
"env.reset()\n",
Expand Down Expand Up @@ -273,7 +400,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.10.13"
}
},
"nbformat": 4,
Expand Down

0 comments on commit b348694

Please sign in to comment.