diff --git a/reinforcement_learning/notebooks/Reinforcement_Q_Learning_from_Scratch_in_Python_with_OpenAI_Gym_Taxi.ipynb b/reinforcement_learning/notebooks/Reinforcement_Q_Learning_from_Scratch_in_Python_with_OpenAI_Gym_Taxi.ipynb new file mode 100644 index 0000000..e564b11 --- /dev/null +++ b/reinforcement_learning/notebooks/Reinforcement_Q_Learning_from_Scratch_in_Python_with_OpenAI_Gym_Taxi.ipynb @@ -0,0 +1,1256 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 689 + }, + "id": "T2YZUIm4tgmJ", + "outputId": "a4b83d45-5a86-4283-abfe-c2d226996aa0" + }, + "source": [ + "!pip install cmake 'gym[atari]==0.22.0' scipy" + ], + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.9/dist-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", + " and should_run_async(code)\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Requirement already satisfied: cmake in /usr/local/lib/python3.9/dist-packages (3.25.2)\n", + "Collecting gym[atari]==0.22.0\n", + " Downloading gym-0.22.0.tar.gz (631 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m631.1/631.1 kB\u001b[0m \u001b[31m10.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: scipy in /usr/local/lib/python3.9/dist-packages (1.10.1)\n", + "Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.9/dist-packages (from gym[atari]==0.22.0) (1.22.4)\n", + "Requirement already satisfied: importlib-metadata>=4.10.0 in /usr/local/lib/python3.9/dist-packages (from gym[atari]==0.22.0) (6.3.0)\n", + "Requirement already satisfied: gym-notices>=0.0.4 in /usr/local/lib/python3.9/dist-packages (from gym[atari]==0.22.0) (0.0.8)\n", + "Requirement already satisfied: cloudpickle>=1.2.0 in /usr/local/lib/python3.9/dist-packages (from gym[atari]==0.22.0) (2.2.1)\n", + "Collecting ale-py~=0.7.4\n", + " Downloading ale_py-0.7.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m45.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: importlib-resources in /usr/local/lib/python3.9/dist-packages (from ale-py~=0.7.4->gym[atari]==0.22.0) (5.12.0)\n", + "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.9/dist-packages (from importlib-metadata>=4.10.0->gym[atari]==0.22.0) (3.15.0)\n", + "Building wheels for collected packages: gym\n", + " Building wheel for gym (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for gym: filename=gym-0.22.0-py3-none-any.whl size=708393 sha256=d6829a811711e7f91023bc5545600c609f35fea94f934d60d8ed621ffddc1e37\n", + " Stored in directory: /root/.cache/pip/wheels/c4/15/15/94c62e06887fb88768c5fa41482b80905ea71f3ede81040ffa\n", + "Successfully built gym\n", + "Installing collected packages: gym, ale-py\n", + " Attempting uninstall: gym\n", + " Found existing installation: gym 0.25.2\n", + " Uninstalling gym-0.25.2:\n", + " Successfully uninstalled gym-0.25.2\n", + "Successfully installed ale-py-0.7.5 gym-0.22.0\n" + ] + }, + { + "output_type": "display_data", + "data": { + "application/vnd.colab-display-data+json": { + "pip_warning": { + "packages": [ + "gym" + ] + } + } + }, + "metadata": {} + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "https://www.gymlibrary.dev/environments/toy_text/taxi/" + ], + "metadata": { + "id": "1zPmMwbagW-J" + } + }, + { + "cell_type": "code", + "source": [ + "import time" + ], + "metadata": { + "id": "5AwJw8_qwmu4" + }, + "execution_count": 1, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "R-fJA5uPthQx", + "outputId": "2935628d-99eb-46c9-9d44-16507cb55659" + }, + "source": [ + "import gym\n", + "\n", + "env = gym.make(\"Taxi-v3\").env\n", + "\n", + "env.reset() # reset environment to a new, random state\n", + "\n", + "env.render()" + ], + "execution_count": 6, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "+---------+\n", + "|R: |\u001b[43m \u001b[0m: :\u001b[35mG\u001b[0m|\n", + "| : | : : |\n", + "| : : : : |\n", + "| | : | : |\n", + "|\u001b[34;1mY\u001b[0m| : |B: |\n", + "+---------+\n", + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "5 * 5 * 5 * 4\n", + "# row, col, pick (R,G,B,Y,Car), drop(R,G,B,Y)" + ], + "metadata": { + "id": "IAeH0kRedcXQ" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "The filled square represents the taxi, which is **yellow** without a passenger and **green** with a passenger.\n", + "\n", + "The pipe (\"|\") represents a wall which the taxi cannot cross.\n", + "\n", + "- **R, G, Y, B** are the possible pickup and destination locations. \n", + "- The **blue** letter represents the current passenger pick-up location, and the **purple** letter is the current destination." + ], + "metadata": { + "id": "aOmdtV5hp5AG" + } + }, + { + "cell_type": "code", + "source": [ + "env.reset() # reset environment to a new, random state\n", + "env.render()\n", + "\n", + "print(\"Action Space {}\".format(env.action_space))\n", + "print(\"State Space {}\".format(env.observation_space))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "OPO2pwrQpja0", + "outputId": "3fa309f5-6fa7-4831-b9cb-f3b0d55e7bd1" + }, + "execution_count": 7, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "+---------+\n", + "|\u001b[34;1mR\u001b[0m:\u001b[43m \u001b[0m| : :G|\n", + "| : | : : |\n", + "| : : : : |\n", + "| | : | : |\n", + "|Y| : |\u001b[35mB\u001b[0m: |\n", + "+---------+\n", + "\n", + "Action Space Discrete(6)\n", + "State Space Discrete(500)\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "-cp0nKFQtmeW", + "outputId": "2c99ac8b-1745-49c5-9a81-f53d28084d1d" + }, + "source": [ + "print(env.step(5)) # 0=Back, 1=Fwd , 2=Right , 3=Left, 4=pickup, 5=dropoff\n", + "#observation, reward, done, info\n", + "env.render()" + ], + "execution_count": 21, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(475, 20, True, {'prob': 1.0})\n", + "+---------+\n", + "|R: | : :G|\n", + "| : | : : |\n", + "| : : : : |\n", + "| | : | : |\n", + "|Y| : |\u001b[35m\u001b[34;1m\u001b[43mB\u001b[0m\u001b[0m\u001b[0m: |\n", + "+---------+\n", + " (Dropoff)\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JvR2a-FAuNxl", + "outputId": "12a64971-cec0-4ac6-faaa-a595c47a6a3a" + }, + "source": [ + "env.reset() # reset environment to a new, random state\n", + "env.render()\n", + "\n", + "print(\"Action Space {}\".format(env.action_space))\n", + "print(\"State Space {}\".format(env.observation_space))" + ], + "execution_count": 22, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "+---------+\n", + "|\u001b[34;1mR\u001b[0m: | : :\u001b[35mG\u001b[0m|\n", + "| : |\u001b[43m \u001b[0m: : |\n", + "| : : : : |\n", + "| | : | : |\n", + "|Y| : |B: |\n", + "+---------+\n", + "\n", + "Action Space Discrete(6)\n", + "State Space Discrete(500)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# blue - pickup, purple - dropoff\n", + "steps = [0,3,3,1,1,4,0,0,2,2,2,2,1,1,5]\n", + "states = []\n", + "\n", + "for step in steps:\n", + " resp = env.step(step)\n", + " print(resp)\n", + " states.append(resp)\n", + " env.render()\n", + " time.sleep(1)" + ], + "metadata": { + "id": "k_B5TwhlsmwO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "env.reset()\n", + "env.render()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KN_GZrupvP8h", + "outputId": "cb02fcbc-1893-4472-c030-a2e8779ab7b9" + }, + "execution_count": 26, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "+---------+\n", + "|R: | : :\u001b[35mG\u001b[0m|\n", + "| : | : : |\n", + "| : : : : |\n", + "| | : | : |\n", + "|Y| :\u001b[43m \u001b[0m|\u001b[34;1mB\u001b[0m: |\n", + "+---------+\n", + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "vmN4Rjxqv1Vg", + "outputId": "9d905a7b-199a-4c1a-d2b9-b7340dfe4e38" + }, + "source": [ + "state = env.encode(3, 1, 2, 0) # (taxi row, taxi column, passenger index, destination index)\n", + "print(\"State:\", state)\n", + "\n", + "env.s = state\n", + "env.render()" + ], + "execution_count": 27, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "State: 328\n", + "+---------+\n", + "|R: | : :\u001b[35mG\u001b[0m|\n", + "| : | : : |\n", + "| : : : : |\n", + "| | : | : |\n", + "|Y| :\u001b[43m \u001b[0m|\u001b[34;1mB\u001b[0m: |\n", + "+---------+\n", + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "env.env.s = 328\n", + "env.render()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "kpMfr0XJheef", + "outputId": "38d62669-15a8-437e-e1b4-78809ead749f" + }, + "execution_count": 28, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "+---------+\n", + "|\u001b[35mR\u001b[0m: | : :G|\n", + "| : | : : |\n", + "| : : : : |\n", + "| |\u001b[43m \u001b[0m: | : |\n", + "|\u001b[34;1mY\u001b[0m| : |B: |\n", + "+---------+\n", + "\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## All possible states" + ], + "metadata": { + "id": "njQX5jhv8Kxc" + } + }, + { + "cell_type": "code", + "source": [ + "for i in range(50):\n", + " print(i)\n", + " env.env.s = i\n", + " env.render()\n", + " # time.sleep(1)" + ], + "metadata": { + "id": "KdTsHXnCwYAh" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "env.env.s = 328\n", + "env.render()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0MVGRfgwzI4-", + "outputId": "ea317787-e1f0-4597-b4f7-2b66234e9571" + }, + "execution_count": 31, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "+---------+\n", + "|\u001b[35mR\u001b[0m: | : :G|\n", + "| : | : : |\n", + "| : : : : |\n", + "| |\u001b[43m \u001b[0m: | : |\n", + "|\u001b[34;1mY\u001b[0m| : |B: |\n", + "+---------+\n", + "\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Reward Table\n", + "\n", + "The reward table, also known as the reward function, defines the rewards or penalties associated with each action in each state. It is a fixed table that specifies the immediate reward that the agent receives for taking a particular action in a particular state. The reward table is typically defined by the problem domain, and the Q-learning algorithm uses this table to learn the optimal policy." + ], + "metadata": { + "id": "ekg3LiN08ycz" + } + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2LL-qRKP1vwh", + "outputId": "e2f2fc6e-4b55-4b21-df78-02496f8ab21e" + }, + "source": [ + "env.P[328]" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{0: [(1.0, 428, -1, False)],\n", + " 1: [(1.0, 228, -1, False)],\n", + " 2: [(1.0, 348, -1, False)],\n", + " 3: [(1.0, 328, -1, False)],\n", + " 4: [(1.0, 328, -10, False)],\n", + " 5: [(1.0, 328, -10, False)]}" + ] + }, + "metadata": {}, + "execution_count": 21 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "This dictionary has the structure \n", + "\n", + "```{action: [(probability, nextstate, reward, done)]}.```" + ], + "metadata": { + "id": "qjBAsSMI9Jol" + } + }, + { + "cell_type": "code", + "source": [ + "for step in range(6):\n", + " env.env.s = 328\n", + " print(env.step(step))\n", + " env.render()" + ], + "metadata": { + "id": "zuM9bjmUzpQW" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "for i in range(10):\n", + " print(env.action_space.sample())" + ], + "metadata": { + "id": "8MaHb7EV2AmM", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "ffe368a4-6dc6-4e47-9b07-707ef59f5868" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "4\n", + "3\n", + "5\n", + "4\n", + "4\n", + "5\n", + "2\n", + "4\n", + "0\n", + "3\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Solving the environment without Reinforcement Learning\n", + "\n" + ], + "metadata": { + "id": "7DAMMYw69mWH" + } + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ZxHflKla8Tn2", + "outputId": "591bb1ec-ba69-419a-a07b-66e3855446d9" + }, + "source": [ + "env.s = 328 # set environment to illustration's state\n", + "\n", + "epochs = 0\n", + "penalties, reward = 0, 0\n", + "\n", + "frames = [] # for animation\n", + "\n", + "done = False\n", + "\n", + "while not done:\n", + " action = env.action_space.sample()\n", + " state, reward, done, info = env.step(action)\n", + "\n", + " if reward == -10:\n", + " penalties += 1\n", + " \n", + " # Put each rendered frame into dict for animation\n", + " frames.append({\n", + " 'frame': env.render(mode='ansi'),\n", + " 'state': state,\n", + " 'action': action,\n", + " 'reward': reward\n", + " }\n", + " )\n", + "\n", + " epochs += 1\n", + " \n", + "print(\"Timesteps taken: {}\".format(epochs))\n", + "print(\"Penalties incurred: {}\".format(penalties))" + ], + "execution_count": 54, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Timesteps taken: 99\n", + "Penalties incurred: 29\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "len(frames)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cl_9msXV1vZu", + "outputId": "4f12e6cf-34f2-42c2-80fa-39573172dd8f" + }, + "execution_count": 47, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "568" + ] + }, + "metadata": {}, + "execution_count": 47 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "lXeu2sz6_IKj", + "outputId": "e79413d2-a91a-489f-f8fc-1dc7a6d8beaa" + }, + "source": [ + "from IPython.display import clear_output\n", + "from time import sleep\n", + "\n", + "def print_frames(frames):\n", + " for i, frame in enumerate(frames):\n", + " clear_output(wait=True)\n", + " print(frame['frame'])\n", + " print(f\"Timestep: {i + 1}\")\n", + " print(f\"State: {frame['state']}\")\n", + " print(f\"Action: {frame['action']}\")\n", + " print(f\"Reward: {frame['reward']}\")\n", + " sleep(0.5)\n", + " \n", + "print_frames(frames)" + ], + "execution_count": 55, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "+---------+\n", + "|\u001b[35m\u001b[34;1m\u001b[43mR\u001b[0m\u001b[0m\u001b[0m: | : :G|\n", + "| : | : : |\n", + "| : : : : |\n", + "| | : | : |\n", + "|Y| : |B: |\n", + "+---------+\n", + " (Dropoff)\n", + "\n", + "Timestep: 99\n", + "State: 0\n", + "Action: 5\n", + "Reward: 20\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# DP\n", + "```\n", + "[1,2,3,4,5]\n", + "1 = 1\n", + "3 = 1+2\n", + "6 = 1+2+3\n", + "10 = 1+2+3+4\n", + "15 = 1+2+3+4+5\n", + "\n", + "1 = 1\n", + "3 = 1 + 2\n", + "6 = 3 + 3\n", + "10 = 6 + 4\n", + "15 = 10 + 5\n", + "\n", + "```\n" + ], + "metadata": { + "id": "4Pq1CCK1ltAh" + } + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "E1Z7s6TClF1u" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Enter Reinforcement Learning\n", + "\n", + "We are going to use a simple RL algorithm called Q-learning which will give our agent some memory.\n", + "\n", + "## Intro to Q-learning\n", + "\n", + "Essentially, Q-learning lets the agent use the environment's rewards to learn, over time, the best action to take in a given state.\n", + "\n", + "\n", + "![](https://miro.medium.com/v2/resize:fit:1400/1*EQ-tDj-iMdsHlGKUR81Xgw.png)" + ], + "metadata": { + "id": "BQOGUzI0AZJk" + } + }, + { + "cell_type": "markdown", + "source": [ + "![image.png]()" + ], + "metadata": { + "id": "Vsll17yDBkub" + } + }, + { + "cell_type": "markdown", + "source": [ + "# Implementing Q-learning in python\n" + ], + "metadata": { + "id": "WZjbBG1nCy3e" + } + }, + { + "cell_type": "code", + "source": [ + "import numpy as np\n", + "q_table = np.zeros([env.observation_space.n, env.action_space.n]) " + ], + "metadata": { + "id": "i_qIgEQl_e_Z" + }, + "execution_count": 56, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "q_table.shape" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "U-o-RjkMEYw0", + "outputId": "86d40cc3-513b-4753-dd25-4ff601ce14af" + }, + "execution_count": 57, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(500, 6)" + ] + }, + "metadata": {}, + "execution_count": 57 + } + ] + }, + { + "cell_type": "code", + "source": [ + "q_table" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "VBDP4B9mEdQo", + "outputId": "0ba4f69c-5494-4f49-c2d7-53685cb9a470" + }, + "execution_count": 58, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "array([[0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0.],\n", + " ...,\n", + " [0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0.]])" + ] + }, + "metadata": {}, + "execution_count": 58 + } + ] + }, + { + "cell_type": "code", + "source": [ + "np.argmax([1,2,3])" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "_3qzI1fpqBdd", + "outputId": "69627f73-1033-41af-aaef-824a151c8440" + }, + "execution_count": 62, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "2" + ] + }, + "metadata": {}, + "execution_count": 62 + } + ] + }, + { + "cell_type": "code", + "source": [ + "import random\n", + "from IPython.display import clear_output" + ], + "metadata": { + "id": "yOXysHFaDMNE" + }, + "execution_count": 59, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "%%time\n", + "\"\"\"Training the agent\"\"\"\n", + "\n", + "import random\n", + "from IPython.display import clear_output\n", + "\n", + "# Hyperparameters\n", + "alpha = 0.1 # learning rate\n", + "gamma = 0.6 # discount factor\n", + "epsilon = 0.1 # exploit and explore\n", + "\n", + "# For plotting metrics\n", + "all_epochs = []\n", + "all_penalties = []\n", + "\n", + "for i in range(1, 100001):\n", + " state = env.reset()\n", + "\n", + " epochs, penalties, reward, = 0, 0, 0\n", + " done = False\n", + " \n", + " while not done:\n", + " if random.uniform(0, 1) < epsilon:\n", + " action = env.action_space.sample() # Explore action space\n", + " else:\n", + " action = np.argmax(q_table[state]) # Exploit learned values\n", + "\n", + " next_state, reward, done, info = env.step(action) \n", + " \n", + " old_value = q_table[state, action]\n", + " next_max = np.max(q_table[next_state])\n", + " \n", + " new_value = (1 - alpha) * old_value + alpha * (reward + gamma * next_max)\n", + " q_table[state, action] = new_value\n", + "\n", + " if reward == -10:\n", + " penalties += 1\n", + "\n", + " state = next_state\n", + " epochs += 1\n", + " \n", + " if i % 100 == 0:\n", + " clear_output(wait=True)\n", + " print(f\"Episode: {i}\")\n", + "\n", + "print(\"Training finished.\\n\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "uzLQGY9_DqAK", + "outputId": "75ca1c68-5d1c-4257-f6ed-9b1d04cfd4e6" + }, + "execution_count": 63, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Episode: 100000\n", + "Training finished.\n", + "\n", + "CPU times: user 1min 8s, sys: 8.1 s, total: 1min 17s\n", + "Wall time: 1min 14s\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "for i in range(10):\n", + " print(q_table[i])" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "B-NUS2JLrQ0g", + "outputId": "5923e22b-6e1f-4d95-e9ef-28e5c28f2cc7" + }, + "execution_count": 65, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[0. 0. 0. 0. 0. 0.]\n", + "[ -2.41837065 -2.3639511 -2.41837066 -2.36395109 -2.27325184\n", + " -11.3639493 ]\n", + "[ -1.870144 -1.45024 -1.87014399 -1.45024007 -0.7504\n", + " -10.45023379]\n", + "[ -2.36395101 -2.27325184 -2.36395029 -2.27325181 -2.1220864\n", + " -11.27325008]\n", + "[-2.4961915 -2.49656291 -2.4961915 -2.49680945 -9.45879238 -8.48029525]\n", + "[0. 0. 0. 0. 0. 0.]\n", + "[ -2.4961915 -2.49715321 -2.4961915 -2.49689163 -10.46663985\n", + " -9.50877724]\n", + "[-2.48236806 -2.48455841 -2.48236806 -2.484406 -8.47407677 -9.66789661]\n", + "[-2.27325184 -2.32928432 -2.34522429 -2.34113999 -8.69224525 -9.39155281]\n", + "[ -2.47061344 -2.47818772 -2.47855343 -2.47607242 -9.68543571\n", + " -10.17891183]\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "q_table[328]" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "_bfFkPsWDqRh", + "outputId": "471c0837-7a52-4b0a-9cc1-bc2fa965581e" + }, + "execution_count": 66, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "array([ -2.40800594, -2.27325184, -2.3922264 , -2.35603561,\n", + " -10.60994385, -10.80888382])" + ] + }, + "metadata": {}, + "execution_count": 66 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Evaluating the agent\n" + ], + "metadata": { + "id": "JSYUfQQ-Fa0E" + } + }, + { + "cell_type": "code", + "source": [ + "\"\"\"Evaluate agent's performance after Q-learning\"\"\"\n", + "\n", + "total_epochs, total_penalties = 0, 0\n", + "episodes = 100\n", + "\n", + "for _ in range(episodes):\n", + " state = env.reset()\n", + " epochs, penalties, reward = 0, 0, 0\n", + " \n", + " done = False\n", + " \n", + " while not done:\n", + " action = np.argmax(q_table[state])\n", + " state, reward, done, info = env.step(action)\n", + "\n", + " if reward == -10:\n", + " penalties += 1\n", + "\n", + " epochs += 1\n", + "\n", + " total_penalties += penalties\n", + " total_epochs += epochs\n", + "\n", + "print(f\"Results after {episodes} episodes:\")\n", + "print(f\"Average timesteps per episode: {total_epochs / episodes}\")\n", + "print(f\"Average penalties per episode: {total_penalties / episodes}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "hyvGhkEtE4Qc", + "outputId": "24681d5a-81f4-4e75-cd63-5c1fb03bf35e" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Results after 100 episodes:\n", + "Average timesteps per episode: 13.14\n", + "Average penalties per episode: 0.0\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Realtime testing" + ], + "metadata": { + "id": "UTW0luFEFv_C" + } + }, + { + "cell_type": "code", + "source": [ + "env.env.s = 328\n", + "env.render()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "HFLV9boOGJok", + "outputId": "38d50c38-019a-41c9-b7f7-6631236e2a5b" + }, + "execution_count": 67, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "+---------+\n", + "|\u001b[35mR\u001b[0m: | : :G|\n", + "| : | : : |\n", + "| : : : : |\n", + "| |\u001b[43m \u001b[0m: | : |\n", + "|\u001b[34;1mY\u001b[0m| : |B: |\n", + "+---------+\n", + " (Dropoff)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "env.reset()\n", + "env.s = 328 # set environment to illustration's state\n", + "env.render()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "7Uvgjw24rqid", + "outputId": "85529957-ec46-4462-d64b-13ff642b9d9a" + }, + "execution_count": 77, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "+---------+\n", + "|R: | : :\u001b[34;1mG\u001b[0m|\n", + "| : | : : |\n", + "|\u001b[43m \u001b[0m: : : : |\n", + "| | : | : |\n", + "|\u001b[35mY\u001b[0m| : |B: |\n", + "+---------+\n", + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "\n", + "epochs = 0\n", + "penalties, reward = 0, 0\n", + "\n", + "frames = [] # for animation\n", + "\n", + "done = False\n", + "\n", + "while not done:\n", + " action = np.argmax(q_table[state])\n", + " state, reward, done, info = env.step(action)\n", + "\n", + " if reward == -10:\n", + " penalties += 1\n", + " \n", + " # Put each rendered frame into dict for animation\n", + " frames.append({\n", + " 'frame': env.render(mode='ansi'),\n", + " 'state': state,\n", + " 'action': action,\n", + " 'reward': reward\n", + " }\n", + " )\n", + "\n", + " epochs += 1\n", + " \n", + "print(\"Timesteps taken: {}\".format(epochs))\n", + "print(\"Penalties incurred: {}\".format(penalties))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "4ln37j7oFTlx", + "outputId": "d1b821bf-af7f-4ea9-87e6-89245be11b22" + }, + "execution_count": 78, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Timesteps taken: 18\n", + "Penalties incurred: 0\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "env.render()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "xmlwNJ82FTn5", + "outputId": "cefa9289-0507-4d0f-c922-f7afa39181f2" + }, + "execution_count": 79, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "+---------+\n", + "|R: | : :G|\n", + "| : | : : |\n", + "| : : : : |\n", + "| | : | : |\n", + "|\u001b[35m\u001b[34;1m\u001b[43mY\u001b[0m\u001b[0m\u001b[0m| : |B: |\n", + "+---------+\n", + " (Dropoff)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "WCGeM6lSFTqU" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "4roMfUOvFTxb" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "https://www.learndatasci.com/tutorials/reinforcement-q-learning-scratch-python-openai-gym/\n", + "\n", + "https://developer.nvidia.com/blog/deep-learning-nutshell-reinforcement-learning/\n", + "\n", + "https://medium.com/@MoneyAndData/ai-anyone-can-understand-part-1-reinforcement-learning-6c3b3d623a2d\n", + "\n", + "https://arshren.medium.com/deep-q-learning-a-deep-reinforcement-learning-algorithm-f1366cf1b53d\n", + "\n", + "https://www.coursera.org/specializations/reinforcement-learning" + ], + "metadata": { + "id": "nFefBluGFT6Z" + } + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "my6tYnl6FV04" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file