Skip to content

Commit

Permalink
[Hex] Fix feature (#857)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored May 9, 2023
1 parent 25befb5 commit 71c2911
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 13 deletions.
7 changes: 4 additions & 3 deletions docs/hex.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ env = Hex()
| Version | `v0` |
| Number of players | `2` |
| Number of actions | `121 (= 11 x 11)` |
| Observation shape | `(11, 11, 2)` |
| Observation shape | `(11, 11, 3)` |
| Observation type | `bool` |
| Rewards | `{-1, 1}` |

Expand All @@ -51,8 +51,9 @@ env = Hex()

| Index | Description |
|:---:|:----|
| `[:, :, 0]` | represents `(11, 11)` cells filled by the current player |
| `[:, :, 1]` | represents `(11, 11)` cells filled by the opponent player of current player |
| `[:, :, 0]` | represents `(11, 11)` cells filled by `player_id` |
| `[:, :, 1]` | represents `(11, 11)` cells filled by the opponent player of `player_id` |
| `[:, :, 2]` | represents whether `player_id` is black or white |

## Action
Each action represents the cell index to be filled.
Expand Down
17 changes: 10 additions & 7 deletions pgx/hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class State(v1.State):
# .
# .
# [110, 111, 112, ..., 119, 120]]
_board: jnp.ndarray = -jnp.zeros(
_board: jnp.ndarray = jnp.zeros(
11 * 11, jnp.int32
) # <0(oppo), 0(empty), 0<(self)

Expand Down Expand Up @@ -125,16 +125,19 @@ def merge(i, b):


def _observe(state: State, player_id: jnp.ndarray, size) -> jnp.ndarray:
board = jax.lax.cond(
board = jax.lax.select(
player_id == state.current_player,
lambda: state._board.reshape((size, size)),
lambda: (state._board * -1).reshape((size, size)),
state._board.reshape((size, size)),
-state._board.reshape((size, size)),
)

def make(color):
return board * color > 0
my_board = board * 1 > 0
opp_board = board * -1 > 0
my_color = jax.lax.select(
player_id == state.current_player, state._turn, 1 - state._turn
) * jnp.ones_like(my_board)

return jnp.stack(jax.vmap(make)(jnp.int8([1, -1])), 2)
return jnp.stack([my_board, opp_board, my_color], 2, dtype=jnp.bool_)


def _neighbour(xy, size):
Expand Down
8 changes: 5 additions & 3 deletions tests/test_hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,17 @@ def test_observe():
key = jax.random.PRNGKey(0)
state = init(key=key)
assert state.current_player == 0
assert (jnp.zeros((11, 11, 2)) == observe(state, 0)).all()
assert (jnp.zeros((11, 11, 3)) == observe(state, 0)).all()
state = step(state, 0)
assert (observe(state, 0)[:, :, 2] == 0).all()
assert (observe(state, 1)[:, :, 2] == 1).all()
state = step(state, 1)
assert (
jnp.zeros((11, 11, 2)).at[0, 0, 0].set(1).at[0, 1, 1].set(1)
jnp.zeros((11, 11, 3), dtype=jnp.bool_).at[0, 0, 0].set(True).at[0, 1, 1].set(True)
== observe(state, 0)
).all()
assert (
jnp.zeros((11, 11, 2)).at[0, 1, 0].set(1).at[0, 0, 1].set(1)
jnp.zeros((11, 11, 3), dtype=jnp.bool_).at[0, 1, 0].set(True).at[0, 0, 1].set(True).at[:, :, 2].set(True)
== observe(state, 1)
).all()

Expand Down

0 comments on commit 71c2911

Please sign in to comment.