diff --git a/.github/workflows/CD.yml b/.github/workflows/CD.yml index aa2c6e3..25ff538 100644 --- a/.github/workflows/CD.yml +++ b/.github/workflows/CD.yml @@ -28,12 +28,7 @@ jobs: uses: TriPSs/conventional-changelog-action@v3 with: github-token: ${{ secrets.GITHUB_TOKEN }} - input-file: CHANGELOG.md - output-file: CHANGELOG.md fallback-version: ${{ env.NAVIX_VERSION }} - skip-commit: false - skip-tag: true - - name: Create Release uses: ncipollo/release-action@v1 @@ -70,6 +65,7 @@ jobs: - name: Setup navix run: | pip install . -v + pip install -r docs/requirements.txt - name: Build docs run: | mkdocs build diff --git a/CHANGELOG.md b/CHANGELOG.md index e69de29..4f436d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -0,0 +1,6 @@ +# Changelog +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## Unreleased diff --git a/docs/api/index.md b/docs/api/index.md deleted file mode 100644 index bd4026d..0000000 --- a/docs/api/index.md +++ /dev/null @@ -1 +0,0 @@ -**Coming soon** \ No newline at end of file diff --git a/docs/assets/macros/macros.py b/docs/assets/macros/macros.py deleted file mode 100644 index ea0546b..0000000 --- a/docs/assets/macros/macros.py +++ /dev/null @@ -1,3 +0,0 @@ -from plumkdocs import define_env - -__all__ = ['define_env'] \ No newline at end of file diff --git a/docs/changelog.md b/docs/changelog.md deleted file mode 100644 index 90cb31c..0000000 --- a/docs/changelog.md +++ /dev/null @@ -1 +0,0 @@ ---8<-- "CHANGELOG.md" \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index dc4df11..afeb140 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,5 +1,5 @@ -

A fast, fully jittable MiniGrid reimplemented in JAX

-

Welcome to NAVIX!

+

A fast, fully jittable MiniGrid reimplemented in JAX

+

Welcome to NAVIX!

