diff --git a/.github/workflows/CD.yml b/.github/workflows/CD.yml
index aa2c6e3..25ff538 100644
--- a/.github/workflows/CD.yml
+++ b/.github/workflows/CD.yml
@@ -28,12 +28,7 @@ jobs:
uses: TriPSs/conventional-changelog-action@v3
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
- input-file: CHANGELOG.md
- output-file: CHANGELOG.md
fallback-version: ${{ env.NAVIX_VERSION }}
- skip-commit: false
- skip-tag: true
-
- name: Create Release
uses: ncipollo/release-action@v1
@@ -70,6 +65,7 @@ jobs:
- name: Setup navix
run: |
pip install . -v
+ pip install -r docs/requirements.txt
- name: Build docs
run: |
mkdocs build
diff --git a/CHANGELOG.md b/CHANGELOG.md
index e69de29..4f436d9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -0,0 +1,6 @@
+# Changelog
+All notable changes to this project will be documented in this file.
+
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+
+## Unreleased
diff --git a/docs/api/index.md b/docs/api/index.md
deleted file mode 100644
index bd4026d..0000000
--- a/docs/api/index.md
+++ /dev/null
@@ -1 +0,0 @@
-**Coming soon**
\ No newline at end of file
diff --git a/docs/assets/macros/macros.py b/docs/assets/macros/macros.py
deleted file mode 100644
index ea0546b..0000000
--- a/docs/assets/macros/macros.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from plumkdocs import define_env
-
-__all__ = ['define_env']
\ No newline at end of file
diff --git a/docs/changelog.md b/docs/changelog.md
deleted file mode 100644
index 90cb31c..0000000
--- a/docs/changelog.md
+++ /dev/null
@@ -1 +0,0 @@
---8<-- "CHANGELOG.md"
\ No newline at end of file
diff --git a/docs/index.md b/docs/index.md
index dc4df11..afeb140 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -1,5 +1,5 @@
-
A fast, fully jittable MiniGrid reimplemented in JAX
-Welcome to NAVIX!
+A fast, fully jittable MiniGrid reimplemented in JAX
+Welcome to NAVIX!
**NAVIX** is a reimplementation of the [MiniGrid](https://minigrid.farama.org/) environment suite in JAX, and leverages JAX’s intermediate language representation to migrate the computation to different accelerators, such as GPUs and TPUs.
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 5a39183..bc046e8 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -2,5 +2,8 @@ mkdocs
mkdocs-material
mkdocs-jupyter
mkdocstrings
-mkdocs-mermaid2-plugin
+mkdocstrings-python
+mkdocs-gen-files
+mkdocs-literate-nav
+mkdocs-section-index
plumkdocs
\ No newline at end of file
diff --git a/docs/scripts/gen_doc_stubs.py b/docs/scripts/gen_doc_stubs.py
new file mode 100644
index 0000000..4ad7f58
--- /dev/null
+++ b/docs/scripts/gen_doc_stubs.py
@@ -0,0 +1,47 @@
+"""Generate the code reference pages and navigation."""
+
+from pathlib import Path
+
+import mkdocs_gen_files
+
+nav = mkdocs_gen_files.nav.Nav()
+
+root = Path(__file__).parent.parent.parent
+src = root / "navix"
+out = "api"
+
+# Modules that must not get a reference page.
+exclude_files = {"_version.py", "config.py"}
+
+for path in sorted(src.rglob("*.py")):
+    if path.name in exclude_files:
+        continue
+
+    print("Generating stub for", path)
+    module_path = path.relative_to(src).with_suffix("")
+    doc_path = path.relative_to(src).with_suffix(".md")
+    full_doc_path = Path(out, doc_path)
+
+    parts = ("navix",) + tuple(module_path.parts)
+
+    # A package `__init__` is documented as the index page of its section.
+    if parts[-1] == "__init__":
+        parts = parts[:-1]
+        doc_path = doc_path.with_name("index.md")
+        full_doc_path = full_doc_path.with_name("index.md")
+    elif parts[-1] == "__main__":
+        continue
+
+    nav[parts] = doc_path.as_posix()
+
+    # The stub only holds the mkdocstrings directive for the module.
+    with mkdocs_gen_files.open(full_doc_path, "w") as fd:
+        ident = ".".join(parts)
+        fd.write(f"::: {ident}")
+
+    mkdocs_gen_files.set_edit_path(full_doc_path, path.relative_to(root))
+
+# literate-nav is configured with `nav_file: SUMMARY.md`; writing the nav to
+# `index.md` would also overwrite the `navix/__init__.py` page generated above.
+with mkdocs_gen_files.open(f"{out}/SUMMARY.md", "w") as nav_file:
+    nav_file.writelines(nav.build_literate_nav())
diff --git a/mkdocs.yml b/mkdocs.yml
index dbf846e..cc6dc78 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -21,9 +21,8 @@ nav:
- "Customizing envs": examples/customisation.ipynb
- Install: install/index.md
- Becnhmarks: benchmarks/index.md
- - API:
- - api/index.md
- - Changelog: changelog.md
+ - API: api/
+ - Changelog: https://github.com/epignatelli/navix/releases
# Customization
extra:
@@ -49,7 +48,9 @@ extra_javascript:
theme:
name: "material"
logo: assets/images/navix_logo.png
- font: "Sherpa"
+ font:
+ text: Roboto
+ code: Roboto Mono
features:
- announce.dismiss
@@ -74,34 +75,43 @@ theme:
palette:
- scheme: default
- primary: yellow
+ primary: pink
accent: red
toggle:
icon: material/weather-night
name: Switch to dark mode
- scheme: slate
- primary: yellow
+ primary: light green
accent: red
toggle:
icon: material/weather-sunny
name: Switch to light mode
- font:
- text: Roboto
- code: Roboto Mono
-
plugins:
- mkdocs-jupyter
- mkdocstrings:
default_handler: python
handlers:
python:
- rendering:
+ options:
+ docstring_style: google
+ show_bases: true
show_source: false
- # custom_templates: templates
+ heading_level: 3
+ show_root_full_path: true
+ show_symbol_type_heading: true
+ show_symbol_type_toc: true
+ show_signature: true
+ show_signature_annotations: false
+ signature_crossrefs: false
- search
- - mermaid2
+ - gen-files:
+ scripts:
+ - docs/scripts/gen_doc_stubs.py # or any other name or path
+ - literate-nav:
+ nav_file: SUMMARY.md
+ - section-index
markdown_extensions:
- toc:
@@ -115,9 +125,4 @@ markdown_extensions:
- pymdownx.details # For collapsible admonitions
- pymdownx.superfences
- # - changelog/index.md
- # - customization.md
- # - insiders/changelog/*
- # - setup/extensions/*-
-
copyright: Copyright © 2023 - 2024 NAVIX Authors
diff --git a/navix/_version.py b/navix/_version.py
index 676f51d..9fef7b9 100644
--- a/navix/_version.py
+++ b/navix/_version.py
@@ -18,5 +18,5 @@
# under the License.
-__version__ = "0.6.9"
+__version__ = "0.6.10"
__version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit())
diff --git a/navix/actions.py b/navix/actions.py
index c3d59e4..f7b6163 100644
--- a/navix/actions.py
+++ b/navix/actions.py
@@ -16,6 +16,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""The *action* system determines the next state of the environment \
+given the current state and an action."""
+
+
from __future__ import annotations
from typing import Tuple
@@ -85,38 +89,92 @@ def _move(state: State, direction: Array) -> State:
def noop(state: State) -> State:
+ """No operation. Does nothing.
+
+ Args:
+ state (State): The current state.
+
+ Returns:
+ State: The same state."""
return state
def rotate_cw(state: State) -> State:
+ """Rotates the player clockwise.
+
+ Args:
+ state (State): The current state.
+
+ Returns:
+ State: The new state with the player rotated clockwise."""
return _rotate(state, 1)
def rotate_ccw(state: State) -> State:
+ """Rotates the player counter-clockwise.
+
+ Args:
+ state (State): The current state.
+
+ Returns:
+ State: The new state with the player rotated counter-clockwise."""
return _rotate(state, -1)
def forward(state: State) -> State:
+ """Moves the player forward.
+
+ Args:
+ state: The current state.
+
+ Returns:
+ State: The new state with the player moved forward."""
player = state.get_player(idx=0)
return _move(state, player.direction)
def right(state: State) -> State:
+ """Steps the player to the right without changing the direction.
+
+ Args:
+ state (State): The current state.
+
+ Returns:
+ State: The new state with the player moved to the right."""
player = state.get_player(idx=0)
return _move(state, player.direction + 1)
def backward(state: State) -> State:
+ """Steps the player backward without changing the direction.
+
+ Args:
+ state (State): The current state.
+
+ Returns:
+ State: The new state with the player moved backward."""
player = state.get_player(idx=0)
return _move(state, player.direction + 2)
def left(state: State) -> State:
+ """Steps the player to the left without changing the direction.
+
+ Args:
+ state (State): The current state.
+
+ Returns:
+ State: The new state with the player moved to the left."""
player = state.get_player(idx=0)
return _move(state, player.direction + 3)
def pickup(state: State) -> State:
+ """Picks up an item in front of the player and puts it in the pocket.
+ Args:
+ state (State): The current state.
+ Returns:
+ State: The new state with the player entity having the item in the pocket."""
if Entities.KEY not in state.entities:
return state
@@ -151,7 +209,13 @@ def pickup(state: State) -> State:
def drop(state: State) -> State:
- """Replaces the position in front of the player with the item in the pocket."""
+ """Replaces the position in front of the player with the item in the pocket.
+
+ Args:
+ state (State): The current state.
+
+ Returns:
+ State: The new state with the item in the pocket dropped in front of the player."""
player = state.get_player(idx=0)
position_in_front = translate(player.position, player.direction)
@@ -171,12 +235,24 @@ def drop(state: State) -> State:
def toggle(state: State) -> State:
+ """Toggles an openable object (like a door) if possible.
+
+ Args:
+ state (State): The current state.
+
+ Returns:
+ State: The new state with the openable object toggled."""
return open(state)
def open(state: State) -> State:
- """Unlocks and opens an openable object (like a door) if possible"""
-
+ """Unlocks and opens an openable object (like a door) if possible.
+
+ Args:
+ state (State): The current state.
+
+ Returns:
+ State: The new state with the openable object opened."""
if Entities.DOOR not in state.entities:
return state
@@ -221,6 +297,14 @@ def open(state: State) -> State:
def done(state: State) -> State:
+ """A placeholder action that does nothing, but is a signal to the environment that the episode is over.
+ This action does not terminate the episode, unless the termination function explicitly checks for it (not default).
+
+ Args:
+ state (State): The current state.
+
+ Returns:
+ State: The same state."""
return state
@@ -249,6 +333,8 @@ def done(state: State) -> State:
open,
done,
)
+"""Complete action set for the environment.
+This set includes all the actions that can be taken by the agent, and does not mirror the Minigrid action set."""
MINIGRID_ACTION_SET = (
rotate_ccw,
@@ -259,5 +345,7 @@ def done(state: State) -> State:
toggle,
done,
)
+"""Default action set from Minigrid. See
+https://github.com/Farama-Foundation/Minigrid/blob/master/minigrid/core/actions.py"""
DEFAULT_ACTION_SET = MINIGRID_ACTION_SET
diff --git a/navix/components.py b/navix/components.py
index c18a3af..430d939 100644
--- a/navix/components.py
+++ b/navix/components.py
@@ -40,31 +40,49 @@ def field(shape: Tuple[int, ...], **kwargs):
class Component(struct.PyTreeNode):
+ """Base class for all components in the game.
+ Components are used to store the data of the entities in the game."""
+
def check_ndim(self, batched: bool = False) -> None:
return
class Positionable(Component):
+ """Flags an entity as positionable in the grid, and provides the `position` attribute"""
+
position: Array = field(shape=(2,))
- """The (row, column) position of the entity in the grid, defaults to the discard pile (-1, -1)"""
+ """The (row, column) position of the entity in the grid as a JAX array, defaults to the discard pile (-1, -1)"""
class Directional(Component):
+ """Flags an entity as directional, and provides the `direction` attribute"""
+
direction: Array = field(shape=())
"""The direction the entity: 0 = east, 1 = south, 2 = west, 3 = north"""
class HasColour(Component):
+ """Flags an entity as having a colour, and provides the `colour` attribute"""
+
colour: Array = field(shape=())
"""The colour of the object for rendering. """
class Stochastic(Component):
+ """Flags an entity as stochastic, and provides the `probability` attribute
+
+ TODO:
+        * consider replacing probability (Array) with a distrax.Distribution
+
+ """
+
probability: Array = field(shape=())
"""The probability of receiving the reward, if reached."""
class Openable(Component):
+ """Flags an entity as openable, and provides the `requires` and `open` attributes"""
+
requires: Array = field(shape=())
"""The id of the item required to consume this item. If set, it must be > 0.
If -1, the door is unlocked and does not require any key to open."""
@@ -73,16 +91,22 @@ class Openable(Component):
class Pickable(Component):
+ """Flags an entity as pickable, and provides the `id` attribute, which is used to identify the item in the inventory"""
+
id: Array = field(shape=())
"""The id of the item. If set, it must be >= 1."""
class Holder(Component):
+ """Flags an entity as a holder, and provides the `pocket` attribute. The pocket is used to store the id of the item in the pocket."""
+
pocket: Array = field(shape=())
"""The id of the item in the pocket (0 if empty)"""
class HasTag(Component):
+ """Flags an entity as having a tag, and provides the `tag` attribute. The tag is used to identify the type of the entity in the observations."""
+
@property
def tag(self) -> Array:
"""The tag of the component, used to identify the type of the component in `observations.categorical`"""
@@ -90,6 +114,8 @@ def tag(self) -> Array:
class HasSprite(Component):
+ """Flags an entity as having a sprite, and provides the `sprite` attribute. The sprite is used to render the entity in the game."""
+
@property
def sprite(self) -> Array:
raise NotImplementedError()
diff --git a/navix/config.py b/navix/config.py
index a094660..7539fba 100644
--- a/navix/config.py
+++ b/navix/config.py
@@ -2,6 +2,8 @@
class Config:
+ """Config class to store global variables."""
+
def __init__(self):
self.ARRAY_CHECKS_ENABLED = False
diff --git a/navix/entities.py b/navix/entities.py
index ac7bcb5..dab5e62 100644
--- a/navix/entities.py
+++ b/navix/entities.py
@@ -24,6 +24,8 @@
class Entities(struct.PyTreeNode):
+ """Entities enum class to store the names of the entities in the game."""
+
WALL: str = struct.field(pytree_node=False, default="wall")
FLOOR: str = struct.field(pytree_node=False, default="floor")
PLAYER: str = struct.field(pytree_node=False, default="player")
@@ -36,6 +38,8 @@ class Entities(struct.PyTreeNode):
class EntityIds:
+ """EntityIds enum class to store the ids of the entities in the game."""
+
UNKNOWN: Array = jnp.asarray(0, dtype=jnp.uint8)
FLOOR: Array = jnp.asarray(1, dtype=jnp.uint8)
WALL: Array = jnp.asarray(2, dtype=jnp.uint8)
@@ -49,6 +53,8 @@ class EntityIds:
class Directions:
+ """Directions enum class to store the directions in the game."""
+
EAST = jnp.asarray(0)
SOUTH = jnp.asarray(1)
WEST = jnp.asarray(2)
@@ -56,30 +62,40 @@ class Directions:
class Entity(Positionable, HasTag, HasSprite):
- """Entities are components that can be placed in the environment"""
+ """Entities are components that can be placed in the environment, and have a position and a tag.
+ To create an entity, use the `create` method."""
def __getitem__(self: T, idx) -> T:
return jax.tree_util.tree_map(lambda x: x[idx], self)
@property
def name(self) -> str:
+ """The name of the entity
+
+ Returns:
+ str: the name of the entity"""
return self.__class__.__name__
@property
def shape(self) -> Tuple[int, ...]:
- """The batch shape of the entity"""
+ """The batch shape of the entity. The batch shape is the shape of the entity excluding the dimensions of the component.
+ For example, if the entity has a position of shape (batch_size, 2), the shape of the entity is (batch_size,).
+ """
return self.position.shape[:-1]
@property
def ndim(self) -> int:
+ """The number of dimensions of the entity. The number of dimensions is the number of dimensions of the position minus 1."""
return self.position.ndim - 1
@property
def walkable(self) -> Array:
+ """The walkable attribute of the entity. The walkable attribute is a boolean array that indicates if the entity can be walked on."""
raise NotImplementedError()
@property
def transparent(self) -> Array:
+ """The transparent attribute of the entity. The transparent attribute is a boolean array that indicates if the entity is transparent to rendering."""
raise NotImplementedError()
diff --git a/navix/events.py b/navix/events.py
index 02927c7..9356b36 100644
--- a/navix/events.py
+++ b/navix/events.py
@@ -27,18 +27,47 @@
def on_goal_reached(state: State) -> Array:
+ """Checks whether the goal has been reached using the `goal_reached` event.
+
+ Args:
+ state (State): The current state of the game.
+
+ Returns:
+ Array: A boolean array indicating whether the goal has been reached."""
return state.events.goal_reached.happened
def on_lava_fall(state: State) -> Array:
+    """Checks whether the `lava_fall` event has happened.
+
+    Args:
+        state (State): The current state of the game.
+
+    Returns:
+        Array: A boolean array indicating whether the `lava_fall` event has happened."""
return state.events.lava_fall.happened
def on_ball_hit(state: State) -> Array:
+ """Checks whether the ball has hit something using the `ball_hit` event.
+
+ Args:
+ state (State): The current state of the game.
+
+ Returns:
+ Array: A boolean array indicating whether the ball has hit something."""
return state.events.ball_hit.happened
def on_door_done(state: State) -> Array:
+ """Checks whether the action `done` has been called in front of a `Door` object with the correct colour.
+
+ Args:
+ state (State): The current state of the game.
+
+ Returns:
+ Array: A boolean array indicating whether the action `done` has been called in front of a `Door` object with the correct colour.
+ """
assert (
state.mission is not None
), "Termination on door done requires the state to specify a mission."
@@ -57,4 +86,11 @@ def on_door_done(state: State) -> Array:
def on_wall_hit(state: State) -> Array:
+ """Checks whether the wall has been hit using the `wall_hit` event.
+
+ Args:
+ state (State): The current state of the game.
+
+ Returns:
+ Array: A boolean array indicating whether the wall has been hit."""
return state.events.wall_hit.happened
diff --git a/navix/experiment.py b/navix/experiment.py
index c6111e0..0e49fb6 100644
--- a/navix/experiment.py
+++ b/navix/experiment.py
@@ -12,6 +12,26 @@
class Experiment:
+ """A class to run an experiment with a given agent and environment.
+
+ Args:
+ name (str): The name of the experiment.
+ agent (Agent): The agent to use in the experiment.
+ env (Environment): The environment to use in the experiment.
+ env_id (str): The ID of the environment.
+ seeds (Tuple[int, ...]): The seeds to use in the experiment.
+ group (str): The group to use in the experiment.
+
+ Attributes:
+ name (str): The name of the experiment.
+ agent (Agent): The agent to use in the experiment.
+ env (Environment): The environment to use in the experiment.
+ env_id (str): The ID of the environment.
+ seeds (Tuple[int, ...]): The seeds to use in the experiment.
+ group (str): The group to use in the experiment.
+
+ """
+
def __init__(
self,
name: str,
@@ -29,6 +49,17 @@ def __init__(
self.group = group
def run(self, do_log: bool = True):
+ """Default function to run the experiment. This function compiles the training function, trains the agent, and logs the results.
+
+ Args:
+ do_log (bool): Whether to log the results to wandb.
+ !!! Warning
+ Logging to `wandb` is usually much slower than training the agent itself.
+ The time is linear in the number of seeds.
+
+ Returns:
+ Tuple: A tuple containing the final training state and the logs.
+ """
print("Running experiment with the following configuration:")
print(vars(self))
rng = jnp.asarray([jax.random.PRNGKey(seed) for seed in self.seeds])
@@ -74,6 +105,21 @@ def run(self, do_log: bool = True):
def run_hparam_search(
self, hparams_distr: Dict[str, distrax.Distribution], pop_size: int
):
+ """Function to run a hyperparameter search for the experiment. This function \
+ samples hyperparameters from the given distributions, trains the agent, and \
+ logs the results.
+
+ Args:
+ hparams_distr (Dict[str, distrax.Distribution]): A dictionary of \
+ hyperparameter distributions. The keys are the hyperparameter names, which \
+ must exist in `self.agent.hparams`, and the values are the corresponding \
+ distributions.
+ pop_size (int): The number of hyperparameter sets to sample.
+
+ Returns:
+ Tuple: A tuple containing the final training states and the logs, batched \
+ over the hyperparameter sets.
+ """
hparams_fields = fields(self.agent.hparams)
for k in hparams_distr:
member = list(filter(lambda x: x.name == k, hparams_fields))
diff --git a/navix/grid.py b/navix/grid.py
index 963d15f..2b4a48b 100644
--- a/navix/grid.py
+++ b/navix/grid.py
@@ -34,11 +34,30 @@
def coordinates(grid: Array) -> Coordinates:
+    """Returns the row and column index of every cell in the grid.
+    A grid array of shape `i32[height, width]` will return a tuple of length 2,
+    containing the row indices and the column indices, each of shape `i32[height, width]`.
+
+ Args:
+ grid (Array): A 2D grid of shape (height, width).
+
+ Returns:
+ Tuple[Array, Array]: A tuple of two arrays containing the 2D coordinates of \
+ each cell in the grid.
+ """
return tuple(jnp.mgrid[0 : grid.shape[0], 0 : grid.shape[1]]) # type: ignore
-def idx_from_coordinates(grid: Array, coordinates: Array):
- """Converts a batch of 2D coordinates [(col, row), ...] into a flat index"""
+def idx_from_coordinates(grid: Array, coordinates: Array) -> Array:
+ """Converts a batch of 2D coordinates [(col, row), ...] into a flat index
+
+ Args:
+ grid (Array): A 2D grid of shape (height, width).
+ coordinates (Array): A batch of 2D coordinates of shape (batch_size, 2).
+
+ Returns:
+ Array: A flat index of shape `i32[batch_size]` for each coordinate in the batch.
+ """
coordinates = coordinates.T
assert coordinates.shape[0] == 2, coordinates.shape
@@ -46,8 +65,16 @@ def idx_from_coordinates(grid: Array, coordinates: Array):
return jnp.asarray(idx, dtype=jnp.int32)
-def coordinates_from_idx(grid: Array, idx: Array):
- """Converts a flat index into a 2D coordinate (col, row)"""
+def coordinates_from_idx(grid: Array, idx: Array) -> Array:
+ """Converts a flat index of shape `i32[]` into a 2D coordinate `i32[2]` containing \
+    (row, col) data. The index is calculated as `idx = row * width + col`.
+
+    Args:
+        grid (Array): A 2D grid of shape (height, width).
+        idx (Array): A flat index of shape `i32[]`.
+
+    Returns:
+        Array: A 2D coordinate of shape `i32[2]` containing the (row, col) data."""
coords = jnp.divmod(idx, grid.shape[1])
return jnp.asarray(coords, dtype=jnp.int32).T
@@ -62,6 +89,16 @@ def mask_by_coordinates(
Returns a mask of the same shape as `grid` where the value is 1 if the
corresponding element in `grid` satisfies the `comparison_fn` with the
corresponding element in `address` (col, row) and 0 otherwise.
+
+ Args:
+ grid (Array): A 2D grid of shape (height, width).
+ address (Coordinates): A tuple of 2D coordinates (col, row).
+ comparison_fn (Callable[[Array, Array], Array], optional): A comparison function. \
+ Defaults to `jnp.greater_equal`.
+
+ Returns:
+ Array: A boolean mask of the same shape as `grid`.
+
"""
mesh = jnp.mgrid[0 : grid.shape[0], 0 : grid.shape[1]]
cond_1 = comparison_fn(mesh[0], address[0])
@@ -73,6 +110,17 @@ def mask_by_coordinates(
def translate(
position: Array, direction: Array, modulus: Array = jnp.asarray(1)
) -> Array:
+ """Translates a point in a grid by a given direction and modulus.
+
+ Args:
+        position (Array): A 2D coordinate of shape `i32[2]` containing the (row, col) data.
+        direction (Array): A direction in the range [0, 1, 2, 3] representing the \
+            cardinal directions [east, south, west, north].
+        modulus (Array, optional): The number of cells to translate by. Defaults to jnp.asarray(1).
+
+    Returns:
+        Array: A 2D coordinate of shape `i32[2]` containing the (row, col) data.
+ """
moves = (
lambda position: position + jnp.asarray((0, modulus)), # east
lambda position: position + jnp.asarray((modulus, 0)), # south
@@ -83,22 +131,73 @@ def translate(
def translate_forward(position: Array, forward_direction: Array, modulus: Array):
+ """Translates a point in a grid by a given forward direction and modulus.
+
+ Args:
+        position (Array): A 2D coordinate of shape `i32[2]` containing the (row, col) data.
+        forward_direction (Array): A direction in the range [0, 1, 2, 3] representing the \
+            cardinal directions [east, south, west, north].
+        modulus (Array): The number of cells to translate by.
+
+    Returns:
+        Array: A 2D coordinate of shape `i32[2]` containing the (row, col) data."""
return translate(position, forward_direction, modulus)
def translate_left(position: Array, forward_direction: Array, modulus: Array):
+ """Translates a point in a grid by a given left direction and modulus.
+
+ Args:
+        position (Array): A 2D coordinate of shape `i32[2]` containing the (row, col) data.
+        forward_direction (Array): A direction in the range [0, 1, 2, 3] representing the \
+            cardinal directions [east, south, west, north].
+        modulus (Array): The number of cells to translate by.
+
+    Returns:
+        Array: A 2D coordinate of shape `i32[2]` containing the (row, col) data."""
return translate(position, (forward_direction + 3) % 4, modulus)
def translate_right(position: Array, forward_direction: Array, modulus: Array):
+ """Translates a point in a grid by a given right direction and modulus.
+
+ Args:
+        position (Array): A 2D coordinate of shape `i32[2]` containing the (row, col) data.
+        forward_direction (Array): A direction in the range [0, 1, 2, 3] representing the \
+            cardinal directions [east, south, west, north].
+        modulus (Array): The number of cells to translate by.
+
+    Returns:
+        Array: A 2D coordinate of shape `i32[2]` containing the (row, col) data."""
return translate(position, (forward_direction + 1) % 4, modulus)
def rotate(direction: Array, spin: int) -> Array:
+    """Rotates a direction by a given number of spins (quarter turns), wrapping around modulo 4.
+
+ Args:
+ direction (Array): A direction vector of shape `i32[]` in the range [0, 3] \
+ representing the cardinal directions [east, south, west, north].
+ spin (int): The number of spins to apply.
+
+ Returns:
+ Array: A direction vector of shape `i32[]` in the range [0, 3] representing \
+ the cardinal directions [east, south, west, north]."""
return (direction + spin) % 4
def align(patch: Array, current_direction: Array, desired_direction: Array) -> Array:
+ """Aligns a patch of the grid from the current direction to the desired direction.
+
+ Args:
+ patch (Array): A patch of the grid.
+ current_direction (Array): The current direction in the range [0, 1, 2, 3] \
+ representing the cardinal directions [east, south, west, north].
+ desired_direction (Array): The desired direction in the range [0, 1, 2, 3] \
+ representing the cardinal directions [east, south, west, north].
+
+ Returns:
+ Array: A patch of the grid aligned to the desired direction."""
return jax.lax.switch(
desired_direction - current_direction,
(
@@ -114,6 +213,16 @@ def align(patch: Array, current_direction: Array, desired_direction: Array) -> A
def random_positions(
key: Array, grid: Array, n: int = 1, exclude: Array = jnp.asarray((-1, -1))
) -> Array:
+ """Generates `n` random positions in the grid, excluding the `exclude` position.
+
+ Args:
+ key (Array): A random key.
+ grid (Array): A 2D grid of shape (height, width).
+ n (int, optional): The number of random positions to generate. Defaults to 1.
+ exclude (Array, optional): The position to exclude. Defaults to jnp.asarray((-1, -1)).
+
+ Returns:
+ Array: A batch of random positions of shape `i32[n, 2]`."""
probs = grid.reshape(-1)
indices = idx_from_coordinates(grid, exclude)
probs = probs.at[indices].set(-1) + 1.0
@@ -123,14 +232,40 @@ def random_positions(
def random_directions(key: Array, n=1) -> Array:
+ """Generates `n` random directions in the range [0, 1, 2, 3] representing the \
+ cardinal directions [east, south, west, north].
+
+ Args:
+ key (Array): A random key.
+ n (int, optional): The number of random directions to generate. Defaults to 1.
+
+ Returns:
+        Array: A batch of random directions of shape `i32[n]`, squeezed to a scalar when `n == 1`."""
return jax.random.randint(key, (n,), 0, 4).squeeze()
def random_colour(key: Array, n=1) -> Array:
+ """Generates `n` random colours in the range [0, 1, 2, 3, 4, 5].
+
+ Args:
+ key (Array): A random key.
+ n (int, optional): The number of random colours to generate. Defaults to 1.
+
+ Returns:
+        Array: An integer array of `n` random colours, squeezed to a scalar when `n == 1`."""
return jax.random.randint(key, (n,), 0, 6).squeeze()
def positions_equal(a: Array, b: Array) -> Array:
+ """Checks if two points are equal.
+
+ Args:
+        a (Array): A 2D coordinate of shape `i32[2]` containing the (row, col) data.
+        b (Array): A 2D coordinate of shape `i32[2]` containing the (row, col) data.
+
+    Returns:
+        Array: Whether the two positions are equal.
+ """
if b.ndim == 1:
b = b[None]
if a.ndim == 1:
@@ -141,14 +276,34 @@ def positions_equal(a: Array, b: Array) -> Array:
return is_equal
-def room(height: int, width: int):
- """A grid of ids of size `width` x `height`, including the sorrounding walls"""
+def room(height: int, width: int) -> Array:
+ """Creates an array representing a room of size `height` x `width`, including
+ a set of walls around the room. The room is represented as a 2D grid of shape
+ `(height, width)`, including walls, with walls set to -1 and empty tiles set to 0.
+
+ Args:
+ height (int): The height of the room.
+ width (int): The width of the room.
+
+ Returns:
+ Array: A 2D grid of shape `(height, width)` representing a room."""
grid = jnp.zeros((height - 2, width - 2), dtype=jnp.int32)
return jnp.pad(grid, 1, mode="constant", constant_values=-1)
def two_rooms(height: int, width: int, key: Array) -> Tuple[Array, Array]:
- """Two rooms separated by a vertical wall at `width // 2`"""
+ """Creates a 2D grid representing two rooms of size `height` x `width`, separated
+ by a wall. The rooms are represented as a 2D grid of shape `(height, width)`, \
+ including walls, with walls set to -1 and empty tiles set to 0.
+
+ Args:
+ height (int): The height of the rooms.
+ width (int): The width of the rooms.
+ key (Array): A random key, determining the position of the wall separating the rooms.
+
+ Returns:
+ Tuple[Array, Array]: A tuple containing the 2D grid representing the rooms \
+ and the column index of the wall separating the rooms."""
# create room
grid = jnp.zeros((height - 2, width - 2), dtype=jnp.int32)
grid = jnp.pad(grid, 1, mode="constant", constant_values=-1)
@@ -162,6 +317,17 @@ def two_rooms(height: int, width: int, key: Array) -> Tuple[Array, Array]:
def vertical_wall(
grid: Array, row_idx: int, opening_col_idx: Array | None = None
) -> Array:
+    """Creates a vertical wall in the grid, spanning every interior row of one column, \
+    optionally with an opening.
+
+    Args:
+        grid (Array): A 2D grid of shape `(height, width)`.
+        row_idx (int): Despite its name, the column index where the wall is placed.
+ opening_col_idx (Array, optional): The column index where the opening is placed. \
+ Defaults to None.
+
+ Returns:
+ Array: A 2D grid of shape `(height, width)` with a vertical wall."""
rows = jnp.arange(1, grid.shape[0] - 1)
cols = jnp.asarray([row_idx] * (grid.shape[0] - 2))
positions = jnp.stack((rows, cols), axis=1)
@@ -175,6 +341,17 @@ def vertical_wall(
def horizontal_wall(
grid: Array, col_idx: int, opening_row_idx: Array | None = None
) -> Array:
+    """Creates a horizontal wall in the grid, spanning every interior column of one row, \
+    optionally with an opening.
+
+    Args:
+        grid (Array): A 2D grid of shape `(height, width)`.
+        col_idx (int): Despite its name, the row index where the wall is placed.
+ opening_row_idx (Array, optional): The row index where the opening is placed. \
+ Defaults to None.
+
+ Returns:
+ Array: A 2D grid of shape `(height, width)` with a horizontal wall."""
rows = jnp.asarray([col_idx] * (grid.shape[1] - 2))
cols = jnp.arange(1, grid.shape[1] - 1)
positions = jnp.stack((rows, cols), axis=1)
@@ -188,6 +365,17 @@ def horizontal_wall(
def crop(
grid: Array, origin: Array, direction: Array, radius: int, padding_value: int = 0
) -> Array:
+ """Crops a grid around a given origin, facing a given direction, with a given radius.
+
+ Args:
+ grid (Array): A 2D grid of shape `(height, width)`.
+ origin (Array): The origin of the crop.
+ direction (Array): The direction the crop is facing.
+ radius (int): The radius of the crop.
+ padding_value (int, optional): The padding value. Defaults to 0.
+
+ Returns:
+ Array: A cropped grid."""
input_shape = grid.shape
# assert radius % 2, "Radius must be an odd number"
# mid = jnp.asarray([g // 2 for g in grid.shape[:2]])
@@ -228,6 +416,17 @@ def crop(
def view_cone(transparency_map: Array, origin: Array, radius: int) -> Array:
+ """Computes the view cone of a given origin in a grid with a given radius.
+ The view cone is a boolean map of transparent (1) and opaque (0) tiles, indicating
+ whether a tile is visible from the origin or not.
+
+ Args:
+ transparency_map (Array): A boolean map of transparent (1) and opaque (0) tiles.
+ origin (Array): The origin of the view cone.
+ radius (int): The radius of the view cone.
+
+ Returns:
+ Array: The view cone of the given origin in the grid with the given radius."""
# transparency_map is a boolean map of transparent (1) and opaque (0) tiles
def fin_diff(array, _):
@@ -251,6 +450,19 @@ def fin_diff(array, _):
def from_ascii_map(ascii_map: str, mapping: Dict[str, int] = {}) -> Array:
+ """Converts an ASCII map into a 2D grid. The ASCII map is a string where each character
+ represents a tile in the grid. The mapping dictionary can be used to map ASCII characters
+ to integer values. By default, the mapping is as follows:
+ - `#` is mapped to -1
+ - `.` is mapped to 0
+
+ Args:
+ ascii_map (str): The ASCII map.
+ mapping (Dict[str, int], optional): A dictionary mapping ASCII characters to integer \
+ values. Defaults to {}.
+
+ Returns:
+ Array: A 2D grid representing the ASCII map."""
mapping = {**{"#": -1, ".": 0}, **mapping}
ascii_map = ascii_map.strip()
@@ -266,6 +478,12 @@ def from_ascii_map(ascii_map: str, mapping: Dict[str, int] = {}) -> Array:
class RoomsGrid(struct.PyTreeNode):
+ """A grid of rooms.
+ Each room is represented as a 2D grid of shape `(room_height, room_width)`,
+ with walls set to -1 and empty tiles set to 0.
+ The grid of rooms is represented as a 2D grid of shape
+ `(rows * (room_height + 1), cols * (room_width + 1))`, with walls set to -1 and empty tiles set to 0."""
+
room_starts: Array # shape (rows, cols)
room_size: Tuple[int, int]
@@ -273,6 +491,15 @@ class RoomsGrid(struct.PyTreeNode):
def create(
cls, num_rows: int, num_cols: int, room_size: Tuple[int, int]
) -> RoomsGrid:
+ """Creates a grid of rooms with the given number of rows and columns, and the given room size.
+
+ Args:
+ num_rows (int): The number of rows.
+ num_cols (int): The number of columns.
+ room_size (Tuple[int, int]): The size of each room `(height, width)`.
+
+ Returns:
+ RoomsGrid: A grid of rooms."""
# generate rooms grid
height = num_rows * (room_size[0] + 1)
width = num_cols * (room_size[1] + 1)
@@ -286,6 +513,15 @@ def create(
return cls(starts, room_size)
def get_grid(self, occupied_positions: Array | None = None) -> Array:
+ """Computes the array representation of the grid of rooms, with walls set to \
+ -1 and empty tiles set to 0.
+
+ Args:
+ occupied_positions (Array, optional): A batch of extra occupied positions \
+ of shape `(n, 2)`. Defaults to None.
+
+ Returns:
+ Array: A 2D grid of shape `(rows * (room_height + 1), cols * (room_width + 1))`."""
room_size = self.room_size
num_rows, num_cols = self.room_starts.shape[:2]
grid = jnp.zeros(
@@ -299,6 +535,15 @@ def get_grid(self, occupied_positions: Array | None = None) -> Array:
return grid
def position_in_room(self, row: Array, col: Array, *, key: Array) -> Array:
+ """Generates a random position in a given room.
+
+ Args:
+ row (Array): The row index of the room.
+ col (Array): The column index of the room.
+ key (Array): A random key.
+
+ Returns:
+ Array: A random position in the given room."""
k1, k2 = jax.random.split(key)
local_row = jax.random.randint(k1, (), minval=1, maxval=self.room_size[0])
local_col = jax.random.randint(k2, (), minval=1, maxval=self.room_size[1])
@@ -308,7 +553,17 @@ def position_in_room(self, row: Array, col: Array, *, key: Array) -> Array:
def position_on_border(
self, row: Array, col: Array, side: int, *, key: Array
) -> Array:
- """Side is 0: west, 1: east, 2: north, 3: south (like padding)"""
+ """Generates a random position on the border of a given room.
+ Side is 0: west, 1: east, 2: north, 3: south (like padding)
+
+ Args:
+ row (Array): The row index of the room.
+ col (Array): The column index of the room.
+ side (int): The side of the room.
+ key (Array): A random key.
+
+ Returns:
+ Array: A random position on the border of the given room."""
starts = self.room_starts[row, col]
room_size = self.room_size
if side == 0:
diff --git a/navix/observations.py b/navix/observations.py
index a83d202..a2577a1 100644
--- a/navix/observations.py
+++ b/navix/observations.py
@@ -35,10 +35,28 @@
def none(state: State) -> Array:
+ """An empty observation represented as an array of shape f32[0].
+ Useful for testing purposes.
+
+ Args:
+ state (State): The current state of the game.
+
+ Returns:
+ Array: A 0-shaped array `f32[0]`."""
return jnp.asarray(())
def categorical(state: State) -> Array:
+ """Fully observable grid with a categorical state representation.
+ Each entity is represented by its unique integer tag.
+
+ Args:
+ state (State): The current state of the game.
+
+ Returns:
+ Array: A grid of integers, where each integer represents an entity, \
+ represented as an array of shape `i32[H, W]`, where `H` and `W` are the height \
+ and width of the grid."""
# get idx of entity on the set of patches
indices = idx_from_coordinates(state.grid, state.get_positions())
# get tags corresponding to the entities
@@ -51,6 +69,15 @@ def categorical(state: State) -> Array:
def categorical_first_person(state: State) -> Array:
+ """Categorical state representation, but cropped to the agent's view, and aligned \
+ with the agent's direction, such that the agent always points upwards.
+
+ Args:
+ state (State): The current state of the game.
+
+ Returns:
+ Array: A grid of integers, where each integer represents an entity, \
+ represented as an array of shape `i32[2 * RADIUS + 1, 2 * RADIUS + 1]`."""
# get transparency map
transparency_map = jnp.where(state.grid == 0, 1, 0)
positions = state.get_positions()
@@ -72,9 +99,19 @@ def categorical_first_person(state: State) -> Array:
def symbolic(state: State) -> Array:
- """Fully observable grid with a symbolic state representation.
- The symbol is a triple of (OBJECT_TAG, COLOUR_IDX, OPEN/CLOSED/LOCKED), \
- where X and Y are the coordinates on the grid, and IDX is the id of the object."""
+ """Fully observable grid with a symbolic state representation as originally \
+ proposed in the MiniGrid environment.
+ The symbol is a triple of (OBJECT_TAG, COLOUR_IDX, OPEN/CLOSED/LOCKED). The
+ last layer might also contain the direction of the entity, for example, the
+ direction of the agent.
+
+ Args:
+ state (State): The current state of the game.
+
+ Returns:
+ Array: A grid of integers, where each integer represents an entity, \
+ represented as an array of shape `u8[H, W, 3]`, where `H` and `W` are the height \
+ and width of the grid."""
# initialise as all floors
H, W = state.grid.shape
obs = jnp.zeros((H, W, 3), dtype=jnp.uint8)
@@ -104,9 +141,16 @@ def symbolic(state: State) -> Array:
def symbolic_first_person(state: State) -> Array:
- """First person view with a symbolic state representation.
- The symbol is a triple of (OBJECT_TAG, COLOUR_IDX, OPEN/CLOSED/LOCKED), \
- where X and Y are the coordinates on the grid, and IDX is the id of the object."""
+ """First person view with a symbolic state representation, but cropped to the \
+ agent's view, and aligned with the agent's direction, such that the agent always \
+ points upwards. See `symbolic` for more details.
+
+ Args:
+ state (State): The current state of the game.
+
+ Returns:
+ Array: A grid of integers, where each integer represents an entity, \
+ represented as an array of shape `u8[2 * RADIUS + 1, 2 * RADIUS + 1, 3]`."""
# get transparency map
obs = symbolic(state)
@@ -131,6 +175,18 @@ def symbolic_first_person(state: State) -> Array:
def rgb(state: State) -> Array:
+ """Fully observable grid with an RGB state representation.
+ Each entity is represented by its unique RGB sprite. The RGB sprites are \
+ stored in a cache, and the entities are placed on the grid according to their \
+ positions.
+
+ Args:
+ state (State): The current state of the game.
+
+ Returns:
+ Array: An RGB image of the grid, represented as an array of shape \
+ `u8[H * S, W * S, 3]`, where `H` and `W` are the height and width of the grid,
+ and `S` is the size of the tile."""
# get idx of entity on the flat set of patches
indices = idx_from_coordinates(state.grid, state.get_positions())
# get tiles corresponding to the entities
@@ -149,6 +205,18 @@ def rgb(state: State) -> Array:
def rgb_first_person(state: State) -> Array:
+ """First person view with an RGB state representation.
+ The image is cropped to the agent's view, and aligned with the agent's direction, \
+ such that the agent always points upwards.
+ See `rgb` for more details.
+
+ Args:
+ state (State): The current state of the game.
+
+ Returns:
+ Array: An RGB image of the agent's view, represented as an array of shape \
+ `u8[(2 * RADIUS + 1) * S, (2 * RADIUS + 1) * S, 3]`, where
+ `S` is the size of the tile."""
# calculate final image size
# get agent's view
# image_size = (
diff --git a/navix/rewards.py b/navix/rewards.py
index 92cc47e..4bdbd64 100644
--- a/navix/rewards.py
+++ b/navix/rewards.py
@@ -31,6 +31,19 @@ def compose(
*reward_functions: Callable[[State, Array, State], Array],
operator: Callable = jnp.sum,
) -> Callable:
+ """Compose multiple reward functions into a single reward function.
+ The functions are called in order and the results are reduced using the `operator` \
+ function.
+
+ Args:
+ *reward_functions (Callable[[State, Array, State], Array]): A list of reward functions.
+ operator (Callable): The operator to reduce the results of the reward functions.
+ It must be a function that takes a list of arrays, or an array and returns an \
+ array of size `f32[]`.
+
+ Returns:
+ Callable: A composed reward function that applies the `operator` to the results of the \
+ reward functions."""
return lambda prev_state, action, state: operator(
jnp.asarray(
[f(prev_state, action, state) for f in reward_functions], dtype=jnp.float32
@@ -39,23 +52,62 @@ def compose(
def free(state: State) -> Array:
+ """A reward function that always returns 0, to simulate reward-free learning.
+
+ Args:
+ state (State): The current state of the game.
+
+ Returns:
+ Array: A scalar array `f32[]` with value 0."""
return jnp.asarray(0.0, dtype=jnp.float32)
def on_goal_reached(prev_state: State, action: Array, state: State) -> Array:
+ """A reward function that returns 1 when the goal is reached, and 0 otherwise.
+
+ Args:
+ state (State): The current state of the game.
+
+ Returns:
+ Array: A scalar array `f32[]` with value 1 if the goal is reached, and 0 otherwise.
+ """
return jnp.asarray(events.on_goal_reached(state), dtype=jnp.float32)
def action_cost(
prev_state: State, action: Array, new_state: State, cost: float = 0.01
) -> Array:
+ """A reward function that returns a negative value when an action is taken.
+ All actions have a cost of `cost`, except for noops.
+
+ Args:
+ prev_state (State): The previous state of the game.
+ action (Array): The action taken.
+ new_state (State): The new state of the game.
+ cost (float): The cost of taking an action.
+
+ Returns:
+ Array: A scalar array `f32[]` with value -`cost` if the action is not a noop, \
+ and 0 otherwise."""
# noops are free
- return -jnp.asarray(action > 0, dtype=jnp.float32) * cost
+ return -jnp.asarray(action != 6, dtype=jnp.float32) * cost
def time_cost(
prev_state: State, action: Array, new_state: State, cost: float = 0.01
) -> Array:
+ """A reward function that returns a negative value as time passes, paying a cost \
+ of `cost` at each time step.
+
+ Args:
+ prev_state (State): The previous state of the game.
+ action (Array): The action taken.
+ new_state (State): The new state of the game.
+ cost (float): The cost of time passing.
+
+ Returns:
+ Array: A scalar array `f32[]` with value -`cost`.
+ """
# time always has a cost
return -jnp.asarray(cost, dtype=jnp.float32)
@@ -63,13 +115,32 @@ def time_cost(
def wall_hit_cost(
prev_state: State, action: Array, state: State, cost: float = 0.01
) -> Array:
+ """A reward function that returns a negative value when the agent hits a wall, \
+ paying a cost of `cost` for each wall hit.
+
+ Args:
+ state (State): The current state of the game.
+ cost (float): The cost of hitting a wall.
+
+ Returns:
+ Array: A scalar array `f32[]` with value -`cost` if the agent hits a wall, \
+ and 0 otherwise."""
return jnp.asarray(events.on_wall_hit(state), dtype=jnp.float32) * cost
-def on_door_done(
- prev_state: State, action: Array, state: State, cost: float = 0.01
-) -> Array:
+def on_door_done(prev_state: State, action: Array, state: State) -> Array:
+ """A reward function that returns a positive value when the agent uses the action \
+ `done` in front of a door.
+
+ Args:
+ state (State): The current state of the game.
+
+ Returns:
+ Array: A scalar array `f32[]` with value 1 if the agent uses the action `done` in \
+ front of a door, and 0 otherwise."""
+
return jnp.asarray(events.on_door_done(state), dtype=jnp.float32)
DEFAULT_TASK = compose(on_goal_reached, action_cost)
+"""The default task for the game, composed of the `on_goal_reached` and `action_cost` reward functions."""
diff --git a/navix/spaces.py b/navix/spaces.py
index 1a6cb95..0158a5d 100644
--- a/navix/spaces.py
+++ b/navix/spaces.py
@@ -25,12 +25,30 @@
class Space(struct.PyTreeNode):
+ """Base class for all spaces in the game. Spaces define the shape and type of the \
+ observations, actions and rewards in the game.
+ The `sample` method is used to generate random samples from the space.
+
+ !!! note
+ To initialize a space, use the `create` method of the specific space class.
+
+ TODO:
+ * maximum and minimum should be static objects, not arrays.
+ But how do we handle the case when they are not scalars? Maybe numpy arrays?"""
+
shape: Shape = struct.field(pytree_node=False)
dtype: jnp.dtype = struct.field(pytree_node=False)
minimum: Array
maximum: Array
def sample(self, key: Array) -> Array:
+ """Generate a random sample from the space.
+
+ Args:
+ key (Array): A random key to generate the sample.
+
+ Returns:
+ Array: A random sample from the space."""
raise NotImplementedError()
@@ -39,6 +57,15 @@ class Discrete(Space):
def create(
cls, n_elements: int | jax.Array, shape: Shape = (), dtype=jnp.int32
) -> Discrete:
+ """Create a discrete space with a given number of elements.
+
+ Args:
+ n_elements (int | jax.Array): The number of elements in the space.
+ shape (Shape): The shape of the space.
+ dtype (jnp.dtype): The data type of the space.
+
+ Returns:
+ Discrete: A discrete space with the given number of elements."""
return Discrete(
shape=shape,
dtype=dtype,
@@ -47,12 +74,23 @@ def create(
)
def sample(self, key: Array) -> Array:
+ """Generate a random sample from the space.
+
+ Args:
+ key (Array): A random key to generate the sample.
+
+ Returns:
+ Array: A random sample from the space."""
item = jax.random.randint(key, self.shape, self.minimum, self.maximum)
# randint cannot draw jnp.uint, so we cast it later
return jnp.asarray(item, dtype=self.dtype)
-
+
@property
def n(self) -> Array:
+ """The number of elements in the space.
+
+ Returns:
+ Array: The number of elements in the space."""
return self.maximum + 1
@@ -61,9 +99,27 @@ class Continuous(Space):
def create(
cls, shape: Shape, minimum: Array, maximum: Array, dtype=jnp.float32
) -> Continuous:
+ """Create a continuous space with a given shape, minimum and maximum values.
+
+ Args:
+ shape (Shape): The shape of the space.
+ minimum (Array): The minimum value of the space.
+ maximum (Array): The maximum value of the space.
+ dtype (jnp.dtype): The data type of the space.
+
+ Returns:
+ Continuous: A continuous space with the given shape, minimum and maximum values.
+ """
return Continuous(shape=shape, dtype=dtype, minimum=minimum, maximum=maximum)
def sample(self, key: Array) -> Array:
+ """Generate a random sample from the space.
+
+ Args:
+ key (Array): A random key to generate the sample.
+
+ Returns:
+ Array: A random sample from the space."""
assert jnp.issubdtype(self.dtype, jnp.floating)
# see: https://github.com/google/jax/issues/14003
lower = jnp.nan_to_num(self.minimum)
diff --git a/navix/states.py b/navix/states.py
index c815626..aaddd99 100644
--- a/navix/states.py
+++ b/navix/states.py
@@ -34,6 +34,8 @@
class EventType:
+ """Enumeration of the different types of events that can happen in the environment."""
+
NONE: Array = jnp.asarray(-1, dtype=jnp.int32)
REACH: Array = jnp.asarray(0, dtype=jnp.int32)
HIT: Array = jnp.asarray(1, dtype=jnp.int32)
@@ -44,6 +46,23 @@ class EventType:
class Event(Positionable, HasColour):
+ """A struct representing an event that happened in the environment. It contains the
+ position of the event, the colour of the entity involved in the event, whether the event
+ happened, and the type of event that happened.
+
+ !!! note
+ Notice that we need the `happened` property, which flags if an event has
+ happened or not, because JAX does not support variable size arrays.
+ This means that we cannot add an event to the list in the middle of training.
+ Instead, we initialise all events, and mask them out as not happened.
+
+
+ Attributes:
+ position (Array): The (row, column) position of the event in the grid.
+ colour (Array): The colour of the entity involved in the event.
+ happened (Array): A boolean flag indicating whether the event happened.
+ event_type (Array): The type of event that happened."""
+
position: Array = jnp.asarray([-1, -1], dtype=jnp.int32)
colour: Array = PALETTE.UNSET
happened: Array = jnp.asarray(False, dtype=jnp.bool_)
@@ -60,6 +79,20 @@ def __ne__(self, other: Event) -> Array:
class EventsManager(struct.PyTreeNode):
+ """A struct that manages the events. It contains the different events that can happen
+ in the environment, such as the goal being reached, the player being hit by a ball, etc.
+
+ Attributes:
+ goal_reached (Event): An event indicating that the goal has been reached.
+ ball_hit (Event): An event indicating that the player has been hit by a ball.
+ wall_hit (Event): An event indicating that the player has hit a wall.
+ lava_fall (Event): An event indicating that the player has fallen into lava.
+ key_pickup (Event): An event indicating that the player has picked up a key.
+ door_opening (Event): An event indicating that the player has opened a door.
+ door_unlock (Event): An event indicating that the player has unlocked a door.
+ ball_pickup (Event): An event indicating that the player has picked up a ball.
+ """
+
goal_reached: Event = Event()
ball_hit: Event = Event()
wall_hit: Event = Event()
@@ -70,6 +103,15 @@ class EventsManager(struct.PyTreeNode):
ball_pickup: Event = Event()
def record_walk_into(self, entity: Entity, position: Array) -> EventsManager:
+ """Flags an event when the player walks into an entity as happened and returns the
+ updated events manager.
+
+ Args:
+ entity (Entity): The entity the player walked into.
+ position (Array): The position of the entity in the grid.
+
+ Returns:
+ EventsManager: The updated events manager."""
if isinstance(entity, Goal):
return self.record_goal_reached(entity, position)
elif isinstance(entity, Wall):
@@ -79,6 +121,15 @@ def record_walk_into(self, entity: Entity, position: Array) -> EventsManager:
return self
def record_pickup(self, entity: Entity, position: Array) -> EventsManager:
+ """Flags an event when the player picks up an entity as happened and returns the
+ updated events manager.
+
+ Args:
+ entity (Entity): The entity the player picked up.
+ position (Array): The position of the entity in the grid.
+
+ Returns:
+ EventsManager: The updated events manager."""
if isinstance(entity, Key):
return self.record_key_pickup(entity, position)
elif isinstance(entity, Ball):
@@ -86,6 +137,15 @@ def record_pickup(self, entity: Entity, position: Array) -> EventsManager:
return self
def record_goal_reached(self, goal: Goal, position: Array) -> EventsManager:
+ """Flags an event when the player reaches the goal as happened and returns the
+ updated events manager.
+
+ Args:
+ goal (Goal): The goal the player reached.
+ position (Array): The position of the goal in the grid.
+
+ Returns:
+ EventsManager: The updated events manager."""
idx = jnp.where(goal.position == position, size=1)[0][0]
goal = goal[idx]
return self.replace(
@@ -98,6 +158,14 @@ def record_goal_reached(self, goal: Goal, position: Array) -> EventsManager:
)
def record_ball_hit(self, ball: Ball) -> EventsManager:
+ """Flags an event when the player is hit by a ball as happened and returns the
+ updated events manager.
+
+ Args:
+ ball (Ball): The ball that hit the player.
+
+ Returns:
+ EventsManager: The updated events manager."""
return self.replace(
ball_hit=Event(
position=ball.position,
@@ -108,6 +176,15 @@ def record_ball_hit(self, ball: Ball) -> EventsManager:
)
def record_wall_hit(self, wall: Wall, position: Array) -> EventsManager:
+ """Flags an event when the player hits a wall as happened and returns the
+ updated events manager.
+
+ Args:
+ wall (Wall): The wall the player hit.
+ position (Array): The position of the wall in the grid.
+
+ Returns:
+ EventsManager: The updated events manager."""
idx = jnp.where(wall.position == position, size=1)[0][0]
wall = wall[idx]
return self.replace(
@@ -120,6 +197,14 @@ def record_wall_hit(self, wall: Wall, position: Array) -> EventsManager:
)
def record_grid_hit(self, position: Array) -> EventsManager:
+ """Flags an event when the player hits a wall as happened and returns the
+ updated events manager.
+
+ Args:
+ position (Array): The position of the wall in the grid.
+
+ Returns:
+ EventsManager: The updated events manager."""
return self.replace(
wall_hit=Event(
position=position,
@@ -130,6 +215,15 @@ def record_grid_hit(self, position: Array) -> EventsManager:
)
def record_lava_fall(self, lava: Lava, position: Array) -> EventsManager:
+ """Flags an event when the player falls into lava as happened and returns the
+ updated events manager.
+
+ Args:
+ lava (Lava): The lava the player fell into.
+ position (Array): The position of the lava in the grid.
+
+ Returns:
+ EventsManager: The updated events manager."""
idx = jnp.where(lava.position == position, size=1)[0][0]
lava = lava[idx]
return self.replace(
@@ -142,6 +236,15 @@ def record_lava_fall(self, lava: Lava, position: Array) -> EventsManager:
)
def record_key_pickup(self, key: Key, position: Array) -> EventsManager:
+ """Flags an event when the player picks up a key as happened and returns the
+ updated events manager.
+
+ Args:
+ key (Key): The key the player picked up.
+ position (Array): The position of the key in the grid.
+
+ Returns:
+ EventsManager: The updated events manager."""
idx = jnp.where(key.position == position, size=1)[0][0]
key = key[idx]
return self.replace(
@@ -154,6 +257,15 @@ def record_key_pickup(self, key: Key, position: Array) -> EventsManager:
)
def record_door_opening(self, door: Door, position: Array) -> EventsManager:
+ """Flags an event when the player opens a door as happened and returns the
+ updated events manager.
+
+ Args:
+ door (Door): The door the player opened.
+ position (Array): The position of the door in the grid.
+
+ Returns:
+ EventsManager: The updated events manager."""
idx = jnp.where(door.position == position, size=1)[0][0]
door = door[idx]
return self.replace(
@@ -166,6 +278,15 @@ def record_door_opening(self, door: Door, position: Array) -> EventsManager:
)
def record_door_unlock(self, door: Door, position: Array) -> EventsManager:
+ """Flags an event when the player unlocks a door as happened and returns the
+ updated events manager.
+
+ Args:
+ door (Door): The door the player unlocked.
+ position (Array): The position of the door in the grid.
+
+ Returns:
+ EventsManager: The updated events manager."""
idx = jnp.where(door.position == position, size=1)[0][0]
door = door[idx]
return self.replace(
@@ -178,6 +299,15 @@ def record_door_unlock(self, door: Door, position: Array) -> EventsManager:
)
def record_ball_pickup(self, ball: Ball, position: Array) -> EventsManager:
+ """Flags an event when the player picks up a ball as happened and returns the
+ updated events manager.
+
+ Args:
+ ball (Ball): The ball the player picked up.
+ position (Array): The position of the ball in the grid.
+
+ Returns:
+ EventsManager: The updated events manager."""
idx = jnp.where(ball.position == position, size=1)[0][0]
ball = ball[idx]
return self.replace(
@@ -208,76 +338,112 @@ class State(struct.PyTreeNode):
mission: Event | None = None
def get_entity(self, entity_enum: str) -> Entity:
+ """Get an entity from the state by its enum.
+
+ Args:
+ entity_enum (str): The enum of the entity to get.
+
+ Returns:
+ Entity: The entity from the state."""
return self.entities[entity_enum]
def set_entity(self, entity_enum: str, entity: Entity) -> State:
+ """Set an entity in the state by its enum.
+
+ Args:
+ entity_enum (str): The enum of the entity to set.
+ entity (Entity): The entity to set.
+
+ Returns:
+ State: The updated state."""
self.entities[entity_enum] = entity
return self
def get_walls(self) -> Wall:
+ """Gets all the `WALL` entities from the state."""
return self.entities.get(Entities.WALL, Wall()) # type: ignore
def set_walls(self, walls: Wall) -> State:
+ """Sets the `WALL` entities in the state."""
self.entities[Entities.WALL] = walls
return self
def get_player(self, idx: int = 0) -> Player:
+ """Gets the player entity from the state."""
return self.entities[Entities.PLAYER][idx] # type: ignore
def set_player(self, player: Player, idx: int = 0) -> State:
+ """Sets the player entity in the state. Notice that we only support one player in the
+ environment for now, but this can easily be extended to multiple players."""
# TODO(epignatelli): this is a hack and won't work in multi-agent settings
self.entities[Entities.PLAYER] = player[None]
return self
def get_goals(self) -> Goal:
+ """Gets the goal entity from the state."""
return self.entities[Entities.GOAL] # type: ignore
def set_goals(self, goals: Goal) -> State:
+ """Sets the goal entity in the state."""
self.entities[Entities.GOAL] = goals
return self
def get_keys(self) -> Key:
+ """Gets the key entity from the state."""
return self.entities[Entities.KEY] # type: ignore
def set_keys(self, keys: Key) -> State:
+ """Sets the key entity in the state."""
self.entities[Entities.KEY] = keys
return self
def get_doors(self) -> Door:
+ """Gets the door entity from the state."""
return self.entities[Entities.DOOR] # type: ignore
def set_doors(self, doors: Door) -> State:
+ """Sets the door entity in the state."""
self.entities[Entities.DOOR] = doors
return self
def get_lavas(self) -> Lava:
+ """Gets the lava entity from the state."""
return self.entities[Entities.LAVA] # type: ignore
def get_balls(self) -> Ball:
+ """Gets the ball entity from the state."""
return self.entities[Entities.BALL] # type: ignore
def get_boxes(self) -> Ball:
+ """Gets the box entity from the state."""
return self.entities[Entities.BOX] # type: ignore
def set_balls(self, balls: Ball) -> State:
+ """Sets the ball entity in the state."""
self.entities[Entities.BALL] = balls
return self
def set_boxes(self, boxes: Box) -> State:
+ """Sets the box entity in the state."""
self.entities[Entities.BOX] = boxes
return self
def set_events(self, events: EventsManager) -> State:
+ """Sets the events in the state."""
return self.replace(events=events)
def get_positions(self) -> Array:
+ """Get the positions of all the entities in the state."""
return jnp.concatenate([self.entities[k].position for k in self.entities])
def get_tags(self) -> Array:
+ """Get the tags of all the entities in the state."""
return jnp.concatenate([self.entities[k].tag for k in self.entities])
def get_sprites(self) -> Array:
+ """Get the sprites of all the entities in the state."""
return jnp.concatenate([self.entities[k].sprite for k in self.entities])
def get_transparency(self) -> Array:
+ """Get the transparency of all the entities in the state."""
return jnp.concatenate([self.entities[k].transparent for k in self.entities])
diff --git a/navix/terminations.py b/navix/terminations.py
index 85d6796..6345a0a 100644
--- a/navix/terminations.py
+++ b/navix/terminations.py
@@ -24,37 +24,91 @@
from . import events
from .states import State
-from .grid import translate
-from .entities import Entities, Player
def compose(
*term_functions: Callable[[State, Array, State], Array],
operator: Callable = jnp.any,
) -> Callable:
+ """Compose termination functions into a single termination function.
+
+ Args:
+ *term_functions (Callable): List of termination functions.
+ operator (Callable): Operator to combine the termination functions.
+
+ Returns:
+ Callable: A single termination function."""
return lambda prev_state, action, state: operator(
jnp.asarray([term_f(prev_state, action, state) for term_f in term_functions])
)
def check_truncation(terminated: Array, truncated: Array) -> Array:
+ """Checks whether the episode is truncated or terminated, and returns a value
+ that conforms to the `StepType` enum.
+
+ Args:
+ terminated (Array): A boolean array indicating whether the episode is terminated.
+ truncated (Array): A boolean array indicating whether the episode is truncated.
+
+ Returns:
+ Array: An integer array that represents the step type."""
result = jnp.asarray(truncated + 2 * terminated, dtype=jnp.int32)
return jnp.clip(result, 0, 2)
def on_goal_reached(prev_state: State, action: Array, state: State) -> Array:
+ """Check if the goal has been reached using the `goal_reached` event.
+
+ Args:
+ prev_state (State): The previous state of the game.
+ action (Array): The action taken by the player.
+ state (State): The current state of the game.
+
+ Returns:
+ Array: A boolean array indicating whether the goal has been reached."""
return jnp.asarray(events.on_goal_reached(state), dtype=jnp.bool_)
def on_lava_fall(prev_state: State, action: Array, state: State) -> Array:
+ """Check if the player has fallen into lava using the `lava_fall` event.
+
+ Args:
+ prev_state (State): The previous state of the game.
+ action (Array): The action taken by the player.
+ state (State): The current state of the game.
+
+ Returns:
+ Array: A boolean array indicating whether the player has fallen into lava."""
return jnp.asarray(events.on_lava_fall(state), dtype=jnp.bool_)
def on_ball_hit(prev_state: State, action: Array, state: State) -> Array:
+ """Check if the ball has hit something using the `ball_hit` event.
+
+ Args:
+ prev_state (State): The previous state of the game.
+ action (Array): The action taken by the player.
+ state (State): The current state of the game.
+
+ Returns:
+ Array: A boolean array indicating whether the ball has hit something."""
return jnp.asarray(events.on_ball_hit(state), dtype=jnp.bool_)
def on_door_done(prev_state: State, action: Array, state: State) -> Array:
+ """Check if the action `done` has been called in front of a `Door` object with the \
+ correct colour.
+
+ Args:
+ prev_state (State): The previous state of the game.
+ action (Array): The action taken by the player.
+ state (State): The current state of the game.
+
+ Returns:
+ Array: A boolean array indicating whether the action `done` has been called in \
+ front of a `Door` object with the correct colour.
+ """
return jnp.asarray(events.on_door_done(state), dtype=jnp.bool_)
diff --git a/navix/transitions.py b/navix/transitions.py
index 1a99032..b17a2ba 100644
--- a/navix/transitions.py
+++ b/navix/transitions.py
@@ -24,7 +24,6 @@
from jax import Array
import jax
import jax.numpy as jnp
-import jax.tree_util as jtu
from .entities import Entities, Ball
from .states import EventsManager, State
from .grid import positions_equal, translate
@@ -33,12 +32,33 @@
def deterministic_transition(
state: State, action: Array, actions_set: Tuple[Callable[[State], State], ...]
) -> State:
+ """Deterministic transition function. It selects the action from the set of actions
+ and applies it to the state.
+
+ Args:
+ state (State): The current state of the game.
+ action (Array): The action to be taken.
+ actions_set (Tuple[Callable[[State], State], ...]): A set of actions that can be taken.
+
+ Returns:
+ State: The new state of the game."""
return jax.lax.switch(action, actions_set, state)
def stochastic_transition(
state: State, action: Array, actions_set: Tuple[Callable[[State], State], ...]
) -> State:
+ """Stochastic transition function. It selects the action from the set of actions
+ and applies it to the state, and updates entities that have stochastic transitions,
+ such as balls.
+
+ Args:
+ state (State): The current state of the game.
+ action (Array): The action to be taken.
+ actions_set (Tuple[Callable[[State], State], ...]): A set of actions that can be taken.
+
+ Returns:
+ State: The new state of the game."""
# actions
state = jax.lax.switch(action, actions_set, state)
@@ -47,6 +67,14 @@ def stochastic_transition(
def update_balls(state: State) -> State:
+ """Update the position of the balls in the game.
+ Balls move in a random direction if they can, otherwise they stay in place.
+
+ Args:
+ state (State): The current state of the game.
+
+ Returns:
+ State: The new state of the game."""
def update_one(ball: Ball, key: Array) -> Tuple[Array, EventsManager]:
direction = jax.random.randint(key, (), minval=0, maxval=4)
new_position = translate(ball.position, direction)