Skip to content

Commit

Permalink
add seeding to random choices in substrate configurations
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 688547963
Change-Id: I732368f402801f8ce6b679a88545530219ff4c64
  • Loading branch information
vezhnick authored and copybara-github committed Oct 22, 2024
1 parent 5ede0bd commit f6256d7
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 12 deletions.
18 changes: 13 additions & 5 deletions examples/modular/environment/modules/player_traits_and_styles.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,30 +177,38 @@
)


def get_trait(flowery: bool = False) -> str:
def get_trait(flowery: bool = False,
rng: random.Random | None = None) -> str:
"""Get a random personality trait from a preset list of traits.
Args:
flowery: if True then use complex and flowery traits, if false then use
single word traits.
rng: a random number generator.
Returns:
trait: a string
"""
if rng is None:
rng = random.Random()
if flowery:
return random.choice(FLOWERY_TRAITS)
return rng.choice(FLOWERY_TRAITS)
else:
return random.choice(TRAITS)
return rng.choice(TRAITS)


def get_conversation_style(player_name: str) -> str:
def get_conversation_style(player_name: str,
rng: random.Random | None = None) -> str:
"""Get a random conversation style from a preset list of styles.
Args:
player_name: name of the player who will be said to have the sampled style
of conversation.
rng: a random number generator.
Returns:
style: a string
"""
return random.choice(CONVERSATION_STYLES).format(player_name=player_name)
if rng is None:
rng = random.Random()
return rng.choice(CONVERSATION_STYLES).format(player_name=player_name)
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,12 @@

def sample_parameters(seed: int | None = None):
"""Samples a set of parameters for the world configuration."""
pubs = random.sample(list(PUB_PREFERENCES.keys()), NUM_PUBS)
pub_preferences = {k: PUB_PREFERENCES[k] for k in pubs}

seed = seed if seed is not None else random.getrandbits(63)
rng = random.Random(seed)

pubs = rng.sample(list(PUB_PREFERENCES.keys()), NUM_PUBS)
pub_preferences = {k: PUB_PREFERENCES[k] for k in pubs}

config = pub_coordination.WorldConfig(
year=YEAR,
Expand All @@ -212,8 +215,6 @@ def sample_parameters(seed: int | None = None):
random_seed=seed,
)

rng = random.Random(config.random_seed)

all_names = list(MALE_NAMES) + list(FEMALE_NAMES)

rng.shuffle(all_names)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,11 @@

def sample_parameters(seed: int | None = None):
"""Samples a set of parameters for the world configuration."""
pubs = random.sample(list(PUB_PREFERENCES.keys()), NUM_PUBS)
pub_preferences = {k: PUB_PREFERENCES[k] for k in pubs}
seed = seed if seed is not None else random.getrandbits(63)
rng = random.Random(seed)

pubs = rng.sample(list(PUB_PREFERENCES.keys()), NUM_PUBS)
pub_preferences = {k: PUB_PREFERENCES[k] for k in pubs}

config = pub_coordination.WorldConfig(
year=YEAR,
Expand All @@ -294,7 +296,6 @@ def sample_parameters(seed: int | None = None):
social_context=SOCIAL_CONTEXT,
random_seed=seed,
)
rng = random.Random(config.random_seed)

all_names = list(MALE_NAMES) + list(FEMALE_NAMES)

Expand Down

0 comments on commit f6256d7

Please sign in to comment.