**NAVIX** is a reimplementation of the [MiniGrid](https://minigrid.farama.org/) environment suite in JAX, and leverages JAX’s intermediate language representation to migrate the computation to different accelerators, such as GPUs and TPUs. diff --git a/docs/requirements.txt b/docs/requirements.txt index 5a39183..bc046e8 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -2,5 +2,8 @@ mkdocs mkdocs-material mkdocs-jupyter mkdocstrings -mkdocs-mermaid2-plugin +mkdocstrings-python +mkdocs-gen-files +mkdocs-literate-nav +mkdocs-section-index plumkdocs \ No newline at end of file diff --git a/docs/scripts/gen_doc_stubs.py b/docs/scripts/gen_doc_stubs.py new file mode 100644 index 0000000..4ad7f58 --- /dev/null +++ b/docs/scripts/gen_doc_stubs.py @@ -0,0 +1,47 @@ +"""Generate the code reference pages and navigation.""" + +from pathlib import Path + +import mkdocs_gen_files + +nav = mkdocs_gen_files.nav.Nav() + +root = Path(__file__).parent.parent.parent +src = root / "navix" +out = "api" + +exclude_files = [ + "_version.py", + "config.py" +] + +for path in sorted(src.rglob("*.py")): + if path.name in exclude_files: + continue + + print("Generating stub for", path) + module_path = path.relative_to(src).with_suffix("") + doc_path = path.relative_to(src).with_suffix(".md") + full_doc_path = Path(out, doc_path) + + parts = tuple(module_path.parts) + parts = ("navix",) + parts + + if parts[-1] == "__init__": + parts = parts[:-1] + doc_path = doc_path.with_name("index.md") + full_doc_path = full_doc_path.with_name("index.md") + elif parts[-1] == "__main__": + continue + + if parts: + nav[parts] = doc_path.as_posix() + + with mkdocs_gen_files.open(full_doc_path, "w") as fd: + ident = ".".join(parts) + fd.write(f"::: {ident}") + + mkdocs_gen_files.set_edit_path(full_doc_path, path.relative_to(root)) + +with mkdocs_gen_files.open(f"{out}/index.md", "w") as nav_file: + nav_file.writelines(nav.build_literate_nav()) diff --git a/mkdocs.yml b/mkdocs.yml index dbf846e..cc6dc78 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -21,9 +21,8 @@ nav: - "Customizing envs": examples/customisation.ipynb - Install: install/index.md - Becnhmarks: benchmarks/index.md - - API: - - api/index.md - - Changelog: changelog.md + - API: api/ + - Changelog: https://github.com/epignatelli/navix/releases # Customization extra: @@ -49,7 +48,9 @@ extra_javascript: theme: name: "material" logo: assets/images/navix_logo.png - font: "Sherpa" + font: + text: Roboto + code: Roboto Mono features: - announce.dismiss @@ -74,34 +75,43 @@ theme: palette: - scheme: default - primary: yellow + primary: pink accent: red toggle: icon: material/weather-night name: Switch to dark mode - scheme: slate - primary: yellow + primary: light green accent: red toggle: icon: material/weather-sunny name: Switch to light mode - font: - text: Roboto - code: Roboto Mono - plugins: - mkdocs-jupyter - mkdocstrings: default_handler: python handlers: python: - rendering: + options: + docstring_style: google + show_bases: true show_source: false - # custom_templates: templates + heading_level: 3 + show_root_full_path: true + show_symbol_type_heading: true + show_symbol_type_toc: true + show_signature: true + show_signature_annotations: false + signature_crossrefs: false - search - - mermaid2 + - gen-files: + scripts: + - docs/scripts/gen_doc_stubs.py # or any other name or path + - literate-nav: + nav_file: SUMMARY.md + - section-index markdown_extensions: - toc: @@ -115,9 +125,4 @@ markdown_extensions: - pymdownx.details # For collapsible admonitions - pymdownx.superfences - # - changelog/index.md - # - customization.md - # - insiders/changelog/* - # - setup/extensions/*- - copyright: Copyright © 2023 - 2024 NAVIX Authors diff --git a/navix/_version.py b/navix/_version.py index 676f51d..9fef7b9 100644 --- a/navix/_version.py +++ b/navix/_version.py @@ -18,5 +18,5 @@ # under the License. -__version__ = "0.6.9" +__version__ = "0.6.10" __version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit()) diff --git a/navix/actions.py b/navix/actions.py index c3d59e4..f7b6163 100644 --- a/navix/actions.py +++ b/navix/actions.py @@ -16,6 +16,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""The *action* system determines the next state of the environment \ +given the current state and an action.""" + + from __future__ import annotations from typing import Tuple @@ -85,38 +89,92 @@ def _move(state: State, direction: Array) -> State: def noop(state: State) -> State: + """No operation. Does nothing. + + Args: + state (State): The current state. + + Returns: + State: The same state.""" return state def rotate_cw(state: State) -> State: + """Rotates the player clockwise. + + Args: + state (State): The current state. + + Returns: + State: The new state with the player rotated clockwise.""" return _rotate(state, 1) def rotate_ccw(state: State) -> State: + """Rotates the player counter-clockwise. + + Args: + state (State): The current state. + + Returns: + State: The new state with the player rotated counter-clockwise.""" return _rotate(state, -1) def forward(state: State) -> State: + """Moves the player forward. + + Args: + state: The current state. + + Returns: + State: The new state with the player moved forward.""" player = state.get_player(idx=0) return _move(state, player.direction) def right(state: State) -> State: + """Steps the player to the right without changing the direction. + + Args: + state (State): The current state. + + Returns: + State: The new state with the player moved to the right.""" player = state.get_player(idx=0) return _move(state, player.direction + 1) def backward(state: State) -> State: + """Steps the player backward without changing the direction. + + Args: + state (State): The current state. + + Returns: + State: The new state with the player moved backward.""" player = state.get_player(idx=0) return _move(state, player.direction + 2) def left(state: State) -> State: + """Steps the player to the left without changing the direction. + + Args: + state (State): The current state. + + Returns: + State: The new state with the player moved to the left.""" player = state.get_player(idx=0) return _move(state, player.direction + 3) def pickup(state: State) -> State: + """Picks up an item in front of the player and puts it in the pocket. + Args: + state (State): The current state. + Returns: + State: The new state with the player entity having the item in the pocket.""" if Entities.KEY not in state.entities: return state @@ -151,7 +209,13 @@ def pickup(state: State) -> State: def drop(state: State) -> State: - """Replaces the position in front of the player with the item in the pocket.""" + """Replaces the position in front of the player with the item in the pocket. + + Args: + state (State): The current state. + + Returns: + State: The new state with the item in the pocket dropped in front of the player.""" player = state.get_player(idx=0) position_in_front = translate(player.position, player.direction) @@ -171,12 +235,24 @@ def drop(state: State) -> State: def toggle(state: State) -> State: + """Toggles an openable object (like a door) if possible. + + Args: + state (State): The current state. + + Returns: + State: The new state with the openable object toggled.""" return open(state) def open(state: State) -> State: - """Unlocks and opens an openable object (like a door) if possible""" - + """Unlocks and opens an openable object (like a door) if possible. + + Args: + state (State): The current state. + + Returns: + State: The new state with the openable object opened.""" if Entities.DOOR not in state.entities: return state @@ -221,6 +297,14 @@ def open(state: State) -> State: def done(state: State) -> State: + """A placeholder action that does nothing, but is a signal to the environment that the episode is over. + This action does not terminate the episode, unless the termination function explicitly checks for it (not default). + + Args: + state (State): The current state. + + Returns: + State: The same state.""" return state @@ -249,6 +333,8 @@ def done(state: State) -> State: open, done, ) +"""Complete action set for the environment. +This set includes all the actions that can be taken by the agent, and does not mirror the Minigrid action set.""" MINIGRID_ACTION_SET = ( rotate_ccw, @@ -259,5 +345,7 @@ def done(state: State) -> State: toggle, done, ) +"""Default action set from Minigrid. See +https://github.com/Farama-Foundation/Minigrid/blob/master/minigrid/core/actions.py""" DEFAULT_ACTION_SET = MINIGRID_ACTION_SET diff --git a/navix/components.py b/navix/components.py index c18a3af..430d939 100644 --- a/navix/components.py +++ b/navix/components.py @@ -40,31 +40,49 @@ def field(shape: Tuple[int, ...], **kwargs): class Component(struct.PyTreeNode): + """Base class for all components in the game. + Components are used to store the data of the entities in the game.""" + def check_ndim(self, batched: bool = False) -> None: return class Positionable(Component): + """Flags an entity as positionable in the grid, and provides the `position` attribute""" + position: Array = field(shape=(2,)) - """The (row, column) position of the entity in the grid, defaults to the discard pile (-1, -1)""" + """The (row, column) position of the entity in the grid as a JAX array, defaults to the discard pile (-1, -1)""" class Directional(Component): + """Flags an entity as directional, and provides the `direction` attribute""" + direction: Array = field(shape=()) """The direction the entity: 0 = east, 1 = south, 2 = west, 3 = north""" class HasColour(Component): + """Flags an entity as having a colour, and provides the `colour` attribute""" + colour: Array = field(shape=()) """The colour of the object for rendering. """ class Stochastic(Component): + """Flags an entity as stochastic, and provides the `probability` attribute + + TODO: + * consider replace probability (Array) with a distrax.Distribution + + """ + probability: Array = field(shape=()) """The probability of receiving the reward, if reached.""" class Openable(Component): + """Flags an entity as openable, and provides the `requires` and `open` attributes""" + requires: Array = field(shape=()) """The id of the item required to consume this item. If set, it must be > 0. If -1, the door is unlocked and does not require any key to open.""" @@ -73,16 +91,22 @@ class Openable(Component): class Pickable(Component): + """Flags an entity as pickable, and provides the `id` attribute, which is used to identify the item in the inventory""" + id: Array = field(shape=()) """The id of the item. If set, it must be >= 1.""" class Holder(Component): + """Flags an entity as a holder, and provides the `pocket` attribute. The pocket is used to store the id of the item in the pocket.""" + pocket: Array = field(shape=()) """The id of the item in the pocket (0 if empty)""" class HasTag(Component): + """Flags an entity as having a tag, and provides the `tag` attribute. The tag is used to identify the type of the entity in the observations.""" + @property def tag(self) -> Array: """The tag of the component, used to identify the type of the component in `observations.categorical`""" @@ -90,6 +114,8 @@ def tag(self) -> Array: class HasSprite(Component): + """Flags an entity as having a sprite, and provides the `sprite` attribute. The sprite is used to render the entity in the game.""" + @property def sprite(self) -> Array: raise NotImplementedError() diff --git a/navix/config.py b/navix/config.py index a094660..7539fba 100644 --- a/navix/config.py +++ b/navix/config.py @@ -2,6 +2,8 @@ class Config: + """Config class to store global variables.""" + def __init__(self): self.ARRAY_CHECKS_ENABLED = False diff --git a/navix/entities.py b/navix/entities.py index ac7bcb5..dab5e62 100644 --- a/navix/entities.py +++ b/navix/entities.py @@ -24,6 +24,8 @@ class Entities(struct.PyTreeNode): + """Entities enum class to store the names of the entities in the game.""" + WALL: str = struct.field(pytree_node=False, default="wall") FLOOR: str = struct.field(pytree_node=False, default="floor") PLAYER: str = struct.field(pytree_node=False, default="player") @@ -36,6 +38,8 @@ class Entities(struct.PyTreeNode): class EntityIds: + """EntityIds enum class to store the ids of the entities in the game.""" + UNKNOWN: Array = jnp.asarray(0, dtype=jnp.uint8) FLOOR: Array = jnp.asarray(1, dtype=jnp.uint8) WALL: Array = jnp.asarray(2, dtype=jnp.uint8) @@ -49,6 +53,8 @@ class EntityIds: class Directions: + """Directions enum class to store the directions in the game.""" + EAST = jnp.asarray(0) SOUTH = jnp.asarray(1) WEST = jnp.asarray(2) @@ -56,30 +62,40 @@ class Directions: class Entity(Positionable, HasTag, HasSprite): - """Entities are components that can be placed in the environment""" + """Entities are components that can be placed in the environment, and have a position and a tag. + To create an entity, use the `create` method.""" def __getitem__(self: T, idx) -> T: return jax.tree_util.tree_map(lambda x: x[idx], self) @property def name(self) -> str: + """The name of the entity + + Returns: + str: the name of the entity""" return self.__class__.__name__ @property def shape(self) -> Tuple[int, ...]: - """The batch shape of the entity""" + """The batch shape of the entity. The batch shape is the shape of the entity excluding the dimensions of the component. + For example, if the entity has a position of shape (batch_size, 2), the shape of the entity is (batch_size,). + """ return self.position.shape[:-1] @property def ndim(self) -> int: + """The number of dimensions of the entity. The number of dimensions is the number of dimensions of the position minus 1.""" return self.position.ndim - 1 @property def walkable(self) -> Array: + """The walkable attribute of the entity. The walkable attribute is a boolean array that indicates if the entity can be walked on.""" raise NotImplementedError() @property def transparent(self) -> Array: + """The transparent attribute of the entity. The transparent attribute is a boolean array that indicates if the entity is transparent to rendering.""" raise NotImplementedError() diff --git a/navix/events.py b/navix/events.py index 02927c7..9356b36 100644 --- a/navix/events.py +++ b/navix/events.py @@ -27,18 +27,47 @@ def on_goal_reached(state: State) -> Array: + """Checks whether the goal has been reached using the `goal_reached` event. + + Args: + state (State): The current state of the game. + + Returns: + Array: A boolean array indicating whether the goal has been reached.""" return state.events.goal_reached.happened def on_lava_fall(state: State) -> Array: + """Checks whether the lava has fallen using the `lava_fall` event. + + Args: + state (State): The current state of the game. + + Returns: + Array: A boolean array indicating whether the lava has fallen.""" return state.events.lava_fall.happened def on_ball_hit(state: State) -> Array: + """Checks whether the ball has hit something using the `ball_hit` event. + + Args: + state (State): The current state of the game. + + Returns: + Array: A boolean array indicating whether the ball has hit something.""" return state.events.ball_hit.happened def on_door_done(state: State) -> Array: + """Checks whether the action `done` has been called in front of a `Door` object with the correct colour. + + Args: + state (State): The current state of the game. + + Returns: + Array: A boolean array indicating whether the action `done` has been called in front of a `Door` object with the correct colour. + """ assert ( state.mission is not None ), "Termination on door done requires the state to specify a mission." @@ -57,4 +86,11 @@ def on_door_done(state: State) -> Array: def on_wall_hit(state: State) -> Array: + """Checks whether the wall has been hit using the `wall_hit` event. + + Args: + state (State): The current state of the game. + + Returns: + Array: A boolean array indicating whether the wall has been hit.""" return state.events.wall_hit.happened diff --git a/navix/experiment.py b/navix/experiment.py index c6111e0..0e49fb6 100644 --- a/navix/experiment.py +++ b/navix/experiment.py @@ -12,6 +12,26 @@ class Experiment: + """A class to run an experiment with a given agent and environment. + + Args: + name (str): The name of the experiment. + agent (Agent): The agent to use in the experiment. + env (Environment): The environment to use in the experiment. + env_id (str): The ID of the environment. + seeds (Tuple[int, ...]): The seeds to use in the experiment. + group (str): The group to use in the experiment. + + Attributes: + name (str): The name of the experiment. + agent (Agent): The agent to use in the experiment. + env (Environment): The environment to use in the experiment. + env_id (str): The ID of the environment. + seeds (Tuple[int, ...]): The seeds to use in the experiment. + group (str): The group to use in the experiment. + + """ + def __init__( self, name: str, @@ -29,6 +49,17 @@ def __init__( self.group = group def run(self, do_log: bool = True): + """Default function to run the experiment. This function compiles the training function, trains the agent, and logs the results. + + Args: + do_log (bool): Whether to log the results to wandb. + !!! Warning + Logging to `wandb` is usually much slower than training the agent itself. + The time is linear in the number of seeds. + + Returns: + Tuple: A tuple containing the final training state and the logs. + """ print("Running experiment with the following configuration:") print(vars(self)) rng = jnp.asarray([jax.random.PRNGKey(seed) for seed in self.seeds]) @@ -74,6 +105,21 @@ def run(self, do_log: bool = True): def run_hparam_search( self, hparams_distr: Dict[str, distrax.Distribution], pop_size: int ): + """Function to run a hyperparameter search for the experiment. This function \ + samples hyperparameters from the given distributions, trains the agent, and \ + logs the results. + + Args: + hparams_distr (Dict[str, distrax.Distribution]): A dictionary of \ + hyperparameter distributions. The keys are the hyperparameter names, which \ + must exist in `self.agent.hparams`, and the values are the corresponding \ + distributions. + pop_size (int): The number of hyperparameter sets to sample. + + Returns: + Tuple: A tuple containing the final training states and the logs, batched \ + over the hyperparameter sets. + """ hparams_fields = fields(self.agent.hparams) for k in hparams_distr: member = list(filter(lambda x: x.name == k, hparams_fields)) diff --git a/navix/grid.py b/navix/grid.py index 963d15f..2b4a48b 100644 --- a/navix/grid.py +++ b/navix/grid.py @@ -34,11 +34,30 @@ def coordinates(grid: Array) -> Coordinates: + """Returns a tuple of 2D coordinates [(col, row), ...] for each cell in the grid. + A grid array of shape `i32[height, width]` will return a tuple of length (height * width), + containing two arrays, each of shape `i32[2]`. + + Args: + grid (Array): A 2D grid of shape (height, width). + + Returns: + Tuple[Array, Array]: A tuple of two arrays containing the 2D coordinates of \ + each cell in the grid. + """ return tuple(jnp.mgrid[0 : grid.shape[0], 0 : grid.shape[1]]) # type: ignore -def idx_from_coordinates(grid: Array, coordinates: Array): - """Converts a batch of 2D coordinates [(col, row), ...] into a flat index""" +def idx_from_coordinates(grid: Array, coordinates: Array) -> Array: + """Converts a batch of 2D coordinates [(col, row), ...] into a flat index + + Args: + grid (Array): A 2D grid of shape (height, width). + coordinates (Array): A batch of 2D coordinates of shape (batch_size, 2). + + Returns: + Array: A flat index of shape `i32[batch_size]` for each coordinate in the batch. + """ coordinates = coordinates.T assert coordinates.shape[0] == 2, coordinates.shape @@ -46,8 +65,16 @@ def idx_from_coordinates(grid: Array, coordinates: Array): return jnp.asarray(idx, dtype=jnp.int32) -def coordinates_from_idx(grid: Array, idx: Array): - """Converts a flat index into a 2D coordinate (col, row)""" +def coordinates_from_idx(grid: Array, idx: Array) -> Array: + """Converts a flat index of shape `i32[]` into a 2D coordinate `i32[2]` containing \ + (col, row) data. The index is calculated as `idx = row * width + col`. + + Args: + grid (Array): A 2D grid of shape (height, width). + idx (Array): A flat index of shape `i32[]`. + + Returns: + Array: A 2D coordinate of shape `i32[2]` containing the (col, row) data.""" coords = jnp.divmod(idx, grid.shape[1]) return jnp.asarray(coords, dtype=jnp.int32).T @@ -62,6 +89,16 @@ def mask_by_coordinates( Returns a mask of the same shape as `grid` where the value is 1 if the corresponding element in `grid` satisfies the `comparison_fn` with the corresponding element in `address` (col, row) and 0 otherwise. + + Args: + grid (Array): A 2D grid of shape (height, width). + address (Coordinates): A tuple of 2D coordinates (col, row). + comparison_fn (Callable[[Array, Array], Array], optional): A comparison function. \ + Defaults to `jnp.greater_equal`. + + Returns: + Array: A boolean mask of the same shape as `grid`. + """ mesh = jnp.mgrid[0 : grid.shape[0], 0 : grid.shape[1]] cond_1 = comparison_fn(mesh[0], address[0]) @@ -73,6 +110,17 @@ def mask_by_coordinates( def translate( position: Array, direction: Array, modulus: Array = jnp.asarray(1) ) -> Array: + """Translates a point in a grid by a given direction and modulus. + + Args: + position (Array): A 2D coordinate of shape `i32[2]` containing the (col, row) data. + direction (Array): A direction in the range [0, 1, 2, 3] representing the \ + cardinal directions [east, south, west, north]. + modulus (Array, optional): The modulus of the translation. Defaults to jnp.asarray(1). + + Returns: + Array: A 2D coordinate of shape `i32[2]` containing the (col, row) data. + """ moves = ( lambda position: position + jnp.asarray((0, modulus)), # east lambda position: position + jnp.asarray((modulus, 0)), # south @@ -83,22 +131,73 @@ def translate( def translate_forward(position: Array, forward_direction: Array, modulus: Array): + """Translates a point in a grid by a given forward direction and modulus. + + Args: + position (Array): A 2D coordinate of shape `i32[2]` containing the (col, row) data. + forward_direction (Array): A direction in the range [0, 1, 2, 3] representing the \ + cardinal directions [east, south, west, north]. + modulus (Array): The modulus of the translation. + + Returns: + Array: A 2D coordinate of shape `i32[2]` containing the (col, row) data.""" return translate(position, forward_direction, modulus) def translate_left(position: Array, forward_direction: Array, modulus: Array): + """Translates a point in a grid by a given left direction and modulus. + + Args: + position (Array): A 2D coordinate of shape `i32[2]` containing the (col, row) data. + forward_direction (Array): A direction in the range [0, 1, 2, 3] representing the \ + cardinal directions [east, south, west, north]. + modulus (Array): The modulus of the translation. + + Returns: + Array: A 2D coordinate of shape `i32[2]` containing the (col, row) data.""" return translate(position, (forward_direction + 3) % 4, modulus) def translate_right(position: Array, forward_direction: Array, modulus: Array): + """Translates a point in a grid by a given right direction and modulus. + + Args: + position (Array): A 2D coordinate of shape `i32[2]` containing the (col, row) data. + forward_direction (Array): A direction in the range [0, 1, 2, 3] representing the \ + cardinal directions [east, south, west, north]. + modulus (Array): The modulus of the translation. + + Returns: + Array: A 2D coordinate of shape `i32[2]` containing the (col, row) data.""" return translate(position, (forward_direction + 1) % 4, modulus) def rotate(direction: Array, spin: int) -> Array: + """Changes a direction vectory by a given number of spins. + + Args: + direction (Array): A direction vector of shape `i32[]` in the range [0, 3] \ + representing the cardinal directions [east, south, west, north]. + spin (int): The number of spins to apply. + + Returns: + Array: A direction vector of shape `i32[]` in the range [0, 3] representing \ + the cardinal directions [east, south, west, north].""" return (direction + spin) % 4 def align(patch: Array, current_direction: Array, desired_direction: Array) -> Array: + """Aligns a patch of the grid from the current direction to the desired direction. + + Args: + patch (Array): A patch of the grid. + current_direction (Array): The current direction in the range [0, 1, 2, 3] \ + representing the cardinal directions [east, south, west, north]. + desired_direction (Array): The desired direction in the range [0, 1, 2, 3] \ + representing the cardinal directions [east, south, west, north]. + + Returns: + Array: A patch of the grid aligned to the desired direction.""" return jax.lax.switch( desired_direction - current_direction, ( @@ -114,6 +213,16 @@ def align(patch: Array, current_direction: Array, desired_direction: Array) -> A def random_positions( key: Array, grid: Array, n: int = 1, exclude: Array = jnp.asarray((-1, -1)) ) -> Array: + """Generates `n` random positions in the grid, excluding the `exclude` position. + + Args: + key (Array): A random key. + grid (Array): A 2D grid of shape (height, width). + n (int, optional): The number of random positions to generate. Defaults to 1. + exclude (Array, optional): The position to exclude. Defaults to jnp.asarray((-1, -1)). + + Returns: + Array: A batch of random positions of shape `i32[n, 2]`.""" probs = grid.reshape(-1) indices = idx_from_coordinates(grid, exclude) probs = probs.at[indices].set(-1) + 1.0 @@ -123,14 +232,40 @@ def random_positions( def random_directions(key: Array, n=1) -> Array: + """Generates `n` random directions in the range [0, 1, 2, 3] representing the \ + cardinal directions [east, south, west, north]. + + Args: + key (Array): A random key. + n (int, optional): The number of random directions to generate. Defaults to 1. + + Returns: + Array: A batch of random directions of shape `i32[n]`.""" return jax.random.randint(key, (n,), 0, 4).squeeze() def random_colour(key: Array, n=1) -> Array: + """Generates `n` random colours in the range [0, 1, 2, 3, 4, 5]. + + Args: + key (Array): A random key. + n (int, optional): The number of random colours to generate. Defaults to 1. + + Returns: + Array: A batch of random colours of shape `u8[n]`.""" return jax.random.randint(key, (n,), 0, 6).squeeze() def positions_equal(a: Array, b: Array) -> Array: + """Checks if two points are equal. + + Args: + a (Array): A 2D coordinate of shape `i32[2]` containing the (col, row) data. + b (Array): A 2D coordinate of shape `i32[2]` containing the (col, row) data. + + Returns: + + """ if b.ndim == 1: b = b[None] if a.ndim == 1: @@ -141,14 +276,34 @@ def positions_equal(a: Array, b: Array) -> Array: return is_equal -def room(height: int, width: int): - """A grid of ids of size `width` x `height`, including the sorrounding walls""" +def room(height: int, width: int) -> Array: + """Creates an array representing a room of size `height` x `width`, including + a set of walls around the room. The room is represented as a 2D grid of shape + `(height, width)`, including walls, with walls set to -1 and empty tiles set to 0. + + Args: + height (int): The height of the room. + width (int): The width of the room. + + Returns: + Array: A 2D grid of shape `(height, width)` representing a room.""" grid = jnp.zeros((height - 2, width - 2), dtype=jnp.int32) return jnp.pad(grid, 1, mode="constant", constant_values=-1) def two_rooms(height: int, width: int, key: Array) -> Tuple[Array, Array]: - """Two rooms separated by a vertical wall at `width // 2`""" + """Creates a 2D grid representing two rooms of size `height` x `width`, separated + by a wall. The rooms are represented as a 2D grid of shape `(height, width)`, \ + including walls, with walls set to -1 and empty tiles set to 0. + + Args: + height (int): The height of the rooms. + width (int): The width of the rooms. + key (Array): A random key, determining the position of the wall separating the rooms. + + Returns: + Tuple[Array, Array]: A tuple containing the 2D grid representing the rooms \ + and the column index of the wall separating the rooms.""" # create room grid = jnp.zeros((height - 2, width - 2), dtype=jnp.int32) grid = jnp.pad(grid, 1, mode="constant", constant_values=-1) @@ -162,6 +317,17 @@ def two_rooms(height: int, width: int, key: Array) -> Tuple[Array, Array]: def vertical_wall( grid: Array, row_idx: int, opening_col_idx: Array | None = None ) -> Array: + """Creates a vertical wall in the grid at the given row index, with an opening at the \ + given column index. + + Args: + grid (Array): A 2D grid of shape `(height, width)`. + row_idx (int): The row index where the wall is placed. + opening_col_idx (Array, optional): The column index where the opening is placed. \ + Defaults to None. + + Returns: + Array: A 2D grid of shape `(height, width)` with a vertical wall.""" rows = jnp.arange(1, grid.shape[0] - 1) cols = jnp.asarray([row_idx] * (grid.shape[0] - 2)) positions = jnp.stack((rows, cols), axis=1) @@ -175,6 +341,17 @@ def vertical_wall( def horizontal_wall( grid: Array, col_idx: int, opening_row_idx: Array | None = None ) -> Array: + """Creates a horizontal wall in the grid at the given column index, with an opening at the \ + given row index. + + Args: + grid (Array): A 2D grid of shape `(height, width)`. + col_idx (int): The column index where the wall is placed. + opening_row_idx (Array, optional): The row index where the opening is placed. \ + Defaults to None. + + Returns: + Array: A 2D grid of shape `(height, width)` with a horizontal wall.""" rows = jnp.asarray([col_idx] * (grid.shape[1] - 2)) cols = jnp.arange(1, grid.shape[1] - 1) positions = jnp.stack((rows, cols), axis=1) @@ -188,6 +365,17 @@ def horizontal_wall( def crop( grid: Array, origin: Array, direction: Array, radius: int, padding_value: int = 0 ) -> Array: + """Crops a grid around a given origin, facing a given direction, with a given radius. + + Args: + grid (Array): A 2D grid of shape `(height, width)`. + origin (Array): The origin of the crop. + direction (Array): The direction the crop is facing. + radius (int): The radius of the crop. + padding_value (int, optional): The padding value. Defaults to 0. + + Returns: + Array: A cropped grid.""" input_shape = grid.shape # assert radius % 2, "Radius must be an odd number" # mid = jnp.asarray([g // 2 for g in grid.shape[:2]]) @@ -228,6 +416,17 @@ def crop( def view_cone(transparency_map: Array, origin: Array, radius: int) -> Array: + """Computes the view cone of a given origin in a grid with a given radius. + The view cone is a boolean map of transparent (1) and opaque (0) tiles, indicating + whether a tile is visible from the origin or not. + + Args: + transparency_map (Array): A boolean map of transparent (1) and opaque (0) tiles. + origin (Array): The origin of the view cone. + radius (int): The radius of the view cone. + + Returns: + Array: The view cone of the given origin in the grid with the given radius.""" # transparency_map is a boolean map of transparent (1) and opaque (0) tiles def fin_diff(array, _): @@ -251,6 +450,19 @@ def fin_diff(array, _): def from_ascii_map(ascii_map: str, mapping: Dict[str, int] = {}) -> Array: + """Converts an ASCII map into a 2D grid. The ASCII map is a string where each character + represents a tile in the grid. The mapping dictionary can be used to map ASCII characters + to integer values. By default, the mapping is as follows: + - `#` is mapped to -1 + - `.` is mapped to 0 + + Args: + ascii_map (str): The ASCII map. + mapping (Dict[str, int], optional): A dictionary mapping ASCII characters to integer \ + values. Defaults to {}. + + Returns: + Array: A 2D grid representing the ASCII map.""" mapping = {**{"#": -1, ".": 0}, **mapping} ascii_map = ascii_map.strip() @@ -266,6 +478,12 @@ def from_ascii_map(ascii_map: str, mapping: Dict[str, int] = {}) -> Array: class RoomsGrid(struct.PyTreeNode): + """A grid of rooms. Each room is represented as a 2D grid of shape `(room_height, room_width)`, + with walls set to -1 and empty tiles set to 0. The grid of rooms is represented as a 2D grid of + shape `(rows * (room_height + 1), cols * (room_width + 1))`, with walls set to -1 and empty tiles + set to 0. The grid of rooms is represented as a 2D grid of shape `(rows * (room_height + 1), cols * (room_width + 1))`, + with walls set to -1 and empty tiles set to 0.""" + room_starts: Array # shape (rows, cols) room_size: Tuple[int, int] @@ -273,6 +491,15 @@ class RoomsGrid(struct.PyTreeNode): def create( cls, num_rows: int, num_cols: int, room_size: Tuple[int, int] ) -> RoomsGrid: + """Creates a grid of rooms with the given number of rows and columns, and the given room size. + + Args: + num_rows (int): The number of rows. + num_cols (int): The number of columns. + room_size (Tuple[int, int]): The size of each room `(height, width)`. + + Returns: + RoomsGrid: A grid of rooms.""" # generate rooms grid height = num_rows * (room_size[0] + 1) width = num_cols * (room_size[1] + 1) @@ -286,6 +513,15 @@ def create( return cls(starts, room_size) def get_grid(self, occupied_positions: Array | None = None) -> Array: + """Computes the array representation of the grid of rooms, with walls set to \ + -1 and empty tiles set to 0. + + Args: + occupied_positions (Array, optional): A batch of extra occupied positions \ + of shape `(n, 2)`. Defaults to None. + + Returns: + Array: A 2D grid of shape `(rows * (room_height + 1), cols * (room_width + 1))`.""" room_size = self.room_size num_rows, num_cols = self.room_starts.shape[:2] grid = jnp.zeros( @@ -299,6 +535,15 @@ def get_grid(self, occupied_positions: Array | None = None) -> Array: return grid def position_in_room(self, row: Array, col: Array, *, key: Array) -> Array: + """Generates a random position in a given room. + + Args: + row (Array): The row index of the room. + col (Array): The column index of the room. + key (Array): A random key. + + Returns: + Array: A random position in the given room.""" k1, k2 = jax.random.split(key) local_row = jax.random.randint(k1, (), minval=1, maxval=self.room_size[0]) local_col = jax.random.randint(k2, (), minval=1, maxval=self.room_size[1]) @@ -308,7 +553,17 @@ def position_in_room(self, row: Array, col: Array, *, key: Array) -> Array: def position_on_border( self, row: Array, col: Array, side: int, *, key: Array ) -> Array: - """Side is 0: west, 1: east, 2: north, 3: south (like padding)""" + """Generates a random position on the border of a given room. + Side is 0: west, 1: east, 2: north, 3: south (like padding) + + Args: + row (Array): The row index of the room. + col (Array): The column index of the room. + side (int): The side of the room. + key (Array): A random key. + + Returns: + Array: A random position on the border of the given room.""" starts = self.room_starts[row, col] room_size = self.room_size if side == 0: diff --git a/navix/observations.py b/navix/observations.py index a83d202..a2577a1 100644 --- a/navix/observations.py +++ b/navix/observations.py @@ -35,10 +35,28 @@ def none(state: State) -> Array: + """An empty observation represented as an array of shape f32[0]. + Useful for testing purposes. + + Args: + state (State): The current state of the game. + + Returns: + Array: A 0-shaped array `f32[0]`.""" return jnp.asarray(()) def categorical(state: State) -> Array: + """Fully observable grid with a categorical state representation. + Each entity is represented by its unique integer tag. + + Args: + state (State): The current state of the game. + + Returns: + Array: A grid of integers, where each integer represents an entity, \ + represented as an array of shape `i32[H, W]`, where `H` and `W` are the height \ + and width of the grid.""" # get idx of entity on the set of patches indices = idx_from_coordinates(state.grid, state.get_positions()) # get tags corresponding to the entities @@ -51,6 +69,15 @@ def categorical(state: State) -> Array: def categorical_first_person(state: State) -> Array: + """Categorical state representation, but cropped to the agent's view, and aligned \ + with the agent's direction, such that the agent always points upwards. + + Args: + state (State): The current state of the game. + + Returns: + Array: A grid of integers, where each integer represents an entity, \ + represented as an array of shape `i32[2 * RADIUS + 1, 2 * RADIUS + 1]`.""" # get transparency map transparency_map = jnp.where(state.grid == 0, 1, 0) positions = state.get_positions() @@ -72,9 +99,19 @@ def categorical_first_person(state: State) -> Array: def symbolic(state: State) -> Array: - """Fully observable grid with a symbolic state representation. - The symbol is a triple of (OBJECT_TAG, COLOUR_IDX, OPEN/CLOSED/LOCKED), \ - where X and Y are the coordinates on the grid, and IDX is the id of the object.""" + """Fully observable grid with a symbolic state representation as originally \ + proposed in the MiniGrid environment. + The symbol is a triple of (OBJECT_TAG, COLOUR_IDX, OPEN/CLOSED/LOCKED). The + last layer might also contain the direction of the entity, for example, the + direction of the agent. + + Args: + state (State): The current state of the game. + + Returns: + Array: A grid of integers, where each integer represents an entity, \ + represented as an array of shape `u8[H, W, 3]`, where `H` and `W` are the height \ + and width of the grid.""" # initialise as all floors H, W = state.grid.shape obs = jnp.zeros((H, W, 3), dtype=jnp.uint8) @@ -104,9 +141,16 @@ def symbolic(state: State) -> Array: def symbolic_first_person(state: State) -> Array: - """First person view with a symbolic state representation. - The symbol is a triple of (OBJECT_TAG, COLOUR_IDX, OPEN/CLOSED/LOCKED), \ - where X and Y are the coordinates on the grid, and IDX is the id of the object.""" + """First person view with a symbolic state representation, but cropped to the \ + agent's view, and aligned with the agent's direction, such that the agent always \ + points upwards. See `symbolic` for more details. + + Args: + state (State): The current state of the game. + + Returns: + Array: A grid of integers, where each integer represents an entity, \ + represented as an array of shape `u8[2 * RADIUS + 1, 2 * RADIUS + 1, 3]`.""" # get transparency map obs = symbolic(state) @@ -131,6 +175,18 @@ def symbolic_first_person(state: State) -> Array: def rgb(state: State) -> Array: + """Fully observable grid with an RGB state representation. + Each entity is represented by its unique RGB sprite. The RGB sprites are \ + stored in a cache, and the entities are placed on the grid according to their \ + positions. + + Args: + state (State): The current state of the game. + + Returns: + Array: An RGB image of the grid, represented as an array of shape \ + `u8[H * S, W * S, 3]`, where `H` and `W` are the height and width of the grid, + and `S` is the size of the tile.""" # get idx of entity on the flat set of patches indices = idx_from_coordinates(state.grid, state.get_positions()) # get tiles corresponding to the entities @@ -149,6 +205,18 @@ def rgb(state: State) -> Array: def rgb_first_person(state: State) -> Array: + """First person view with an RGB state representation. + The image is cropped to the agent's view, and aligned with the agent's direction, \ + such that the agent always points upwards. See `rgb` for more details. + See `rgb` for more details. + + Args: + state (State): The current state of the game. + + Returns: + Array: An RGB image of the agent's view, represented as an array of shape \ + `u8[(2 * RADIUS + 1) * S, (2 * RADIUS + 1) * S, 3]`, where + `S` is the size of the tile.""" # calculate final image size # get agent's view # image_size = ( diff --git a/navix/rewards.py b/navix/rewards.py index 92cc47e..4bdbd64 100644 --- a/navix/rewards.py +++ b/navix/rewards.py @@ -31,6 +31,19 @@ def compose( *reward_functions: Callable[[State, Array, State], Array], operator: Callable = jnp.sum, ) -> Callable: + """Compose multiple reward functions into a single reward function. + The functions are called in order and the results are reduced using the `operator` \ + function. + + Args: + *reward_functions (Callable[[State, Array, State], Array]): A list of reward functions. + operator (Callable): The operator to reduce the results of the reward functions. + It must be a function that takes a list of arrays, or an array and returns an \ + array of size `f32[]`. + + Returns: + Callable: A composed reward function that applies the `operator` to the results of the \ + reward functions.""" return lambda prev_state, action, state: operator( jnp.asarray( [f(prev_state, action, state) for f in reward_functions], dtype=jnp.float32 @@ -39,23 +52,62 @@ def compose( def free(state: State) -> Array: + """A reward function that always returns 0, to simulate reward-free learning. + + Args: + state (State): The current state of the game. + + Returns: + Array: A scalar array `f32[]` with value 0.""" return jnp.asarray(0.0, dtype=jnp.float32) def on_goal_reached(prev_state: State, action: Array, state: State) -> Array: + """A reward function that returns 1 when the goal is reached, and 0 otherwise. + + Args: + state (State): The current state of the game. + + Returns: + Array: A scalar array `f32[]` with value 1 if the goal is reached, and 0 otherwise. + """ return jnp.asarray(events.on_goal_reached(state), dtype=jnp.float32) def action_cost( prev_state: State, action: Array, new_state: State, cost: float = 0.01 ) -> Array: + """A reward function that returns a negative value when an action is taken. + All actions have a cost of `cost`, except for noops. + + Args: + prev_state (State): The previous state of the game. + action (Array): The action taken. + new_state (State): The new state of the game. + cost (float): The cost of taking an action. + + Returns: + Array: A scalar array `f32[]` with value -`cost` if the action is not a noop, \ + and 0 otherwise.""" # noops are free - return -jnp.asarray(action > 0, dtype=jnp.float32) * cost + return -jnp.asarray(action != 6, dtype=jnp.float32) * cost def time_cost( prev_state: State, action: Array, new_state: State, cost: float = 0.01 ) -> Array: + """A reward function that returns a negative value as time passes, paying a cost \ + of `cost` at each time step. + + Args: + prev_state (State): The previous state of the game. + action (Array): The action taken. + new_state (State): The new state of the game. + cost (float): The cost of time passing. + + Returns: + Array: A scalar array `f32[]` with value -`cost`. + """ # time always has a cost return -jnp.asarray(cost, dtype=jnp.float32) @@ -63,13 +115,32 @@ def time_cost( def wall_hit_cost( prev_state: State, action: Array, state: State, cost: float = 0.01 ) -> Array: + """A reward function that returns a negative value when the agent hits a wall, \ + paying a cost of `cost` for each wall hit. + + Args: + state (State): The current state of the game. + cost (float): The cost of hitting a wall. + + Returns: + Array: A scalar array `f32[]` with value -`cost` if the agent hits a wall, \ + and 0 otherwise.""" return jnp.asarray(events.on_wall_hit(state), dtype=jnp.float32) * cost -def on_door_done( - prev_state: State, action: Array, state: State, cost: float = 0.01 -) -> Array: +def on_door_done(prev_state: State, action: Array, state: State) -> Array: + """A reward function that returns a positive value when the agent uses the action \ + `done` in front of a door. + + Args: + state (State): The current state of the game. + + Returns: + Array: A scalar array `f32[]` with value 1 if the agent uses the action `done` in \ + front of a door, and 0 otherwise.""" + return jnp.asarray(events.on_door_done(state), dtype=jnp.float32) DEFAULT_TASK = compose(on_goal_reached, action_cost) +"""The default task for the game, composed of the `on_goal_reached` and `action_cost` reward functions.""" diff --git a/navix/spaces.py b/navix/spaces.py index 1a6cb95..0158a5d 100644 --- a/navix/spaces.py +++ b/navix/spaces.py @@ -25,12 +25,30 @@ class Space(struct.PyTreeNode): + """Base class for all spaces in the game. Spaces define the shape and type of the \ + observations, actions and rewards in the game. + The `sample` method is used to generate random samples from the space. + + !!! note + To initialize a space, use the `create` method of the specific space class. + + TODO: + * maximum and minimum should be static objects, not arrays. + But how do we handle the case when they are not scalars? Maybe numpy arrays?""" + shape: Shape = struct.field(pytree_node=False) dtype: jnp.dtype = struct.field(pytree_node=False) minimum: Array maximum: Array def sample(self, key: Array) -> Array: + """Generate a random sample from the space. + + Args: + key (Array): A random key to generate the sample. + + Returns: + Array: A random sample from the space.""" raise NotImplementedError() @@ -39,6 +57,15 @@ class Discrete(Space): def create( cls, n_elements: int | jax.Array, shape: Shape = (), dtype=jnp.int32 ) -> Discrete: + """Create a discrete space with a given number of elements. + + Args: + n_elements (int | jax.Array): The number of elements in the space. + shape (Shape): The shape of the space. + dtype (jnp.dtype): The data type of the space. + + Returns: + Discrete: A discrete space with the given number of elements.""" return Discrete( shape=shape, dtype=dtype, @@ -47,12 +74,23 @@ def create( ) def sample(self, key: Array) -> Array: + """Generate a random sample from the space. + + Args: + key (Array): A random key to generate the sample. + + Returns: + Array: A random sample from the space.""" item = jax.random.randint(key, self.shape, self.minimum, self.maximum) # randint cannot draw jnp.uint, so we cast it later return jnp.asarray(item, dtype=self.dtype) - + @property def n(self) -> Array: + """The number of elements in the space. + + Returns: + Array: The number of elements in the space.""" return self.maximum + 1 @@ -61,9 +99,27 @@ class Continuous(Space): def create( cls, shape: Shape, minimum: Array, maximum: Array, dtype=jnp.float32 ) -> Continuous: + """Create a continuous space with a given shape, minimum and maximum values. + + Args: + shape (Shape): The shape of the space. + minimum (Array): The minimum value of the space. + maximum (Array): The maximum value of the space. + dtype (jnp.dtype): The data type of the space. + + Returns: + Continuous: A continuous space with the given shape, minimum and maximum values. + """ return Continuous(shape=shape, dtype=dtype, minimum=minimum, maximum=maximum) def sample(self, key: Array) -> Array: + """Generate a random sample from the space. + + Args: + key (Array): A random key to generate the sample. + + Returns: + Array: A random sample from the space.""" assert jnp.issubdtype(self.dtype, jnp.floating) # see: https://github.com/google/jax/issues/14003 lower = jnp.nan_to_num(self.minimum) diff --git a/navix/states.py b/navix/states.py index c815626..aaddd99 100644 --- a/navix/states.py +++ b/navix/states.py @@ -34,6 +34,8 @@ class EventType: + """Enumeration of the different types of events that can happen in the environment.""" + NONE: Array = jnp.asarray(-1, dtype=jnp.int32) REACH: Array = jnp.asarray(0, dtype=jnp.int32) HIT: Array = jnp.asarray(1, dtype=jnp.int32) @@ -44,6 +46,23 @@ class EventType: class Event(Positionable, HasColour): + """A struct representing an event that happened in the environment. It contains the + position of the event, the colour of the entity involved in the event, whether the event + happened, and the type of event that happened. + + !!! note + Notice that we need the `happened` property, which flags if an event has + happened or not, because JAX does not support variable size arrays. + This means that we cannot add an event to the list in the middle of training. + Instead, we initialise all events, and mask them out as not happened. + + + Attributes: + position (Array): The (row, column) position of the event in the grid. + colour (Array): The colour of the entity involved in the event. + happened (Array): A boolean flag indicating whether the event happened. + event_type (Array): The type of event that happened.""" + position: Array = jnp.asarray([-1, -1], dtype=jnp.int32) colour: Array = PALETTE.UNSET happened: Array = jnp.asarray(False, dtype=jnp.bool_) @@ -60,6 +79,20 @@ def __ne__(self, other: Event) -> Array: class EventsManager(struct.PyTreeNode): + """A struct that manages the events. It contains the different events that can happen + in the environment, such as the goal being reached, the player being hit by a ball, etc. + + Attributes: + goal_reached (Event): An event indicating that the goal has been reached. + ball_hit (Event): An event indicating that the player has been hit by a ball. + wall_hit (Event): An event indicating that the player has hit a wall. + lava_fall (Event): An event indicating that the lava has fallen. + key_pickup (Event): An event indicating that the player has picked up a key. + door_opening (Event): An event indicating that the player has opened a door. + door_unlock (Event): An event indicating that the player has unlocked a door. + ball_pickup (Event): An event indicating that the player has picked up a ball. + """ + goal_reached: Event = Event() ball_hit: Event = Event() wall_hit: Event = Event() @@ -70,6 +103,15 @@ class EventsManager(struct.PyTreeNode): ball_pickup: Event = Event() def record_walk_into(self, entity: Entity, position: Array) -> EventsManager: + """Flags an event when the player walks into an entity as happened and returns the + updated events manager. + + Args: + entity (Entity): The entity the player walked into. + position (Array): The position of the entity in the grid. + + Returns: + EventsManager: The updated events manager.""" if isinstance(entity, Goal): return self.record_goal_reached(entity, position) elif isinstance(entity, Wall): @@ -79,6 +121,15 @@ def record_walk_into(self, entity: Entity, position: Array) -> EventsManager: return self def record_pickup(self, entity: Entity, position: Array) -> EventsManager: + """Flags an event when the player picks up an entity as happened and returns the + updated events manager. + + Args: + entity (Entity): The entity the player picked up. + position (Array): The position of the entity in the grid. + + Returns: + EventsManager: The updated events manager.""" if isinstance(entity, Key): return self.record_key_pickup(entity, position) elif isinstance(entity, Ball): @@ -86,6 +137,15 @@ def record_pickup(self, entity: Entity, position: Array) -> EventsManager: return self def record_goal_reached(self, goal: Goal, position: Array) -> EventsManager: + """Flags an event when the player reaches the goal as happened and returns the + updated events manager. + + Args: + goal (Goal): The goal the player reached. + position (Array): The position of the goal in the grid. + + Returns: + EventsManager: The updated events manager.""" idx = jnp.where(goal.position == position, size=1)[0][0] goal = goal[idx] return self.replace( @@ -98,6 +158,14 @@ def record_goal_reached(self, goal: Goal, position: Array) -> EventsManager: ) def record_ball_hit(self, ball: Ball) -> EventsManager: + """Flags an event when the player is hit by a ball as happened and returns the + updated events manager. + + Args: + ball (Ball): The ball that hit the player. + + Returns: + EventsManager: The updated events manager.""" return self.replace( ball_hit=Event( position=ball.position, @@ -108,6 +176,15 @@ def record_ball_hit(self, ball: Ball) -> EventsManager: ) def record_wall_hit(self, wall: Wall, position: Array) -> EventsManager: + """Flags an event when the player hits a wall as happened and returns the + updated events manager. + + Args: + wall (Wall): The wall the player hit. + position (Array): The position of the wall in the grid. + + Returns: + EventsManager: The updated events manager.""" idx = jnp.where(wall.position == position, size=1)[0][0] wall = wall[idx] return self.replace( @@ -120,6 +197,14 @@ def record_wall_hit(self, wall: Wall, position: Array) -> EventsManager: ) def record_grid_hit(self, position: Array) -> EventsManager: + """Flags an event when the player hits a wall as happened and returns the + updated events manager. + + Args: + position (Array): The position of the wall in the grid. + + Returns: + EventsManager: The updated events manager.""" return self.replace( wall_hit=Event( position=position, @@ -130,6 +215,15 @@ def record_grid_hit(self, position: Array) -> EventsManager: ) def record_lava_fall(self, lava: Lava, position: Array) -> EventsManager: + """Flags an event when the lava falls as happened and returns the + updated events manager. + + Args: + lava (Lava): The lava that fell. + position (Array): The position of the lava in the grid. + + Returns: + EventsManager: The updated events manager.""" idx = jnp.where(lava.position == position, size=1)[0][0] lava = lava[idx] return self.replace( @@ -142,6 +236,15 @@ def record_lava_fall(self, lava: Lava, position: Array) -> EventsManager: ) def record_key_pickup(self, key: Key, position: Array) -> EventsManager: + """Flags an event when the player picks up a key as happened and returns the + updated events manager. + + Args: + key (Key): The key the player picked up. + position (Array): The position of the key in the grid. + + Returns: + EventsManager: The updated events manager.""" idx = jnp.where(key.position == position, size=1)[0][0] key = key[idx] return self.replace( @@ -154,6 +257,15 @@ def record_key_pickup(self, key: Key, position: Array) -> EventsManager: ) def record_door_opening(self, door: Door, position: Array) -> EventsManager: + """Flags an event when the player opens a door as happened and returns the + updated events manager. + + Args: + door (Door): The door the player opened. + position (Array): The position of the door in the grid. + + Returns: + EventsManager: The updated events manager.""" idx = jnp.where(door.position == position, size=1)[0][0] door = door[idx] return self.replace( @@ -166,6 +278,15 @@ def record_door_opening(self, door: Door, position: Array) -> EventsManager: ) def record_door_unlock(self, door: Door, position: Array) -> EventsManager: + """Flags an event when the player unlocks a door as happened and returns the + updated events manager. + + Args: + door (Door): The door the player unlocked. + position (Array): The position of the door in the grid. + + Returns: + EventsManager: The updated events manager.""" idx = jnp.where(door.position == position, size=1)[0][0] door = door[idx] return self.replace( @@ -178,6 +299,15 @@ def record_door_unlock(self, door: Door, position: Array) -> EventsManager: ) def record_ball_pickup(self, ball: Ball, position: Array) -> EventsManager: + """Flags an event when the player picks up a ball as happened and returns the + updated events manager. + + Args: + ball (Ball): The ball the player picked up. + position (Array): The position of the ball in the grid. + + Returns: + EventsManager: The updated events manager.""" idx = jnp.where(ball.position == position, size=1)[0][0] ball = ball[idx] return self.replace( @@ -208,76 +338,112 @@ class State(struct.PyTreeNode): mission: Event | None = None def get_entity(self, entity_enum: str) -> Entity: + """Get an entity from the state by its enum. + + Args: + entity_enum (str): The enum of the entity to get. + + Returns: + Entity: The entity from the state.""" return self.entities[entity_enum] def set_entity(self, entity_enum: str, entity: Entity) -> State: + """Set an entity in the state by its enum. + + Args: + entity_enum (str): The enum of the entity to set. + entity (Entity): The entity to set. + + Returns: + State: The updated state.""" self.entities[entity_enum] = entity return self def get_walls(self) -> Wall: + """Gets all the `WALL` entities from the state.""" return self.entities.get(Entities.WALL, Wall()) # type: ignore def set_walls(self, walls: Wall) -> State: + """Sets the `WALL` entities in the state.""" self.entities[Entities.WALL] = walls return self def get_player(self, idx: int = 0) -> Player: + """Gets the player entity from the state.""" return self.entities[Entities.PLAYER][idx] # type: ignore def set_player(self, player: Player, idx: int = 0) -> State: + """Sets the player entity in the state. Notice that we only support one player in the + environment for now, but this can easily be extended to multiple players.""" # TODO(epignatelli): this is a hack and won't work in multi-agent settings self.entities[Entities.PLAYER] = player[None] return self def get_goals(self) -> Goal: + """Gets the goal entity from the state.""" return self.entities[Entities.GOAL] # type: ignore def set_goals(self, goals: Goal) -> State: + """Sets the goal entity in the state.""" self.entities[Entities.GOAL] = goals return self def get_keys(self) -> Key: + """Gets the key entity from the state.""" return self.entities[Entities.KEY] # type: ignore def set_keys(self, keys: Key) -> State: + """Sets the key entity in the state.""" self.entities[Entities.KEY] = keys return self def get_doors(self) -> Door: + """Gets the door entity from the state.""" return self.entities[Entities.DOOR] # type: ignore def set_doors(self, doors: Door) -> State: + """Sets the door entity in the state.""" self.entities[Entities.DOOR] = doors return self def get_lavas(self) -> Lava: + """Gets the lava entity from the state.""" return self.entities[Entities.LAVA] # type: ignore def get_balls(self) -> Ball: + """Gets the ball entity from the state.""" return self.entities[Entities.BALL] # type: ignore def get_boxes(self) -> Ball: + """Gets the box entity from the state.""" return self.entities[Entities.BOX] # type: ignore def set_balls(self, balls: Ball) -> State: + """Sets the ball entity in the state.""" self.entities[Entities.BALL] = balls return self def set_boxes(self, boxes: Box) -> State: + """Sets the box entity in the state.""" self.entities[Entities.BOX] = boxes return self def set_events(self, events: EventsManager) -> State: + """Sets the events in the state.""" return self.replace(events=events) def get_positions(self) -> Array: + """Get the positions of all the entities in the state.""" return jnp.concatenate([self.entities[k].position for k in self.entities]) def get_tags(self) -> Array: + """Get the tags of all the entities in the state.""" return jnp.concatenate([self.entities[k].tag for k in self.entities]) def get_sprites(self) -> Array: + """Get the sprites of all the entities in the state.""" return jnp.concatenate([self.entities[k].sprite for k in self.entities]) def get_transparency(self) -> Array: + """Get the transparency of all the entities in the state.""" return jnp.concatenate([self.entities[k].transparent for k in self.entities]) diff --git a/navix/terminations.py b/navix/terminations.py index 85d6796..6345a0a 100644 --- a/navix/terminations.py +++ b/navix/terminations.py @@ -24,37 +24,91 @@ from . import events from .states import State -from .grid import translate -from .entities import Entities, Player def compose( *term_functions: Callable[[State, Array, State], Array], operator: Callable = jnp.any, ) -> Callable: + """Compose termination functions into a single termination function. + + Args: + *term_functions (Callable): List of termination functions. + operator (Callable): Operator to combine the termination functions. + + Returns: + Callable: A single termination function.""" return lambda prev_state, action, state: operator( jnp.asarray([term_f(prev_state, action, state) for term_f in term_functions]) ) def check_truncation(terminated: Array, truncated: Array) -> Array: + """Check if the episode is truncated or terminated, and returns a value + that conforms to the `StepType` enum. + + Args: + terminated (Array): A boolean array indicating whether the episode is terminated. + truncated (Array): A boolean array indicating whether the episode is truncated. + + Returns: + Array: An integer array that represents the step type.""" result = jnp.asarray(truncated + 2 * terminated, dtype=jnp.int32) return jnp.clip(result, 0, 2) def on_goal_reached(prev_state: State, action: Array, state: State) -> Array: + """Check if the goal has been reached using the `goal_reached` event. + + Args: + prev_state (State): The previous state of the game. + action (Array): The action taken by the player. + state (State): The current state of the game. + + Returns: + Array: A boolean array indicating whether the goal has been reached.""" return jnp.asarray(events.on_goal_reached(state), dtype=jnp.bool_) def on_lava_fall(prev_state: State, action: Array, state: State) -> Array: + """Check if the lava has fallen using the `lava_fall` event. + + Args: + prev_state (State): The previous state of the game. + action (Array): The action taken by the player. + state (State): The current state of the game. + + Returns: + Array: A boolean array indicating whether the lava has fallen.""" return jnp.asarray(events.on_lava_fall(state), dtype=jnp.bool_) def on_ball_hit(prev_state: State, action: Array, state: State) -> Array: + """Check if the ball has hit something using the `ball_hit` event. + + Args: + prev_state (State): The previous state of the game. + action (Array): The action taken by the player. + state (State): The current state of the game. + + Returns: + Array: A boolean array indicating whether the ball has hit something.""" return jnp.asarray(events.on_ball_hit(state), dtype=jnp.bool_) def on_door_done(prev_state: State, action: Array, state: State) -> Array: + """Check if the action `done` has been called in front of a `Door` object with the \ + correct colour. + + Args: + prev_state (State): The previous state of the game. + action (Array): The action taken by the player. + state (State): The current state of the game. + + Returns: + Array: A boolean array indicating whether the action `done` has been called in \ + front of a `Door` object with the correct colour. + """ return jnp.asarray(events.on_door_done(state), dtype=jnp.bool_) diff --git a/navix/transitions.py b/navix/transitions.py index 1a99032..b17a2ba 100644 --- a/navix/transitions.py +++ b/navix/transitions.py @@ -24,7 +24,6 @@ from jax import Array import jax import jax.numpy as jnp -import jax.tree_util as jtu from .entities import Entities, Ball from .states import EventsManager, State from .grid import positions_equal, translate @@ -33,12 +32,33 @@ def deterministic_transition( state: State, action: Array, actions_set: Tuple[Callable[[State], State], ...] ) -> State: + """Deterministic transition function. It selects the action from the set of actions + and applies it to the state. + + Args: + state (State): The current state of the game. + action (Array): The action to be taken. + actions_set (Tuple[Callable[[State], State]): A set of actions that can be taken. + + Returns: + State: The new state of the game.""" return jax.lax.switch(action, actions_set, state) def stochastic_transition( state: State, action: Array, actions_set: Tuple[Callable[[State], State], ...] ) -> State: + """Stochastic transition function. It selects the action from the set of actions + and applies it to the state, and updates entities that have stochastic transitions, + such as balls. + + Args: + state (State): The current state of the game. + action (Array): The action to be taken. + actions_set (Tuple[Callable[[State], State]): A set of actions that can be taken. + + Returns: + State: The new state of the game.""" # actions state = jax.lax.switch(action, actions_set, state) @@ -47,6 +67,14 @@ def stochastic_transition( def update_balls(state: State) -> State: + """Update the position of the balls in the game. + Balls move in a random direction if they can, otherwise they stay in place. + + Args: + state (State): The current state of the game. + + Returns: + State: The new state of the game.""" def update_one(ball: Ball, key: Array) -> Tuple[Array, EventsManager]: direction = jax.random.randint(key, (), minval=0, maxval=4) new_position = translate(ball.position, direction)