Skip to content

Commit

Permalink
feat: add goto door envs
Browse files Browse the repository at this point in the history
  • Loading branch information
epignatelli committed Mar 12, 2024
1 parent 3e47b4c commit 00749e9
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 0 deletions.
1 change: 1 addition & 0 deletions navix/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@
from .crossings import Crossings
from .dynamic_obstacles import DynamicObstacles
from.dist_shift import DistShift
from .go_to_door import GoToDoor
116 changes: 116 additions & 0 deletions navix/environments/go_to_door.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright 2023 The Navix Authors.

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


from __future__ import annotations
from typing import Union

import jax
import jax.numpy as jnp
from jax import Array
from flax import struct

from ..components import EMPTY_POCKET_ID
from ..entities import Entities, Goal, Door, Player, State
from ..grid import random_colour, random_positions, random_directions, room
from ..rendering.cache import RenderingCache
from .environment import Environment, Timestep
from .registry import register_env


class GoToDoor(Environment):
split_lava: bool = struct.field(pytree_node=False, default=False)

def reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Timestep:
# map
grid = jnp.zeros((self.height, self.width), dtype=jnp.int32)

k1, k2, k3, k4, k5 = jax.random.split(key, num=5)
room_height = jax.random.randint(k1, (), minval=5, maxval=self.height)
room_width = jax.random.randint(k1, (), minval=5, maxval=self.width)

# set wall on grid
grid = grid.at[jnp.asarray([0, room_height - 1])].set(-1)
grid = grid.at[:, jnp.asarray([0, room_width - 1])].set(-1)

# goal and player
player_row = jax.random.randint(k2, (), minval=1, maxval=room_height - 1)
player_col = jax.random.randint(k3, (), minval=1, maxval=room_width - 1)
player_pos = jnp.asarray([player_row, player_col])
direction = random_directions(k4)
player = Player(
position=player_pos,
direction=direction,
pocket=EMPTY_POCKET_ID,
)

# doors
k6, k7 = jax.random.split(k5, num=2)
rows = jax.random.randint(k6, (2,), minval=2, maxval=room_height - 1)
cols = jax.random.randint(k7, (2,), minval=2, maxval=room_width - 1)
positions = jnp.asarray(
[
[rows[0], room_width - 1],
[room_height - 1, cols[0]],
[rows[1], 0],
[0, cols[1]],
]
)
colours = random_colour(key, n=4)
open = jnp.asarray([0] * 4)
requires = jnp.asarray([-1] * 4)
doors = Door.create(
position=positions, requires=requires, colour=colours, open=open
)

entities = {
Entities.PLAYER: player[None],
Entities.LAVA: doors,
}

# systems
state = State(
key=key,
grid=grid,
cache=cache or RenderingCache.init(grid),
entities=entities,
)

return Timestep(
t=jnp.asarray(0, dtype=jnp.int32),
observation=self.observation(state),
action=jnp.asarray(0, dtype=jnp.int32),
reward=jnp.asarray(0.0, dtype=jnp.float32),
step_type=jnp.asarray(0, dtype=jnp.int32),
state=state,
)


register_env(
"Navix-GoToDoor-5x5-v0",
lambda *args, **kwargs: GoToDoor(height=5, width=5, *args, **kwargs),
)
register_env(
"Navix-GoToDoor-6x6-v0",
lambda *args, **kwargs: GoToDoor(height=6, width=6, *args, **kwargs),
)
register_env(
"Navix-GoToDoor-8x8-v0",
lambda *args, **kwargs: GoToDoor(height=8, width=8, *args, **kwargs),
)

0 comments on commit 00749e9

Please sign in to comment.