diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index b913bcf..c166b2e 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -62,3 +62,25 @@ jobs: issue: ${{ github.event.pull_request.number }} message: ${{ env.MESSAGE }} repo-token: ${{ secrets.GITHUB_TOKEN }} + + Check-next-version: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + - name: Get version from file + run: | + VERSION_FILE="navix/_version.py" + NAVIX_VERSION="$(cat navix/_version.py | grep '__version__ = ' | cut -d'=' -f2 | sed 's,\",,g' | sed "s,\',,g" | sed 's, ,,g')" + echo "Current version is:" + echo "$NAVIX_VERSION" + echo "NAVIX_VERSION=$NAVIX_VERSION" >> $GITHUB_ENV + - name: Check that git tag does not exist + run: | + NAVIX_VERSION=${{ env.NAVIX_VERSION }} + git fetch --tags + if [ $(git tag -l "$NAVIX_VERSION") ]; then + echo "Tag $NAVIX_VERSION already exists. Please update the version in navix/_version.py file." + exit 1 + fi + echo "Tag $NAVIX_VERSION will be the deployed version." diff --git a/README.md b/README.md index 498fc6c..ea716f0 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ [![CI](https://github.com/epignatelli/navix/actions/workflows/CI.yml/badge.svg)](https://github.com/epignatelli/navix/actions/workflows/CI.yml) [![CD](https://github.com/epignatelli/navix/actions/workflows/CD.yml/badge.svg)](https://github.com/epignatelli/navix/actions/workflows/CD.yml) ![PyPI version](https://img.shields.io/pypi/v/navix?label=PyPI&color=%230099ab) + **[Quickstart](#what-is-navix)** | **[Installation](#installation)** | **[Examples](#examples)** | **[The JAX ecosystem](#jax-ecosystem-for-rl)** | **[Contribute](#join-us)** | **[Cite](#cite)** @@ -154,6 +155,7 @@ NAVIX is not alone and part of an ecosystem of JAX-powered modules for RL. Check - [Rejax](https://github.com/keraJLi/rejax): a suite of diverse agents, among which, DDPG, DQN, PPO, SAC, TD3 - [Stoix](https://github.com/EdanToledo/Stoix): useful implementations of popular single-agent RL algorithms in JAX - [JAX-CORL](https://github.com/nissymori/JAX-CORL): lean single-file implementations of offline RL algorithms with solid performance reports + - [Dopamine](https://github.com/google/dopamine): a research framework for fast prototyping of reinforcement learning algorithms ## Join Us! diff --git a/docs/api/index.md b/docs/api/index.md new file mode 100644 index 0000000..bd4026d --- /dev/null +++ b/docs/api/index.md @@ -0,0 +1 @@ +**Coming soon** \ No newline at end of file diff --git a/docs/assets/images/navix_logo.png b/docs/assets/images/navix_logo.png new file mode 100644 index 0000000..a8f351c Binary files /dev/null and b/docs/assets/images/navix_logo.png differ diff --git a/docs/assets/macros/macros.py b/docs/assets/macros/macros.py new file mode 100644 index 0000000..ea0546b --- /dev/null +++ b/docs/assets/macros/macros.py @@ -0,0 +1,3 @@ +from plumkdocs import define_env + +__all__ = ['define_env'] \ No newline at end of file diff --git a/docs/assets/stylesheets/extra.css b/docs/assets/stylesheets/extra.css new file mode 100644 index 0000000..5fc5bdf --- /dev/null +++ b/docs/assets/stylesheets/extra.css @@ -0,0 +1,140 @@ +.md-header__button.md-logo { + margin-bottom: 0; + padding-bottom: 0; +} + +.md-header__button.md-logo img, +.md-header__button.md-logo svg { + height: 3rem !important; +} + +.center { + display: block; + margin-left: auto; + margin-right: auto; +} + +.doc-attribute { + border-top: 1px solid #ccc; +} + +.doc-method { + border-top: 1px solid #ccc; +} + +.doc-function { + border-top: 1px solid #ccc; +} + +.doc-class { + border-top: 5px solid var(--md-code-bg-color); +} + +/* Remove the `In` and `Out` block in rendered Jupyter notebooks */ +.md-container .jp-Cell-outputWrapper .jp-OutputPrompt.jp-OutputArea-prompt, +.md-container .jp-Cell-inputWrapper .jp-InputPrompt.jp-InputArea-prompt { + display: none !important; +} + +/* Fix /page#foo going to the top of the viewport and being hidden by the navbar */ +html { + scroll-padding-top: 50px; +} + +/* Emphasise sections of nav on left hand side */ +nav.md-nav { + padding-left: 5px; +} + +nav.md-nav--secondary { + border-left: revert !important; +} + +.md-nav__title { + font-size: 0.9rem; +} + +.md-nav__item--section>.md-nav__link { + font-size: 0.9rem; +} + +/* More space at the bottom of the page */ + +.md-main__inner { + margin-bottom: 1.5rem; +} + + +/* Change font sizes */ +html { + /* Decrease font size for overall webpage + Down from 137.5% which is the Material default */ + font-size: 110%; +} + +.md-typeset .admonition { + /* Increase font size in admonitions */ + font-size: 100% !important; +} + +.md-typeset details { + /* Increase font size in details */ + font-size: 100% !important; +} + +.md-typeset h1 { + font-size: 1.6rem; +} + +.md-typeset h2 { + font-size: 1.5rem; +} + +.md-typeset h3 { + font-size: 1.3rem; +} + +.md-typeset h4 { + font-size: 1.1rem; +} + +.md-typeset h5 { + font-size: 0.9rem; +} + +.md-typeset h6 { + font-size: 0.8rem; +} + + +/* Highlight functions, classes etc. type signatures. Really helps to make clear where + one item ends and another begins. */ + +[data-md-color-scheme="default"] { + --doc-heading-color: #DDD; + --doc-heading-border-color: #CCC; + --doc-heading-color-alt: #F0F0F0; +} + +[data-md-color-scheme="slate"] { + --doc-heading-color: rgb(25, 25, 33); + --doc-heading-border-color: rgb(25, 25, 33); + --doc-heading-color-alt: rgb(33, 33, 44); + --md-code-bg-color: rgb(38, 38, 50); +} + +h4.doc-heading { + /* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/ + background-color: var(--doc-heading-color); + border: solid var(--doc-heading-border-color); + border-width: 1.5pt; + border-radius: 2pt; + padding: 0pt 5pt 2pt 5pt; +} + +h5.doc-heading, +h6.heading { + background-color: var(--doc-heading-color-alt); + border-radius: 2pt; + padding: 0pt 5pt 2pt 5pt; +} \ No newline at end of file diff --git a/docs/benchmarks/index.md b/docs/benchmarks/index.md new file mode 100644 index 0000000..bd4026d --- /dev/null +++ b/docs/benchmarks/index.md @@ -0,0 +1 @@ +**Coming soon** \ No newline at end of file diff --git a/docs/changelog/index.md b/docs/changelog/index.md new file mode 100644 index 0000000..bd4026d --- /dev/null +++ b/docs/changelog/index.md @@ -0,0 +1 @@ +**Coming soon** \ No newline at end of file diff --git a/docs/design_notes.md b/docs/design_notes.md deleted file mode 100644 index f4538b6..0000000 --- a/docs/design_notes.md +++ /dev/null @@ -1,22 +0,0 @@ -# An Entity Component System Model for Minigrid -... - -# Fully jittable training loop -We might just use xla to compute the agent moves and grid update, but the speedup would be beneficial only if we aim to spawn thousands of objects in the environment. -On one hand, this would be easy to code, as we can use python to manage the ECS (e.g., use a dictionary to store di entities and iterate through it). -On the other hand: -1. This is not the usual case of minigrid, where we have a few objects in the environment. -2. You cannot jit the full training loop, becuase at every `env.step` we return execution management to python to handle entities, components and systems. - - -Thereforse, we must extend then jit-compatibility to the whole training loop, including the environment step and the agent moves. -To do this we have the following components: -- An Array representing the world map (the grid). Notice that entities are not visible on the grid, but only walkable/non-walkable tiles that restrict agent's ability to move. -- A Player entity, with a position and an action -- a Goal entity, with a position - - - -- Coordinates (-1, -1) are a discard pile for entities that are not in the grid. -- Every entity but the Player is batched, such that its properties have shape (B, *prop_shape). -- Every entity has a `tag`, which we use to represent the entity in categorical form. The `tag` must be an `int` greater than `1`. \ No newline at end of file diff --git a/docs/examples/customisation.ipynb b/docs/examples/customisation.ipynb new file mode 100644 index 0000000..216c984 --- /dev/null +++ b/docs/examples/customisation.ipynb @@ -0,0 +1,507 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Env customization\n", + "\n", + "This tutorial will guide you through the basics of using NAVIX. You will learn:\n", + "- How to create a `navix.Environment`,\n", + "- A vanilla, suboptimal interaction with it\n", + "- How to `jax.jit` compile the environment for faster execution\n", + "- How to run batched simulations" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Installation\n", + "\n", + "For a full guide on how to install NAVIX and its dependencies, please refer to the [official installation guide](../install/index.html)\n", + "For a quickstart, you can install NAVIX via pip:\n", + "```bash\n", + "pip install navix\n", + "```\n", + "\n", + "This will provide a standard CPU-based JAX installation. If you want to use a GPU, please [install JAX](https://github.com/google/jax/?tab=readme-ov-file#installation) with the appropriate backend." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating an Environment\n", + "\n", + "NAVIX provides a variety of MiniGrid environments.\n", + "You can find an exhaustive list [here](../home/environments.html). \n", + "If the environment you are looking for is not listed, please open an new [feature request](https://github.com/epignatelli/navix/issues/new?assignees=&labels=enhancement&template=feature_request.md).\n", + "\n", + "Now, let's create a simple DoorKey environment. The syntax is similar to the usual `gym.make`." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(64, 64, 3)\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGbCAYAAAAr/4yjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAVVklEQVR4nO3dfZCd89nA8evsht1kN2y3SURpJSKkImZMUO+LmmYiZCghHYbElLaUodJpURVvNaqI6RQtbaLVMSJaJRgd0+h0hCmqakrTVBKjSkJ2hTZKk72fP54n1+Nkz6ldk305m89nxoz97Z09157cfPc+57fnlIqiKAIAIqKuvwcAYOAQBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBT5UqVSKuXPnduvYMWPGxKxZs3p8G6tWrYpSqRQLFizo8Z+t5LHHHotSqRSLFi3aIl+vlmz63h977LH+HoUaJApbgQULFkSpVIqnn356i3y9pUuXxty5c+Ott97aIl+Pj+bmm2/eYhGFTYb09wAMfO+++24MGfL/p8rSpUvj8ssvj1mzZkVLS0vZscuWLYu6Oj9r9IWbb745RowY0eXK7LDDDot33303tt122/4ZjJomCnyoxsbGbh/b0NDQi5PUpvXr18ewYcP67Pbq6up69HcGH+RHuq3UrFmzorm5OV599dU47rjjorm5OUaOHBlz5syJjRs3lh37wecU5s6dG1//+tcjImLs2LFRKpWiVCrFqlWrIqLrcwrt7e0xZ86cmDRpUjQ3N8d2220XU6dOjeeee+4jz75ixYqYMWNGtLa2xrBhw+KAAw6IBx98sOKxGzdujIsvvjhGjx4dTU1NMX369HjllVfKjlm+fHmccMIJMXr06GhsbIydd945Zs6cGevWrSs77s4774zJkyfH0KFDo7W1NWbOnNnlax1++OGx1157xTPPPBOHHXZYDBs2LC6++OI45phjYtddd60444EHHhj77rtvfjx//vw48sgjY9SoUdHQ0BB77rln3HLLLWV/ZsyYMfHnP/85fvvb3+bfweGHHx4R1Z9TuOeee3L+ESNGxKmnnhqvvvpq2TE9OS8YnFwpbMU2btwYU6ZMic985jPxve99Lx599NG4/vrrY9y4cfGVr3yl4p/5/Oc/H3/961/jrrvuihtvvDFGjBgREREjR46sePyKFSvivvvuixkzZsTYsWNj9erV8cMf/jDa2trihRdeiE984hM9mnn16tVx0EEHxfr16+O8886Lj3/843HHHXfE9OnTY9GiRXH88ceXHX/11VdHqVSKb3zjG7FmzZqYN29eHHXUUfHHP/4xhg4dGu+//35MmTIl3nvvvTj33HNj9OjR8eqrr8bixYvjrbfeiu233z6/zqWXXhonnXRSfPGLX4w33ngjvv/978dhhx0Wzz77bNnDaGvXro2pU6fGzJkz49RTT40ddtghJk+eHKeddlo89dRTsd9+++WxL7/8cjz55JNx3XXX5dott9wSEydOjOnTp8eQIUPigQceiLPPPjs6OzvjnHPOiYiIefPmxbnnnhvNzc1xySWXRETEDjvsUPV+W7BgQcyePTv222+/uOaaa2L16tVx0003xeOPP95l/o9yXjCIFAx68+fPLyKieOqpp3Lt9NNPLyKiuOKKK8qO3WeffYrJkyeXrUVEcdlll+XH1113XRERxcqVK7vc1i677FKcfvrp+fG///3vYuPGjWXHrFy5smhoaCi77ZUrVxYRUcyfP/+/fi/nn39+ERHF7373u1x75513irFjxxZjxozJ21qyZEkREcVOO+1UvP3223nswoULi4gobrrppqIoiuLZZ58tIqK45557qt7mqlWrivr6+uLqq68uW3/++eeLIUOGlK23tbUVEVHceuutZceuW7euaGhoKC688MKy9e9+97tFqVQqXn755Vxbv359lxmmTJlS7LrrrmVrEydOLNra2rocu+l7X7JkSVEURfH+++8Xo0aNKvbaa6/i3XffzeMWL15cRETx7W9/O9d6cl4wOHn4aCv35S9/uezjQw89NFasWLHFvn5DQ0M+8bxx48ZYu3ZtNDc3xx577BF/+MMfevz1Hnroodh///3jkEMOybXm5uY466yzYtWqVfHCCy+UHX/aaafF8OHD8+MTTzwxdtxxx3jooYciIvJK4JFHHon169dXvM1f/OIX0dnZGSeddFK8+eab+c/o0aNj/PjxsWTJki7f8+zZs8vWNj1stnDhwig+8L5Wd999dxxwwAHxqU99KteGDh2a/75u3bp48803o62tLVasWNHlIa3uePrpp2PNmjVx9tlnlz3XMG3atJgwYULFh956+7xg4BKFrVhjY2OXh30+9rGPRUdHxxa7jc7Ozrjxxhtj/Pjx0dDQECNGjIiRI0fGn/70p4/0P7iXX3459thjjy7rn/70p/PzHzR+/Piyj0ulUuy22275HMjYsWPja1/7Wtx+++0xYsSImDJlSvzgBz8om2358uVRFEWMHz8+Ro4cWfbPiy++GGvWrCm7jZ122qnizp+TTz45XnnllXjiiSciIuKll16KZ555Jk4++eSy4x5//PE46qijoqmpKVpaWmLkyJFx8cUXR0R85PssIirebxMmTOhyn/XFecHA5TmFrVh9fX2v38Z3vvOduPTSS+OMM86IK6+8MlpbW6Ouri7OP//86Ozs7PXb747rr78+Zs2aFb/61a/i17/+dZx33nlxzTXXxJNPPhk777xzdHZ2RqlUiocffrjifdbc3Fz28Qd/0v+gY489NoYNGxYLFy6Mgw46KBYuXBh1dXUxY8aMPOall16Kz372szFhwoS44YYb4pOf/GRsu+228dBDD8WNN97YJ/dZX5wXDFyiQI+VSqVuH7to0aI44ogj4sc//nHZ+ltvvZVPUvfELrvsEsuWLeuy/pe//CU//0HLly8v+7goivjb3/4We++9d9n6pEmTYtKkSfGtb30rli5dGgcffHDceuutcdVVV8W4ceOiKIoYO3Zs7L777j2eeZOmpqY45phj4p577okbbrgh7r777jj00EPLnmx/4IEH4r333ov777+/7CGlzR+iiuj+38Om+2TZsmVx5JFHln1u2bJlXe4ztm4ePqLHmpqaIiK69RvN9fX1ZY+hR/zv1sjNt0J219FHHx2///3v8yGYiIh//etf8aMf/SjGjBkTe+65Z9nxP/3pT+Odd97JjxctWhSvvfZaTJ06NSIi3n777diwYUPZn5k0aVLU1dXFe++9FxH/u+Oqvr4+Lr/88i7fS1EUsXbt2m7Pf/LJJ8c//vGPuP322+O5557r8tDRpp/SP3g769ati/nz53f5Wk1NTd36O9h3331j1KhRceutt+b3FBHx8MMPx4svvhjTpk3r9vwMfq4U6LHJkydHRMQll1wSM2fOjG222SaOPfbYjMUHHXPMMXHFFVfE7Nmz46CDDornn38+fv7zn1fds/9hvvnNb8Zdd90VU6dOjfPOOy9aW1vjjjvuiJUrV8a9997b5bepW1tb45BDDonZs2fH6tWrY968ebHbbrvFmWeeGRERv/nNb+KrX/1qzJgxI3bffffYsGFD/OxnP4v6+vo44YQTIiJi3LhxcdVVV8VFF10Uq1atiuOOOy6GDx8eK1eujF/+8pdx1llnxZw5c7o1/9FHHx3Dhw+POXPmlN3GJp/73Odi2223jWOPPTa+9KUvxT//+c+47bbbYtSoUfHaa6+VHTt58uS45ZZb4qqrrorddtstRo0a1eVKICJim222iWuvvTZmz54dbW1t8YUvfCG3pI4ZMyYuuOCCbt//bAX6b+MTfaXaltSmpqYux1522WXF5qdFbLYltSiK4sorryx22mmnoq6urmx7aqUtqRdeeGGx4447FkOHDi0OPvjg4oknnija2trKtlN2d0tqURTFSy+9VJx44olFS0tL0djYWOy///7F4sWLy47ZtC3zrrvuKi666KJi1KhRxdChQ4tp06aVbf9csWJFccYZZxTjxo0rGhsbi9bW1uKII44oHn300S63e++99xaHHHJI0dTUVDQ1NRUTJkwozjnnnGLZsmV5TFtbWzFx4sT/Ov8pp5xSRERx1FFHVfz8/fffX+y9995FY2NjMWbMmOLaa68tfvKTn3TZBvz6668X06ZNK4YPH15ERN6fm29J3eTuu+8u9tlnn6KhoaFobW0tTjnllOLvf/972TE9OS8YnEpFsdn1MABbLc8pAJBEAYAkCgAkUQAgiQIASRQASN3+5bVNv+wDQG267bbbPvQYVwoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEAa0t8D9LbfLFlScf31jo6K6w2lUsX1vq5nS0tLxfWOKnOXqszdH3oy+0Ca+/jjj+/vET6SJVXO8YF+rtTyOV6r50p3uFIAIIkCAEkUAEiiAEASBQDSoN999J8qOxn2bm+vuP50la+zYQvN011FUVRcb68y90BSy7PXomq7dQb6/e08GZhcKQCQRAGAJAoAJFEAIIkCAGnQ7z4aXuX1Uu6rcvyaKuu3V1hbWOXY1//7SN0ykF7npadqefZaVKv3d63OPdi5UgAgiQIASRQASKIAQBIFANKg333UWWW98quuREyqsn5ThbW5VY5dXGV9XpX1P1RZB+hrrhQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBr0L3NRrXrV3t7j+Srrff0mOwD9wZUCAEkUAEiiAEASBQCSKACQBv3uo39VWf98lfWnqqz/ZwvMAjDQuVIAIIkCAEkUAEiiAEASBQDSoN999E6V9aV9OgVAbXClAEASBQCSKACQRAGANOifaP5YS0vlTxRFxeVSqdrb7/StlipzFwN87oiezT6Q5q5VtXqu1Orcg50rBQCSKACQRAGAJAoAJFEAIA363UcdHR0V19vb2/t4kp6ptgNjoM8dUduz1yLnOFuSKwUAkigAkEQBgCQKACRRACAN+t1Htfp6KbU6d0Rtz16LavX+rtW5BztXCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgDenvAXpbS0tLxfWiKCqud3aWKq7vsEPXtQMPrHybjz9eef2NNyqv19d3Xevp3KVS5bn7Q09mH0hz16paPVdqde7BzpUCAEkUAEiiAEASBQCSKACQBv3uo46Ojorr7e3tFdcr7QSKiLjzzq5rY8dWPnb69MrrJ5xQeb3SZotqOzCqzT2Q1PLstain5/hA4TwZmFwpAJBEAYAkCgAkUQAgDfonmnv6q/Hbb195feLErmvz51c+9vTTK69vt13l9XXruq7V8q/01/LstahW7+9anXuwc6UAQBIFAJIoAJBEAYAkCgCkQb/7qKeq/Yb9ZZd1XVuwoPKxs2dXXq+0ywhgIHGlAEASBQCSKACQRAGAJAoAJLuPNlPtTXYqvY/JihWVj73vvi02DkCfcqUAQBIFAJIoAJBEAYAkCgAku482UxTdXx9S5d5raNhy8wD0JVcKACRRACCJAgBJFABInmjeTLWXudh5565ra9dWPnbjxi03D0BfcqUAQBIFAJIoAJBEAYAkCgAku482U23n0IMPdl27997Kx7a3b7l5APqSKwUAkigAkEQBgCQKACRRACAN+t1HLS0tFdeLKu+mUyqVKq6vX9/922xt7f6x1WypuftDT2YfSHPXqlo9V2p17sHOlQIASRQASKIAQBIFAJIoAJAG/e6jjo6OiuvtA/wFiqrtwBjoc0fU9uy1yDnOluRKAYAkCgAkUQAgiQIASRQASIN+91Gtvl5Krc4dUduz16Javb9rde7BzpUCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAGlIfw/Q21paWiquF0VRcb1UKvXiNN1Xq3NH9Gz2gTR3rarVc6VW5x7sXCkAkEQBgCQKACRRACCJAgBp0O8+6ujoqLje3t7ex5P0TLUdGAN97ojanr0WOcfZklwpAJBEAYAkCgAkUQAgiQIAadDvPqrV10up1bkjanv2WlSr93etzj3YuVIAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIA3p7wF6W0tLS8X1oigqrpdKpV6cpvtqde6Ins0+kOa+4ILbK64vX175+Lpe/JGqs7Pr2vjxlY9dsmRcxfWBfq6sfXNtf4/w0Z3Z3wP0HlcKACRRACCJAgBJFABIogBAGvS7jzo6Oiqut7e39/EkPVNt58hAnzuidmevtsvokUcqrw/pxf96Nmzo/rG1eo5XVfn06R8DY6NWn3KlAEASBQCSKACQRAGAJAoApEG/+2igvM5LT9Xq3BG1O3u11zKqtstom226rlXZeFV1vSeqzVer93eP9ea3OZB2PPUzVwoAJFEAIIkCAEkUAEiiAEAa9LuPoLdU2lFUbSPQdttVXq+vr7y+tobflIza5koBgCQKACRRACCJAgBJFABIdh9BHxg3rvJ6Y2Pl9Sef7Lq2JV4/CT6MKwUAkigAkEQBgCQKACRPNMNHVOklLao9Gfz005XXe/Lk8dbyXjr0L1cKACRRACCJAgBJFABIogBAsvsI/k9nZ+X1DRv6do5qt1ltvq2Gl/noE64UAEiiAEASBQCSKACQRAGAZPcR/J/x43t2fF0v/khVaadRT+cbdHrztZ/sbEquFABIogBAEgUAkigAkEQBgFQqiu6999OZZ57Z27MA0Ituu+22Dz3GlQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBKRVEU/T0EAAODKwUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUA0v8AOC1h/s09XdEAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import navix as nx\n", + "\n", + "# Create the environment\n", + "env = nx.make('Navix-DoorKey-8x8-v0', observation_fn=nx.observations.rgb)\n", + "key = jax.random.PRNGKey(0)\n", + "timestep = env.reset(key)\n", + "\n", + "def render(obs, title):\n", + " plt.imshow(obs)\n", + " plt.title(title)\n", + " plt.axis('off')\n", + " plt.show()\n", + "\n", + "print(timestep.observation.shape)\n", + "render(timestep.observation, \"Initial observation\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Take-home message:**\n", + "1. To sample an initial environment state (`timestep`), we need to pass the `key` (seed) argument to the environment constructor. This is because NAVIX uses JAX's PRNGKey to generate random numbers. You can read more here\n", + "2. `env.reset` returns a [`navix.Timestep`]() object, which contains all the useful information." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The environment interface\n", + "\n", + "We can now simulate a sequence of actions in the environment. For this example, we'll make the agent take random actions." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGbCAYAAAAr/4yjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAU5ElEQVR4nO3dfYycZbn48Wu2pVtAyrbQVkhhoYWWFgrVEk8ILw2tEEtpgfBSNEIBSzSligIxsajtqiEBE5RTKiBHoRYTKRxyEFGgyBoUOSoRMFDRKi8xBu3LLlWgvLR7//7wcMVhZ2CXX/dltp9P0j/27tOdax+G/e4zc+9MpZRSAgAiommgBwBg8BAFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFGt6KFSuiUqnE5s2bB3qUfvfW1w47iyjswm699daoVCrx2GOP9cvtrV+/PlasWBHPP/98v9zeUPHqq6/GihUr4mc/+9lAj8IuQBToN+vXr4+2tjZR6KVXX3012traakbhi1/8Ymzbtq3/h2LIEgXYSbq6uuK1117r19scPnx4jBw5sl9vk6FNFHhHb7zxRnz5y1+OmTNnxt577x177rlnHH/88dHe3t7t2B/84Acxc+bM2GuvvWLUqFExffr0uO666yLiXw9VnX322RERceKJJ0alUolKpfKuD4k89NBDcfzxx8eee+4ZLS0tcdppp8Xvf//7msdu3rw5zjnnnBg1alTss88+cemll3b7Jr1u3bo47rjjoqWlJd73vvfFlClTYtmyZVXHvP7667F8+fI45JBDorm5OQ444ID4/Oc/H6+//nrVcZVKJZYuXRrf//734/DDD4/m5ua45557YsyYMXHhhRd2m+8f//hHjBw5Mq644ooen9vnn38+xo4dGxERbW1ted5WrFgREbWfU9i+fXt89atfjUmTJkVzc3McdNBBsWzZsm7zH3TQQXHqqafGL37xi/jQhz4UI0eOjIkTJ8b3vve9ev852BUUdlm33HJLiYjym9/8pu4xmzZtKvvtt1+57LLLyg033FCuueaaMmXKlLLbbruVxx9/PI974IEHSkSUOXPmlFWrVpVVq1aVpUuXlrPPPruUUsqf//zn8pnPfKZERFm2bFlZs2ZNWbNmTfnb3/5W97bXrVtXhg8fXiZPnlyuueaa0tbWVvbdd98yevTo8txzz+Vxy5cvLxFRpk+fXubPn1+uv/768vGPf7xERDnvvPPyuKeeeqqMGDGiHH300eW6664rN954Y7niiivKCSeckMfs2LGjnHzyyWWPPfYon/3sZ8tNN91Uli5dWoYPH15OO+20qvkiokydOrWMHTu2tLW1lVWrVpXHH3+8XHTRRaWlpaW8/vrrVcevXr266nz35Ny+/PLL5YYbbigRUc4444w8b08++WTV1/7vFi1aVCKinHXWWWXVqlXl/PPPLxFRTj/99KrjWltby5QpU8r48ePLsmXLyvXXX18++MEPlkqlUp566qm6/10Y2kRhF9aTKGzfvr3bN7fOzs4yfvz4ctFFF+XapZdeWkaNGlW2b99e93PdcccdJSJKe3t7j+abMWNGGTduXNmyZUuuPfnkk6Wpqamcf/75ufbWN8YFCxZU/fslS5aUiMhvoN/4xjdKRJRNmzbVvc01a9aUpqam8vOf/7xq/cYbbywRUR555JFci4jS1NRUnn766apj77///hIR5Z577qlaP+WUU8rEiRPz456e202bNpWIKMuXL+8279uj8MQTT5SIKIsXL6467oorrigRUR566KFca21tLRFRHn744VzbuHFjaW5uLpdffnm322LX4OEj3tGwYcNixIgREfGvx8w7Ojpi+/btcfTRR8dvf/vbPK6lpSVeeeWVWLdu3U653RdffDGeeOKJuOCCC2LMmDG5fuSRR8ZJJ50UP/7xj7v9m0suuaTq409/+tMREXlsS0tLRETcfffd0dXVVfN277jjjpg6dWocdthhsXnz5vwze/bsiIhuD5vNmjUrpk2bVrU2e/bs2HfffeP222/Ptc7Ozli3bl0sXLgw13p6bnvjra/1sssuq1q//PLLIyLi3nvvrVqfNm1aHH/88fnx2LFjY8qUKfHss8++p9un8YkC72r16tVx5JFHxsiRI2OfffaJsWPHxr333htbt27NY5YsWRKTJ0+OuXPnxoQJE+Kiiy6K++677z3f5gsvvBAREVOmTOn2d1OnTo3NmzfHK6+8UrV+6KGHVn08adKkaGpqyt1OCxcujGOPPTYWL14c48ePj3PPPTfWrl1bFYgNGzbE008/HWPHjq36M3ny5IiI2LhxY9VtHHzwwd3mGz58eJx55plx99135+P4d911V7z55ptVUYjo2bntjRdeeCGamprikEMOqVp///vfHy0tLXle33LggQd2+xyjR4+Ozs7O93T7ND5R4B3ddtttccEFF8SkSZPiO9/5Ttx3332xbt26mD17dtU303HjxsUTTzwRP/zhD2PBggXR3t4ec+fOjUWLFg3Y7G9/Anb33XePhx9+OB588ME477zz4ne/+10sXLgwTjrppNixY0dE/Osn9unTp8e6detq/lmyZEm3z1nLueeeG//85z/jJz/5SURErF27Ng477LA46qij8pientud8bXXM2zYsJrrxbv07rKGD/QADG533nlnTJw4Me66666qbzTLly/vduyIESNi/vz5MX/+/Ojq6oolS5bETTfdFF/60pfikEMO6dVv3ra2tkZExB/+8Iduf/fMM8/EvvvuG3vuuWfV+oYNG6p+cv/Tn/4UXV1dcdBBB+VaU1NTzJkzJ+bMmRPXXnttXHXVVXHllVdGe3t7fPjDH45JkybFk08+GXPmzPn/+k3hE044Ifbbb7+4/fbb47jjjouHHnoorrzyyqpjenpue3veurq6YsOGDTF16tRc//vf/x4vvfRSnleox5UC7+itnyT//SfHX/3qV/Hoo49WHbdly5aqj5uamuLII4+MiMiHUN76Jv7SSy+96+3ut99+MWPGjFi9enXV8U899VQ88MADccopp3T7N6tWrar6eOXKlRERMXfu3IiI6Ojo6PZvZsyYUTXjOeecE3/961/j5ptv7nbstm3buj1kVU9TU1OcddZZcc8998SaNWti+/bt3R466um53WOPPSKiZ+ftrfPyzW9+s2r92muvjYiIefPm9Wh+dl2uFIjvfve7NR//v/TSS+PUU0+Nu+66K84444yYN29ePPfcc3HjjTfGtGnT4uWXX85jFy9eHB0dHTF79uyYMGFCvPDCC7Fy5cqYMWNG/sQ6Y8aMGDZsWFx99dWxdevWaG5ujtmzZ8e4ceNqzvX1r3895s6dG8ccc0x84hOfiG3btsXKlStj7733zn36/+65556LBQsWxEc+8pF49NFH47bbbouPfexj+ZDNV77ylXj44Ydj3rx50draGhs3boxvfetbMWHChDjuuOMiIuK8886LtWvXxqc+9alob2+PY489Nnbs2BHPPPNMrF27Nu6///44+uije3ReFy5cGCtXrozly5fH9OnTq35yj4gen9vdd989pk2bFrfffntMnjw5xowZE0cccUQcccQR3W7zqKOOikWLFsW3v/3teOmll2LWrFnx61//OlavXh2nn356nHjiiT2anV3YAO9+YgC9tSW13p+//OUvpaurq1x11VWltbW1NDc3lw984APlRz/6UVm0aFFpbW3Nz3XnnXeWk08+uYwbN66MGDGiHHjggeWTn/xkefHFF6tu8+abby4TJ04sw4YN69H21AcffLAce+yxZffddy+jRo0q8+fPL+vXr6865q1tmevXry9nnXVW2Wuvvcro0aPL0qVLy7Zt2/K4n/70p+W0004r+++/fxkxYkTZf//9y0c/+tHyxz/+serzvfHGG+Xqq68uhx9+eGlubi6jR48uM2fOLG1tbWXr1q15XESUSy65pO7sXV1d5YADDigRUb72ta/V/PuenNtSSvnlL39ZZs6cWUaMGFG1PbXW7ym8+eabpa2trRx88MFlt912KwcccED5whe+UF577bWq41pbW8u8efO6zTVr1qwya9asul8XQ1ulFM8oAfAvnlMAIIkCAEkUAEiiAEASBQCSKACQevzLaxdffHFfzgFAH6v1m/pv50oBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIwwd6gL7W3t5ec72zs7PmeqVS6ctxeqylpaXmel/O3VVn/c066yPqrI/uxeyD5XxHRJxxxhkDPcJ74j7e/xr1vtITrhQASKIAQBIFAJIoAJBEAYA05Hcf1dvJ0NHR0c+T9E4ppeZ6X85d787woTrrj9f7RAMw+67MfZydyZUCAEkUAEiiAEASBQCSKACQhvzuo8H0eim9sbPmfn+d9XNqrC3u5ec4os76jgY9541qV7+Ps3O5UgAgiQIASRQASKIAQBIFANKQ333UqHbUWf9AnfXP1Vk/tc766F7MsrHOeu1XrgEamSsFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJC8zMUgNazO+uN11s+vs96Xb7LjLVJg6HGlAEASBQCSKACQRAGAJAoAJLuPhri/1Vn/zxprN9Q59j/qrL9SZ33kO04EDGauFABIogBAEgUAkigAkEQBgGT3EenNOuu/6OXnsfsIGpcrBQCSKACQRAGAJAoApCH/RHNLS0vN9VJKzfVKZXC8dUyjzh3Ru9kH09yNqlHvK40691DnSgGAJAoAJFEAIIkCAEkUAEhDfvdRZ2dnzfWOjo5+nqR36u3AGOxzRzT27I3IfZydyZUCAEkUAEiiAEASBQCSKACQhvzuo0Z9vZRGnTuisWdvRI16vht17qHOlQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASMMHeoC+1tLSUnO9lFJzvaurUnN9/Pjua8ccU/s2H3mk9vqmTbXXhw3rvtbbuSuV2nMPhN7MPpjmblSNel9p1LmHOlcKACRRACCJAgBJFABIogBAGvK7jzo7O2uud3R01FyvtRMoIuK227qvHXxw7WMXLKi9fuaZtddrbbaotwOj3tyDSSPP3oh6ex8fLNxPBidXCgAkUQAgiQIASRQASEP+iebe/mr83nvXXj/88O5rt9xS+9hFi2qvjxpVe33r1u5rjfwr/Y08eyNq1PPdqHMPda4UAEiiAEASBQCSKACQRAGANOR3H/VWvd+wX768+9qtt9Y+9sILa6/X2mUEMJi4UgAgiQIASRQASKIAQBIFAJLdR29T7012ar2PybPP1j72f/5np40D0K9cKQCQRAGAJAoAJFEAIIkCAMnuo7cppefrw+ucvebmnTcPQH9ypQBAEgUAkigAkEQBgOSJ5rep9zIXEyZ0X9uypfaxO3bsvHkA+pMrBQCSKACQRAGAJAoAJFEAINl99Db1dg7de2/3tf/+79rHdnTsvHkA+pMrBQCSKACQRAGAJAoAJFEAIA353UctLS0110udd9OpVCo11199tee3OWZMz4+tZ2fNPRB6M/tgmrtRNep9pVHnHupcKQCQRAGAJAoAJFEAIIkCAGnI7z7q7Oysud4xyF+gqN4OjME+d0Rjz96I3MfZmVwpAJBEAYAkCgAkUQAgiQIAacjvPmrU10tp1LkjGnv2RtSo57tR5x7qXCkAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkIYP9AB9raWlpeZ6KaXmeqVS6cNpeq5R547o3eyDae5G1aj3lUade6hzpQBAEgUAkigAkEQBgCQKAKQhv/uos7Oz5npHR0c/T9I79XZgDPa5Ixp79kbkPs7O5EoBgCQKACRRACCJAgBJFABIQ373UaO+Xkqjzh3R2LM3okY9340691DnSgGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGANHygB+hrLS0tNddLKTXXK5VKH07Tc406d0TvZh9Mc3/uc/9Vc33DhtrHN/Xhj1RdXd3XDj209rHt7ZNqrg/2+8qWzVsGeoT37uKBHqDvuFIAIIkCAEkUAEiiAEASBQDSkN991NnZWXO9o6OjnyfpnXo7Rwb73BGNO3u9XUb33197fXgf/t+zfXvPj23U+3hdte8+A2NwbNTqV64UAEiiAEASBQCSKACQRAGANOR3Hw2W13nprUadO6JxZ6/3Wkb1dhnttlv3tTobr+qu90a9+Rr1fPdaX36Zg2nH0wBzpQBAEgUAkigAkEQBgCQKAKQhv/sI+kqtHUX1NgKNGlV7fdiw2utbGvhNyWhsrhQASKIAQBIFAJIoAJBEAYBk9xH0g0mTaq+PHFl7/X//t/vaznj9JHg3rhQASKIAQBIFAJIoAJA80QzvUa2XtKj3ZPBjj9Ve782Tx7vKe+kwsFwpAJBEAYAkCgAkUQAgiQIAye4j+D9dXbXXt2/v3znq3Wa9+XYZXuajX7hSACCJAgBJFABIogBAEgUAkt1H8H8OPbR3xzf14Y9UtXYa9Xa+IacvX/vJzqbkSgGAJAoAJFEAIIkCAEkUAEiVUnr23k8XX3xxX88CQB+6+eab3/UYVwoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAECqlFLKQA8BwODgSgGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGA9P8AwhxpshzyH3sAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def unroll(key, num_steps=5):\n", + " timestep = env.reset(key)\n", + " actions = jax.random.randint(key, (num_steps,), 0, env.action_space.n)\n", + "\n", + " steps = [timestep]\n", + " for action in actions:\n", + " timestep = env.step(timestep, action)\n", + " steps.append(timestep)\n", + "\n", + " return steps\n", + "\n", + "# Unroll and print steps\n", + "steps = unroll(key, num_steps=5)\n", + "render(steps[-1].observation, \"Last observation\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Take home message:**\n", + "1. `env.step`, take two arguments: the current state of the environment (the `timestep`), and the action to take, and returns the new environment state.\n", + "2. Despite `env.step` being stochastic, it does not take a `key` argument. This is because NAVIX manages the PRNGKey internally.\n", + "3. You can still sample different environments by sampling different `keys` when creating the environment.\n", + "\n", + "This way of using NAVIX is suboptimal (and probably slower than using `gym`), as it does not take advantage of JAX's JIT compiler. We'll see how to do that in the next section." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## `Optimizing with JAX`\n", + "\n", + "One of the major perks of NAVIX is its performance optimization capabilities through JAX. We can use JAX's `jit` and `vmap` to compile and parallelize our simulation code.\n", + "We can compile the `step` function to make it faster. This is done by using the `jax.jit` decorator." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### JIT Compilation" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGbCAYAAAAr/4yjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAU+UlEQVR4nO3dfWyddd348c/pxjpARjfYJmRQ2GBjg8F0xDuEh4VNiGNsQHgYGmGAI5oxRYGYONStakjABOU3JiA/hTlMZHCTGxEFhtSgyK0SAQMTnfIQY9A9tEyB8bD1e//hzScceo603F3b071eyf7od9fO9enFoe9ePd+2lVJKCQCIiKaBHgCAwUMUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUaHgrVqyISqUSmzdvHuhR+t1b7zv0FVHYhd16661RqVTiscce65fzrV+/PlasWBHPP/98v5xvqHj11VdjxYoV8bOf/WygR2EXIAr0m/Xr10dbW5so9NKrr74abW1tNaPwxS9+MbZt29b/QzFkiQL0ka6urnjttdf69ZzDhw+PkSNH9us5GdpEgX/rjTfeiC9/+csxc+bM2HvvvWPPPfeM448/Ptrb27sd+4Mf/CBmzpwZe+21V4waNSqmT58e1113XUT860tVZ599dkREnHjiiVGpVKJSqbzrl0QeeuihOP7442PPPfeMlpaWOO200+L3v/99zWM3b94c55xzTowaNSr22WefuPTSS7t9kF63bl0cd9xx0dLSEu973/tiypQpsWzZsqpjXn/99Vi+fHkccsgh0dzcHAcccEB8/vOfj9dff73quEqlEkuXLo3vf//7cfjhh0dzc3Pcc889MWbMmLjwwgu7zfePf/wjRo4cGVdccUWPr+3zzz8fY8eOjYiItra2vG4rVqyIiNqvKWzfvj2++tWvxqRJk6K5uTkOOuigWLZsWbf5DzrooDj11FPjF7/4RXzoQx+KkSNHxsSJE+N73/tevf8c7AoKu6xbbrmlRET5zW9+U/eYTZs2lf32269cdtll5YYbbijXXHNNmTJlStltt93K448/nsc98MADJSLKnDlzyqpVq8qqVavK0qVLy9lnn11KKeXPf/5z+cxnPlMioixbtqysWbOmrFmzpvztb3+re+5169aV4cOHl8mTJ5drrrmmtLW1lX333beMHj26PPfcc3nc8uXLS0SU6dOnl/nz55frr7++fPzjHy8RUc4777w87qmnniojRowoRx99dLnuuuvKjTfeWK644opywgkn5DE7duwoJ598ctljjz3KZz/72XLTTTeVpUuXluHDh5fTTjutar6IKFOnTi1jx44tbW1tZdWqVeXxxx8vF110UWlpaSmvv/561fGrV6+uut49ubYvv/xyueGGG0pElDPOOCOv25NPPln1vr/dokWLSkSUs846q6xataqcf/75JSLK6aefXnVca2trmTJlShk/fnxZtmxZuf7668sHP/jBUqlUylNPPVX3vwtDmyjswnoShe3bt3f74NbZ2VnGjx9fLrrooly79NJLy6hRo8r27dvrPtYdd9xRIqK0t7f3aL4ZM2aUcePGlS1btuTak08+WZqamsr555+fa299YFywYEHVv1+yZEmJiPwA+o1vfKNERNm0aVPdc65Zs6Y0NTWVn//851XrN954Y4mI8sgjj+RaRJSmpqby9NNPVx17//33l4go99xzT9X6KaecUiZOnJhv9/Tabtq0qUREWb58ebd53xmFJ554okREWbx4cdVxV1xxRYmI8tBDD+Vaa2triYjy8MMP59rGjRtLc3Nzufzyy7udi12DLx/xbw0bNixGjBgREf/6mnlHR0ds3749jj766Pjtb3+bx7W0tMQrr7wS69at65Pzvvjii/HEE0/EBRdcEGPGjMn1I488Mk466aT48Y9/3O3fXHLJJVVvf/rTn46IyGNbWloiIuLuu++Orq6umue94447YurUqXHYYYfF5s2b88/s2bMjIrp92WzWrFkxbdq0qrXZs2fHvvvuG7fffnuudXZ2xrp162LhwoW51tNr2xtvva+XXXZZ1frll18eERH33ntv1fq0adPi+OOPz7fHjh0bU6ZMiWefffY9nZ/GJwq8q9WrV8eRRx4ZI0eOjH322SfGjh0b9957b2zdujWPWbJkSUyePDnmzp0bEyZMiIsuuijuu+++93zOF154ISIipkyZ0u3vpk6dGps3b45XXnmlav3QQw+tenvSpEnR1NSUu50WLlwYxx57bCxevDjGjx8f5557bqxdu7YqEBs2bIinn346xo4dW/Vn8uTJERGxcePGqnMcfPDB3eYbPnx4nHnmmXH33Xfn1/HvuuuuePPNN6uiENGza9sbL7zwQjQ1NcUhhxxStf7+978/Wlpa8rq+5cADD+z2GKNHj47Ozs73dH4anyjwb912221xwQUXxKRJk+I73/lO3HfffbFu3bqYPXt21QfTcePGxRNPPBE//OEPY8GCBdHe3h5z586NRYsWDdjs73wBdvfdd4+HH344HnzwwTjvvPPid7/7XSxcuDBOOumk2LFjR0T86zP26dOnx7p162r+WbJkSbfHrOXcc8+Nf/7zn/GTn/wkIiLWrl0bhx12WBx11FF5TE+vbV+87/UMGzas5nrxW3p3WcMHegAGtzvvvDMmTpwYd911V9UHmuXLl3c7dsSIETF//vyYP39+dHV1xZIlS+Kmm26KL33pS3HIIYf06jtvW1tbIyLiD3/4Q7e/e+aZZ2LfffeNPffcs2p9w4YNVZ+5/+lPf4qurq446KCDcq2pqSnmzJkTc+bMiWuvvTauuuqquPLKK6O9vT0+/OEPx6RJk+LJJ5+MOXPm/J++U/iEE06I/fbbL26//fY47rjj4qGHHoorr7yy6pieXtveXreurq7YsGFDTJ06Ndf//ve/x0svvZTXFepxp8C/9dZnkm//zPFXv/pVPProo1XHbdmypertpqamOPLIIyMi8ksob30Qf+mll971vPvtt1/MmDEjVq9eXXX8U089FQ888ECccsop3f7NqlWrqt5euXJlRETMnTs3IiI6Ojq6/ZsZM2ZUzXjOOefEX//617j55pu7Hbtt27ZuX7Kqp6mpKc4666y45557Ys2aNbF9+/ZuXzrq6bXdY489IqJn1+2t6/LNb36zav3aa6+NiIh58+b1aH52Xe4UiO9+97s1v/5/6aWXxqmnnhp33XVXnHHGGTFv3rx47rnn4sYbb4xp06bFyy+/nMcuXrw4Ojo6Yvbs2TFhwoR44YUXYuXKlTFjxoz8jHXGjBkxbNiwuPrqq2Pr1q3R3Nwcs2fPjnHjxtWc6+tf/3rMnTs3jjnmmPjEJz4R27Zti5UrV8bee++d+/Tf7rnnnosFCxbERz7ykXj00Ufjtttui4997GP5JZuvfOUr8fDDD8e8efOitbU1Nm7cGN/61rdiwoQJcdxxx0VExHnnnRdr166NT33qU9He3h7HHnts7NixI5555plYu3Zt3H///XH00Uf36LouXLgwVq5cGcuXL4/p06dXfeYeET2+trvvvntMmzYtbr/99pg8eXKMGTMmjjjiiDjiiCO6nfOoo46KRYsWxbe//e146aWXYtasWfHrX/86Vq9eHaeffnqceOKJPZqdXdgA735iAL21JbXen7/85S+lq6urXHXVVaW1tbU0NzeXD3zgA+VHP/pRWbRoUWltbc3HuvPOO8vJJ59cxo0bV0aMGFEOPPDA8slPfrK8+OKLVee8+eaby8SJE8uwYcN6tD31wQcfLMcee2zZfffdy6hRo8r8+fPL+vXrq455a1vm+vXry1lnnVX22muvMnr06LJ06dKybdu2PO6nP/1pOe2008r+++9fRowYUfbff//y0Y9+tPzxj3+serw33nijXH311eXwww8vzc3NZfTo0WXmzJmlra2tbN26NY+LiHLJJZfUnb2rq6sccMABJSLK1772tZp/35NrW0opv/zlL8vMmTPLiBEjqran1vo+hTfffLO0tbWVgw8+uOy2227lgAMOKF/4whfKa6+9VnVca2trmTdvXre5Zs2aVWbNmlX3/WJoq5TiFSUA/sVrCgAkUQAgiQIASRQASKIAQBIFAFKPv3nt4osv3plzALCT1fpO/XdypwBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKAKThAz3Aztbe3l5zvbOzs+Z6pVLZmeP0WEtLS831wT53RO9mrzf3jl6ec1gvj6/ljDPO6INH6X+e4/2vUZ8rPeFOAYAkCgAkUQAgiQIASRQASEN+91G9nQwdHR39PEnvlFJqrg/2uSP6ZvYP9PKcj/fy+KHEc5y+5E4BgCQKACRRACCJAgBJFABIQ3730WD6eSm90ahzR/TN7J/r5fHn/5/P2Lga9bnSqHMPde4UAEiiAEASBQCSKACQRAGANOR3HzG4vb/O+ql98Dh/6+VjAO4UAHgbUQAgiQIASRQASF5oZkCdU2d9dB88zv/r5WMA7hQAeBtRACCJAgBJFABIogBAsvuIPtdVZ73Wk21xH52z1uPcUOfYN/vonDAUuVMAIIkCAEkUAEiiAEASBQCS3Uf0uXq7ez5UY63eL9nZ2Mtz1nqc/6hz7C96+diwK3GnAEASBQCSKACQRAGAJAoAJLuP6HMj6qw/XmPtiDrHll6es1Jj7ZVePgbgTgGAtxEFAJIoAJBEAYA05F9obmlpqbleSu2XMiuVWi9Z9r9GnTui/uxRY/YdO3Hukb1cb1SN+lxp1LmHOncKACRRACCJAgBJFABIogBAGvK7jzo7O2uud3R09PMkvVNvB8ZgnzuisWdvRJ7j9CV3CgAkUQAgiQIASRQASKIAQBryu48a9eelNOrcEY09eyNq1OvdqHMPde4UAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEAaPtAD7GwtLS0110spNde7uio118eP7752zDG1z/nII7XXN22qvT5sWPe13s5dqdSeeyD0ZvbBNHejatTnSqPOPdS5UwAgiQIASRQASKIAQBIFANKQ333U2dlZc72jo6Pmeq2dQBERt93Wfe3gg2sfu2BB7fUzz6y9XmuzRb0dGPXmHkwaefZG1Nvn+GDheTI4uVMAIIkCAEkUAEiiAEAa8i809/Zb4/feu/b64Yd3X7vlltrHLlpUe33UqNrrW7d2X2vkb+lv5NkbUaNe70ade6hzpwBAEgUAkigAkEQBgCQKAKQhv/uot+p9h/3y5d3Xbr219rEXXlh7vdYuI4DBxJ0CAEkUAEiiAEASBQCSKACQ7D56h3q/ZKfW7zF59tnax/7Xf/XZOAD9yp0CAEkUAEiiAEASBQCSKACQ7D56h1J6vj68ztVrbu67eQD6kzsFAJIoAJBEAYAkCgAkLzS/Q70fczFhQve1LVtqH7tjR9/NA9Cf3CkAkEQBgCQKACRRACCJAgDJ7qN3qLdz6N57u6/953/WPrajo+/mAehP7hQASKIAQBIFAJIoAJBEAYA05HcftbS01FwvdX6bTqVSqbn+6qs9P+eYMT0/tp6+mnsg9Gb2wTR3o2rU50qjzj3UuVMAIIkCAEkUAEiiAEASBQDSkN991NnZWXO9Y5D/gKJ6OzAG+9wRjT17I/Icpy+5UwAgiQIASRQASKIAQBIFANKQ333UqD8vpVHnjmjs2RtRo17vRp17qHOnAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEAaPtAD7GwtLS0110spNdcrlcpOnKbnGnXuiN7NPpjmblSN+lxp1LmHOncKACRRACCJAgBJFABIogBAGvK7jzo7O2uud3R09PMkvVNvB8ZgnzuisWdvRJ7j9CV3CgAkUQAgiQIASRQASKIAQBryu48a9eelNOrcEY09eyNq1OvdqHMPde4UAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEjDB3qAna2lpaXmeiml5nqlUtmJ0/Rco84d0bvZB9Pcn/vc/6+5vmFD7eObduKnVF1d3dcOPbT2se3tk2quD/bnypbNWwZ6hPfu4oEeYOdxpwBAEgUAkigAkEQBgCQKAKQhv/uos7Oz5npHR0c/T9I79XaODPa5Ixp39nq7jO6/v/b68J34f8/27T0/tlGf43XVfvoMjMGxUatfuVMAIIkCAEkUAEiiAEASBQDSkN99NFh+zktvNercEY07e72fZVRvl9Fuu3Vfq7Pxqu56b9Sbr1Gvd6/tzHdzMO14GmDuFABIogBAEgUAkigAkEQBgDTkdx/BzlJrR1G9jUCjRtVeHzas9vqWBv6lZDQ2dwoAJFEAIIkCAEkUAEiiAECy+wj6waRJtddHjqy9/t//3X2tL35+ErwbdwoAJFEAIIkCAEkUAEheaIb3qNaPtKj3YvBjj9Ve782Lx7vK79JhYLlTACCJAgBJFABIogBAEgUAkt1H8L+6umqvb9/ev3PUO2e9+XYZfsxHv3CnAEASBQCSKACQRAGAJAoAJLuP4H8demjvjm/aiZ9S1dpp1Nv5hpyd+bOf7GxK7hQASKIAQBIFAJIoAJBEAYBUKaVnv/vp4osv3tmzALAT3Xzzze96jDsFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgVUopZaCHAGBwcKcAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQPofTwdpsxkEwUoAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "@jax.jit\n", + "def env_step_jit(timestep, action):\n", + " return env.step(timestep, action)\n", + "\n", + "def unroll_jit_step(key, num_steps=10):\n", + " timestep = env.reset(key)\n", + " actions = jax.random.randint(key, (num_steps,), 0, env.action_space.n)\n", + "\n", + " steps = [timestep]\n", + " for action in actions:\n", + " timestep = env_step_jit(timestep, action)\n", + " steps.append(timestep)\n", + "\n", + " return steps\n", + "\n", + "# Example usage\n", + "steps = unroll_jit_step(key, num_steps=10)\n", + "render(steps[-1].observation, \"Last observation\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's compare the two head to head." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "27.4 s ± 130 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "%timeit -n 1 -r 3 unroll(key, num_steps=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "328 ns ± 153 ns per loop (mean ± std. dev. of 3 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "%timeit -n 1 -r 3 lambda: unroll_jit_step(key, num_steps=10)[-1].block_until_ready()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that it's roughly in the order of $10^9$ times faster compared to its unjitted counterpart.\n", + "\n", + "But that's not the end of the story.\n", + "We can go even further and `jit` the whole simulation loop, which improves performance even more." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGbCAYAAAAr/4yjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAU+UlEQVR4nO3dfWyddd348c/pxjpARjfYJmRQ2GBjg8F0xDuEh4VNiGNsQHgYGmGAI5oxRYGYONStakjABOU3JiA/hTlMZHCTGxEFhtSgyK0SAQMTnfIQY9A9tEyB8bD1e//hzScceo603F3b071eyf7od9fO9enFoe9ePd+2lVJKCQCIiKaBHgCAwUMUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUaHgrVqyISqUSmzdvHuhR+t1b7zv0FVHYhd16661RqVTiscce65fzrV+/PlasWBHPP/98v5xvqHj11VdjxYoV8bOf/WygR2EXIAr0m/Xr10dbW5so9NKrr74abW1tNaPwxS9+MbZt29b/QzFkiQL0ka6urnjttdf69ZzDhw+PkSNH9us5GdpEgX/rjTfeiC9/+csxc+bM2HvvvWPPPfeM448/Ptrb27sd+4Mf/CBmzpwZe+21V4waNSqmT58e1113XUT860tVZ599dkREnHjiiVGpVKJSqbzrl0QeeuihOP7442PPPfeMlpaWOO200+L3v/99zWM3b94c55xzTowaNSr22WefuPTSS7t9kF63bl0cd9xx0dLSEu973/tiypQpsWzZsqpjXn/99Vi+fHkccsgh0dzcHAcccEB8/vOfj9dff73quEqlEkuXLo3vf//7cfjhh0dzc3Pcc889MWbMmLjwwgu7zfePf/wjRo4cGVdccUWPr+3zzz8fY8eOjYiItra2vG4rVqyIiNqvKWzfvj2++tWvxqRJk6K5uTkOOuigWLZsWbf5DzrooDj11FPjF7/4RXzoQx+KkSNHxsSJE+N73/tevf8c7AoKu6xbbrmlRET5zW9+U/eYTZs2lf32269cdtll5YYbbijXXHNNmTJlStltt93K448/nsc98MADJSLKnDlzyqpVq8qqVavK0qVLy9lnn11KKeXPf/5z+cxnPlMioixbtqysWbOmrFmzpvztb3+re+5169aV4cOHl8mTJ5drrrmmtLW1lX333beMHj26PPfcc3nc8uXLS0SU6dOnl/nz55frr7++fPzjHy8RUc4777w87qmnniojRowoRx99dLnuuuvKjTfeWK644opywgkn5DE7duwoJ598ctljjz3KZz/72XLTTTeVpUuXluHDh5fTTjutar6IKFOnTi1jx44tbW1tZdWqVeXxxx8vF110UWlpaSmvv/561fGrV6+uut49ubYvv/xyueGGG0pElDPOOCOv25NPPln1vr/dokWLSkSUs846q6xataqcf/75JSLK6aefXnVca2trmTJlShk/fnxZtmxZuf7668sHP/jBUqlUylNPPVX3vwtDmyjswnoShe3bt3f74NbZ2VnGjx9fLrrooly79NJLy6hRo8r27dvrPtYdd9xRIqK0t7f3aL4ZM2aUcePGlS1btuTak08+WZqamsr555+fa299YFywYEHVv1+yZEmJiPwA+o1vfKNERNm0aVPdc65Zs6Y0NTWVn//851XrN954Y4mI8sgjj+RaRJSmpqby9NNPVx17//33l4go99xzT9X6KaecUiZOnJhv9/Tabtq0qUREWb58ebd53xmFJ554okREWbx4cdVxV1xxRYmI8tBDD+Vaa2triYjy8MMP59rGjRtLc3Nzufzyy7udi12DLx/xbw0bNixGjBgREf/6mnlHR0ds3749jj766Pjtb3+bx7W0tMQrr7wS69at65Pzvvjii/HEE0/EBRdcEGPGjMn1I488Mk466aT48Y9/3O3fXHLJJVVvf/rTn46IyGNbWloiIuLuu++Orq6umue94447YurUqXHYYYfF5s2b88/s2bMjIrp92WzWrFkxbdq0qrXZs2fHvvvuG7fffnuudXZ2xrp162LhwoW51tNr2xtvva+XXXZZ1frll18eERH33ntv1fq0adPi+OOPz7fHjh0bU6ZMiWefffY9nZ/GJwq8q9WrV8eRRx4ZI0eOjH322SfGjh0b9957b2zdujWPWbJkSUyePDnmzp0bEyZMiIsuuijuu+++93zOF154ISIipkyZ0u3vpk6dGps3b45XXnmlav3QQw+tenvSpEnR1NSUu50WLlwYxx57bCxevDjGjx8f5557bqxdu7YqEBs2bIinn346xo4dW/Vn8uTJERGxcePGqnMcfPDB3eYbPnx4nHnmmXH33Xfn1/HvuuuuePPNN6uiENGza9sbL7zwQjQ1NcUhhxxStf7+978/Wlpa8rq+5cADD+z2GKNHj47Ozs73dH4anyjwb912221xwQUXxKRJk+I73/lO3HfffbFu3bqYPXt21QfTcePGxRNPPBE//OEPY8GCBdHe3h5z586NRYsWDdjs73wBdvfdd4+HH344HnzwwTjvvPPid7/7XSxcuDBOOumk2LFjR0T86zP26dOnx7p162r+WbJkSbfHrOXcc8+Nf/7zn/GTn/wkIiLWrl0bhx12WBx11FF5TE+vbV+87/UMGzas5nrxW3p3WcMHegAGtzvvvDMmTpwYd911V9UHmuXLl3c7dsSIETF//vyYP39+dHV1xZIlS+Kmm26KL33pS3HIIYf06jtvW1tbIyLiD3/4Q7e/e+aZZ2LfffeNPffcs2p9w4YNVZ+5/+lPf4qurq446KCDcq2pqSnmzJkTc+bMiWuvvTauuuqquPLKK6O9vT0+/OEPx6RJk+LJJ5+MOXPm/J++U/iEE06I/fbbL26//fY47rjj4qGHHoorr7yy6pieXtveXreurq7YsGFDTJ06Ndf//ve/x0svvZTXFepxp8C/9dZnkm//zPFXv/pVPProo1XHbdmypertpqamOPLIIyMi8ksob30Qf+mll971vPvtt1/MmDEjVq9eXXX8U089FQ888ECccsop3f7NqlWrqt5euXJlRETMnTs3IiI6Ojq6/ZsZM2ZUzXjOOefEX//617j55pu7Hbtt27ZuX7Kqp6mpKc4666y45557Ys2aNbF9+/ZuXzrq6bXdY489IqJn1+2t6/LNb36zav3aa6+NiIh58+b1aH52Xe4UiO9+97s1v/5/6aWXxqmnnhp33XVXnHHGGTFv3rx47rnn4sYbb4xp06bFyy+/nMcuXrw4Ojo6Yvbs2TFhwoR44YUXYuXKlTFjxoz8jHXGjBkxbNiwuPrqq2Pr1q3R3Nwcs2fPjnHjxtWc6+tf/3rMnTs3jjnmmPjEJz4R27Zti5UrV8bee++d+/Tf7rnnnosFCxbERz7ykXj00Ufjtttui4997GP5JZuvfOUr8fDDD8e8efOitbU1Nm7cGN/61rdiwoQJcdxxx0VExHnnnRdr166NT33qU9He3h7HHnts7NixI5555plYu3Zt3H///XH00Uf36LouXLgwVq5cGcuXL4/p06dXfeYeET2+trvvvntMmzYtbr/99pg8eXKMGTMmjjjiiDjiiCO6nfOoo46KRYsWxbe//e146aWXYtasWfHrX/86Vq9eHaeffnqceOKJPZqdXdgA735iAL21JbXen7/85S+lq6urXHXVVaW1tbU0NzeXD3zgA+VHP/pRWbRoUWltbc3HuvPOO8vJJ59cxo0bV0aMGFEOPPDA8slPfrK8+OKLVee8+eaby8SJE8uwYcN6tD31wQcfLMcee2zZfffdy6hRo8r8+fPL+vXrq455a1vm+vXry1lnnVX22muvMnr06LJ06dKybdu2PO6nP/1pOe2008r+++9fRowYUfbff//y0Y9+tPzxj3+serw33nijXH311eXwww8vzc3NZfTo0WXmzJmlra2tbN26NY+LiHLJJZfUnb2rq6sccMABJSLK1772tZp/35NrW0opv/zlL8vMmTPLiBEjqran1vo+hTfffLO0tbWVgw8+uOy2227lgAMOKF/4whfKa6+9VnVca2trmTdvXre5Zs2aVWbNmlX3/WJoq5TiFSUA/sVrCgAkUQAgiQIASRQASKIAQBIFAFKPv3nt4osv3plzALCT1fpO/XdypwBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKAKThAz3Aztbe3l5zvbOzs+Z6pVLZmeP0WEtLS831wT53RO9mrzf3jl6ec1gvj6/ljDPO6INH6X+e4/2vUZ8rPeFOAYAkCgAkUQAgiQIASRQASEN+91G9nQwdHR39PEnvlFJqrg/2uSP6ZvYP9PKcj/fy+KHEc5y+5E4BgCQKACRRACCJAgBJFABIQ3730WD6eSm90ahzR/TN7J/r5fHn/5/P2Lga9bnSqHMPde4UAEiiAEASBQCSKACQRAGANOR3HzG4vb/O+ql98Dh/6+VjAO4UAHgbUQAgiQIASRQASF5oZkCdU2d9dB88zv/r5WMA7hQAeBtRACCJAgBJFABIogBAsvuIPtdVZ73Wk21xH52z1uPcUOfYN/vonDAUuVMAIIkCAEkUAEiiAEASBQCS3Uf0uXq7ez5UY63eL9nZ2Mtz1nqc/6hz7C96+diwK3GnAEASBQCSKACQRAGAJAoAJLuP6HMj6qw/XmPtiDrHll6es1Jj7ZVePgbgTgGAtxEFAJIoAJBEAYA05F9obmlpqbleSu2XMiuVWi9Z9r9GnTui/uxRY/YdO3Hukb1cb1SN+lxp1LmHOncKACRRACCJAgBJFABIogBAGvK7jzo7O2uud3R09PMkvVNvB8ZgnzuisWdvRJ7j9CV3CgAkUQAgiQIASRQASKIAQBryu48a9eelNOrcEY09eyNq1OvdqHMPde4UAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEAaPtAD7GwtLS0110spNde7uio118eP7752zDG1z/nII7XXN22qvT5sWPe13s5dqdSeeyD0ZvbBNHejatTnSqPOPdS5UwAgiQIASRQASKIAQBIFANKQ333U2dlZc72jo6Pmeq2dQBERt93Wfe3gg2sfu2BB7fUzz6y9XmuzRb0dGPXmHkwaefZG1Nvn+GDheTI4uVMAIIkCAEkUAEiiAEAa8i809/Zb4/feu/b64Yd3X7vlltrHLlpUe33UqNrrW7d2X2vkb+lv5NkbUaNe70ade6hzpwBAEgUAkigAkEQBgCQKAKQhv/uot+p9h/3y5d3Xbr219rEXXlh7vdYuI4DBxJ0CAEkUAEiiAEASBQCSKACQ7D56h3q/ZKfW7zF59tnax/7Xf/XZOAD9yp0CAEkUAEiiAEASBQCSKACQ7D56h1J6vj68ztVrbu67eQD6kzsFAJIoAJBEAYAkCgAkLzS/Q70fczFhQve1LVtqH7tjR9/NA9Cf3CkAkEQBgCQKACRRACCJAgDJ7qN3qLdz6N57u6/953/WPrajo+/mAehP7hQASKIAQBIFAJIoAJBEAYA05HcftbS01FwvdX6bTqVSqbn+6qs9P+eYMT0/tp6+mnsg9Gb2wTR3o2rU50qjzj3UuVMAIIkCAEkUAEiiAEASBQDSkN991NnZWXO9Y5D/gKJ6OzAG+9wRjT17I/Icpy+5UwAgiQIASRQASKIAQBIFANKQ333UqD8vpVHnjmjs2RtRo17vRp17qHOnAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEAaPtAD7GwtLS0110spNdcrlcpOnKbnGnXuiN7NPpjmblSN+lxp1LmHOncKACRRACCJAgBJFABIogBAGvK7jzo7O2uud3R09PMkvVNvB8ZgnzuisWdvRJ7j9CV3CgAkUQAgiQIASRQASKIAQBryu48a9eelNOrcEY09eyNq1OvdqHMPde4UAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEjDB3qAna2lpaXmeiml5nqlUtmJ0/Rco84d0bvZB9Pcn/vc/6+5vmFD7eObduKnVF1d3dcOPbT2se3tk2quD/bnypbNWwZ6hPfu4oEeYOdxpwBAEgUAkigAkEQBgCQKAKQhv/uos7Oz5npHR0c/T9I79XaODPa5Ixp39nq7jO6/v/b68J34f8/27T0/tlGf43XVfvoMjMGxUatfuVMAIIkCAEkUAEiiAEASBQDSkN99NFh+zktvNercEY07e72fZVRvl9Fuu3Vfq7Pxqu56b9Sbr1Gvd6/tzHdzMO14GmDuFABIogBAEgUAkigAkEQBgDTkdx/BzlJrR1G9jUCjRtVeHzas9vqWBv6lZDQ2dwoAJFEAIIkCAEkUAEiiAECy+wj6waRJtddHjqy9/t//3X2tL35+ErwbdwoAJFEAIIkCAEkUAEheaIb3qNaPtKj3YvBjj9Ve782Lx7vK79JhYLlTACCJAgBJFABIogBAEgUAkt1H8L+6umqvb9/ev3PUO2e9+XYZfsxHv3CnAEASBQCSKACQRAGAJAoAJLuP4H8demjvjm/aiZ9S1dpp1Nv5hpyd+bOf7GxK7hQASKIAQBIFAJIoAJBEAYBUKaVnv/vp4osv3tmzALAT3Xzzze96jDsFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgVUopZaCHAGBwcKcAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQPofTwdpsxkEwUoAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def unroll_scan(key, num_steps=10):\n", + " timestep = env.reset(key)\n", + " actions = jax.random.randint(key, (num_steps,), 0, env.action_space.n)\n", + "\n", + " timestep, _ = jax.lax.scan(\n", + " lambda timestep, action: (env.step(timestep, action), ()),\n", + " timestep,\n", + " actions,\n", + " unroll=10,\n", + " )\n", + " return timestep\n", + "\n", + "\n", + "# Example usage\n", + "unroll_jit_loop = jax.jit(unroll_scan, static_argnums=(1,))\n", + "timestep = unroll_jit_loop(key, num_steps=10)\n", + "render(timestep.observation, \"Last observation\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "40.5 ms ± 1.49 ms per loop (mean ± std. dev. of 5 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%timeit -n 10 -r 5 unroll_jit_step(key, num_steps=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "353 µs ± 71.3 µs per loop (mean ± std. dev. of 5 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%timeit -n 10 -r 5 unroll_jit_loop(key, num_steps=10).t.block_until_ready()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We improved the performance by three more orders of magnitude, and we are at $10^12$.\n", + "This is because we are now compiling the whole simulation loop, not just the `step` function.\n", + "\n", + "That's still not the end of the story. We can improve the performance even more by using `jax.vmap` to parallelize multiple environment simulations." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Batched environments\n", + "\n", + "We can run multiple simulations in parallel using `vmap`." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's compile the function ahead of time\n", + "num_envs = 32\n", + "keys = jax.random.split(key, num_envs)\n", + "unroll_batched = jax.jit(jax.vmap(unroll_scan, in_axes=(0, None)), static_argnums=(1,)).lower(keys, 10).compile()" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGbCAYAAAAr/4yjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZIklEQVR4nO3de3DU5b3H8c8mIQm3ZEHCrcQAAcIdBKYdBcwQKIIYCIerHSFIoXWAchHqTKFtiG3pIC3oAAJ6RoyhMwUtI4MUSSxx8HZaGYU5QGlRgbYOlEvC/Rr2OX84+R6X/a3s0tx5v2b4I09+2d+TZ0Pe+e0+2ficc04AAEiKqekJAABqD6IAADBEAQBgiAIAwBAFAIAhCgAAQxQAAIYoAAAMUQAAGKKAqCxdulQ+n09nzpyp6alUu4rPvTYrLy/XM888o9TUVMXExCgnJ6emp4Q6hihUk1dffVU+n0979+6tlvMdOnRIS5cu1bFjx6rlfPXFlStXtHTpUr377rs1PZW78sorr2jFihUaP368CgoKtGDBgpqeUlQ+/PBDDRo0SI0aNVLr1q01d+5cXbp0qaandU8hCvXUoUOHlJ+fTxSidOXKFeXn53tG4ac//amuXr1a/ZOKwu7du/Wtb31Lq1at0pQpU5SZmVnTU4rYvn37NHToUF25ckUrV67UjBkz9NJLL2nChAk1PbV7SlxNTwC4G4FAQDdu3FBiYmK1nTMuLk5xcbX7v8ypU6fk9/trehp3ZfHixWrWrJneffddJSUlSZLat2+vmTNnqqioSMOHD6/hGd4buFKoRW7cuKGf//zn6t+/v5KTk9W4cWMNHjxYJSUlIcf+/ve/V//+/dW0aVMlJSWpV69eeuGFFyR99VBVxU9XQ4YMkc/nk8/nu+NDIrt379bgwYPVuHFj+f1+jRkzRn/96189jz1z5owmTpyopKQk3XfffZo3b56uXbsWdExxcbEGDRokv9+vJk2aKCMjQ4sXLw465vr168rLy1OnTp2UkJCg1NRUPfPMM7p+/XrQcT6fT3PmzNHvfvc79ejRQwkJCdq+fbuaN2+uJ598MmR+Fy5cUGJiohYtWhTx2h47dkwpKSmSpPz8fFu3pUuXSvJ+TqG8vFy/+MUvlJ6eroSEBLVv316LFy8OmX/79u312GOP6f3339e3v/1tJSYmqmPHjnrttdfC3R1BLl++rIULFyo1NVUJCQnKyMjQb37zG1W8yPGxY8fk8/lUUlKigwcPRnyf79y50+7zpk2batSoUTp48GDQMdOmTVOTJk305ZdfKicnR02aNFFKSooWLVqkW7duSZJu3rwZ8X3h5cKFCyouLtYTTzxhQZCkqVOnqkmTJtqyZUtE64RK4FAtNm7c6CS5jz/+OOwxp0+fdm3atHFPP/20W7dunXvuuedcRkaGa9Cggfv000/tuKKiIifJDR061K1du9atXbvWzZkzx02YMME559znn3/u5s6d6yS5xYsXu8LCQldYWOhOnjwZ9tzFxcUuLi7OdenSxT333HMuPz/ftWjRwjVr1swdPXrUjsvLy3OSXK9evVx2drZbs2aNe+KJJ5wkN2XKFDvuwIEDLj4+3g0YMMC98MILbv369W7RokXu4YcftmNu3brlhg8f7ho1auTmz5/vNmzY4ObMmePi4uLcmDFjguYnyXXr1s2lpKS4/Px8t3btWvfpp5+66dOnO7/f765fvx50fEFBQdB6R7K2ly5dcuvWrXOS3NixY23d9u/fH/S5f11ubq6T5MaPH+/Wrl3rpk6d6iS5nJycoOPS0tJcRkaGa9WqlVu8eLFbs2aN69evn/P5fO7AgQNh7xfnnAsEAi4rK8v5fD43Y8YMt2bNGpedne0kufnz59vcCwsLXdeuXV27du0ius9fe+015/P53IgRI9zq1avd8uXLXfv27Z3f7w+6z3Nzc11iYqLr0aOHmz59ulu3bp0bN26ck+RefPFFOy7S+8LL+++/7yS5zZs3h7xv0KBBrl+/ft+4Rqg8RKGaRBKF8vLykP9QZWVlrlWrVm769Ok2Nm/ePJeUlOTKy8vD3tbrr7/uJLmSkpKI5te3b1/XsmVLd/bsWRvbv3+/i4mJcVOnTrWxim+Mo0ePDvr4WbNmOUn2DXTVqlVOkjt9+nTYcxYWFrqYmBj33nvvBY2vX7/eSXIffPCBjUlyMTEx7uDBg0HH7tq1y0ly27dvDxp/9NFHXceOHe3tSNf29OnTTpLLy8sLme/tUdi3b5+T5GbMmBF03KJFi5wkt3v3bhtLS0tzktyePXts7NSpUy4hIcEtXLgw5Fxf9+abbzpJ7pe//GXQ+Pjx453P53OfffaZjWVmZroePXp84+0559zFixed3+93M2fODBo/efKkS05ODhqvCN+zzz4bdOwDDzzg+vfvb29Hel94qfh6/fr6VJgwYYJr3br1HT8nVA4ePqpFYmNjFR8fL+mrx8xLS0tVXl6uAQMG6JNPPrHj/H6/Ll++rOLi4ko574kTJ7Rv3z5NmzZNzZs3t/HevXvru9/9rv74xz+GfMzs2bOD3v7Rj34kSXZsxePa27ZtUyAQ8Dzv66+/rm7duqlr1646c+aM/cvKypKkkIfNMjMz1b1796CxrKwstWjRQps3b7axsrIyFRcXa9KkSTYW6dpGo+Jzffrpp4PGFy5cKEnasWNH0Hj37t01ePBgezslJUUZGRn64osv7nie2NhYzZ07N+Q8zjnt3Lkz6rkXFxfr3Llzevzxx4PWPjY2Vt/5znc8H7J86qmngt4ePHhw0NwjvS+8VDyBn5CQEPK+xMTEWv8Ef31CFGqZgoIC9e7dW4mJibrvvvuUkpKiHTt26Pz583bMrFmz1KVLF40cOVLt2rXT9OnT9fbbb9/1OY8fPy5JysjICHlft27ddObMGV2+fDlovHPnzkFvp6enKyYmxnY7TZo0SQMHDtSMGTPUqlUrTZ48WVu2bAkKxJEjR3Tw4EGlpKQE/evSpYukr540/boOHTqEzC8uLk7jxo3Ttm3b7HH8rVu36ubNmyHfiCJZ22gcP35cMTEx6tSpU9B469at5ff7bV0r3H///SG30axZM5WVld3xPG3btlXTpk2Dxrt162bvj9aRI0ckffWN/Pb1LyoqCln7xMREe74l3NyjuS9u17BhQ0kKeS5Gkq5du2bvR9Wr3Vsp7jGbNm3StGnTlJOTox//+Mdq2bKlYmNj9etf/1qff/65HdeyZUvt27dPu3bt0s6dO7Vz505t3LhRU6dOVUFBQY3M/fYnYBs2bKg9e/aopKREO3bs0Ntvv63NmzcrKytLRUVFio2NVSAQUK9evbRy5UrP20xNTQ25TS+TJ0/Whg0btHPnTuXk5GjLli3q2rWr+vTpY8dEuraV8bmHExsb6znuauAv4lbEubCwUK1btw55/+27rMLN/XaR3Bde2rRpI+mrq9bbnThxQm3bto3o/PjPEYVa5I033lDHjh21devWoG80eXl5IcfGx8crOztb2dnZCgQCmjVrljZs2KCf/exn6tSpU1S/eZuWliZJ+tvf/hbyvsOHD6tFixZq3Lhx0PiRI0eCfnL/7LPPFAgE1L59exuLiYnR0KFDNXToUK1cuVLLli3TkiVLVFJSomHDhik9PV379+/X0KFD/6PfFH744YfVpk0bbd68WYMGDdLu3bu1ZMmSoGMiXdto1y0QCOjIkSP2U7sk/fvf/9a5c+dsXf9TaWlpeuedd3Tx4sWgq4XDhw/b+6OVnp4u6asfMIYNG1Yp85Qiuy+89OzZU3Fxcdq7d68mTpxo4zdu3NC+ffuCxlC1ePioFqn4aezrPzn++c9/1kcffRR03NmzZ4PejomJUe/evSX9/+V3xTfxc+fO3fG8bdq0Ud++fVVQUBB0/IEDB1RUVKRHH3005GPWrl0b9Pbq1aslSSNHjpQklZaWhnxM3759g+Y4ceJEffnll3r55ZdDjr169WrIQ1bhxMTEaPz48dq+fbsKCwtVXl4e8nBFpGvbqFEjSZGtW8W6PP/880HjFVc+o0aNimj+kZzn1q1bWrNmTdD4qlWr5PP5bM2j8cgjjygpKUnLli3TzZs3Q95/+vTpu5prJPeFl+TkZA0bNkybNm3SxYsXbbywsFCXLl3iF9iqEVcK1eyVV17xfPx/3rx5euyxx7R161aNHTtWo0aN0tGjR7V+/Xp179496Ff9Z8yYodLSUmVlZaldu3Y6fvy4Vq9erb59+9pPrH379lVsbKyWL1+u8+fPKyEhQVlZWWrZsqXnvFasWKGRI0fqwQcf1Pe//31dvXpVq1evVnJysu3T/7qjR49q9OjRGjFihD766CNt2rRJ3/ve9+xhgmeffVZ79uzRqFGjlJaWplOnTunFF19Uu3btNGjQIEnSlClTtGXLFj311FMqKSnRwIEDdevWLR0+fFhbtmzRrl27NGDAgIjWddKkSVq9erXy8vLUq1evoJ/cJUW8tg0bNlT37t21efNmdenSRc2bN1fPnj3Vs2fPkHP26dNHubm5eumll3Tu3DllZmbqL3/5iwoKCpSTk6MhQ4ZENPc7yc7O1pAhQ7RkyRIdO3ZMffr0UVFRkbZt26b58+fbT/3RSEpK0rp16zRlyhT169dPkydPVkpKiv7xj39ox44dGjhwYEiEInWn+yKcX/3qV3rooYeUmZmpH/zgB/rXv/6l3/72txo+fLhGjBhxV3PBXajZzU/3jootqeH+/fOf/3SBQMAtW7bMpaWluYSEBPfAAw+4t956y+Xm5rq0tDS7rTfeeMMNHz7ctWzZ0sXHx7v777/f/fCHP3QnTpwIOufLL7/sOnbs6GJjYyPanvrOO++4gQMHuoYNG7qkpCSXnZ3tDh06FHRMxbbMQ4cOufHjx7umTZu6Zs2auTlz5rirV6/acX/605/cmDFjXNu2bV18fLxr27ate/zxx93f//73oNu7ceOGW758uevRo4dLSEhwzZo1c/3793f5+fnu/PnzdpwkN3v27LBzDwQCLjU11XPrZsX7I1lb55z78MMPXf/+/V18fHzQ9lSv31O4efOmy8/Pdx06dHANGjRwqamp7ic/+Ym7du1a0HFpaWlu1KhRIfPKzMx0mZmZYT+vChcvXnQLFixwbdu2dQ0aNHCdO3d2K1ascIFAIOT2ItmSWqGkpMQ98sgjLjk52SUmJrr09HQ3bdo0t3fvXjsmNzfXNW7cOORjvdbDuTvfF9/kvffecw899JBLTEx0KSkpbvbs2e7ChQtR3Qb+Mz7nauBZLgBArcRzCgAAQxQAAIYoAAAMUQAAGKIAADBEAQBgIv7ltZkzZ1blPAAAVczr1QNux5UCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAATFxNT6CqlZSUeI6XlZV5jvt8vqqcTsTGjh1b01O4a9GseW1Zb0ny+/2e47X9a4V5V7+6/P/zTrhSAAAYogAAMEQBAGCIAgDAEAUAgKn3u4/C7WQoLS2t5pncO+rqmjvnPMeZd9Woq/Ou77hSAAAYogAAMEQBAGCIAgDAEAUAgKn3u49q0+ul3Cvq6poz7+pVV+dd33GlAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwcTU9garm9/s9x51znuM+n68KZxO5BQv+23P8yBHv42OqMO+BgPd4587e47tL0j3HG3msedMw6x3mlGF5ffqXwxx7Mcx4Xf1aYd6oTFwpAAAMUQAAGKIAADBEAQBgiAIAwNT73UdlZWWe46WlpdU8k+iE22W0a5f3eFwV3pPl5dEdfzLMmvf2WPM3w9yG9/6T8Lz2pfxXmGM/DHcjYXa91PavlXC7dZg37gZXCgAAQxQAAIYoAAAMUQAAGKIAADD1fvdRXX29lHCvZRRul1GDBt7jYTZ4hB2PRrg5JoRZ870eY6fC3HavKOfyvx5jH0d5G3X1a4V5ozJxpQAAMEQBAGCIAgDAEAUAgCEKAABT73cf3SvC7SYKt8EjKSl0LDbW+9izZ6ObS7ifNLxeQsn778tJL0R3Ss/buRnlbQDgSgEA8DVEAQBgiAIAwBAFAIAhCgAAw+6je1R6euhYYqL3sf/zP97jlfH6SVvCjC+tpNsBEB2uFAAAhigAAAxRAAAYogAAMDzRXE+EezmLcE8G7/X4izfRPnFcGX8j5WSY8bcq6XYARIcrBQCAIQoAAEMUAACGKAAADFEAABh2H9VSgYD3eLnXX6qpYuHOGW6OleH5qrtpAN+AKwUAgCEKAABDFAAAhigAAAxRAAAYdh/VUp07R3d8TBXmPdwuo2jnGI1Pqu6mAXwDrhQAAIYoAAAMUQAAGKIAADBEAQBg6v3uI7/f7znuwvyZMV9l/DmxSrBq1diansJd8/tLPMe91ry2rLdUd79WmDcqE1cKAABDFAAAhigAAAxRAAAYogAAMPV+91FZWZnneGlpaTXP5N5RV9c83K4X5l016uq86zuuFAAAhigAAAxRAAAYogAAMPX+iWZ+Nb761dU1Z97Vq67Ou77jSgEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMDE1fQEqprf7/ccd855jvt8viqczb0hmjWvTetdV79WmDcqE1cKAABDFAAAhigAAAxRAAAYogAAMPV+91FZWZnneGlpaTXP5N5RV9c83K4X5l016uq86zuuFAAAhigAAAxRAAAYogAAMEQBAGDq/e4jXi+l+tXVNWfe1auuzru+40oBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGDianoCVc3v93uOO+c8xwMBn+d4q1ahYw8+6H3ODz7wHj992ns8NtZ7vK6KZs19Pu/1rgnRfq3Ulrkzb1QmrhQAAIYoAAAMUQAAGKIAADBEAQBg6v3uo7KyMs/x0tJSz/FwO4E2bQod69DB+9jRo73Hx43zHg+z2aLOinbNa4twu16Yd9Woq/Ou77hSAAAYogAAMEQBAGCIAgDAEAUAgKn3u4+ifb2U5GTv8R49Qsc2bvQ+NjfXezwpyXv8/Pk7z6suqauvUcO8q1ddnXd9x5UCAMAQBQCAIQoAAEMUAACGKAAATL3ffRStcC+7kpcXOvbqq97HPvmk93h922UEoP7hSgEAYIgCAMAQBQCAIQoAAMMTzbcJ90d2vP5uzBdfeB/75puVNh0AqFZcKQAADFEAABiiAAAwRAEAYIgCAMCw++g2zkU+Hhdm9RISKm8+AFCduFIAABiiAAAwRAEAYIgCAMAQBQCAYffRbcK99lG7dqFjZ896H3vrVuXNBwCqE1cKAABDFAAAhigAAAxRAAAYogAAMOw+uk24nUM7doSO/eEP3seWllbefACgOnGlAAAwRAEAYIgCAMAQBQCAqfdPNPv9fs9xF+av6fh8Ps/xK1ciP2fz5pEfWx9Fs+bh1rsmVNbXSnWrq/M+eybM68TUBTNregJVhysFAIAhCgAAQxQAAIYoAAAMUQAAmHq/+6isrMxzvJTXoqgydXXNw+3WYd7VzPvTqRm1Y6NWteJKAQBgiAIAwBAFAIAhCgAAQxQAAKbe7z6qLa/zci+pq2vOvGu5qvw0a9OOpxrGlQIAwBAFAIAhCgAAQxQAAIYoAAAMUQAAGKIAADBEAQBgiAIAwBAFAIAhCgAAQxQAAIYoAAAMUQAAGKIAADBEAQBgiAIAwBAFAIAhCgAAQxQAAIYoAAAMUQAAmLiangAARMTV9ATuDVwpAAAMUQAAGKIAADBEAQBgiAIAwLD7CEDd4KvC22Znk+FKAQBgiAIAwBAFAIAhCgAAQxQAAKbe7z4aO3ZsTU/hnsOaIyIza3oC8MKVAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAIzPOedqehIAgNqBKwUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgPk/ZYrXuZGHOs4AAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch size of the results 32\n" + ] + } + ], + "source": [ + "# and run it\n", + "last_steps = unroll_batched(keys)\n", + "render(last_steps.observation[0], \"Last observation of env 0\")\n", + "print(\"Batch size of the results\", last_steps.reward.shape[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can benchmark the performance of the batched simulation as well." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "686 µs ± 215 µs per loop (mean ± std. dev. of 5 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%timeit -n 10 -r 5 unroll_batched(keys).t.block_until_ready()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Which takes roughly twice as long as the single simulation. An increment of $16\\times$, roughly." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And we can scale this up to as many simulations as we want.\n", + "We get to **32768 environments** on a NVIDIA A100 GPU 80Gb.\n", + "\n", + "Feel free to scale this up if you are GPU-richer." + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's compile the function ahead of time\n", + "num_envs = 32768\n", + "keys = jax.random.split(key, num_envs)\n", + "unroll_batched = jax.jit(jax.vmap(unroll_scan, in_axes=(0, None)), static_argnums=(1,)).lower(keys, 10).compile()" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "8.46 ms ± 1.06 ms per loop (mean ± std. dev. of 5 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%timeit -n 10 -r 5 unroll_batched(keys).t.block_until_ready()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "That's a $(32768 * 10) / 0.00846 = 387,218,045$ : a bit less than $400M$ frames per second." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "This tutorial demonstrated the basic usage and key features of NAVIX, including environment creation, running simulations, performance optimization with JAX.\n", + "We wet from running a single environment in around 27s, to running 32768 environment in roughly around 8ms, with a throughput of **400M fps**.\n", + "In comparison, MiniGrid runs at roughly **3K fps**.\n", + "\n", + "Check the [NAVIX paper](TODO) for more details on the performance of NAVIX.\n", + "For more advanced usage and examples, refer to the [NAVIX examples](https://github.com/epignatelli/navix/examples).\n", + "\n", + "[In the next tutorial](ppo.html) we will see how to train a simple PPO agent on a NAVIX environment." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/examples/getting_started.ipynb b/docs/examples/getting_started.ipynb new file mode 100644 index 0000000..e4298a1 --- /dev/null +++ b/docs/examples/getting_started.ipynb @@ -0,0 +1,507 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# NAVIX 101\n", + "\n", + "This tutorial will guide you through the basics of using NAVIX. You will learn:\n", + "- How to create a `navix.Environment`,\n", + "- A vanilla, suboptimal interaction with it\n", + "- How to `jax.jit` compile the environment for faster execution\n", + "- How to run batched simulations" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Installation\n", + "\n", + "For a full guide on how to install NAVIX and its dependencies, please refer to the [official installation guide](../install/index.html)\n", + "For a quickstart, you can install NAVIX via pip:\n", + "```bash\n", + "pip install navix\n", + "```\n", + "\n", + "This will provide a standard CPU-based JAX installation. If you want to use a GPU, please [install JAX](https://github.com/google/jax/?tab=readme-ov-file#installation) with the appropriate backend." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating an Environment\n", + "\n", + "NAVIX provides a variety of MiniGrid environments.\n", + "You can find an exhaustive list [here](../home/environments.html). \n", + "If the environment you are looking for is not listed, please open an new [feature request](https://github.com/epignatelli/navix/issues/new?assignees=&labels=enhancement&template=feature_request.md).\n", + "\n", + "Now, let's create a simple DoorKey environment. The syntax is similar to the usual `gym.make`." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(64, 64, 3)\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGbCAYAAAAr/4yjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAVVklEQVR4nO3dfZCd89nA8evsht1kN2y3SURpJSKkImZMUO+LmmYiZCghHYbElLaUodJpURVvNaqI6RQtbaLVMSJaJRgd0+h0hCmqakrTVBKjSkJ2hTZKk72fP54n1+Nkz6ldk305m89nxoz97Z09157cfPc+57fnlIqiKAIAIqKuvwcAYOAQBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBT5UqVSKuXPnduvYMWPGxKxZs3p8G6tWrYpSqRQLFizo8Z+t5LHHHotSqRSLFi3aIl+vlmz63h977LH+HoUaJApbgQULFkSpVIqnn356i3y9pUuXxty5c+Ott97aIl+Pj+bmm2/eYhGFTYb09wAMfO+++24MGfL/p8rSpUvj8ssvj1mzZkVLS0vZscuWLYu6Oj9r9IWbb745RowY0eXK7LDDDot33303tt122/4ZjJomCnyoxsbGbh/b0NDQi5PUpvXr18ewYcP67Pbq6up69HcGH+RHuq3UrFmzorm5OV599dU47rjjorm5OUaOHBlz5syJjRs3lh37wecU5s6dG1//+tcjImLs2LFRKpWiVCrFqlWrIqLrcwrt7e0xZ86cmDRpUjQ3N8d2220XU6dOjeeee+4jz75ixYqYMWNGtLa2xrBhw+KAAw6IBx98sOKxGzdujIsvvjhGjx4dTU1NMX369HjllVfKjlm+fHmccMIJMXr06GhsbIydd945Zs6cGevWrSs77s4774zJkyfH0KFDo7W1NWbOnNnlax1++OGx1157xTPPPBOHHXZYDBs2LC6++OI45phjYtddd60444EHHhj77rtvfjx//vw48sgjY9SoUdHQ0BB77rln3HLLLWV/ZsyYMfHnP/85fvvb3+bfweGHHx4R1Z9TuOeee3L+ESNGxKmnnhqvvvpq2TE9OS8YnFwpbMU2btwYU6ZMic985jPxve99Lx599NG4/vrrY9y4cfGVr3yl4p/5/Oc/H3/961/jrrvuihtvvDFGjBgREREjR46sePyKFSvivvvuixkzZsTYsWNj9erV8cMf/jDa2trihRdeiE984hM9mnn16tVx0EEHxfr16+O8886Lj3/843HHHXfE9OnTY9GiRXH88ceXHX/11VdHqVSKb3zjG7FmzZqYN29eHHXUUfHHP/4xhg4dGu+//35MmTIl3nvvvTj33HNj9OjR8eqrr8bixYvjrbfeiu233z6/zqWXXhonnXRSfPGLX4w33ngjvv/978dhhx0Wzz77bNnDaGvXro2pU6fGzJkz49RTT40ddtghJk+eHKeddlo89dRTsd9+++WxL7/8cjz55JNx3XXX5dott9wSEydOjOnTp8eQIUPigQceiLPPPjs6OzvjnHPOiYiIefPmxbnnnhvNzc1xySWXRETEDjvsUPV+W7BgQcyePTv222+/uOaaa2L16tVx0003xeOPP95l/o9yXjCIFAx68+fPLyKieOqpp3Lt9NNPLyKiuOKKK8qO3WeffYrJkyeXrUVEcdlll+XH1113XRERxcqVK7vc1i677FKcfvrp+fG///3vYuPGjWXHrFy5smhoaCi77ZUrVxYRUcyfP/+/fi/nn39+ERHF7373u1x75513irFjxxZjxozJ21qyZEkREcVOO+1UvP3223nswoULi4gobrrppqIoiuLZZ58tIqK45557qt7mqlWrivr6+uLqq68uW3/++eeLIUOGlK23tbUVEVHceuutZceuW7euaGhoKC688MKy9e9+97tFqVQqXn755Vxbv359lxmmTJlS7LrrrmVrEydOLNra2rocu+l7X7JkSVEURfH+++8Xo0aNKvbaa6/i3XffzeMWL15cRETx7W9/O9d6cl4wOHn4aCv35S9/uezjQw89NFasWLHFvn5DQ0M+8bxx48ZYu3ZtNDc3xx577BF/+MMfevz1Hnroodh///3jkEMOybXm5uY466yzYtWqVfHCCy+UHX/aaafF8OHD8+MTTzwxdtxxx3jooYciIvJK4JFHHon169dXvM1f/OIX0dnZGSeddFK8+eab+c/o0aNj/PjxsWTJki7f8+zZs8vWNj1stnDhwig+8L5Wd999dxxwwAHxqU99KteGDh2a/75u3bp48803o62tLVasWNHlIa3uePrpp2PNmjVx9tlnlz3XMG3atJgwYULFh956+7xg4BKFrVhjY2OXh30+9rGPRUdHxxa7jc7Ozrjxxhtj/Pjx0dDQECNGjIiRI0fGn/70p4/0P7iXX3459thjjy7rn/70p/PzHzR+/Piyj0ulUuy22275HMjYsWPja1/7Wtx+++0xYsSImDJlSvzgBz8om2358uVRFEWMHz8+Ro4cWfbPiy++GGvWrCm7jZ122qnizp+TTz45XnnllXjiiSciIuKll16KZ555Jk4++eSy4x5//PE46qijoqmpKVpaWmLkyJFx8cUXR0R85PssIirebxMmTOhyn/XFecHA5TmFrVh9fX2v38Z3vvOduPTSS+OMM86IK6+8MlpbW6Ouri7OP//86Ozs7PXb747rr78+Zs2aFb/61a/i17/+dZx33nlxzTXXxJNPPhk777xzdHZ2RqlUiocffrjifdbc3Fz28Qd/0v+gY489NoYNGxYLFy6Mgw46KBYuXBh1dXUxY8aMPOall16Kz372szFhwoS44YYb4pOf/GRsu+228dBDD8WNN97YJ/dZX5wXDFyiQI+VSqVuH7to0aI44ogj4sc//nHZ+ltvvZVPUvfELrvsEsuWLeuy/pe//CU//0HLly8v+7goivjb3/4We++9d9n6pEmTYtKkSfGtb30rli5dGgcffHDceuutcdVVV8W4ceOiKIoYO3Zs7L777j2eeZOmpqY45phj4p577okbbrgh7r777jj00EPLnmx/4IEH4r333ov777+/7CGlzR+iiuj+38Om+2TZsmVx5JFHln1u2bJlXe4ztm4ePqLHmpqaIiK69RvN9fX1ZY+hR/zv1sjNt0J219FHHx2///3v8yGYiIh//etf8aMf/SjGjBkTe+65Z9nxP/3pT+Odd97JjxctWhSvvfZaTJ06NSIi3n777diwYUPZn5k0aVLU1dXFe++9FxH/u+Oqvr4+Lr/88i7fS1EUsXbt2m7Pf/LJJ8c//vGPuP322+O5557r8tDRpp/SP3g769ati/nz53f5Wk1NTd36O9h3331j1KhRceutt+b3FBHx8MMPx4svvhjTpk3r9vwMfq4U6LHJkydHRMQll1wSM2fOjG222SaOPfbYjMUHHXPMMXHFFVfE7Nmz46CDDornn38+fv7zn1fds/9hvvnNb8Zdd90VU6dOjfPOOy9aW1vjjjvuiJUrV8a9997b5bepW1tb45BDDonZs2fH6tWrY968ebHbbrvFmWeeGRERv/nNb+KrX/1qzJgxI3bffffYsGFD/OxnP4v6+vo44YQTIiJi3LhxcdVVV8VFF10Uq1atiuOOOy6GDx8eK1eujF/+8pdx1llnxZw5c7o1/9FHHx3Dhw+POXPmlN3GJp/73Odi2223jWOPPTa+9KUvxT//+c+47bbbYtSoUfHaa6+VHTt58uS45ZZb4qqrrorddtstRo0a1eVKICJim222iWuvvTZmz54dbW1t8YUvfCG3pI4ZMyYuuOCCbt//bAX6b+MTfaXaltSmpqYux1522WXF5qdFbLYltSiK4sorryx22mmnoq6urmx7aqUtqRdeeGGx4447FkOHDi0OPvjg4oknnija2trKtlN2d0tqURTFSy+9VJx44olFS0tL0djYWOy///7F4sWLy47ZtC3zrrvuKi666KJi1KhRxdChQ4tp06aVbf9csWJFccYZZxTjxo0rGhsbi9bW1uKII44oHn300S63e++99xaHHHJI0dTUVDQ1NRUTJkwozjnnnGLZsmV5TFtbWzFx4sT/Ov8pp5xSRERx1FFHVfz8/fffX+y9995FY2NjMWbMmOLaa68tfvKTn3TZBvz6668X06ZNK4YPH15ERN6fm29J3eTuu+8u9tlnn6KhoaFobW0tTjnllOLvf/972TE9OS8YnEpFsdn1MABbLc8pAJBEAYAkCgAkUQAgiQIASRQASN3+5bVNv+wDQG267bbbPvQYVwoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEAa0t8D9LbfLFlScf31jo6K6w2lUsX1vq5nS0tLxfWOKnOXqszdH3oy+0Ca+/jjj+/vET6SJVXO8YF+rtTyOV6r50p3uFIAIIkCAEkUAEiiAEASBQDSoN999J8qOxn2bm+vuP50la+zYQvN011FUVRcb68y90BSy7PXomq7dQb6/e08GZhcKQCQRAGAJAoAJFEAIIkCAGnQ7z4aXuX1Uu6rcvyaKuu3V1hbWOXY1//7SN0ykF7npadqefZaVKv3d63OPdi5UgAgiQIASRQASKIAQBIFANKg333UWWW98quuREyqsn5ThbW5VY5dXGV9XpX1P1RZB+hrrhQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBr0L3NRrXrV3t7j+Srrff0mOwD9wZUCAEkUAEiiAEASBQCSKACQBv3uo39VWf98lfWnqqz/ZwvMAjDQuVIAIIkCAEkUAEiiAEASBQDSoN999E6V9aV9OgVAbXClAEASBQCSKACQRAGANOifaP5YS0vlTxRFxeVSqdrb7/StlipzFwN87oiezT6Q5q5VtXqu1Orcg50rBQCSKACQRAGAJAoAJFEAIA363UcdHR0V19vb2/t4kp6ptgNjoM8dUduz1yLnOFuSKwUAkigAkEQBgCQKACRRACAN+t1Htfp6KbU6d0Rtz16LavX+rtW5BztXCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgDenvAXpbS0tLxfWiKCqud3aWKq7vsEPXtQMPrHybjz9eef2NNyqv19d3Xevp3KVS5bn7Q09mH0hz16paPVdqde7BzpUCAEkUAEiiAEASBQCSKACQBv3uo46Ojorr7e3tFdcr7QSKiLjzzq5rY8dWPnb69MrrJ5xQeb3SZotqOzCqzT2Q1PLstain5/hA4TwZmFwpAJBEAYAkCgAkUQAgDfonmnv6q/Hbb195feLErmvz51c+9vTTK69vt13l9XXruq7V8q/01/LstahW7+9anXuwc6UAQBIFAJIoAJBEAYAkCgCkQb/7qKeq/Yb9ZZd1XVuwoPKxs2dXXq+0ywhgIHGlAEASBQCSKACQRAGAJAoAJLuPNlPtTXYqvY/JihWVj73vvi02DkCfcqUAQBIFAJIoAJBEAYAkCgAku482UxTdXx9S5d5raNhy8wD0JVcKACRRACCJAgBJFABInmjeTLWXudh5565ra9dWPnbjxi03D0BfcqUAQBIFAJIoAJBEAYAkCgAku482U23n0IMPdl27997Kx7a3b7l5APqSKwUAkigAkEQBgCQKACRRACAN+t1HLS0tFdeLKu+mUyqVKq6vX9/922xt7f6x1WypuftDT2YfSHPXqlo9V2p17sHOlQIASRQASKIAQBIFAJIoAJAG/e6jjo6OiuvtA/wFiqrtwBjoc0fU9uy1yDnOluRKAYAkCgAkUQAgiQIASRQASIN+91Gtvl5Krc4dUduz16Javb9rde7BzpUCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAGlIfw/Q21paWiquF0VRcb1UKvXiNN1Xq3NH9Gz2gTR3rarVc6VW5x7sXCkAkEQBgCQKACRRACCJAgBp0O8+6ujoqLje3t7ex5P0TLUdGAN97ojanr0WOcfZklwpAJBEAYAkCgAkUQAgiQIAadDvPqrV10up1bkjanv2WlSr93etzj3YuVIAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIA3p7wF6W0tLS8X1oigqrpdKpV6cpvtqde6Ins0+kOa+4ILbK64vX175+Lpe/JGqs7Pr2vjxlY9dsmRcxfWBfq6sfXNtf4/w0Z3Z3wP0HlcKACRRACCJAgBJFABIogBAGvS7jzo6Oiqut7e39/EkPVNt58hAnzuidmevtsvokUcqrw/pxf96Nmzo/rG1eo5XVfn06R8DY6NWn3KlAEASBQCSKACQRAGAJAoApEG/+2igvM5LT9Xq3BG1O3u11zKqtstom226rlXZeFV1vSeqzVer93eP9ea3OZB2PPUzVwoAJFEAIIkCAEkUAEiiAEAa9LuPoLdU2lFUbSPQdttVXq+vr7y+tobflIza5koBgCQKACRRACCJAgBJFABIdh9BHxg3rvJ6Y2Pl9Sef7Lq2JV4/CT6MKwUAkigAkEQBgCQKACRPNMNHVOklLao9Gfz005XXe/Lk8dbyXjr0L1cKACRRACCJAgBJFABIogBAsvsI/k9nZ+X1DRv6do5qt1ltvq2Gl/noE64UAEiiAEASBQCSKACQRAGAZPcR/J/x43t2fF0v/khVaadRT+cbdHrztZ/sbEquFABIogBAEgUAkigAkEQBgFQqiu6999OZZ57Z27MA0Ituu+22Dz3GlQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBKRVEU/T0EAAODKwUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUA0v8AOC1h/s09XdEAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import navix as nx\n", + "\n", + "# Create the environment\n", + "env = nx.make('Navix-DoorKey-8x8-v0', observation_fn=nx.observations.rgb)\n", + "key = jax.random.PRNGKey(0)\n", + "timestep = env.reset(key)\n", + "\n", + "def render(obs, title):\n", + " plt.imshow(obs)\n", + " plt.title(title)\n", + " plt.axis('off')\n", + " plt.show()\n", + "\n", + "print(timestep.observation.shape)\n", + "render(timestep.observation, \"Initial observation\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Take-home message:**\n", + "1. To sample an initial environment state (`timestep`), we need to pass the `key` (seed) argument to the environment constructor. This is because NAVIX uses JAX's PRNGKey to generate random numbers. You can read more here\n", + "2. `env.reset` returns a [`navix.Timestep`]() object, which contains all the useful information." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The environment interface\n", + "\n", + "We can now simulate a sequence of actions in the environment. For this example, we'll make the agent take random actions." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGbCAYAAAAr/4yjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAU5ElEQVR4nO3dfYycZbn48Wu2pVtAyrbQVkhhoYWWFgrVEk8ILw2tEEtpgfBSNEIBSzSligIxsajtqiEBE5RTKiBHoRYTKRxyEFGgyBoUOSoRMFDRKi8xBu3LLlWgvLR7//7wcMVhZ2CXX/dltp9P0j/27tOdax+G/e4zc+9MpZRSAgAiommgBwBg8BAFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFGt6KFSuiUqnE5s2bB3qUfvfW1w47iyjswm699daoVCrx2GOP9cvtrV+/PlasWBHPP/98v9zeUPHqq6/GihUr4mc/+9lAj8IuQBToN+vXr4+2tjZR6KVXX3012traakbhi1/8Ymzbtq3/h2LIEgXYSbq6uuK1117r19scPnx4jBw5sl9vk6FNFHhHb7zxRnz5y1+OmTNnxt577x177rlnHH/88dHe3t7t2B/84Acxc+bM2GuvvWLUqFExffr0uO666yLiXw9VnX322RERceKJJ0alUolKpfKuD4k89NBDcfzxx8eee+4ZLS0tcdppp8Xvf//7msdu3rw5zjnnnBg1alTss88+cemll3b7Jr1u3bo47rjjoqWlJd73vvfFlClTYtmyZVXHvP7667F8+fI45JBDorm5OQ444ID4/Oc/H6+//nrVcZVKJZYuXRrf//734/DDD4/m5ua45557YsyYMXHhhRd2m+8f//hHjBw5Mq644ooen9vnn38+xo4dGxERbW1ted5WrFgREbWfU9i+fXt89atfjUmTJkVzc3McdNBBsWzZsm7zH3TQQXHqqafGL37xi/jQhz4UI0eOjIkTJ8b3vve9ev852BUUdlm33HJLiYjym9/8pu4xmzZtKvvtt1+57LLLyg033FCuueaaMmXKlLLbbruVxx9/PI974IEHSkSUOXPmlFWrVpVVq1aVpUuXlrPPPruUUsqf//zn8pnPfKZERFm2bFlZs2ZNWbNmTfnb3/5W97bXrVtXhg8fXiZPnlyuueaa0tbWVvbdd98yevTo8txzz+Vxy5cvLxFRpk+fXubPn1+uv/768vGPf7xERDnvvPPyuKeeeqqMGDGiHH300eW6664rN954Y7niiivKCSeckMfs2LGjnHzyyWWPPfYon/3sZ8tNN91Uli5dWoYPH15OO+20qvkiokydOrWMHTu2tLW1lVWrVpXHH3+8XHTRRaWlpaW8/vrrVcevXr266nz35Ny+/PLL5YYbbigRUc4444w8b08++WTV1/7vFi1aVCKinHXWWWXVqlXl/PPPLxFRTj/99KrjWltby5QpU8r48ePLsmXLyvXXX18++MEPlkqlUp566qm6/10Y2kRhF9aTKGzfvr3bN7fOzs4yfvz4ctFFF+XapZdeWkaNGlW2b99e93PdcccdJSJKe3t7j+abMWNGGTduXNmyZUuuPfnkk6Wpqamcf/75ufbWN8YFCxZU/fslS5aUiMhvoN/4xjdKRJRNmzbVvc01a9aUpqam8vOf/7xq/cYbbywRUR555JFci4jS1NRUnn766apj77///hIR5Z577qlaP+WUU8rEiRPz456e202bNpWIKMuXL+8279uj8MQTT5SIKIsXL6467oorrigRUR566KFca21tLRFRHn744VzbuHFjaW5uLpdffnm322LX4OEj3tGwYcNixIgREfGvx8w7Ojpi+/btcfTRR8dvf/vbPK6lpSVeeeWVWLdu3U653RdffDGeeOKJuOCCC2LMmDG5fuSRR8ZJJ50UP/7xj7v9m0suuaTq409/+tMREXlsS0tLRETcfffd0dXVVfN277jjjpg6dWocdthhsXnz5vwze/bsiIhuD5vNmjUrpk2bVrU2e/bs2HfffeP222/Ptc7Ozli3bl0sXLgw13p6bnvjra/1sssuq1q//PLLIyLi3nvvrVqfNm1aHH/88fnx2LFjY8qUKfHss8++p9un8YkC72r16tVx5JFHxsiRI2OfffaJsWPHxr333htbt27NY5YsWRKTJ0+OuXPnxoQJE+Kiiy6K++677z3f5gsvvBAREVOmTOn2d1OnTo3NmzfHK6+8UrV+6KGHVn08adKkaGpqyt1OCxcujGOPPTYWL14c48ePj3PPPTfWrl1bFYgNGzbE008/HWPHjq36M3ny5IiI2LhxY9VtHHzwwd3mGz58eJx55plx99135+P4d911V7z55ptVUYjo2bntjRdeeCGamprikEMOqVp///vfHy0tLXle33LggQd2+xyjR4+Ozs7O93T7ND5R4B3ddtttccEFF8SkSZPiO9/5Ttx3332xbt26mD17dtU303HjxsUTTzwRP/zhD2PBggXR3t4ec+fOjUWLFg3Y7G9/Anb33XePhx9+OB588ME477zz4ne/+10sXLgwTjrppNixY0dE/Osn9unTp8e6detq/lmyZEm3z1nLueeeG//85z/jJz/5SURErF27Ng477LA46qij8pientud8bXXM2zYsJrrxbv07rKGD/QADG533nlnTJw4Me66666qbzTLly/vduyIESNi/vz5MX/+/Ojq6oolS5bETTfdFF/60pfikEMO6dVv3ra2tkZExB/+8Iduf/fMM8/EvvvuG3vuuWfV+oYNG6p+cv/Tn/4UXV1dcdBBB+VaU1NTzJkzJ+bMmRPXXnttXHXVVXHllVdGe3t7fPjDH45JkybFk08+GXPmzPn/+k3hE044Ifbbb7+4/fbb47jjjouHHnoorrzyyqpjenpue3veurq6YsOGDTF16tRc//vf/x4vvfRSnleox5UC7+itnyT//SfHX/3qV/Hoo49WHbdly5aqj5uamuLII4+MiMiHUN76Jv7SSy+96+3ut99+MWPGjFi9enXV8U899VQ88MADccopp3T7N6tWrar6eOXKlRERMXfu3IiI6Ojo6PZvZsyYUTXjOeecE3/961/j5ptv7nbstm3buj1kVU9TU1OcddZZcc8998SaNWti+/bt3R466um53WOPPSKiZ+ftrfPyzW9+s2r92muvjYiIefPm9Wh+dl2uFIjvfve7NR//v/TSS+PUU0+Nu+66K84444yYN29ePPfcc3HjjTfGtGnT4uWXX85jFy9eHB0dHTF79uyYMGFCvPDCC7Fy5cqYMWNG/sQ6Y8aMGDZsWFx99dWxdevWaG5ujtmzZ8e4ceNqzvX1r3895s6dG8ccc0x84hOfiG3btsXKlStj7733zn36/+65556LBQsWxEc+8pF49NFH47bbbouPfexj+ZDNV77ylXj44Ydj3rx50draGhs3boxvfetbMWHChDjuuOMiIuK8886LtWvXxqc+9alob2+PY489Nnbs2BHPPPNMrF27Nu6///44+uije3ReFy5cGCtXrozly5fH9OnTq35yj4gen9vdd989pk2bFrfffntMnjw5xowZE0cccUQcccQR3W7zqKOOikWLFsW3v/3teOmll2LWrFnx61//OlavXh2nn356nHjiiT2anV3YAO9+YgC9tSW13p+//OUvpaurq1x11VWltbW1NDc3lw984APlRz/6UVm0aFFpbW3Nz3XnnXeWk08+uYwbN66MGDGiHHjggeWTn/xkefHFF6tu8+abby4TJ04sw4YN69H21AcffLAce+yxZffddy+jRo0q8+fPL+vXr6865q1tmevXry9nnXVW2Wuvvcro0aPL0qVLy7Zt2/K4n/70p+W0004r+++/fxkxYkTZf//9y0c/+tHyxz/+serzvfHGG+Xqq68uhx9+eGlubi6jR48uM2fOLG1tbWXr1q15XESUSy65pO7sXV1d5YADDigRUb72ta/V/PuenNtSSvnlL39ZZs6cWUaMGFG1PbXW7ym8+eabpa2trRx88MFlt912KwcccED5whe+UF577bWq41pbW8u8efO6zTVr1qwya9asul8XQ1ulFM8oAfAvnlMAIIkCAEkUAEiiAEASBQCSKACQevzLaxdffHFfzgFAH6v1m/pv50oBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIwwd6gL7W3t5ec72zs7PmeqVS6ctxeqylpaXmel/O3VVn/c066yPqrI/uxeyD5XxHRJxxxhkDPcJ74j7e/xr1vtITrhQASKIAQBIFAJIoAJBEAYA05Hcf1dvJ0NHR0c+T9E4ppeZ6X85d787woTrrj9f7RAMw+67MfZydyZUCAEkUAEiiAEASBQCSKACQhvzuo8H0eim9sbPmfn+d9XNqrC3u5ec4os76jgY9541qV7+Ps3O5UgAgiQIASRQASKIAQBIFANKQ333UqHbUWf9AnfXP1Vk/tc766F7MsrHOeu1XrgEamSsFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJC8zMUgNazO+uN11s+vs96Xb7LjLVJg6HGlAEASBQCSKACQRAGAJAoAJLuPhri/1Vn/zxprN9Q59j/qrL9SZ33kO04EDGauFABIogBAEgUAkigAkEQBgGT3EenNOuu/6OXnsfsIGpcrBQCSKACQRAGAJAoApCH/RHNLS0vN9VJKzfVKZXC8dUyjzh3Ru9kH09yNqlHvK40691DnSgGAJAoAJFEAIIkCAEkUAEhDfvdRZ2dnzfWOjo5+nqR36u3AGOxzRzT27I3IfZydyZUCAEkUAEiiAEASBQCSKACQhvzuo0Z9vZRGnTuisWdvRI16vht17qHOlQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASMMHeoC+1tLSUnO9lFJzvaurUnN9/Pjua8ccU/s2H3mk9vqmTbXXhw3rvtbbuSuV2nMPhN7MPpjmblSNel9p1LmHOlcKACRRACCJAgBJFABIogBAGvK7jzo7O2uud3R01FyvtRMoIuK227qvHXxw7WMXLKi9fuaZtddrbbaotwOj3tyDSSPP3oh6ex8fLNxPBidXCgAkUQAgiQIASRQASEP+iebe/mr83nvXXj/88O5rt9xS+9hFi2qvjxpVe33r1u5rjfwr/Y08eyNq1PPdqHMPda4UAEiiAEASBQCSKACQRAGANOR3H/VWvd+wX768+9qtt9Y+9sILa6/X2mUEMJi4UgAgiQIASRQASKIAQBIFAJLdR29T7012ar2PybPP1j72f/5np40D0K9cKQCQRAGAJAoAJFEAIIkCAMnuo7cppefrw+ucvebmnTcPQH9ypQBAEgUAkigAkEQBgOSJ5rep9zIXEyZ0X9uypfaxO3bsvHkA+pMrBQCSKACQRAGAJAoAJFEAINl99Db1dg7de2/3tf/+79rHdnTsvHkA+pMrBQCSKACQRAGAJAoAJFEAIA353UctLS0110udd9OpVCo11199tee3OWZMz4+tZ2fNPRB6M/tgmrtRNep9pVHnHupcKQCQRAGAJAoAJFEAIIkCAGnI7z7q7Oysud4xyF+gqN4OjME+d0Rjz96I3MfZmVwpAJBEAYAkCgAkUQAgiQIAacjvPmrU10tp1LkjGnv2RtSo57tR5x7qXCkAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkIYP9AB9raWlpeZ6KaXmeqVS6cNpeq5R547o3eyDae5G1aj3lUade6hzpQBAEgUAkigAkEQBgCQKAKQhv/uos7Oz5npHR0c/T9I79XZgDPa5Ixp79kbkPs7O5EoBgCQKACRRACCJAgBJFABIQ373UaO+Xkqjzh3R2LM3okY9340691DnSgGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGANHygB+hrLS0tNddLKTXXK5VKH07Tc406d0TvZh9Mc3/uc/9Vc33DhtrHN/Xhj1RdXd3XDj209rHt7ZNqrg/2+8qWzVsGeoT37uKBHqDvuFIAIIkCAEkUAEiiAEASBQDSkN991NnZWXO9o6OjnyfpnXo7Rwb73BGNO3u9XUb33197fXgf/t+zfXvPj23U+3hdte8+A2NwbNTqV64UAEiiAEASBQCSKACQRAGANOR3Hw2W13nprUadO6JxZ6/3Wkb1dhnttlv3tTobr+qu90a9+Rr1fPdaX36Zg2nH0wBzpQBAEgUAkigAkEQBgCQKAKQhv/sI+kqtHUX1NgKNGlV7fdiw2utbGvhNyWhsrhQASKIAQBIFAJIoAJBEAYBk9xH0g0mTaq+PHFl7/X//t/vaznj9JHg3rhQASKIAQBIFAJIoAJA80QzvUa2XtKj3ZPBjj9Ve782Tx7vKe+kwsFwpAJBEAYAkCgAkUQAgiQIAye4j+D9dXbXXt2/v3znq3Wa9+XYZXuajX7hSACCJAgBJFABIogBAEgUAkt1H8H8OPbR3xzf14Y9UtXYa9Xa+IacvX/vJzqbkSgGAJAoAJFEAIIkCAEkUAEiVUnr23k8XX3xxX88CQB+6+eab3/UYVwoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAECqlFLKQA8BwODgSgGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGA9P8AwhxpshzyH3sAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def unroll(key, num_steps=5):\n", + " timestep = env.reset(key)\n", + " actions = jax.random.randint(key, (num_steps,), 0, env.action_space.n)\n", + "\n", + " steps = [timestep]\n", + " for action in actions:\n", + " timestep = env.step(timestep, action)\n", + " steps.append(timestep)\n", + "\n", + " return steps\n", + "\n", + "# Unroll and print steps\n", + "steps = unroll(key, num_steps=5)\n", + "render(steps[-1].observation, \"Last observation\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Take home message:**\n", + "1. `env.step`, take two arguments: the current state of the environment (the `timestep`), and the action to take, and returns the new environment state.\n", + "2. Despite `env.step` being stochastic, it does not take a `key` argument. This is because NAVIX manages the PRNGKey internally.\n", + "3. You can still sample different environments by sampling different `keys` when creating the environment.\n", + "\n", + "This way of using NAVIX is suboptimal (and probably slower than using `gym`), as it does not take advantage of JAX's JIT compiler. We'll see how to do that in the next section." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## `Optimizing with JAX`\n", + "\n", + "One of the major perks of NAVIX is its performance optimization capabilities through JAX. We can use JAX's `jit` and `vmap` to compile and parallelize our simulation code.\n", + "We can compile the `step` function to make it faster. This is done by using the `jax.jit` decorator." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### JIT Compilation" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGbCAYAAAAr/4yjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAU+UlEQVR4nO3dfWyddd348c/pxjpARjfYJmRQ2GBjg8F0xDuEh4VNiGNsQHgYGmGAI5oxRYGYONStakjABOU3JiA/hTlMZHCTGxEFhtSgyK0SAQMTnfIQY9A9tEyB8bD1e//hzScceo603F3b071eyf7od9fO9enFoe9ePd+2lVJKCQCIiKaBHgCAwUMUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUaHgrVqyISqUSmzdvHuhR+t1b7zv0FVHYhd16661RqVTiscce65fzrV+/PlasWBHPP/98v5xvqHj11VdjxYoV8bOf/WygR2EXIAr0m/Xr10dbW5so9NKrr74abW1tNaPwxS9+MbZt29b/QzFkiQL0ka6urnjttdf69ZzDhw+PkSNH9us5GdpEgX/rjTfeiC9/+csxc+bM2HvvvWPPPfeM448/Ptrb27sd+4Mf/CBmzpwZe+21V4waNSqmT58e1113XUT860tVZ599dkREnHjiiVGpVKJSqbzrl0QeeuihOP7442PPPfeMlpaWOO200+L3v/99zWM3b94c55xzTowaNSr22WefuPTSS7t9kF63bl0cd9xx0dLSEu973/tiypQpsWzZsqpjXn/99Vi+fHkccsgh0dzcHAcccEB8/vOfj9dff73quEqlEkuXLo3vf//7cfjhh0dzc3Pcc889MWbMmLjwwgu7zfePf/wjRo4cGVdccUWPr+3zzz8fY8eOjYiItra2vG4rVqyIiNqvKWzfvj2++tWvxqRJk6K5uTkOOuigWLZsWbf5DzrooDj11FPjF7/4RXzoQx+KkSNHxsSJE+N73/tevf8c7AoKu6xbbrmlRET5zW9+U/eYTZs2lf32269cdtll5YYbbijXXHNNmTJlStltt93K448/nsc98MADJSLKnDlzyqpVq8qqVavK0qVLy9lnn11KKeXPf/5z+cxnPlMioixbtqysWbOmrFmzpvztb3+re+5169aV4cOHl8mTJ5drrrmmtLW1lX333beMHj26PPfcc3nc8uXLS0SU6dOnl/nz55frr7++fPzjHy8RUc4777w87qmnniojRowoRx99dLnuuuvKjTfeWK644opywgkn5DE7duwoJ598ctljjz3KZz/72XLTTTeVpUuXluHDh5fTTjutar6IKFOnTi1jx44tbW1tZdWqVeXxxx8vF110UWlpaSmvv/561fGrV6+uut49ubYvv/xyueGGG0pElDPOOCOv25NPPln1vr/dokWLSkSUs846q6xataqcf/75JSLK6aefXnVca2trmTJlShk/fnxZtmxZuf7668sHP/jBUqlUylNPPVX3vwtDmyjswnoShe3bt3f74NbZ2VnGjx9fLrrooly79NJLy6hRo8r27dvrPtYdd9xRIqK0t7f3aL4ZM2aUcePGlS1btuTak08+WZqamsr555+fa299YFywYEHVv1+yZEmJiPwA+o1vfKNERNm0aVPdc65Zs6Y0NTWVn//851XrN954Y4mI8sgjj+RaRJSmpqby9NNPVx17//33l4go99xzT9X6KaecUiZOnJhv9/Tabtq0qUREWb58ebd53xmFJ554okREWbx4cdVxV1xxRYmI8tBDD+Vaa2triYjy8MMP59rGjRtLc3Nzufzyy7udi12DLx/xbw0bNixGjBgREf/6mnlHR0ds3749jj766Pjtb3+bx7W0tMQrr7wS69at65Pzvvjii/HEE0/EBRdcEGPGjMn1I488Mk466aT48Y9/3O3fXHLJJVVvf/rTn46IyGNbWloiIuLuu++Orq6umue94447YurUqXHYYYfF5s2b88/s2bMjIrp92WzWrFkxbdq0qrXZs2fHvvvuG7fffnuudXZ2xrp162LhwoW51tNr2xtvva+XXXZZ1frll18eERH33ntv1fq0adPi+OOPz7fHjh0bU6ZMiWefffY9nZ/GJwq8q9WrV8eRRx4ZI0eOjH322SfGjh0b9957b2zdujWPWbJkSUyePDnmzp0bEyZMiIsuuijuu+++93zOF154ISIipkyZ0u3vpk6dGps3b45XXnmlav3QQw+tenvSpEnR1NSUu50WLlwYxx57bCxevDjGjx8f5557bqxdu7YqEBs2bIinn346xo4dW/Vn8uTJERGxcePGqnMcfPDB3eYbPnx4nHnmmXH33Xfn1/HvuuuuePPNN6uiENGza9sbL7zwQjQ1NcUhhxxStf7+978/Wlpa8rq+5cADD+z2GKNHj47Ozs73dH4anyjwb912221xwQUXxKRJk+I73/lO3HfffbFu3bqYPXt21QfTcePGxRNPPBE//OEPY8GCBdHe3h5z586NRYsWDdjs73wBdvfdd4+HH344HnzwwTjvvPPid7/7XSxcuDBOOumk2LFjR0T86zP26dOnx7p162r+WbJkSbfHrOXcc8+Nf/7zn/GTn/wkIiLWrl0bhx12WBx11FF5TE+vbV+87/UMGzas5nrxW3p3WcMHegAGtzvvvDMmTpwYd911V9UHmuXLl3c7dsSIETF//vyYP39+dHV1xZIlS+Kmm26KL33pS3HIIYf06jtvW1tbIyLiD3/4Q7e/e+aZZ2LfffeNPffcs2p9w4YNVZ+5/+lPf4qurq446KCDcq2pqSnmzJkTc+bMiWuvvTauuuqquPLKK6O9vT0+/OEPx6RJk+LJJ5+MOXPm/J++U/iEE06I/fbbL26//fY47rjj4qGHHoorr7yy6pieXtveXreurq7YsGFDTJ06Ndf//ve/x0svvZTXFepxp8C/9dZnkm//zPFXv/pVPProo1XHbdmypertpqamOPLIIyMi8ksob30Qf+mll971vPvtt1/MmDEjVq9eXXX8U089FQ888ECccsop3f7NqlWrqt5euXJlRETMnTs3IiI6Ojq6/ZsZM2ZUzXjOOefEX//617j55pu7Hbtt27ZuX7Kqp6mpKc4666y45557Ys2aNbF9+/ZuXzrq6bXdY489IqJn1+2t6/LNb36zav3aa6+NiIh58+b1aH52Xe4UiO9+97s1v/5/6aWXxqmnnhp33XVXnHHGGTFv3rx47rnn4sYbb4xp06bFyy+/nMcuXrw4Ojo6Yvbs2TFhwoR44YUXYuXKlTFjxoz8jHXGjBkxbNiwuPrqq2Pr1q3R3Nwcs2fPjnHjxtWc6+tf/3rMnTs3jjnmmPjEJz4R27Zti5UrV8bee++d+/Tf7rnnnosFCxbERz7ykXj00Ufjtttui4997GP5JZuvfOUr8fDDD8e8efOitbU1Nm7cGN/61rdiwoQJcdxxx0VExHnnnRdr166NT33qU9He3h7HHnts7NixI5555plYu3Zt3H///XH00Uf36LouXLgwVq5cGcuXL4/p06dXfeYeET2+trvvvntMmzYtbr/99pg8eXKMGTMmjjjiiDjiiCO6nfOoo46KRYsWxbe//e146aWXYtasWfHrX/86Vq9eHaeffnqceOKJPZqdXdgA735iAL21JbXen7/85S+lq6urXHXVVaW1tbU0NzeXD3zgA+VHP/pRWbRoUWltbc3HuvPOO8vJJ59cxo0bV0aMGFEOPPDA8slPfrK8+OKLVee8+eaby8SJE8uwYcN6tD31wQcfLMcee2zZfffdy6hRo8r8+fPL+vXrq455a1vm+vXry1lnnVX22muvMnr06LJ06dKybdu2PO6nP/1pOe2008r+++9fRowYUfbff//y0Y9+tPzxj3+serw33nijXH311eXwww8vzc3NZfTo0WXmzJmlra2tbN26NY+LiHLJJZfUnb2rq6sccMABJSLK1772tZp/35NrW0opv/zlL8vMmTPLiBEjqran1vo+hTfffLO0tbWVgw8+uOy2227lgAMOKF/4whfKa6+9VnVca2trmTdvXre5Zs2aVWbNmlX3/WJoq5TiFSUA/sVrCgAkUQAgiQIASRQASKIAQBIFAFKPv3nt4osv3plzALCT1fpO/XdypwBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKAKThAz3Aztbe3l5zvbOzs+Z6pVLZmeP0WEtLS831wT53RO9mrzf3jl6ec1gvj6/ljDPO6INH6X+e4/2vUZ8rPeFOAYAkCgAkUQAgiQIASRQASEN+91G9nQwdHR39PEnvlFJqrg/2uSP6ZvYP9PKcj/fy+KHEc5y+5E4BgCQKACRRACCJAgBJFABIQ3730WD6eSm90ahzR/TN7J/r5fHn/5/P2Lga9bnSqHMPde4UAEiiAEASBQCSKACQRAGANOR3HzG4vb/O+ql98Dh/6+VjAO4UAHgbUQAgiQIASRQASF5oZkCdU2d9dB88zv/r5WMA7hQAeBtRACCJAgBJFABIogBAsvuIPtdVZ73Wk21xH52z1uPcUOfYN/vonDAUuVMAIIkCAEkUAEiiAEASBQCS3Uf0uXq7ez5UY63eL9nZ2Mtz1nqc/6hz7C96+diwK3GnAEASBQCSKACQRAGAJAoAJLuP6HMj6qw/XmPtiDrHll6es1Jj7ZVePgbgTgGAtxEFAJIoAJBEAYA05F9obmlpqbleSu2XMiuVWi9Z9r9GnTui/uxRY/YdO3Hukb1cb1SN+lxp1LmHOncKACRRACCJAgBJFABIogBAGvK7jzo7O2uud3R09PMkvVNvB8ZgnzuisWdvRJ7j9CV3CgAkUQAgiQIASRQASKIAQBryu48a9eelNOrcEY09eyNq1OvdqHMPde4UAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEAaPtAD7GwtLS0110spNde7uio118eP7752zDG1z/nII7XXN22qvT5sWPe13s5dqdSeeyD0ZvbBNHejatTnSqPOPdS5UwAgiQIASRQASKIAQBIFANKQ333U2dlZc72jo6Pmeq2dQBERt93Wfe3gg2sfu2BB7fUzz6y9XmuzRb0dGPXmHkwaefZG1Nvn+GDheTI4uVMAIIkCAEkUAEiiAEAa8i809/Zb4/feu/b64Yd3X7vlltrHLlpUe33UqNrrW7d2X2vkb+lv5NkbUaNe70ade6hzpwBAEgUAkigAkEQBgCQKAKQhv/uot+p9h/3y5d3Xbr219rEXXlh7vdYuI4DBxJ0CAEkUAEiiAEASBQCSKACQ7D56h3q/ZKfW7zF59tnax/7Xf/XZOAD9yp0CAEkUAEiiAEASBQCSKACQ7D56h1J6vj68ztVrbu67eQD6kzsFAJIoAJBEAYAkCgAkLzS/Q70fczFhQve1LVtqH7tjR9/NA9Cf3CkAkEQBgCQKACRRACCJAgDJ7qN3qLdz6N57u6/953/WPrajo+/mAehP7hQASKIAQBIFAJIoAJBEAYA05HcftbS01FwvdX6bTqVSqbn+6qs9P+eYMT0/tp6+mnsg9Gb2wTR3o2rU50qjzj3UuVMAIIkCAEkUAEiiAEASBQDSkN991NnZWXO9Y5D/gKJ6OzAG+9wRjT17I/Icpy+5UwAgiQIASRQASKIAQBIFANKQ333UqD8vpVHnjmjs2RtRo17vRp17qHOnAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEAaPtAD7GwtLS0110spNdcrlcpOnKbnGnXuiN7NPpjmblSN+lxp1LmHOncKACRRACCJAgBJFABIogBAGvK7jzo7O2uud3R09PMkvVNvB8ZgnzuisWdvRJ7j9CV3CgAkUQAgiQIASRQASKIAQBryu48a9eelNOrcEY09eyNq1OvdqHMPde4UAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEjDB3qAna2lpaXmeiml5nqlUtmJ0/Rco84d0bvZB9Pcn/vc/6+5vmFD7eObduKnVF1d3dcOPbT2se3tk2quD/bnypbNWwZ6hPfu4oEeYOdxpwBAEgUAkigAkEQBgCQKAKQhv/uos7Oz5npHR0c/T9I79XaODPa5Ixp39nq7jO6/v/b68J34f8/27T0/tlGf43XVfvoMjMGxUatfuVMAIIkCAEkUAEiiAEASBQDSkN99NFh+zktvNercEY07e72fZVRvl9Fuu3Vfq7Pxqu56b9Sbr1Gvd6/tzHdzMO14GmDuFABIogBAEgUAkigAkEQBgDTkdx/BzlJrR1G9jUCjRtVeHzas9vqWBv6lZDQ2dwoAJFEAIIkCAEkUAEiiAECy+wj6waRJtddHjqy9/t//3X2tL35+ErwbdwoAJFEAIIkCAEkUAEheaIb3qNaPtKj3YvBjj9Ve782Lx7vK79JhYLlTACCJAgBJFABIogBAEgUAkt1H8L+6umqvb9/ev3PUO2e9+XYZfsxHv3CnAEASBQCSKACQRAGAJAoAJLuP4H8demjvjm/aiZ9S1dpp1Nv5hpyd+bOf7GxK7hQASKIAQBIFAJIoAJBEAYBUKaVnv/vp4osv3tmzALAT3Xzzze96jDsFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgVUopZaCHAGBwcKcAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQPofTwdpsxkEwUoAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "@jax.jit\n", + "def env_step_jit(timestep, action):\n", + " return env.step(timestep, action)\n", + "\n", + "def unroll_jit_step(key, num_steps=10):\n", + " timestep = env.reset(key)\n", + " actions = jax.random.randint(key, (num_steps,), 0, env.action_space.n)\n", + "\n", + " steps = [timestep]\n", + " for action in actions:\n", + " timestep = env_step_jit(timestep, action)\n", + " steps.append(timestep)\n", + "\n", + " return steps\n", + "\n", + "# Example usage\n", + "steps = unroll_jit_step(key, num_steps=10)\n", + "render(steps[-1].observation, \"Last observation\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's compare the two head to head." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "27.4 s ± 130 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "%timeit -n 1 -r 3 unroll(key, num_steps=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "328 ns ± 153 ns per loop (mean ± std. dev. of 3 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "%timeit -n 1 -r 3 lambda: unroll_jit_step(key, num_steps=10)[-1].block_until_ready()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that it's roughly in the order of $10^9$ times faster compared to its unjitted counterpart.\n", + "\n", + "But that's not the end of the story.\n", + "We can go even further and `jit` the whole simulation loop, which improves performance even more." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGbCAYAAAAr/4yjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAU+UlEQVR4nO3dfWyddd348c/pxjpARjfYJmRQ2GBjg8F0xDuEh4VNiGNsQHgYGmGAI5oxRYGYONStakjABOU3JiA/hTlMZHCTGxEFhtSgyK0SAQMTnfIQY9A9tEyB8bD1e//hzScceo603F3b071eyf7od9fO9enFoe9ePd+2lVJKCQCIiKaBHgCAwUMUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUaHgrVqyISqUSmzdvHuhR+t1b7zv0FVHYhd16661RqVTiscce65fzrV+/PlasWBHPP/98v5xvqHj11VdjxYoV8bOf/WygR2EXIAr0m/Xr10dbW5so9NKrr74abW1tNaPwxS9+MbZt29b/QzFkiQL0ka6urnjttdf69ZzDhw+PkSNH9us5GdpEgX/rjTfeiC9/+csxc+bM2HvvvWPPPfeM448/Ptrb27sd+4Mf/CBmzpwZe+21V4waNSqmT58e1113XUT860tVZ599dkREnHjiiVGpVKJSqbzrl0QeeuihOP7442PPPfeMlpaWOO200+L3v/99zWM3b94c55xzTowaNSr22WefuPTSS7t9kF63bl0cd9xx0dLSEu973/tiypQpsWzZsqpjXn/99Vi+fHkccsgh0dzcHAcccEB8/vOfj9dff73quEqlEkuXLo3vf//7cfjhh0dzc3Pcc889MWbMmLjwwgu7zfePf/wjRo4cGVdccUWPr+3zzz8fY8eOjYiItra2vG4rVqyIiNqvKWzfvj2++tWvxqRJk6K5uTkOOuigWLZsWbf5DzrooDj11FPjF7/4RXzoQx+KkSNHxsSJE+N73/tevf8c7AoKu6xbbrmlRET5zW9+U/eYTZs2lf32269cdtll5YYbbijXXHNNmTJlStltt93K448/nsc98MADJSLKnDlzyqpVq8qqVavK0qVLy9lnn11KKeXPf/5z+cxnPlMioixbtqysWbOmrFmzpvztb3+re+5169aV4cOHl8mTJ5drrrmmtLW1lX333beMHj26PPfcc3nc8uXLS0SU6dOnl/nz55frr7++fPzjHy8RUc4777w87qmnniojRowoRx99dLnuuuvKjTfeWK644opywgkn5DE7duwoJ598ctljjz3KZz/72XLTTTeVpUuXluHDh5fTTjutar6IKFOnTi1jx44tbW1tZdWqVeXxxx8vF110UWlpaSmvv/561fGrV6+uut49ubYvv/xyueGGG0pElDPOOCOv25NPPln1vr/dokWLSkSUs846q6xataqcf/75JSLK6aefXnVca2trmTJlShk/fnxZtmxZuf7668sHP/jBUqlUylNPPVX3vwtDmyjswnoShe3bt3f74NbZ2VnGjx9fLrrooly79NJLy6hRo8r27dvrPtYdd9xRIqK0t7f3aL4ZM2aUcePGlS1btuTak08+WZqamsr555+fa299YFywYEHVv1+yZEmJiPwA+o1vfKNERNm0aVPdc65Zs6Y0NTWVn//851XrN954Y4mI8sgjj+RaRJSmpqby9NNPVx17//33l4go99xzT9X6KaecUiZOnJhv9/Tabtq0qUREWb58ebd53xmFJ554okREWbx4cdVxV1xxRYmI8tBDD+Vaa2triYjy8MMP59rGjRtLc3Nzufzyy7udi12DLx/xbw0bNixGjBgREf/6mnlHR0ds3749jj766Pjtb3+bx7W0tMQrr7wS69at65Pzvvjii/HEE0/EBRdcEGPGjMn1I488Mk466aT48Y9/3O3fXHLJJVVvf/rTn46IyGNbWloiIuLuu++Orq6umue94447YurUqXHYYYfF5s2b88/s2bMjIrp92WzWrFkxbdq0qrXZs2fHvvvuG7fffnuudXZ2xrp162LhwoW51tNr2xtvva+XXXZZ1frll18eERH33ntv1fq0adPi+OOPz7fHjh0bU6ZMiWefffY9nZ/GJwq8q9WrV8eRRx4ZI0eOjH322SfGjh0b9957b2zdujWPWbJkSUyePDnmzp0bEyZMiIsuuijuu+++93zOF154ISIipkyZ0u3vpk6dGps3b45XXnmlav3QQw+tenvSpEnR1NSUu50WLlwYxx57bCxevDjGjx8f5557bqxdu7YqEBs2bIinn346xo4dW/Vn8uTJERGxcePGqnMcfPDB3eYbPnx4nHnmmXH33Xfn1/HvuuuuePPNN6uiENGza9sbL7zwQjQ1NcUhhxxStf7+978/Wlpa8rq+5cADD+z2GKNHj47Ozs73dH4anyjwb912221xwQUXxKRJk+I73/lO3HfffbFu3bqYPXt21QfTcePGxRNPPBE//OEPY8GCBdHe3h5z586NRYsWDdjs73wBdvfdd4+HH344HnzwwTjvvPPid7/7XSxcuDBOOumk2LFjR0T86zP26dOnx7p162r+WbJkSbfHrOXcc8+Nf/7zn/GTn/wkIiLWrl0bhx12WBx11FF5TE+vbV+87/UMGzas5nrxW3p3WcMHegAGtzvvvDMmTpwYd911V9UHmuXLl3c7dsSIETF//vyYP39+dHV1xZIlS+Kmm26KL33pS3HIIYf06jtvW1tbIyLiD3/4Q7e/e+aZZ2LfffeNPffcs2p9w4YNVZ+5/+lPf4qurq446KCDcq2pqSnmzJkTc+bMiWuvvTauuuqquPLKK6O9vT0+/OEPx6RJk+LJJ5+MOXPm/J++U/iEE06I/fbbL26//fY47rjj4qGHHoorr7yy6pieXtveXreurq7YsGFDTJ06Ndf//ve/x0svvZTXFepxp8C/9dZnkm//zPFXv/pVPProo1XHbdmypertpqamOPLIIyMi8ksob30Qf+mll971vPvtt1/MmDEjVq9eXXX8U089FQ888ECccsop3f7NqlWrqt5euXJlRETMnTs3IiI6Ojq6/ZsZM2ZUzXjOOefEX//617j55pu7Hbtt27ZuX7Kqp6mpKc4666y45557Ys2aNbF9+/ZuXzrq6bXdY489IqJn1+2t6/LNb36zav3aa6+NiIh58+b1aH52Xe4UiO9+97s1v/5/6aWXxqmnnhp33XVXnHHGGTFv3rx47rnn4sYbb4xp06bFyy+/nMcuXrw4Ojo6Yvbs2TFhwoR44YUXYuXKlTFjxoz8jHXGjBkxbNiwuPrqq2Pr1q3R3Nwcs2fPjnHjxtWc6+tf/3rMnTs3jjnmmPjEJz4R27Zti5UrV8bee++d+/Tf7rnnnosFCxbERz7ykXj00Ufjtttui4997GP5JZuvfOUr8fDDD8e8efOitbU1Nm7cGN/61rdiwoQJcdxxx0VExHnnnRdr166NT33qU9He3h7HHnts7NixI5555plYu3Zt3H///XH00Uf36LouXLgwVq5cGcuXL4/p06dXfeYeET2+trvvvntMmzYtbr/99pg8eXKMGTMmjjjiiDjiiCO6nfOoo46KRYsWxbe//e146aWXYtasWfHrX/86Vq9eHaeffnqceOKJPZqdXdgA735iAL21JbXen7/85S+lq6urXHXVVaW1tbU0NzeXD3zgA+VHP/pRWbRoUWltbc3HuvPOO8vJJ59cxo0bV0aMGFEOPPDA8slPfrK8+OKLVee8+eaby8SJE8uwYcN6tD31wQcfLMcee2zZfffdy6hRo8r8+fPL+vXrq455a1vm+vXry1lnnVX22muvMnr06LJ06dKybdu2PO6nP/1pOe2008r+++9fRowYUfbff//y0Y9+tPzxj3+serw33nijXH311eXwww8vzc3NZfTo0WXmzJmlra2tbN26NY+LiHLJJZfUnb2rq6sccMABJSLK1772tZp/35NrW0opv/zlL8vMmTPLiBEjqran1vo+hTfffLO0tbWVgw8+uOy2227lgAMOKF/4whfKa6+9VnVca2trmTdvXre5Zs2aVWbNmlX3/WJoq5TiFSUA/sVrCgAkUQAgiQIASRQASKIAQBIFAFKPv3nt4osv3plzALCT1fpO/XdypwBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKAKThAz3Aztbe3l5zvbOzs+Z6pVLZmeP0WEtLS831wT53RO9mrzf3jl6ec1gvj6/ljDPO6INH6X+e4/2vUZ8rPeFOAYAkCgAkUQAgiQIASRQASEN+91G9nQwdHR39PEnvlFJqrg/2uSP6ZvYP9PKcj/fy+KHEc5y+5E4BgCQKACRRACCJAgBJFABIQ3730WD6eSm90ahzR/TN7J/r5fHn/5/P2Lga9bnSqHMPde4UAEiiAEASBQCSKACQRAGANOR3HzG4vb/O+ql98Dh/6+VjAO4UAHgbUQAgiQIASRQASF5oZkCdU2d9dB88zv/r5WMA7hQAeBtRACCJAgBJFABIogBAsvuIPtdVZ73Wk21xH52z1uPcUOfYN/vonDAUuVMAIIkCAEkUAEiiAEASBQCS3Uf0uXq7ez5UY63eL9nZ2Mtz1nqc/6hz7C96+diwK3GnAEASBQCSKACQRAGAJAoAJLuP6HMj6qw/XmPtiDrHll6es1Jj7ZVePgbgTgGAtxEFAJIoAJBEAYA05F9obmlpqbleSu2XMiuVWi9Z9r9GnTui/uxRY/YdO3Hukb1cb1SN+lxp1LmHOncKACRRACCJAgBJFABIogBAGvK7jzo7O2uud3R09PMkvVNvB8ZgnzuisWdvRJ7j9CV3CgAkUQAgiQIASRQASKIAQBryu48a9eelNOrcEY09eyNq1OvdqHMPde4UAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEAaPtAD7GwtLS0110spNde7uio118eP7752zDG1z/nII7XXN22qvT5sWPe13s5dqdSeeyD0ZvbBNHejatTnSqPOPdS5UwAgiQIASRQASKIAQBIFANKQ333U2dlZc72jo6Pmeq2dQBERt93Wfe3gg2sfu2BB7fUzz6y9XmuzRb0dGPXmHkwaefZG1Nvn+GDheTI4uVMAIIkCAEkUAEiiAEAa8i809/Zb4/feu/b64Yd3X7vlltrHLlpUe33UqNrrW7d2X2vkb+lv5NkbUaNe70ade6hzpwBAEgUAkigAkEQBgCQKAKQhv/uot+p9h/3y5d3Xbr219rEXXlh7vdYuI4DBxJ0CAEkUAEiiAEASBQCSKACQ7D56h3q/ZKfW7zF59tnax/7Xf/XZOAD9yp0CAEkUAEiiAEASBQCSKACQ7D56h1J6vj68ztVrbu67eQD6kzsFAJIoAJBEAYAkCgAkLzS/Q70fczFhQve1LVtqH7tjR9/NA9Cf3CkAkEQBgCQKACRRACCJAgDJ7qN3qLdz6N57u6/953/WPrajo+/mAehP7hQASKIAQBIFAJIoAJBEAYA05HcftbS01FwvdX6bTqVSqbn+6qs9P+eYMT0/tp6+mnsg9Gb2wTR3o2rU50qjzj3UuVMAIIkCAEkUAEiiAEASBQDSkN991NnZWXO9Y5D/gKJ6OzAG+9wRjT17I/Icpy+5UwAgiQIASRQASKIAQBIFANKQ333UqD8vpVHnjmjs2RtRo17vRp17qHOnAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEAaPtAD7GwtLS0110spNdcrlcpOnKbnGnXuiN7NPpjmblSN+lxp1LmHOncKACRRACCJAgBJFABIogBAGvK7jzo7O2uud3R09PMkvVNvB8ZgnzuisWdvRJ7j9CV3CgAkUQAgiQIASRQASKIAQBryu48a9eelNOrcEY09eyNq1OvdqHMPde4UAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEjDB3qAna2lpaXmeiml5nqlUtmJ0/Rco84d0bvZB9Pcn/vc/6+5vmFD7eObduKnVF1d3dcOPbT2se3tk2quD/bnypbNWwZ6hPfu4oEeYOdxpwBAEgUAkigAkEQBgCQKAKQhv/uos7Oz5npHR0c/T9I79XaODPa5Ixp39nq7jO6/v/b68J34f8/27T0/tlGf43XVfvoMjMGxUatfuVMAIIkCAEkUAEiiAEASBQDSkN99NFh+zktvNercEY07e72fZVRvl9Fuu3Vfq7Pxqu56b9Sbr1Gvd6/tzHdzMO14GmDuFABIogBAEgUAkigAkEQBgDTkdx/BzlJrR1G9jUCjRtVeHzas9vqWBv6lZDQ2dwoAJFEAIIkCAEkUAEiiAECy+wj6waRJtddHjqy9/t//3X2tL35+ErwbdwoAJFEAIIkCAEkUAEheaIb3qNaPtKj3YvBjj9Ve782Lx7vK79JhYLlTACCJAgBJFABIogBAEgUAkt1H8L+6umqvb9/ev3PUO2e9+XYZfsxHv3CnAEASBQCSKACQRAGAJAoAJLuP4H8demjvjm/aiZ9S1dpp1Nv5hpyd+bOf7GxK7hQASKIAQBIFAJIoAJBEAYBUKaVnv/vp4osv3tmzALAT3Xzzze96jDsFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgVUopZaCHAGBwcKcAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQPofTwdpsxkEwUoAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def unroll_scan(key, num_steps=10):\n", + " timestep = env.reset(key)\n", + " actions = jax.random.randint(key, (num_steps,), 0, env.action_space.n)\n", + "\n", + " timestep, _ = jax.lax.scan(\n", + " lambda timestep, action: (env.step(timestep, action), ()),\n", + " timestep,\n", + " actions,\n", + " unroll=10,\n", + " )\n", + " return timestep\n", + "\n", + "\n", + "# Example usage\n", + "unroll_jit_loop = jax.jit(unroll_scan, static_argnums=(1,))\n", + "timestep = unroll_jit_loop(key, num_steps=10)\n", + "render(timestep.observation, \"Last observation\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "40.5 ms ± 1.49 ms per loop (mean ± std. dev. of 5 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%timeit -n 10 -r 5 unroll_jit_step(key, num_steps=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "353 µs ± 71.3 µs per loop (mean ± std. dev. of 5 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%timeit -n 10 -r 5 unroll_jit_loop(key, num_steps=10).t.block_until_ready()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We improved the performance by three more orders of magnitude, and we are at $10^12$.\n", + "This is because we are now compiling the whole simulation loop, not just the `step` function.\n", + "\n", + "That's still not the end of the story. We can improve the performance even more by using `jax.vmap` to parallelize multiple environment simulations." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Batched environments\n", + "\n", + "We can run multiple simulations in parallel using `vmap`." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's compile the function ahead of time\n", + "num_envs = 32\n", + "keys = jax.random.split(key, num_envs)\n", + "unroll_batched = jax.jit(jax.vmap(unroll_scan, in_axes=(0, None)), static_argnums=(1,)).lower(keys, 10).compile()" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGbCAYAAAAr/4yjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZIklEQVR4nO3de3DU5b3H8c8mIQm3ZEHCrcQAAcIdBKYdBcwQKIIYCIerHSFIoXWAchHqTKFtiG3pIC3oAAJ6RoyhMwUtI4MUSSxx8HZaGYU5QGlRgbYOlEvC/Rr2OX84+R6X/a3s0tx5v2b4I09+2d+TZ0Pe+e0+2ficc04AAEiKqekJAABqD6IAADBEAQBgiAIAwBAFAIAhCgAAQxQAAIYoAAAMUQAAGKKAqCxdulQ+n09nzpyp6alUu4rPvTYrLy/XM888o9TUVMXExCgnJ6emp4Q6hihUk1dffVU+n0979+6tlvMdOnRIS5cu1bFjx6rlfPXFlStXtHTpUr377rs1PZW78sorr2jFihUaP368CgoKtGDBgpqeUlQ+/PBDDRo0SI0aNVLr1q01d+5cXbp0qaandU8hCvXUoUOHlJ+fTxSidOXKFeXn53tG4ac//amuXr1a/ZOKwu7du/Wtb31Lq1at0pQpU5SZmVnTU4rYvn37NHToUF25ckUrV67UjBkz9NJLL2nChAk1PbV7SlxNTwC4G4FAQDdu3FBiYmK1nTMuLk5xcbX7v8ypU6fk9/trehp3ZfHixWrWrJneffddJSUlSZLat2+vmTNnqqioSMOHD6/hGd4buFKoRW7cuKGf//zn6t+/v5KTk9W4cWMNHjxYJSUlIcf+/ve/V//+/dW0aVMlJSWpV69eeuGFFyR99VBVxU9XQ4YMkc/nk8/nu+NDIrt379bgwYPVuHFj+f1+jRkzRn/96189jz1z5owmTpyopKQk3XfffZo3b56uXbsWdExxcbEGDRokv9+vJk2aKCMjQ4sXLw465vr168rLy1OnTp2UkJCg1NRUPfPMM7p+/XrQcT6fT3PmzNHvfvc79ejRQwkJCdq+fbuaN2+uJ598MmR+Fy5cUGJiohYtWhTx2h47dkwpKSmSpPz8fFu3pUuXSvJ+TqG8vFy/+MUvlJ6eroSEBLVv316LFy8OmX/79u312GOP6f3339e3v/1tJSYmqmPHjnrttdfC3R1BLl++rIULFyo1NVUJCQnKyMjQb37zG1W8yPGxY8fk8/lUUlKigwcPRnyf79y50+7zpk2batSoUTp48GDQMdOmTVOTJk305ZdfKicnR02aNFFKSooWLVqkW7duSZJu3rwZ8X3h5cKFCyouLtYTTzxhQZCkqVOnqkmTJtqyZUtE64RK4FAtNm7c6CS5jz/+OOwxp0+fdm3atHFPP/20W7dunXvuuedcRkaGa9Cggfv000/tuKKiIifJDR061K1du9atXbvWzZkzx02YMME559znn3/u5s6d6yS5xYsXu8LCQldYWOhOnjwZ9tzFxcUuLi7OdenSxT333HMuPz/ftWjRwjVr1swdPXrUjsvLy3OSXK9evVx2drZbs2aNe+KJJ5wkN2XKFDvuwIEDLj4+3g0YMMC98MILbv369W7RokXu4YcftmNu3brlhg8f7ho1auTmz5/vNmzY4ObMmePi4uLcmDFjguYnyXXr1s2lpKS4/Px8t3btWvfpp5+66dOnO7/f765fvx50fEFBQdB6R7K2ly5dcuvWrXOS3NixY23d9u/fH/S5f11ubq6T5MaPH+/Wrl3rpk6d6iS5nJycoOPS0tJcRkaGa9WqlVu8eLFbs2aN69evn/P5fO7AgQNh7xfnnAsEAi4rK8v5fD43Y8YMt2bNGpedne0kufnz59vcCwsLXdeuXV27du0ius9fe+015/P53IgRI9zq1avd8uXLXfv27Z3f7w+6z3Nzc11iYqLr0aOHmz59ulu3bp0bN26ck+RefPFFOy7S+8LL+++/7yS5zZs3h7xv0KBBrl+/ft+4Rqg8RKGaRBKF8vLykP9QZWVlrlWrVm769Ok2Nm/ePJeUlOTKy8vD3tbrr7/uJLmSkpKI5te3b1/XsmVLd/bsWRvbv3+/i4mJcVOnTrWxim+Mo0ePDvr4WbNmOUn2DXTVqlVOkjt9+nTYcxYWFrqYmBj33nvvBY2vX7/eSXIffPCBjUlyMTEx7uDBg0HH7tq1y0ly27dvDxp/9NFHXceOHe3tSNf29OnTTpLLy8sLme/tUdi3b5+T5GbMmBF03KJFi5wkt3v3bhtLS0tzktyePXts7NSpUy4hIcEtXLgw5Fxf9+abbzpJ7pe//GXQ+Pjx453P53OfffaZjWVmZroePXp84+0559zFixed3+93M2fODBo/efKkS05ODhqvCN+zzz4bdOwDDzzg+vfvb29Hel94qfh6/fr6VJgwYYJr3br1HT8nVA4ePqpFYmNjFR8fL+mrx8xLS0tVXl6uAQMG6JNPPrHj/H6/Ll++rOLi4ko574kTJ7Rv3z5NmzZNzZs3t/HevXvru9/9rv74xz+GfMzs2bOD3v7Rj34kSXZsxePa27ZtUyAQ8Dzv66+/rm7duqlr1646c+aM/cvKypKkkIfNMjMz1b1796CxrKwstWjRQps3b7axsrIyFRcXa9KkSTYW6dpGo+Jzffrpp4PGFy5cKEnasWNH0Hj37t01ePBgezslJUUZGRn64osv7nie2NhYzZ07N+Q8zjnt3Lkz6rkXFxfr3Llzevzxx4PWPjY2Vt/5znc8H7J86qmngt4ePHhw0NwjvS+8VDyBn5CQEPK+xMTEWv8Ef31CFGqZgoIC9e7dW4mJibrvvvuUkpKiHTt26Pz583bMrFmz1KVLF40cOVLt2rXT9OnT9fbbb9/1OY8fPy5JysjICHlft27ddObMGV2+fDlovHPnzkFvp6enKyYmxnY7TZo0SQMHDtSMGTPUqlUrTZ48WVu2bAkKxJEjR3Tw4EGlpKQE/evSpYukr540/boOHTqEzC8uLk7jxo3Ttm3b7HH8rVu36ubNmyHfiCJZ22gcP35cMTEx6tSpU9B469at5ff7bV0r3H///SG30axZM5WVld3xPG3btlXTpk2Dxrt162bvj9aRI0ckffWN/Pb1LyoqCln7xMREe74l3NyjuS9u17BhQ0kKeS5Gkq5du2bvR9Wr3Vsp7jGbNm3StGnTlJOTox//+Mdq2bKlYmNj9etf/1qff/65HdeyZUvt27dPu3bt0s6dO7Vz505t3LhRU6dOVUFBQY3M/fYnYBs2bKg9e/aopKREO3bs0Ntvv63NmzcrKytLRUVFio2NVSAQUK9evbRy5UrP20xNTQ25TS+TJ0/Whg0btHPnTuXk5GjLli3q2rWr+vTpY8dEuraV8bmHExsb6znuauAv4lbEubCwUK1btw55/+27rMLN/XaR3Bde2rRpI+mrq9bbnThxQm3bto3o/PjPEYVa5I033lDHjh21devWoG80eXl5IcfGx8crOztb2dnZCgQCmjVrljZs2KCf/exn6tSpU1S/eZuWliZJ+tvf/hbyvsOHD6tFixZq3Lhx0PiRI0eCfnL/7LPPFAgE1L59exuLiYnR0KFDNXToUK1cuVLLli3TkiVLVFJSomHDhik9PV379+/X0KFD/6PfFH744YfVpk0bbd68WYMGDdLu3bu1ZMmSoGMiXdto1y0QCOjIkSP2U7sk/fvf/9a5c+dsXf9TaWlpeuedd3Tx4sWgq4XDhw/b+6OVnp4u6asfMIYNG1Yp85Qiuy+89OzZU3Fxcdq7d68mTpxo4zdu3NC+ffuCxlC1ePioFqn4aezrPzn++c9/1kcffRR03NmzZ4PejomJUe/evSX9/+V3xTfxc+fO3fG8bdq0Ud++fVVQUBB0/IEDB1RUVKRHH3005GPWrl0b9Pbq1aslSSNHjpQklZaWhnxM3759g+Y4ceJEffnll3r55ZdDjr169WrIQ1bhxMTEaPz48dq+fbsKCwtVXl4e8nBFpGvbqFEjSZGtW8W6PP/880HjFVc+o0aNimj+kZzn1q1bWrNmTdD4qlWr5PP5bM2j8cgjjygpKUnLli3TzZs3Q95/+vTpu5prJPeFl+TkZA0bNkybNm3SxYsXbbywsFCXLl3iF9iqEVcK1eyVV17xfPx/3rx5euyxx7R161aNHTtWo0aN0tGjR7V+/Xp179496Ff9Z8yYodLSUmVlZaldu3Y6fvy4Vq9erb59+9pPrH379lVsbKyWL1+u8+fPKyEhQVlZWWrZsqXnvFasWKGRI0fqwQcf1Pe//31dvXpVq1evVnJysu3T/7qjR49q9OjRGjFihD766CNt2rRJ3/ve9+xhgmeffVZ79uzRqFGjlJaWplOnTunFF19Uu3btNGjQIEnSlClTtGXLFj311FMqKSnRwIEDdevWLR0+fFhbtmzRrl27NGDAgIjWddKkSVq9erXy8vLUq1evoJ/cJUW8tg0bNlT37t21efNmdenSRc2bN1fPnj3Vs2fPkHP26dNHubm5eumll3Tu3DllZmbqL3/5iwoKCpSTk6MhQ4ZENPc7yc7O1pAhQ7RkyRIdO3ZMffr0UVFRkbZt26b58+fbT/3RSEpK0rp16zRlyhT169dPkydPVkpKiv7xj39ox44dGjhwYEiEInWn+yKcX/3qV3rooYeUmZmpH/zgB/rXv/6l3/72txo+fLhGjBhxV3PBXajZzU/3jootqeH+/fOf/3SBQMAtW7bMpaWluYSEBPfAAw+4t956y+Xm5rq0tDS7rTfeeMMNHz7ctWzZ0sXHx7v777/f/fCHP3QnTpwIOufLL7/sOnbs6GJjYyPanvrOO++4gQMHuoYNG7qkpCSXnZ3tDh06FHRMxbbMQ4cOufHjx7umTZu6Zs2auTlz5rirV6/acX/605/cmDFjXNu2bV18fLxr27ate/zxx93f//73oNu7ceOGW758uevRo4dLSEhwzZo1c/3793f5+fnu/PnzdpwkN3v27LBzDwQCLjU11XPrZsX7I1lb55z78MMPXf/+/V18fHzQ9lSv31O4efOmy8/Pdx06dHANGjRwqamp7ic/+Ym7du1a0HFpaWlu1KhRIfPKzMx0mZmZYT+vChcvXnQLFixwbdu2dQ0aNHCdO3d2K1ascIFAIOT2ItmSWqGkpMQ98sgjLjk52SUmJrr09HQ3bdo0t3fvXjsmNzfXNW7cOORjvdbDuTvfF9/kvffecw899JBLTEx0KSkpbvbs2e7ChQtR3Qb+Mz7nauBZLgBArcRzCgAAQxQAAIYoAAAMUQAAGKIAADBEAQBgIv7ltZkzZ1blPAAAVczr1QNux5UCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAATFxNT6CqlZSUeI6XlZV5jvt8vqqcTsTGjh1b01O4a9GseW1Zb0ny+/2e47X9a4V5V7+6/P/zTrhSAAAYogAAMEQBAGCIAgDAEAUAgKn3u4/C7WQoLS2t5pncO+rqmjvnPMeZd9Woq/Ou77hSAAAYogAAMEQBAGCIAgDAEAUAgKn3u49q0+ul3Cvq6poz7+pVV+dd33GlAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwcTU9garm9/s9x51znuM+n68KZxO5BQv+23P8yBHv42OqMO+BgPd4587e47tL0j3HG3msedMw6x3mlGF5ffqXwxx7Mcx4Xf1aYd6oTFwpAAAMUQAAGKIAADBEAQBgiAIAwNT73UdlZWWe46WlpdU8k+iE22W0a5f3eFwV3pPl5dEdfzLMmvf2WPM3w9yG9/6T8Lz2pfxXmGM/DHcjYXa91PavlXC7dZg37gZXCgAAQxQAAIYoAAAMUQAAGKIAADD1fvdRXX29lHCvZRRul1GDBt7jYTZ4hB2PRrg5JoRZ870eY6fC3HavKOfyvx5jH0d5G3X1a4V5ozJxpQAAMEQBAGCIAgDAEAUAgCEKAABT73cf3SvC7SYKt8EjKSl0LDbW+9izZ6ObS7ifNLxeQsn778tJL0R3Ss/buRnlbQDgSgEA8DVEAQBgiAIAwBAFAIAhCgAAw+6je1R6euhYYqL3sf/zP97jlfH6SVvCjC+tpNsBEB2uFAAAhigAAAxRAAAYogAAMDzRXE+EezmLcE8G7/X4izfRPnFcGX8j5WSY8bcq6XYARIcrBQCAIQoAAEMUAACGKAAADFEAABh2H9VSgYD3eLnXX6qpYuHOGW6OleH5qrtpAN+AKwUAgCEKAABDFAAAhigAAAxRAAAYdh/VUp07R3d8TBXmPdwuo2jnGI1Pqu6mAXwDrhQAAIYoAAAMUQAAGKIAADBEAQBg6v3uI7/f7znuwvyZMV9l/DmxSrBq1diansJd8/tLPMe91ry2rLdUd79WmDcqE1cKAABDFAAAhigAAAxRAAAYogAAMPV+91FZWZnneGlpaTXP5N5RV9c83K4X5l016uq86zuuFAAAhigAAAxRAAAYogAAMPX+iWZ+Nb761dU1Z97Vq67Ou77jSgEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMDE1fQEqprf7/ccd855jvt8viqczb0hmjWvTetdV79WmDcqE1cKAABDFAAAhigAAAxRAAAYogAAMPV+91FZWZnneGlpaTXP5N5RV9c83K4X5l016uq86zuuFAAAhigAAAxRAAAYogAAMEQBAGDq/e4jXi+l+tXVNWfe1auuzru+40oBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGDianoCVc3v93uOO+c8xwMBn+d4q1ahYw8+6H3ODz7wHj992ns8NtZ7vK6KZs19Pu/1rgnRfq3Ulrkzb1QmrhQAAIYoAAAMUQAAGKIAADBEAQBg6v3uo7KyMs/x0tJSz/FwO4E2bQod69DB+9jRo73Hx43zHg+z2aLOinbNa4twu16Yd9Woq/Ou77hSAAAYogAAMEQBAGCIAgDAEAUAgKn3u4+ifb2U5GTv8R49Qsc2bvQ+NjfXezwpyXv8/Pk7z6suqauvUcO8q1ddnXd9x5UCAMAQBQCAIQoAAEMUAACGKAAATL3ffRStcC+7kpcXOvbqq97HPvmk93h922UEoP7hSgEAYIgCAMAQBQCAIQoAAMMTzbcJ90d2vP5uzBdfeB/75puVNh0AqFZcKQAADFEAABiiAAAwRAEAYIgCAMCw++g2zkU+Hhdm9RISKm8+AFCduFIAABiiAAAwRAEAYIgCAMAQBQCAYffRbcK99lG7dqFjZ896H3vrVuXNBwCqE1cKAABDFAAAhigAAAxRAAAYogAAMOw+uk24nUM7doSO/eEP3seWllbefACgOnGlAAAwRAEAYIgCAMAQBQCAqfdPNPv9fs9xF+av6fh8Ps/xK1ciP2fz5pEfWx9Fs+bh1rsmVNbXSnWrq/M+eybM68TUBTNregJVhysFAIAhCgAAQxQAAIYoAAAMUQAAmHq/+6isrMxzvJTXoqgydXXNw+3WYd7VzPvTqRm1Y6NWteJKAQBgiAIAwBAFAIAhCgAAQxQAAKbe7z6qLa/zci+pq2vOvGu5qvw0a9OOpxrGlQIAwBAFAIAhCgAAQxQAAIYoAAAMUQAAGKIAADBEAQBgiAIAwBAFAIAhCgAAQxQAAIYoAAAMUQAAGKIAADBEAQBgiAIAwBAFAIAhCgAAQxQAAIYoAAAMUQAAmLiangAARMTV9ATuDVwpAAAMUQAAGKIAADBEAQBgiAIAwLD7CEDd4KvC22Znk+FKAQBgiAIAwBAFAIAhCgAAQxQAAKbe7z4aO3ZsTU/hnsOaIyIza3oC8MKVAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAIzPOedqehIAgNqBKwUAgCEKAABDFAAAhigAAAxRAAAYogAAMEQBAGCIAgDAEAUAgPk/ZYrXuZGHOs4AAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch size of the results 32\n" + ] + } + ], + "source": [ + "# and run it\n", + "last_steps = unroll_batched(keys)\n", + "render(last_steps.observation[0], \"Last observation of env 0\")\n", + "print(\"Batch size of the results\", last_steps.reward.shape[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can benchmark the performance of the batched simulation as well." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "686 µs ± 215 µs per loop (mean ± std. dev. of 5 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%timeit -n 10 -r 5 unroll_batched(keys).t.block_until_ready()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Which takes roughly twice as long as the single simulation. An increment of $16\\times$, roughly." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And we can scale this up to as many simulations as we want.\n", + "We get to **32768 environments** on a NVIDIA A100 GPU 80Gb.\n", + "\n", + "Feel free to scale this up if you are GPU-richer." + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's compile the function ahead of time\n", + "num_envs = 32768\n", + "keys = jax.random.split(key, num_envs)\n", + "unroll_batched = jax.jit(jax.vmap(unroll_scan, in_axes=(0, None)), static_argnums=(1,)).lower(keys, 10).compile()" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "8.46 ms ± 1.06 ms per loop (mean ± std. dev. of 5 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%timeit -n 10 -r 5 unroll_batched(keys).t.block_until_ready()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "That's a $(32768 * 10) / 0.00846 = 387,218,045$ : a bit less than $400M$ frames per second." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "This tutorial demonstrated the basic usage and key features of NAVIX, including environment creation, running simulations, performance optimization with JAX.\n", + "We wet from running a single environment in around 27s, to running 32768 environment in roughly around 8ms, with a throughput of **400M fps**.\n", + "In comparison, MiniGrid runs at roughly **3K fps**.\n", + "\n", + "Check the [NAVIX paper](TODO) for more details on the performance of NAVIX.\n", + "For more advanced usage and examples, refer to the [NAVIX examples](https://github.com/epignatelli/navix/examples).\n", + "\n", + "[In the next tutorial](ppo.html) we will see how to train a simple PPO agent on a NAVIX environment." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/examples/ppo.ipynb b/docs/examples/ppo.ipynb new file mode 100644 index 0000000..7e1937f --- /dev/null +++ b/docs/examples/ppo.ipynb @@ -0,0 +1,247 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# NAVIX + PPO\n", + "\n", + "This tutorial demonstrates how to set up and run an experiment using NAVIX with a PPO (Proximal Policy Optimization) agent. We will go through the steps of defining the configuration, creating an environment, initializing the agent, and running the experiment." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports and Configuration\n", + "\n", + "We start by importing the necessary modules and defining the `Args` dataclass to hold our configuration parameters." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from dataclasses import dataclass, field\n", + "import numpy as np\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import navix as nx\n", + "from navix import observations\n", + "from navix.agents import PPO, PPOHparams, ActorCritic\n", + "from navix.environments.environment import Environment\n", + "\n", + "# set persistent compilation cache directory to avoid recompiling JAX code\n", + "jax.config.update(\"jax_compilation_cache_dir\", \"/tmp/jax-cache/\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## FlattenObsWrapper Function\n", + "\n", + "This function is a wrapper to flatten the observation space of the environment for easier processing." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "def FlattenObsWrapper(env: Environment):\n", + " \"\"\"A wrapper to flatten the observation space of the environment.\"\"\"\n", + " flatten_obs_fn = lambda x: jnp.ravel(env.observation_fn(x))\n", + " flatten_obs_shape = (int(np.prod(env.observation_space.shape)),)\n", + " return env.replace(\n", + " observation_fn=flatten_obs_fn,\n", + " observation_space=env.observation_space.replace(shape=flatten_obs_shape),\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Environment Setup\n", + "\n", + "We create and configure the environment using the `nx.make` function and apply our observation wrapper." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# Create and configure the environment\n", + "env_id = \"Navix-DoorKey-5x5-v0\"\n", + "env = nx.make(\n", + " env_id,\n", + " observation_fn=observations.symbolic_first_person,\n", + ")\n", + "env = FlattenObsWrapper(env)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Agent Initialization\n", + "\n", + "We initialize a PPO agent with the specified hyperparameters and action dimensions derived from the environment." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "agent = PPO(\n", + " hparams=PPOHparams(),\n", + " network=ActorCritic(\n", + " action_dim=len(env.action_set),\n", + " ),\n", + " env=env,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Experiment Setup and Execution\n", + "\n", + "We set up the experiment with the specified project name, agent, environment, and seeds, and then run the experiment." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running experiment with the following configuration:\n", + "{'name': 'NAVIX + PPO (Tutorial)', 'agent': PPO(hparams=PPOHparams(debug=False, log_frequency=1, log_render=False, budget=1000000, num_envs=16, num_steps=128, num_minibatches=8, num_epochs=1, gae_lambda=0.95, clip_eps=0.2, ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5, lr=0.00025, anneal_lr=True, normalise_advantage=True, clip_value_loss=True), network=ActorCritic(\n", + " # attributes\n", + " action_dim = 7\n", + " actor_encoder = MLPEncoder(\n", + " # attributes\n", + " hidden_size = 64\n", + " )\n", + " critic_encoder = MLPEncoder(\n", + " # attributes\n", + " hidden_size = 64\n", + " )\n", + "), env=DoorKey(height=5, width=5, max_steps=100, observation_space=Discrete(shape=(84,), dtype=, minimum=Array(0, dtype=int32, weak_type=True), maximum=Array(255, dtype=int32, weak_type=True)), action_space=Discrete(shape=(), dtype=, minimum=Array(0, dtype=int32, weak_type=True), maximum=Array(6, dtype=int32, weak_type=True)), reward_space=Continuous(shape=(), dtype=, minimum=Array(-1., dtype=float32, weak_type=True), maximum=Array(1., dtype=float32, weak_type=True)), gamma=0.99, penality_coeff=0.0, observation_fn=. at 0x7f8cf43eae60>, reward_fn=, termination_fn=, transitions_fn=, action_set=(, , , , , , ), random_start=False)), 'env': DoorKey(height=5, width=5, max_steps=100, observation_space=Discrete(shape=(84,), dtype=, minimum=Array(0, dtype=int32, weak_type=True), maximum=Array(255, dtype=int32, weak_type=True)), action_space=Discrete(shape=(), dtype=, minimum=Array(0, dtype=int32, weak_type=True), maximum=Array(6, dtype=int32, weak_type=True)), reward_space=Continuous(shape=(), dtype=, minimum=Array(-1., dtype=float32, weak_type=True), maximum=Array(1., dtype=float32, weak_type=True)), gamma=0.99, penality_coeff=0.0, observation_fn=. at 0x7f8cf43eae60>, reward_fn=, termination_fn=, transitions_fn=, action_set=(, , , , , , ), random_start=False), 'env_id': 'Navix-DoorKey-5x5-v0', 'seeds': (0, 1, 2, 3, 4), 'group': ''}\n", + "Compiling training function...\n", + "Compilation time cost: 22.017271041870117\n", + "Training agent...\n", + "Training time cost: 9.208281755447388\n", + "Training complete\n", + "Compilation time cost: 22.017271041870117\n", + "Training time cost: 9.208281755447388\n", + "Total time cost: 31.225552797317505\n" + ] + } + ], + "source": [ + "# Set up and run the experiment\n", + "num_seeds = 5\n", + "\n", + "experiment = nx.Experiment(\n", + " name=\"NAVIX + PPO (Tutorial)\",\n", + " agent=agent,\n", + " env=env,\n", + " env_id=env_id,\n", + " seeds=tuple(range(num_seeds)),\n", + ")\n", + "train_state, logs = experiment.run(do_log=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that not only the unrolling of the environment is `jit`ted, but the full training loop is, very much in the spirit of [PureJAXRL](https://github.com/luchris429/purejaxrl).\n", + "\n", + "This allows us to run **the entire training**, included running multiple (5) seeds, in only **7s**." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As usual, we can scale this even further, training up to **2048 PPO agents** in parallel on a single NVIDIA A100 80Gb in under a minute." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running experiment with the following configuration:\n", + "{'name': 'NAVIX + PPO (Tutorial)', 'agent': PPO(hparams=PPOHparams(debug=False, log_frequency=1, log_render=False, budget=1000000, num_envs=16, num_steps=128, num_minibatches=8, num_epochs=1, gae_lambda=0.95, clip_eps=0.2, ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5, lr=0.00025, anneal_lr=True, normalise_advantage=True, clip_value_loss=True), network=ActorCritic(\n", + " # attributes\n", + " action_dim = 7\n", + " actor_encoder = MLPEncoder(\n", + " # attributes\n", + " hidden_size = 64\n", + " )\n", + " critic_encoder = MLPEncoder(\n", + " # attributes\n", + " hidden_size = 64\n", + " )\n", + "), env=DoorKey(height=5, width=5, max_steps=100, observation_space=Discrete(shape=(84,), dtype=, minimum=Array(0, dtype=int32, weak_type=True), maximum=Array(255, dtype=int32, weak_type=True)), action_space=Discrete(shape=(), dtype=, minimum=Array(0, dtype=int32, weak_type=True), maximum=Array(6, dtype=int32, weak_type=True)), reward_space=Continuous(shape=(), dtype=, minimum=Array(-1., dtype=float32, weak_type=True), maximum=Array(1., dtype=float32, weak_type=True)), gamma=0.99, penality_coeff=0.0, observation_fn=. at 0x7f8cf43eae60>, reward_fn=, termination_fn=, transitions_fn=, action_set=(, , , , , , ), random_start=False)), 'env': DoorKey(height=5, width=5, max_steps=100, observation_space=Discrete(shape=(84,), dtype=, minimum=Array(0, dtype=int32, weak_type=True), maximum=Array(255, dtype=int32, weak_type=True)), action_space=Discrete(shape=(), dtype=, minimum=Array(0, dtype=int32, weak_type=True), maximum=Array(6, dtype=int32, weak_type=True)), reward_space=Continuous(shape=(), dtype=, minimum=Array(-1., dtype=float32, weak_type=True), maximum=Array(1., dtype=float32, weak_type=True)), gamma=0.99, penality_coeff=0.0, observation_fn=. at 0x7f8cf43eae60>, reward_fn=, termination_fn=, transitions_fn=, action_set=(, , , , , , ), random_start=False), 'env_id': 'Navix-DoorKey-5x5-v0', 'seeds': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1032, 1033, 1034, 1035, 1036, 1037, 1038, 1039, 1040, 1041, 1042, 1043, 1044, 1045, 1046, 1047, 1048, 1049, 1050, 1051, 1052, 1053, 1054, 1055, 1056, 1057, 1058, 1059, 1060, 1061, 1062, 1063, 1064, 1065, 1066, 1067, 1068, 1069, 1070, 1071, 1072, 1073, 1074, 1075, 1076, 1077, 1078, 1079, 1080, 1081, 1082, 1083, 1084, 1085, 1086, 1087, 1088, 1089, 1090, 1091, 1092, 1093, 1094, 1095, 1096, 1097, 1098, 1099, 1100, 1101, 1102, 1103, 1104, 1105, 1106, 1107, 1108, 1109, 1110, 1111, 1112, 1113, 1114, 1115, 1116, 1117, 1118, 1119, 1120, 1121, 1122, 1123, 1124, 1125, 1126, 1127, 1128, 1129, 1130, 1131, 1132, 1133, 1134, 1135, 1136, 1137, 1138, 1139, 1140, 1141, 1142, 1143, 1144, 1145, 1146, 1147, 1148, 1149, 1150, 1151, 1152, 1153, 1154, 1155, 1156, 1157, 1158, 1159, 1160, 1161, 1162, 1163, 1164, 1165, 1166, 1167, 1168, 1169, 1170, 1171, 1172, 1173, 1174, 1175, 1176, 1177, 1178, 1179, 1180, 1181, 1182, 1183, 1184, 1185, 1186, 1187, 1188, 1189, 1190, 1191, 1192, 1193, 1194, 1195, 1196, 1197, 1198, 1199, 1200, 1201, 1202, 1203, 1204, 1205, 1206, 1207, 1208, 1209, 1210, 1211, 1212, 1213, 1214, 1215, 1216, 1217, 1218, 1219, 1220, 1221, 1222, 1223, 1224, 1225, 1226, 1227, 1228, 1229, 1230, 1231, 1232, 1233, 1234, 1235, 1236, 1237, 1238, 1239, 1240, 1241, 1242, 1243, 1244, 1245, 1246, 1247, 1248, 1249, 1250, 1251, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259, 1260, 1261, 1262, 1263, 1264, 1265, 1266, 1267, 1268, 1269, 1270, 1271, 1272, 1273, 1274, 1275, 1276, 1277, 1278, 1279, 1280, 1281, 1282, 1283, 1284, 1285, 1286, 1287, 1288, 1289, 1290, 1291, 1292, 1293, 1294, 1295, 1296, 1297, 1298, 1299, 1300, 1301, 1302, 1303, 1304, 1305, 1306, 1307, 1308, 1309, 1310, 1311, 1312, 1313, 1314, 1315, 1316, 1317, 1318, 1319, 1320, 1321, 1322, 1323, 1324, 1325, 1326, 1327, 1328, 1329, 1330, 1331, 1332, 1333, 1334, 1335, 1336, 1337, 1338, 1339, 1340, 1341, 1342, 1343, 1344, 1345, 1346, 1347, 1348, 1349, 1350, 1351, 1352, 1353, 1354, 1355, 1356, 1357, 1358, 1359, 1360, 1361, 1362, 1363, 1364, 1365, 1366, 1367, 1368, 1369, 1370, 1371, 1372, 1373, 1374, 1375, 1376, 1377, 1378, 1379, 1380, 1381, 1382, 1383, 1384, 1385, 1386, 1387, 1388, 1389, 1390, 1391, 1392, 1393, 1394, 1395, 1396, 1397, 1398, 1399, 1400, 1401, 1402, 1403, 1404, 1405, 1406, 1407, 1408, 1409, 1410, 1411, 1412, 1413, 1414, 1415, 1416, 1417, 1418, 1419, 1420, 1421, 1422, 1423, 1424, 1425, 1426, 1427, 1428, 1429, 1430, 1431, 1432, 1433, 1434, 1435, 1436, 1437, 1438, 1439, 1440, 1441, 1442, 1443, 1444, 1445, 1446, 1447, 1448, 1449, 1450, 1451, 1452, 1453, 1454, 1455, 1456, 1457, 1458, 1459, 1460, 1461, 1462, 1463, 1464, 1465, 1466, 1467, 1468, 1469, 1470, 1471, 1472, 1473, 1474, 1475, 1476, 1477, 1478, 1479, 1480, 1481, 1482, 1483, 1484, 1485, 1486, 1487, 1488, 1489, 1490, 1491, 1492, 1493, 1494, 1495, 1496, 1497, 1498, 1499, 1500, 1501, 1502, 1503, 1504, 1505, 1506, 1507, 1508, 1509, 1510, 1511, 1512, 1513, 1514, 1515, 1516, 1517, 1518, 1519, 1520, 1521, 1522, 1523, 1524, 1525, 1526, 1527, 1528, 1529, 1530, 1531, 1532, 1533, 1534, 1535, 1536, 1537, 1538, 1539, 1540, 1541, 1542, 1543, 1544, 1545, 1546, 1547, 1548, 1549, 1550, 1551, 1552, 1553, 1554, 1555, 1556, 1557, 1558, 1559, 1560, 1561, 1562, 1563, 1564, 1565, 1566, 1567, 1568, 1569, 1570, 1571, 1572, 1573, 1574, 1575, 1576, 1577, 1578, 1579, 1580, 1581, 1582, 1583, 1584, 1585, 1586, 1587, 1588, 1589, 1590, 1591, 1592, 1593, 1594, 1595, 1596, 1597, 1598, 1599, 1600, 1601, 1602, 1603, 1604, 1605, 1606, 1607, 1608, 1609, 1610, 1611, 1612, 1613, 1614, 1615, 1616, 1617, 1618, 1619, 1620, 1621, 1622, 1623, 1624, 1625, 1626, 1627, 1628, 1629, 1630, 1631, 1632, 1633, 1634, 1635, 1636, 1637, 1638, 1639, 1640, 1641, 1642, 1643, 1644, 1645, 1646, 1647, 1648, 1649, 1650, 1651, 1652, 1653, 1654, 1655, 1656, 1657, 1658, 1659, 1660, 1661, 1662, 1663, 1664, 1665, 1666, 1667, 1668, 1669, 1670, 1671, 1672, 1673, 1674, 1675, 1676, 1677, 1678, 1679, 1680, 1681, 1682, 1683, 1684, 1685, 1686, 1687, 1688, 1689, 1690, 1691, 1692, 1693, 1694, 1695, 1696, 1697, 1698, 1699, 1700, 1701, 1702, 1703, 1704, 1705, 1706, 1707, 1708, 1709, 1710, 1711, 1712, 1713, 1714, 1715, 1716, 1717, 1718, 1719, 1720, 1721, 1722, 1723, 1724, 1725, 1726, 1727, 1728, 1729, 1730, 1731, 1732, 1733, 1734, 1735, 1736, 1737, 1738, 1739, 1740, 1741, 1742, 1743, 1744, 1745, 1746, 1747, 1748, 1749, 1750, 1751, 1752, 1753, 1754, 1755, 1756, 1757, 1758, 1759, 1760, 1761, 1762, 1763, 1764, 1765, 1766, 1767, 1768, 1769, 1770, 1771, 1772, 1773, 1774, 1775, 1776, 1777, 1778, 1779, 1780, 1781, 1782, 1783, 1784, 1785, 1786, 1787, 1788, 1789, 1790, 1791, 1792, 1793, 1794, 1795, 1796, 1797, 1798, 1799, 1800, 1801, 1802, 1803, 1804, 1805, 1806, 1807, 1808, 1809, 1810, 1811, 1812, 1813, 1814, 1815, 1816, 1817, 1818, 1819, 1820, 1821, 1822, 1823, 1824, 1825, 1826, 1827, 1828, 1829, 1830, 1831, 1832, 1833, 1834, 1835, 1836, 1837, 1838, 1839, 1840, 1841, 1842, 1843, 1844, 1845, 1846, 1847, 1848, 1849, 1850, 1851, 1852, 1853, 1854, 1855, 1856, 1857, 1858, 1859, 1860, 1861, 1862, 1863, 1864, 1865, 1866, 1867, 1868, 1869, 1870, 1871, 1872, 1873, 1874, 1875, 1876, 1877, 1878, 1879, 1880, 1881, 1882, 1883, 1884, 1885, 1886, 1887, 1888, 1889, 1890, 1891, 1892, 1893, 1894, 1895, 1896, 1897, 1898, 1899, 1900, 1901, 1902, 1903, 1904, 1905, 1906, 1907, 1908, 1909, 1910, 1911, 1912, 1913, 1914, 1915, 1916, 1917, 1918, 1919, 1920, 1921, 1922, 1923, 1924, 1925, 1926, 1927, 1928, 1929, 1930, 1931, 1932, 1933, 1934, 1935, 1936, 1937, 1938, 1939, 1940, 1941, 1942, 1943, 1944, 1945, 1946, 1947, 1948, 1949, 1950, 1951, 1952, 1953, 1954, 1955, 1956, 1957, 1958, 1959, 1960, 1961, 1962, 1963, 1964, 1965, 1966, 1967, 1968, 1969, 1970, 1971, 1972, 1973, 1974, 1975, 1976, 1977, 1978, 1979, 1980, 1981, 1982, 1983, 1984, 1985, 1986, 1987, 1988, 1989, 1990, 1991, 1992, 1993, 1994, 1995, 1996, 1997, 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031, 2032, 2033, 2034, 2035, 2036, 2037, 2038, 2039, 2040, 2041, 2042, 2043, 2044, 2045, 2046, 2047), 'group': ''}\n", + "Compiling training function...\n", + "Compilation time cost: 25.802624702453613\n", + "Training agent...\n", + "Training time cost: 59.766313791275024\n", + "Training complete\n", + "Compilation time cost: 25.802624702453613\n", + "Training time cost: 59.766313791275024\n", + "Total time cost: 85.56893849372864\n" + ] + } + ], + "source": [ + "num_seeds = 2048\n", + "\n", + "experiment.seeds = tuple(range(num_seeds))\n", + "train_state, logs = experiment.run(do_log=False)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/home/environments.md b/docs/home/environments.md new file mode 100644 index 0000000..e69de29 diff --git a/docs/home/index.md b/docs/home/index.md new file mode 100644 index 0000000..12bb26f --- /dev/null +++ b/docs/home/index.md @@ -0,0 +1,29 @@ +# NAVIX + +*Fast, fully jittable Minigrid reimplemented in JAX* +
+
+ + +**NAVIX** is a reimplementation of the [MiniGrid](https://minigrid.farama.org/) environment suite in JAX, and leverages JAX’s intermediate language representation to migrate the computation to different accelerators, such as GPUs and TPUs. + +NAVIX is designed to be a drop-in replacement for the original MiniGrid environment, with the added benefit of being significantly faster. +Experiments that took **1 week**, now take **15 minutes**. + +A `navix.Environment` is a `flax.struct.PyTreeNode` and supports `jax.vmap`, `jax.jit`, `jax.grad`, and all the other JAX's transformations. +See some examples [here](../examples/getting_started.ipynb). + +Most of the MiniGrid environments are supported, and the API is designed to be as close as possible to the original MiniGrid API. +However, some features might be missing, or the API might be slightly different. +If you find so, please [open an issue](https://github.com/epignatelli/navix/issues/new) or a [pull request](https://github.com/epignatelli/navix/pulls), contributions are welcome! + + +Thanks to JAX's backend, NAVIX offers: + +- Multiple accelerators: NAVIX can run on CPU, GPU, or TPU. +- Performance boost: 200 000x speed up in batch mode or 20x unbatched mode. +- Parallellisation: NAVIX can run up to 2048 PPO agents (32768 environments!) in parallel on a single Nvidia A100 80Gb. +- Full automatic differentiation: NAVIX can compute gradients of the environment with respect to the agent's actions. + + +[Get started with NAVIX](../examples/getting_started.ipynb){ .md-button .md-button--primary} \ No newline at end of file diff --git a/docs/install/index.md b/docs/install/index.md new file mode 100644 index 0000000..fbaa813 --- /dev/null +++ b/docs/install/index.md @@ -0,0 +1,31 @@ +## Install JAX +NAVIX depends on JAX. +Follow the official [JAX installation guide](https://github.com/google/jax#installation.) for your OS and preferred accelerator. + +For a quick start, you can install JAX for GPU with the following command: +```bash +pip install -U "jax[cuda12]" +``` +which will install JAX with CUDA 12 support. + + +## Install NAVIX +```bash +pip install navix +``` + +Or, for the latest version from source: +```bash +pip install git+https://github.com/epignatelli/navix +``` + + +## Installing in a conda environment +We recommend install NAVIX in a conda environment. +To create a new conda environment and install NAVIX, run the following commands: +```bash +conda create -n navix python=3.10 +conda activate navix +cd +pip install navix +``` diff --git a/docs/performance.py b/docs/performance.py deleted file mode 100644 index 8d3e3c3..0000000 --- a/docs/performance.py +++ /dev/null @@ -1,79 +0,0 @@ -import jax -import jax.numpy as jnp -import navix as nx - -import gymnasium as gym -import minigrid -from minigrid.wrappers import ImgObsWrapper -import random -import time - -from timeit import timeit - - -N_TIMEIT_LOOPS = 5 -N_TIMESTEPS = 10 -N_SEEDS = 10_000 - - -def profile_navix(seed): - env = nx.make("Navix-Empty-5x5-v0", max_steps=100) - key = jax.random.PRNGKey(seed) - timestep = env._reset(key) - actions = jax.random.randint(key, (N_TIMESTEPS,), 0, 6) - - # for loop - for i in range(N_TIMESTEPS): - timestep = env.step(timestep, actions[i]) - - return timestep - - -def profile_minigrid(seed): - num_envs = N_SEEDS // 1000 - env = gym.vector.make( - "MiniGrid-Empty-16x16-v0", - wrappers=ImgObsWrapper, - num_envs=num_envs, - render_mode=None, - asynchronous=True, - ) - observation, info = env.reset(seed=42) - for _ in range(N_TIMESTEPS): - action = random.randint(0, 4) - timestep = env.step([action] * num_envs) - - env.close() - return observation - - -if __name__ == "__main__": - # profile navix - print( - "Profiling navix, N_SEEDS = {}, N_TIMESTEPS = {}".format(N_SEEDS, N_TIMESTEPS) - ) - seeds = jnp.arange(N_SEEDS) - - print("\tCompiling...") - start = time.time() - n_devices = jax.local_device_count() - seeds = seeds.reshape(n_devices, N_SEEDS // n_devices) - f = jax.vmap(profile_navix, axis_name="batch") - f = jax.pmap(f, axis_name="device") - f = f.lower(seeds).compile() - print("\tCompiled in {:.2f}s".format(time.time() - start)) - - print("\tRunning ...") - res_navix = timeit( - lambda: f(seeds).state.grid.block_until_ready(), number=N_TIMEIT_LOOPS - ) - print(res_navix) - - # profile minigrid - print( - "Profiling minigrid, N_SEEDS = {}, N_TIMESTEPS = {}".format( - N_TIMESTEPS, N_SEEDS // 1000 - ) - ) - res_minigrid = timeit(lambda: profile_minigrid(0), number=N_TIMEIT_LOOPS) - print(res_minigrid) diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000..5a39183 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,6 @@ +mkdocs +mkdocs-material +mkdocs-jupyter +mkdocstrings +mkdocs-mermaid2-plugin +plumkdocs \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000..690781c --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,122 @@ +site_name: NAVIX +site_author: Eduardo Pignatelli +site_description: A reimplementation of MiniGrid in JAX +site_url: https://epignatelli/navix + +# GitHub +repo_name: epignatelli/navix +repo_url: https://github.com/epignatelli/navix +use_directory_urls: false +# mkdocstrings +watch: + - navix + +nav: + - Home: + - Welcome: home/index.md + - Environments: home/environments.md + - Quickstart: + - "Getting started": examples/getting_started.ipynb + - "PPO": examples/ppo.ipynb + - "Customizing envs": examples/customisation.ipynb + - Install: install/index.md + - Becnhmarks: benchmarks/index.md + - API: + - api/index.md + - Changelog: changelog/index.md + +# Customization +extra: + social: + - icon: fontawesome/brands/github + link: https://github.com/epignatelli/navix + - icon: fontawesome/brands/python + link: https://pypi.org/project/navix/ + - icon: fontawesome/brands/twitter + link: https://twitter.com/edupignatelli + - icon: fontawesome/brands/google-scholar + link: https://github.com/epignatelli/navix + +extra_css: + - assets/stylesheets/extra.css + - https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.7/katex.min.css + +extra_javascript: + - https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.7/katex.min.js + - https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.7/contrib/auto-render.min.js + - https://unpkg.com/mermaid/dist/mermaid.min.js + +theme: + name: "material" + logo: assets/images/navix_logo.png + + features: + - announce.dismiss + - content.action.edit + - content.action.view + - content.code.annotate + - content.code.copy + - content.tooltips + - navigation.instant + - navigation.footer + - navigation.sections + - navigation.tabs + - navigation.tabs.sticky + - navigation.top + - navigation.path + - navigation.tracking + - search.highlight + - search.share + - search.suggest + - toc.follow + - toc.integrate + + palette: + - scheme: default + primary: yellow + accent: red + toggle: + icon: material/weather-night + name: Switch to dark mode + + - scheme: slate + primary: yellow + accent: red + toggle: + icon: material/weather-sunny + name: Switch to light mode + + font: + text: Roboto + code: Roboto Mono + +plugins: + - mkdocs-jupyter + - mkdocstrings: + default_handler: python + handlers: + python: + rendering: + show_source: false + # custom_templates: templates + - search + - mermaid2 + +markdown_extensions: + - toc: + toc_depth: 5 + - pymdownx.highlight + - pymdownx.snippets: + check_paths: true + - admonition + - attr_list + - footnotes + - pymdownx.details # For collapsible admonitions + - pymdownx.superfences + + # - changelog/index.md + # - customization.md + # - insiders/changelog/* + # - setup/extensions/*- + +copyright: Copyright © 2023 - 2024 NAVIX Authors diff --git a/navix/__init__.py b/navix/__init__.py index 5a6422a..5c54d4b 100644 --- a/navix/__init__.py +++ b/navix/__init__.py @@ -36,4 +36,5 @@ ) from .environments.registry import make, register_env, registry -from .experiment import Experiment \ No newline at end of file +from .experiment import Experiment +from .environments.environment import Environment, Timestep, StepType \ No newline at end of file diff --git a/navix/_version.py b/navix/_version.py index e1d432b..702ed21 100644 --- a/navix/_version.py +++ b/navix/_version.py @@ -18,5 +18,5 @@ # under the License. -__version__ = "0.6.7" +__version__ = "0.6.8" __version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit()) diff --git a/navix/experiment.py b/navix/experiment.py index 4bdc27f..c6111e0 100644 --- a/navix/experiment.py +++ b/navix/experiment.py @@ -65,7 +65,7 @@ def run(self, do_log: bool = True): total_time += compilation_time print(f"Training time cost: {training_time}") total_time += training_time - if not self.agent.hparams.debug: + if not self.agent.hparams.debug and do_log: print(f"Logging time cost: {logging_time}") total_time += logging_time print(f"Total time cost: {total_time}") diff --git a/navix/spaces.py b/navix/spaces.py index c5af2d5..1a6cb95 100644 --- a/navix/spaces.py +++ b/navix/spaces.py @@ -50,6 +50,10 @@ def sample(self, key: Array) -> Array: item = jax.random.randint(key, self.shape, self.minimum, self.maximum) # randint cannot draw jnp.uint, so we cast it later return jnp.asarray(item, dtype=self.dtype) + + @property + def n(self) -> Array: + return self.maximum + 1 class Continuous(Space):