Skip to content

Commit

Permalink
Merge pull request #59 from epignatelli/examples
Browse files Browse the repository at this point in the history
Add examples and make sure obs match those in minigrid
  • Loading branch information
epignatelli authored Jun 6, 2024
2 parents bf22b15 + ba5b598 commit 346198e
Show file tree
Hide file tree
Showing 44 changed files with 1,746 additions and 521 deletions.
82 changes: 44 additions & 38 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,45 +14,51 @@ jobs:
os: ["ubuntu"]
continue-on-error: false
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Setup navix
run: |
pip install . -v
- name: Check code quality
run: |
pip install pylint
MESSAGE=$(pylint -ry $(git ls-files '*.py') ||:)
echo "$MESSAGE"
- name: Run unit tests with pytest
run: |
pytest
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Setup navix
run: |
pip install . -v
- name: Check code quality
run: |
pip install pylint
MESSAGE=$(pylint -ry $(git ls-files '*.py') ||:)
echo "$MESSAGE"
- name: Run unit tests with pytest
run: |
wandb offline
pytest
- name: Run examples
run: |
for example in examples/*.py; do
python $example
done
Compliance:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: PEP8 Compliance
run: |
pip install pylint
PR_BRANCH=${{ github.event.pull_request.target.ref }}
MAIN_BRANCH=origin/${{ github.event.pull_request.base.ref }}
CURRENT_DIFF=$(git diff --name-only --diff-filter=d $MAIN_BRANCH $PR_BRANCH | grep -E '\.py$' | tr '\n' ' ')
if [[ $CURRENT_DIFF == "" ]];
then MESSAGE="Diff is empty and there is nothing to pylint."
else
MESSAGE=$(pylint -ry --disable=E0401 $CURRENT_DIFF ||:)
fi
echo 'MESSAGE<<EOF' >> $GITHUB_ENV
echo "<pre><code>$MESSAGE</code></pre>" >> $GITHUB_ENV
echo 'EOF' >> $GITHUB_ENV
echo "Printing PR message: $MESSAGE"
- uses: mshick/add-pr-comment@v2
with:
issue: ${{ github.event.pull_request.number }}
message: ${{ env.MESSAGE }}
repo-token: ${{ secrets.GITHUB_TOKEN }}
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: PEP8 Compliance
run: |
pip install pylint
PR_BRANCH=${{ github.event.pull_request.target.ref }}
MAIN_BRANCH=origin/${{ github.event.pull_request.base.ref }}
CURRENT_DIFF=$(git diff --name-only --diff-filter=d $MAIN_BRANCH $PR_BRANCH | grep -E '\.py$' | tr '\n' ' ')
if [[ $CURRENT_DIFF == "" ]];
then MESSAGE="Diff is empty and there is nothing to pylint."
else
MESSAGE=$(pylint -ry --disable=E0401 $CURRENT_DIFF ||:)
fi
echo 'MESSAGE<<EOF' >> $GITHUB_ENV
echo "<pre><code>$MESSAGE</code></pre>" >> $GITHUB_ENV
echo 'EOF' >> $GITHUB_ENV
echo "Printing PR message: $MESSAGE"
- uses: mshick/add-pr-comment@v2
with:
issue: ${{ github.event.pull_request.number }}
message: ${{ env.MESSAGE }}
repo-token: ${{ secrets.GITHUB_TOKEN }}
97 changes: 97 additions & 0 deletions baselines/ppo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from dataclasses import asdict, dataclass
import time
import wandb

import jax
import numpy as np
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import constant, orthogonal
import tyro
import navix as nx
from navix.environments.environment import Environment
from navix.agents import PPO, PPOHparams, ActorCritic


def FlattenObsWrapper(env: Environment):
    """Return a copy of ``env`` whose observations are flattened to 1-D.

    The environment's ``observation_fn`` is wrapped so its output goes
    through ``jnp.ravel``, and the ``shape`` of ``observation_space`` is
    replaced by a single-axis shape holding the product of the original
    dimensions.

    Args:
        env: The environment to wrap. It is not modified; the result is
            built with ``env.replace``.

    Returns:
        Whatever ``env.replace`` returns when given the flattening
        observation function and the reshaped observation space.
    """

    def flatten_obs_fn(x):
        # Delegate to the wrapped environment, then collapse to one axis.
        return jnp.ravel(env.observation_fn(x))

    original_shape = env.observation_space.shape
    flat_shape = (int(np.prod(original_shape)),)
    flat_space = env.observation_space.replace(shape=flat_shape)
    return env.replace(
        observation_fn=flatten_obs_fn,
        observation_space=flat_space,
    )


@dataclass
class Args:
    """Command-line options for the baselines script (parsed by ``tyro.cli``).

    Attributes are documented inline. Every attribute is an annotated
    dataclass field so that it shows up in ``vars(args)`` / ``asdict`` and
    can be overridden from the command line.
    """

    # Passed to ``PPOHparams(budget=...)``.
    budget: int = 10_000_000
    # First seed of the sweep; seeds are
    # ``range(seeds_offset, seeds_offset + n_seeds)``.
    seeds_offset: int = 0
    # Number of consecutive seeds trained (vmapped) per environment.
    n_seeds: int = 10
    # Project name handed to ``wandb.init``. Without the annotation this was
    # a plain class attribute rather than a dataclass field. It is declared
    # last so the positional order of the pre-existing fields is unchanged.
    project_name: str = "navix-baselines"


if __name__ == "__main__":
    args = tyro.cli(Args)

    ppo_hparams = PPOHparams(budget=args.budget)
    # Train one PPO agent (over all seeds at once) per registered environment.
    for env_id in nx.registry():
        # init logging
        config = {**vars(args), **asdict(ppo_hparams)}
        wandb.init(project=args.project_name, config=config)

        # init environment
        env = FlattenObsWrapper(nx.make(env_id))

        # create agent
        agent = PPO(
            hparams=ppo_hparams,
            network=ActorCritic(action_dim=len(env.action_set)),
            env=env,
        )

        # train agent: one PRNG key per seed, training vmapped over the keys
        seeds = range(args.seeds_offset, args.seeds_offset + args.n_seeds)
        rngs = jnp.asarray([jax.random.PRNGKey(seed) for seed in seeds])
        train_fn = jax.vmap(agent.train)

        # Compile ahead of time so compilation and execution are timed apart.
        print("Compiling training function...")
        start_time = time.time()
        train_fn = jax.jit(train_fn).lower(rngs).compile()
        compilation_time = time.time() - start_time
        print(f"Compilation time cost: {compilation_time}")

        print("Training agent...")
        start_time = time.time()
        # JAX dispatches asynchronously: block until the results are ready,
        # otherwise the timer stops before the computation has finished.
        train_state, logs = jax.block_until_ready(train_fn(rngs))
        training_time = time.time() - start_time
        print(f"Training time cost: {training_time}")

        print("Logging final results to wandb...")
        start_time = time.time()
        # transpose logs tree
        # (`jax.tree_map` is a deprecated alias; `jax.tree_util.tree_map` is
        # the long-standing equivalent.)
        logs = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *logs)
        for log in logs:
            agent.log_on_train_end(log)
        logging_time = time.time() - start_time
        print(f"Logging time cost: {logging_time}")

        print("Training complete")
        print(f"Compilation time cost: {compilation_time}")
        print(f"Training time cost: {training_time}")
        print(f"Logging time cost: {logging_time}")
        total_time = compilation_time + training_time + logging_time
        print(f"Total time cost: {total_time}")

        # Close this environment's run so the next `wandb.init` opens a new
        # run instead of reusing the one that is still active.
        wandb.finish()
4 changes: 2 additions & 2 deletions docs/performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@


def profile_navix(seed):
env = nx.environments.Room(16, 16, 8)
env = nx.make("Navix-Empty-5x5-v0", max_steps=100)
key = jax.random.PRNGKey(seed)
timestep = env.reset(key)
timestep = env._reset(key)
actions = jax.random.randint(key, (N_TIMESTEPS,), 0, 6)

# for loop
Expand Down
60 changes: 60 additions & 0 deletions examples/ppo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from dataclasses import dataclass, field
import tyro
import numpy as np
import jax.numpy as jnp
import navix as nx
from navix import observations
from navix.agents import PPO, PPOHparams, ActorCritic
from navix.environments.environment import Environment

# set persistent compilation cache directory
# jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache/")


@dataclass
class Args:
    """Command-line options for the PPO example (parsed by ``tyro.cli``).

    Attributes are documented inline. Every attribute is an annotated
    dataclass field so that it can be overridden from the command line.
    """

    # First seed; seeds are ``range(seeds_offset, seeds_offset + n_seeds)``.
    seeds_offset: int = 0
    # Number of consecutive seeds handed to the experiment.
    n_seeds: int = 1
    # env
    # Environment id passed to ``nx.make``.
    env_id: str = "Navix-Empty-Random-5x5-v0"
    # Passed to ``nx.make`` as ``gamma``.
    discount: float = 0.99
    # ppo
    # Hyperparameters passed to ``PPO(hparams=...)``.
    ppo_config: PPOHparams = field(default_factory=PPOHparams)
    # Name given to ``nx.Experiment``. Without the annotation this was a
    # plain class attribute rather than a dataclass field. It is declared
    # last so the positional order of the pre-existing fields is unchanged.
    project_name: str = "navix-examples"


if __name__ == "__main__":
    args = tyro.cli(Args)

    def FlattenObsWrapper(env: Environment):
        """Return a copy of ``env`` that emits 1-D (raveled) observations."""

        def flatten_obs_fn(x):
            # Delegate to the wrapped environment, then collapse to one axis.
            return jnp.ravel(env.observation_fn(x))

        flat_shape = (int(np.prod(env.observation_space.shape)),)
        flat_space = env.observation_space.replace(shape=flat_shape)
        return env.replace(
            observation_fn=flatten_obs_fn,
            observation_space=flat_space,
        )

    # Build the environment with the symbolic first-person observation, then
    # flatten its observations.
    env = FlattenObsWrapper(
        nx.make(
            args.env_id,
            observation_fn=observations.symbolic_first_person,
            gamma=args.discount,
        )
    )

    # The policy/value network is sized from the environment's action set.
    actor_critic = ActorCritic(action_dim=len(env.action_set))
    agent = PPO(hparams=args.ppo_config, network=actor_critic, env=env)

    first_seed = args.seeds_offset
    seeds = tuple(range(first_seed, first_seed + args.n_seeds))

    experiment = nx.Experiment(
        name=args.project_name,
        budget=1_000_000,
        agent=agent,
        env=env,
        env_id=args.env_id,
        seeds=seeds,
    )
    train_state, logs = experiment.run()
4 changes: 4 additions & 0 deletions navix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,8 @@
rendering,
transitions,
events,
agents,
)

from .environments.registry import make, register_env, registry
from .experiment import Experiment
2 changes: 1 addition & 1 deletion navix/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
# under the License.


__version__ = "0.5.0"
__version__ = "0.6.0"
__version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit())
43 changes: 30 additions & 13 deletions navix/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,13 @@
import jax
from jax import Array
import jax.numpy as jnp
import jax.tree_util as jtu

from .entities import Entities
from .entities import Entities, Player
from .states import EventsManager, State
from .components import DISCARD_PILE_COORDS
from .components import DISCARD_PILE_COORDS, Pickable
from .grid import translate, rotate, positions_equal


class Directions:
EAST = jnp.asarray(0)
SOUTH = jnp.asarray(1)
WEST = jnp.asarray(2)
NORTH = jnp.asarray(3)


def _rotate(state: State, spin: int) -> State:
if Entities.PLAYER not in state.entities:
return state
Expand Down Expand Up @@ -159,11 +151,27 @@ def pickup(state: State) -> State:


def drop(state: State) -> State:
raise NotImplementedError()
"""Replaces the position in front of the player with the item in the pocket."""
player = state.get_player(idx=0)

position_in_front = translate(player.position, player.direction)

has_item = player.pocket != -1
can_drop, events = _can_walk_there(state, position_in_front)
can_drop = jnp.logical_and(can_drop, has_item)

for k in state.entities:
entity = state.entities[k]
if isinstance(entity, Pickable):
cond = jnp.logical_and(can_drop, entity.position == DISCARD_PILE_COORDS)
position = jnp.where(cond, position_in_front, entity.position)
entity = entity.replace(position=position)
state.set_entity(k, entity)
return state


def toggle(state: State) -> State:
raise NotImplementedError()
return open(state)


def open(state: State) -> State:
Expand Down Expand Up @@ -242,5 +250,14 @@ def done(state: State) -> State:
done,
)

MINIGRID_ACTION_SET = (
rotate_ccw,
rotate_cw,
forward,
pickup,
drop,
toggle,
done,
)

DEFAULT_ACTION_SET = COMPLETE_ACTION_SET
DEFAULT_ACTION_SET = MINIGRID_ACTION_SET
22 changes: 22 additions & 0 deletions navix/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2023 The Navix Authors.

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


from .ppo import PPO, PPOHparams as PPOHparams
from .models import MLPEncoder, ConvEncoder, ActorCritic
20 changes: 20 additions & 0 deletions navix/agents/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from dataclasses import dataclass
from typing import Dict, Tuple
import jax
from flax import struct
from flax.training.train_state import TrainState


@dataclass
class HParams:
    """Base container for agent hyperparameters.

    Used as the type of ``Agent.hparams``.
    """

    # Debug switch, off by default. How it is consumed is not visible in
    # this file -- NOTE(review): confirm against the concrete agents.
    debug: bool = False


class Agent(struct.PyTreeNode):
    """Interface for trainable agents, declared as a flax ``PyTreeNode``.

    Both methods raise ``NotImplementedError`` here, so a usable agent has
    to override them.
    """

    # Hyperparameters of the agent.
    hparams: HParams

    def train(self, rng: jax.Array) -> Tuple[TrainState, Dict[str, jax.Array]]:
        """Train the agent from the random key ``rng``.

        Returns:
            Per the annotation, a ``(TrainState, logs)`` pair where ``logs``
            maps names to arrays.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def log_on_train_end(self, logs: Dict[str, jax.Array]):
        """Report ``logs`` once training has finished.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
Loading

0 comments on commit 346198e

Please sign in to comment.