Skip to content

Commit

Permalink
Fix color features (#908)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored May 26, 2023
1 parent f92ecd9 commit 3960a4d
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 16 deletions.
9 changes: 4 additions & 5 deletions pgx/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,8 +628,10 @@ def _possible_piece_positions(state):


def _observe(state: State, player_id: jnp.ndarray):
color = jax.lax.select(
state.current_player == player_id, state._turn, 1 - state._turn
)
ones = jnp.ones((1, 8, 8), dtype=jnp.float32)
color = state._turn * ones

state = jax.lax.cond(
state.current_player == player_id, lambda: state, lambda: _flip(state)
Expand All @@ -651,9 +653,7 @@ def piece_feat(p):
rep1 = ones * (rep >= 1)
return jnp.vstack([my_pieces, opp_pieces, rep0, rep1])

# color = jax.lax.select(
# state.current_player == player_id, state._turn, 1 - state._turn
# )
board_feat = jax.vmap(make)(jnp.arange(8)).reshape(-1, 8, 8)
color = color * ones
total_move_cnt = (state._step_count / MAX_TERMINATION_STEPS) * ones
my_queen_side_castling_right = ones * state._can_castle_queen_side[0]
Expand All @@ -662,7 +662,6 @@ def piece_feat(p):
opp_king_side_castling_right = ones * state._can_castle_king_side[1]
no_prog_cnt = (state._halfmove_count.astype(jnp.float32) / 100.0) * ones

board_feat = jax.vmap(make)(jnp.arange(8)).reshape(-1, 8, 8)
return jnp.vstack(
[
board_feat,
Expand Down
7 changes: 5 additions & 2 deletions pgx/gardner_chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,10 @@ def _update_zobrist_hash(state: State, action: Action):


def _observe(state: State, player_id: jnp.ndarray):
color = jax.lax.select(
state.current_player == player_id, state._turn, 1 - state._turn
)
ones = jnp.ones((1, 5, 5), dtype=jnp.float32)
color = state._turn * ones

state = jax.lax.cond(
state.current_player == player_id, lambda: state, lambda: _flip(state)
Expand All @@ -484,10 +486,11 @@ def piece_feat(p):
rep1 = ones * (rep >= 1)
return jnp.vstack([my_pieces, opp_pieces, rep0, rep1])

board_feat = jax.vmap(make)(jnp.arange(8)).reshape(-1, 5, 5)
color = color * ones
total_move_cnt = (state._step_count / MAX_TERMINATION_STEPS) * ones
no_prog_cnt = (state._halfmove_count.astype(jnp.float32) / 100.0) * ones

board_feat = jax.vmap(make)(jnp.arange(8)).reshape(-1, 5, 5)
return jnp.vstack(
[board_feat, color, total_move_cnt, no_prog_cnt]
).transpose((1, 2, 0))
Expand Down
5 changes: 4 additions & 1 deletion pgx/go.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ def _observe(state: State, player_id, size, history_length):
So, we use player_id's color to let the agent know the komi information.
As long as it's called when state.current_player == player_id, this doesn't matter.
"""
my_turn = jax.lax.select(
player_id == state.current_player, state._turn, 1 - state._turn
)
current_player_color = _my_color(state) # -1 or 1
my_color, opp_color = jax.lax.cond(
player_id == state.current_player,
Expand All @@ -155,7 +158,7 @@ def _make(i):
return state._board_history[i // 2] == color

log = _make(jnp.arange(history_length * 2))
color = jnp.full_like(log[0], my_color == 1) # black=1, white=0
color = jnp.full_like(log[0], my_turn) # black=0, white=1

return jnp.vstack([log, color]).transpose().reshape((size, size, -1))

Expand Down
10 changes: 4 additions & 6 deletions pgx/hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,16 +160,14 @@ def _observe(state: State, player_id: jnp.ndarray, size) -> jnp.ndarray:
my_board = board * 1 > 0
opp_board = board * -1 > 0
ones = jnp.ones_like(my_board)
my_color = (
jax.lax.select(
player_id == state.current_player, state._turn, 1 - state._turn
)
* ones
color = jax.lax.select(
player_id == state.current_player, state._turn, 1 - state._turn
)
color = color * ones
can_swap = state.legal_action_mask[-1] * ones

return jnp.stack(
[my_board, opp_board, my_color, can_swap], 2, dtype=jnp.bool_
[my_board, opp_board, color, can_swap], 2, dtype=jnp.bool_
)


Expand Down
8 changes: 6 additions & 2 deletions tests/test_go.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,10 @@ def test_observe():
state = init(key=key)
assert state.current_player == 1
# player 0 is white, player 1 is black
obs = observe(state, 1) # black turn, black view
assert (obs[:, :, -1] == 0).all()
obs = observe(state, 0) # black turn, white view
assert (obs[:, :, -1] == 1).all()

state = step(state=state, action=0)
state = step(state=state, action=1)
Expand Down Expand Up @@ -306,13 +310,13 @@ def test_observe():
assert obs.shape == (5, 5, 17)
assert (obs[:, :, 0] == (curr_board == -1)).all()
assert (obs[:, :, 1] == (curr_board == 1)).all()
assert (obs[:, :, -1] == 0).all()
assert (obs[:, :, -1] == 1).all()

obs = observe(state, 1) # black
assert obs.shape == (5, 5, 17)
assert (obs[:, :, 0] == (curr_board == 1)).all()
assert (obs[:, :, 1] == (curr_board == -1)).all()
assert (obs[:, :, -1] == 1).all()
assert (obs[:, :, -1] == 0).all()


def test_legal_action():
Expand Down

0 comments on commit 3960a4d

Please sign in to comment.