diff --git a/pgx/minatar/asterix.py b/pgx/minatar/asterix.py index d7901a44f..7b5e42b89 100644 --- a/pgx/minatar/asterix.py +++ b/pgx/minatar/asterix.py @@ -174,20 +174,6 @@ def _step_det( lr, is_gold, slot, -): - return jax.lax.cond( - state._terminal, - lambda: state.replace(_last_action=action, reward=jnp.zeros_like(state.reward)), # type: ignore - lambda: _step_det_at_non_terminal(state, action, lr, is_gold, slot), - ) - - -def _step_det_at_non_terminal( - state: State, - action: jnp.ndarray, - lr: bool, - is_gold: bool, - slot: int, ): ramping: bool = True r = jnp.float32(0) diff --git a/pgx/minatar/breakout.py b/pgx/minatar/breakout.py index 361292c05..f0ba28d45 100644 --- a/pgx/minatar/breakout.py +++ b/pgx/minatar/breakout.py @@ -146,14 +146,6 @@ def _init(rng: jnp.ndarray) -> State: def _step_det(state: State, action: jnp.ndarray): - return jax.lax.cond( - state._terminal, - lambda: state.replace(_last_action=action, reward=jnp.zeros_like(state.reward)), # type: ignore - lambda: _step_det_at_non_terminal(state, action), - ) - - -def _step_det_at_non_terminal(state: State, action: jnp.ndarray): ball_y = state._ball_y ball_x = state._ball_x ball_dir = state._ball_dir diff --git a/pgx/minatar/freeway.py b/pgx/minatar/freeway.py index bbf79b507..2023686e3 100644 --- a/pgx/minatar/freeway.py +++ b/pgx/minatar/freeway.py @@ -142,20 +142,6 @@ def _step_det( speeds: jnp.ndarray, directions: jnp.ndarray, ): - return jax.lax.cond( - state._terminal, - lambda: state.replace(_last_action=action, reward=jnp.zeros_like(state.reward)), # type: ignore - lambda: _step_det_at_non_terminal(state, action, speeds, directions), - ) - - -def _step_det_at_non_terminal( - state: State, - action: jnp.ndarray, - speeds: jnp.ndarray, - directions: jnp.ndarray, -): - cars = state._cars pos = state._pos move_timer = state._move_timer diff --git a/pgx/minatar/seaquest.py b/pgx/minatar/seaquest.py index 398b655dd..9dccf709a 100644 --- a/pgx/minatar/seaquest.py +++ b/pgx/minatar/seaquest.py @@ -170,24 +170,6 @@ def _step_det( enemy_y, diver_lr, diver_y, -): - return lax.cond( - state._terminal, - lambda: state.replace(_last_action=action, reward=jnp.zeros_like(state.reward)), # type: ignore - lambda: _step_det_at_non_terminal( - state, action, enemy_lr, is_sub, enemy_y, diver_lr, diver_y - ), - ) - - -def _step_det_at_non_terminal( - state: State, - action: jnp.ndarray, - enemy_lr, - is_sub, - enemy_y, - diver_lr, - diver_y, ): ramping = TRUE diff --git a/pgx/minatar/space_invaders.py b/pgx/minatar/space_invaders.py index 46ed6608c..d3811e685 100644 --- a/pgx/minatar/space_invaders.py +++ b/pgx/minatar/space_invaders.py @@ -166,17 +166,6 @@ def _observe(state: State) -> jnp.ndarray: def _step_det( state: State, action: jnp.ndarray, -): - return lax.cond( - state._terminal, - lambda: state.replace(_last_action=action, reward=jnp.zeros_like(state.reward)), # type: ignore - lambda: _step_det_at_non_terminal(state, action), - ) - - -def _step_det_at_non_terminal( - state: State, - action: jnp.ndarray, ): r = jnp.float32(0)