Skip to content

Commit

Permalink
[MinAtar] Remove unused termination condition (#816)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Apr 26, 2023
1 parent e268d72 commit 98a0ada
Show file tree
Hide file tree
Showing 5 changed files with 0 additions and 65 deletions.
14 changes: 0 additions & 14 deletions pgx/minatar/asterix.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,20 +174,6 @@ def _step_det(
lr,
is_gold,
slot,
):
return jax.lax.cond(
state._terminal,
lambda: state.replace(_last_action=action, reward=jnp.zeros_like(state.reward)), # type: ignore
lambda: _step_det_at_non_terminal(state, action, lr, is_gold, slot),
)


def _step_det_at_non_terminal(
state: State,
action: jnp.ndarray,
lr: bool,
is_gold: bool,
slot: int,
):
ramping: bool = True
r = jnp.float32(0)
Expand Down
8 changes: 0 additions & 8 deletions pgx/minatar/breakout.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,14 +146,6 @@ def _init(rng: jnp.ndarray) -> State:


def _step_det(state: State, action: jnp.ndarray):
return jax.lax.cond(
state._terminal,
lambda: state.replace(_last_action=action, reward=jnp.zeros_like(state.reward)), # type: ignore
lambda: _step_det_at_non_terminal(state, action),
)


def _step_det_at_non_terminal(state: State, action: jnp.ndarray):
ball_y = state._ball_y
ball_x = state._ball_x
ball_dir = state._ball_dir
Expand Down
14 changes: 0 additions & 14 deletions pgx/minatar/freeway.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,20 +142,6 @@ def _step_det(
speeds: jnp.ndarray,
directions: jnp.ndarray,
):
return jax.lax.cond(
state._terminal,
lambda: state.replace(_last_action=action, reward=jnp.zeros_like(state.reward)), # type: ignore
lambda: _step_det_at_non_terminal(state, action, speeds, directions),
)


def _step_det_at_non_terminal(
state: State,
action: jnp.ndarray,
speeds: jnp.ndarray,
directions: jnp.ndarray,
):

cars = state._cars
pos = state._pos
move_timer = state._move_timer
Expand Down
18 changes: 0 additions & 18 deletions pgx/minatar/seaquest.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,24 +170,6 @@ def _step_det(
enemy_y,
diver_lr,
diver_y,
):
return lax.cond(
state._terminal,
lambda: state.replace(_last_action=action, reward=jnp.zeros_like(state.reward)), # type: ignore
lambda: _step_det_at_non_terminal(
state, action, enemy_lr, is_sub, enemy_y, diver_lr, diver_y
),
)


def _step_det_at_non_terminal(
state: State,
action: jnp.ndarray,
enemy_lr,
is_sub,
enemy_y,
diver_lr,
diver_y,
):
ramping = TRUE

Expand Down
11 changes: 0 additions & 11 deletions pgx/minatar/space_invaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,17 +166,6 @@ def _observe(state: State) -> jnp.ndarray:
def _step_det(
state: State,
action: jnp.ndarray,
):
return lax.cond(
state._terminal,
lambda: state.replace(_last_action=action, reward=jnp.zeros_like(state.reward)), # type: ignore
lambda: _step_det_at_non_terminal(state, action),
)


def _step_det_at_non_terminal(
state: State,
action: jnp.ndarray,
):
r = jnp.float32(0)

Expand Down

0 comments on commit 98a0ada

Please sign in to comment.