From 0978d9eeddd1c892c9379e7c1379fd610e45202a Mon Sep 17 00:00:00 2001 From: John Agapiou Date: Fri, 1 Dec 2023 15:18:43 -0800 Subject: [PATCH 01/14] Drop support for Python 3.10 PiperOrigin-RevId: 587137040 Change-Id: Ib536e639ca8fda954f33781381d8238c6e29a49d --- pyproject.toml | 2 +- setup.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 02ad3de..d7f758a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ addopts = "-n auto" testpaths = ["concordia", "examples"] [tool.pytype] -python_version = "3.10" +python_version = "3.11" inputs = ["concordia", "examples"] # Keep going past errors to analyze as many files as possible. keep_going = true diff --git a/setup.py b/setup.py index a9d9547..d5f953b 100644 --- a/setup.py +++ b/setup.py @@ -42,15 +42,15 @@ 'Operating System :: POSIX :: Linux', 'Operating System :: MacOS :: MacOS X', 'Programming Language :: Python :: 3 :: Only', - 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Topic :: Scientific/Engineering :: Artificial Intelligence', ], package_dir={ 'concordia': 'concordia', }, package_data={}, - python_requires='>=3.10', + python_requires='>=3.11', install_requires=[ # TODO: b/312199199 - remove some requirements. 
'absl-py', From 36b1b6ea07467058d569b5ae104813f8146aa085 Mon Sep 17 00:00:00 2001 From: John Agapiou Date: Mon, 4 Dec 2023 03:33:10 -0800 Subject: [PATCH 02/14] Remove nonexistent .devcontainer from dependabot config PiperOrigin-RevId: 587660226 Change-Id: I6b37b7ce0d790f1a649c22b2fcc5b5047ae2b882 --- .github/dependabot.yml | 8 -------- 1 file changed, 8 deletions(-) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index e933562..9ab3cc0 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -9,11 +9,3 @@ updates: directory: / schedule: interval: monthly - - - package-ecosystem: docker - directory: /.devcontainer - schedule: - interval: monthly - ignore: - - dependency-name: "vscode/devcontainers/python" - versions: [">= 3.11"] From a9620033953b53efb94c0ee896affd0225202624 Mon Sep 17 00:00:00 2001 From: John Agapiou Date: Mon, 4 Dec 2023 03:34:04 -0800 Subject: [PATCH 03/14] Update pypi package name PiperOrigin-RevId: 587660445 Change-Id: Ie1203e09470e7481ecbd24a24a2b6ab7f20df526 --- .github/workflows/pypi-publish.yml | 2 +- .github/workflows/pypi-test.yml | 2 +- README.md | 4 ++-- examples/requirements.txt | 2 +- setup.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/pypi-publish.yml b/.github/workflows/pypi-publish.yml index 484d9de..c97c064 100644 --- a/.github/workflows/pypi-publish.yml +++ b/.github/workflows/pypi-publish.yml @@ -28,7 +28,7 @@ jobs: runs-on: ubuntu-latest environment: name: pypi - url: https://pypi.org/p/dm-concordia + url: https://pypi.org/p/gdm-concordia permissions: id-token: write timeout-minutes: 90 diff --git a/.github/workflows/pypi-test.yml b/.github/workflows/pypi-test.yml index 6c2d01a..896b65c 100644 --- a/.github/workflows/pypi-test.yml +++ b/.github/workflows/pypi-test.yml @@ -56,7 +56,7 @@ jobs: - name: Install from PyPI run: | - pip -vvv install dm-concordia + pip -vvv install gdm-concordia pip list - name: Test installation diff --git a/README.md b/README.md 
index 0be1360..8d50e90 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,8 @@ *A library for generative social simulation* -[![Python](https://img.shields.io/pypi/pyversions/dm-concordia.svg)](https://pypi.python.org/pypi/dm-concordia) -[![PyPI version](https://img.shields.io/pypi/v/dm-concordia.svg)](https://pypi.python.org/pypi/dm-concordia) +[![Python](https://img.shields.io/pypi/pyversions/gdm-concordia.svg)](https://pypi.python.org/pypi/gdm-concordia) +[![PyPI version](https://img.shields.io/pypi/v/gdm-concordia.svg)](https://pypi.python.org/pypi/gdm-concordia) [![PyPI tests](../../actions/workflows/pypi-test.yml/badge.svg)](../../actions/workflows/pypi-test.yml) [![Tests](../../actions/workflows/test-concordia.yml/badge.svg)](../../actions/workflows/test-concordia.yml) [![Examples](../../actions/workflows/test-examples.yml/badge.svg)](../../actions/workflows/test-examples.yml) diff --git a/examples/requirements.txt b/examples/requirements.txt index a7d8842..1f00ced 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -1,2 +1,2 @@ -dm-concordia +gdm-concordia termcolor diff --git a/setup.py b/setup.py index d5f953b..e5a4765 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ import setuptools setuptools.setup( - name='dm-concordia', + name='gdm-concordia', version='1.0.0.dev.0', license='Apache 2.0', license_files=['LICENSE'], From 1b1fe98d48f7b96a715abb80fe05a0f81702b7ac Mon Sep 17 00:00:00 2001 From: Sasha Vezhnevets Date: Mon, 4 Dec 2023 05:38:31 -0800 Subject: [PATCH 04/14] Explaining the meaning of None returned by the component state. 
PiperOrigin-RevId: 587686913 Change-Id: I6f59f30d39b5bd35e3b5ab00bae257de434dd0cd --- concordia/typing/component.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/concordia/typing/component.py b/concordia/typing/component.py index e33786c..cea31f2 100644 --- a/concordia/typing/component.py +++ b/concordia/typing/component.py @@ -37,14 +37,28 @@ def name( def state( self, ) -> str | None: - """Returns the current state of the component.""" + """Returns the current state of the component. + + Returns: + state of the component or None. If none is returned, then the component + will be omitted while forming the context of action. + """ pass def partial_state( self, player_name: str, ) -> str | None: - """Returns the specified player's view of the component's current state.""" + """Returns the specified player's view of the component's current state. + + Args: + player_name: the name of the player for which the view is generated. + + Returns: + specified player's view of the component's current state or None. If none + is returned, then the component will not be sent to the player. 
+ """ + del player_name return None From d2d8c4ecd3e8be1adcb82a4e902500305bfa2a15 Mon Sep 17 00:00:00 2001 From: John Agapiou Date: Mon, 4 Dec 2023 05:59:29 -0800 Subject: [PATCH 05/14] Remove dependency on rwlock PiperOrigin-RevId: 587691171 Change-Id: I0337fd6c25edf7f44237f363823e3a20c740b06a --- concordia/associative_memory/associative_memory.py | 14 +++++++------- setup.py | 1 - 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/concordia/associative_memory/associative_memory.py b/concordia/associative_memory/associative_memory.py index 71aeb6b..5ead51a 100644 --- a/concordia/associative_memory/associative_memory.py +++ b/concordia/associative_memory/associative_memory.py @@ -21,10 +21,10 @@ """ from collections.abc import Callable, Iterable import datetime +import threading import numpy as np import pandas as pd -import rwlock class AssociativeMemory: @@ -46,7 +46,7 @@ def __init__( clock_step_size: sets the step size of the clock. If None, assumes precise time """ - self._memory_bank_lock = rwlock.ReadWriteLock() + self._memory_bank_lock = threading.Lock() self._embedder = sentence_embedder self._importance = importance @@ -92,7 +92,7 @@ def add( .T ) - with self._memory_bank_lock.AcquireWrite(): + with self._memory_bank_lock: self._memory_bank = pd.concat( [self._memory_bank, new_df], ignore_index=True ) @@ -112,7 +112,7 @@ def extend( self.add(text, **kwargs) def get_data_frame(self): - with self._memory_bank_lock.AcquireRead(): + with self._memory_bank_lock: return self._memory_bank.copy() def _get_top_k_cosine(self, x: np.ndarray, k: int): @@ -125,7 +125,7 @@ def _get_top_k_cosine(self, x: np.ndarray, k: int): Returns: Rows, sorted by cosine similarity in descending order. 
""" - with self._memory_bank_lock.AcquireRead(): + with self._memory_bank_lock: cosine_similarities = self._memory_bank['embedding'].apply( lambda y: np.dot(x, y) ) @@ -150,7 +150,7 @@ def _get_top_k_similar_rows( Returns: Rows, sorted by cosine similarity in descending order. """ - with self._memory_bank_lock.AcquireRead(): + with self._memory_bank_lock: cosine_similarities = self._memory_bank['embedding'].apply( lambda y: np.dot(x, y) ) @@ -175,7 +175,7 @@ def _get_top_k_similar_rows( return self._memory_bank.iloc[similarity_score.head(k).index] def _get_k_recent(self, k: int): - with self._memory_bank_lock.AcquireRead(): + with self._memory_bank_lock: recency = self._memory_bank['time'].sort_values(ascending=False) return self._memory_bank.iloc[recency.head(k).index] diff --git a/setup.py b/setup.py index e5a4765..c17cf8b 100644 --- a/setup.py +++ b/setup.py @@ -61,7 +61,6 @@ 'python_dateutil', 'reactivex', 'retry', - 'rwlock', 'saxml', 'scipy', 'tensorflow', From 9b75546d2866b3f194a95869973c56e1415e1de4 Mon Sep 17 00:00:00 2001 From: Sasha Vezhnevets Date: Mon, 4 Dec 2023 06:27:13 -0800 Subject: [PATCH 06/14] add lock on reading/writing agents state PiperOrigin-RevId: 587697951 Change-Id: Icdfc1e0ffc64b443935d58574c31297eaf440909 --- concordia/agents/basic_agent.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/concordia/agents/basic_agent.py b/concordia/agents/basic_agent.py index 7201ace..22b88be 100644 --- a/concordia/agents/basic_agent.py +++ b/concordia/agents/basic_agent.py @@ -26,7 +26,7 @@ import contextlib import copy import datetime - +import threading from concordia.associative_memory import associative_memory from concordia.document import interactive_document from concordia.language_model import language_model @@ -86,6 +86,7 @@ def __init__( self._update_interval = update_interval self._under_interrogation = False + self._state_lock = threading.Lock() self._components = {} for comp in components: @@ -166,10 +167,8 
@@ def get_last_log(self): return self._last_chain_of_thought def state(self): - return '\n'.join( - f"{self._agent_name}'s " + (comp.name() + ':\n' + comp.state()) - for comp in self._components.values() - ) + with self._state_lock: + return self._state def _maybe_update(self): next_update = self._last_update + self._update_interval @@ -181,6 +180,11 @@ def update(self): with concurrent.futures.ThreadPoolExecutor() as executor: for comp in self._components.values(): executor.submit(comp.update) + with self._state_lock: + self._state = '\n'.join( + f"{self._agent_name}'s " + (comp.name() + ':\n' + comp.state()) + for comp in self._components.values() + ) def observe(self, observation: str): if observation and not self._under_interrogation: From 46be54109e0f4d933756dc24bb76f236a99782f0 Mon Sep 17 00:00:00 2001 From: Sasha Vezhnevets Date: Mon, 4 Dec 2023 06:40:49 -0800 Subject: [PATCH 07/14] making sure clock_now is passed when clock is not needed to avoid side-effects PiperOrigin-RevId: 587700999 Change-Id: I76b684860677605828d0e19db93a4a4ad5f113d1 --- concordia/agents/components/characteristic.py | 23 ++++++++-------- .../agents/components/person_by_situation.py | 27 +++++++++---------- concordia/agents/components/report_state.py | 7 +++-- .../agents/components/self_perception.py | 13 ++++----- .../agents/components/situation_perception.py | 13 ++++----- concordia/agents/components/somatic_state.py | 22 +++++++-------- .../environment/components/conversation.py | 13 +++++---- concordia/environment/components/schedule.py | 10 +++---- concordia/examples/phone/calendar.ipynb | 2 +- concordia/examples/three_key_questions.ipynb | 6 ++--- concordia/tests/concordia_integration_test.py | 2 +- examples/village/riverbend_elections.ipynb | 4 +-- 12 files changed, 70 insertions(+), 72 deletions(-) diff --git a/concordia/agents/components/characteristic.py b/concordia/agents/components/characteristic.py index 85415da..8f82be4 100644 --- 
a/concordia/agents/components/characteristic.py +++ b/concordia/agents/components/characteristic.py @@ -14,11 +14,12 @@ """Agent characteristic component.""" +import datetime +from typing import Callable from concordia.associative_memory import associative_memory from concordia.document import interactive_document from concordia.language_model import language_model -from concordia.typing import clock as game_clock from concordia.typing import component import termcolor @@ -50,7 +51,7 @@ def __init__( memory: associative_memory.AssociativeMemory, agent_name: str, characteristic_name: str, - state_clock: game_clock.GameClock | None = None, + state_clock_now: Callable[[], datetime.datetime] | None = None, extra_instructions: str = '', num_memories_to_retrieve: int = 25, verbose: bool = False, @@ -62,7 +63,7 @@ def __init__( memory: an associative memory agent_name: the name of the agent characteristic_name: the string to use in similarity search of memory - state_clock: if None then consider this component as representing a + state_clock_now: if None then consider this component as representing a `trait`. If a clock is used then consider this component to represent a `state`. A state is temporary whereas a trait is meant to endure. 
extra_instructions: append additional instructions when asking the model @@ -77,7 +78,7 @@ def __init__( self._characteristic_name = characteristic_name self._agent_name = agent_name self._extra_instructions = extra_instructions - self._clock = state_clock + self._clock_now = state_clock_now self._num_memories_to_retrieve = num_memories_to_retrieve def name(self) -> str: @@ -88,13 +89,13 @@ def state(self) -> str: def update(self) -> None: query = f"{self._agent_name}'s {self._characteristic_name}" - if self._clock is not None: - query = f'[{self._clock.now()}] {query}' + if self._clock_now is not None: + query = f'[{self._clock_now()}] {query}' mems = '\n'.join( - self._memory.retrieve_associative(query, - self._num_memories_to_retrieve, - add_time=True) + self._memory.retrieve_associative( + query, self._num_memories_to_retrieve, add_time=True + ) ) prompt = interactive_document.InteractiveDocument(self._model) @@ -105,8 +106,8 @@ def update(self) -> None: f'{self._extra_instructions}' f'Start the answer with "{self._agent_name} is"' ) - if self._clock is not None: - question = f'Current time: {self._clock.now()}.\n{question}' + if self._clock_now is not None: + question = f'Current time: {self._clock_now()}.\n{question}' self._cache = prompt.open_question( '\n'.join([question, f'Statements:\n{mems}']), diff --git a/concordia/agents/components/person_by_situation.py b/concordia/agents/components/person_by_situation.py index 19dd7d0..355358a 100644 --- a/concordia/agents/components/person_by_situation.py +++ b/concordia/agents/components/person_by_situation.py @@ -13,12 +13,13 @@ # limitations under the License. 
"""Agent component for self perception.""" - +import datetime +from typing import Callable from typing import Sequence + from concordia.associative_memory import associative_memory from concordia.document import interactive_document from concordia.language_model import language_model -from concordia.typing import clock from concordia.typing import component import termcolor @@ -33,7 +34,7 @@ def __init__( memory: associative_memory.AssociativeMemory, agent_name: str, components=Sequence[component.Component] | None, - state_clock: clock.GameClock | None = None, + clock_now: Callable[[], datetime.datetime] | None = None, num_memories_to_retrieve: int = 25, verbose: bool = False, ): @@ -45,7 +46,7 @@ def __init__( memory: The memory to use. agent_name: The name of the agent. components: The components to condition the answer on. - state_clock: The clock to use. + clock_now: time callback to use for the state. num_memories_to_retrieve: The number of memories to retrieve. verbose: Whether to print the state of the component. """ @@ -56,7 +57,7 @@ def __init__( self._state = '' self._components = components or [] self._agent_name = agent_name - self._clock = state_clock + self._clock_now = clock_now self._num_memories_to_retrieve = num_memories_to_retrieve self._name = name @@ -77,21 +78,19 @@ def update(self) -> None: prompt.statement(f'Memories of {self._agent_name}:\n{mems}') - component_states = '\n'.join( - [ - f"{self._agent_name}'s " - + (construct.name() + ':\n' + construct.state()) - for construct in self._components - ] - ) + component_states = '\n'.join([ + f"{self._agent_name}'s " + + (construct.name() + ':\n' + construct.state()) + for construct in self._components + ]) prompt.statement(component_states) question = ( f'What would a person like {self._agent_name} do in a situation like' ' this?' 
) - if self._clock is not None: - question = f'Current time: {self._clock.now()}.\n{question}' + if self._clock_now is not None: + question = f'Current time: {self._clock_now()}.\n{question}' self._state = prompt.open_question( question, diff --git a/concordia/agents/components/report_state.py b/concordia/agents/components/report_state.py index 6e15347..ce6e3a9 100644 --- a/concordia/agents/components/report_state.py +++ b/concordia/agents/components/report_state.py @@ -15,11 +15,10 @@ """This components report what the get_state returns at the moment. -For example, can be used for reporting current time +For example, can be used for reporting current time current_time_component = ReportState( - 'Current time', + 'Current time', get_state=clock.current_time_interval_str) - """ from typing import Callable @@ -27,7 +26,7 @@ class ReportState(component.Component): - """A component that shows the current time interval.""" + """A component that reports what the get_state returns at the moment.""" def __init__(self, get_state: Callable[[], str], name: str = 'State'): """Initializes the component. diff --git a/concordia/agents/components/self_perception.py b/concordia/agents/components/self_perception.py index 8d05363..22df178 100644 --- a/concordia/agents/components/self_perception.py +++ b/concordia/agents/components/self_perception.py @@ -13,11 +13,12 @@ # limitations under the License. 
"""Agent component for self perception.""" +import datetime +from typing import Callable from concordia.associative_memory import associative_memory from concordia.document import interactive_document from concordia.language_model import language_model -from concordia.typing import clock from concordia.typing import component import termcolor @@ -31,7 +32,7 @@ def __init__( model: language_model.LanguageModel, memory: associative_memory.AssociativeMemory, agent_name: str, - state_clock: clock.GameClock | None = None, + clock_now: Callable[[], datetime.datetime] | None = None, num_memories_to_retrieve: int = 100, verbose: bool = False, ): @@ -42,7 +43,7 @@ def __init__( model: Language model. memory: Associative memory. agent_name: Name of the agent. - state_clock: Clock to use for the state. + clock_now: time callback to use for the state. num_memories_to_retrieve: Number of memories to retrieve. verbose: Whether to print the state. """ @@ -52,7 +53,7 @@ def __init__( self._memory = memory self._state = '' self._agent_name = agent_name - self._clock = state_clock + self._clock_now = clock_now self._num_memories_to_retrieve = num_memories_to_retrieve self._name = name @@ -72,8 +73,8 @@ def update(self) -> None: prompt = interactive_document.InteractiveDocument(self._model) prompt.statement(f'Memories of {self._agent_name}:\n{mems}') - if self._clock is not None: - prompt.statement(f'Current time: {self._clock.now()}.\n') + if self._clock_now is not None: + prompt.statement(f'Current time: {self._clock_now()}.\n') question = ( f'Given the memories above, what kind of person is {self._agent_name}?' diff --git a/concordia/agents/components/situation_perception.py b/concordia/agents/components/situation_perception.py index 93111de..63ef5df 100644 --- a/concordia/agents/components/situation_perception.py +++ b/concordia/agents/components/situation_perception.py @@ -13,11 +13,12 @@ # limitations under the License. 
"""Agent component for situation perception.""" +import datetime +from typing import Callable from concordia.associative_memory import associative_memory from concordia.document import interactive_document from concordia.language_model import language_model -from concordia.typing import clock from concordia.typing import component import termcolor @@ -31,7 +32,7 @@ def __init__( model: language_model.LanguageModel, memory: associative_memory.AssociativeMemory, agent_name: str, - state_clock: clock.GameClock | None = None, + clock_now: Callable[[], datetime.datetime] | None = None, num_memories_to_retrieve: int = 25, verbose: bool = False, ): @@ -42,7 +43,7 @@ def __init__( model: The language model to use. memory: The memory to use. agent_name: The name of the agent. - state_clock: The clock to use. + clock_now: time callback to use for the state. num_memories_to_retrieve: The number of memories to retrieve. verbose: Whether to print the last chain. """ @@ -51,7 +52,7 @@ def __init__( self._memory = memory self._state = '' self._agent_name = agent_name - self._clock = state_clock + self._clock_now = clock_now self._num_memories_to_retrieve = num_memories_to_retrieve self._name = name @@ -71,8 +72,8 @@ def update(self) -> None: prompt = interactive_document.InteractiveDocument(self._model) prompt.statement(f'Memories of {self._agent_name}:\n{mems}') - if self._clock is not None: - prompt.statement(f'Current time: {self._clock.now()}.\n') + if self._clock_now is not None: + prompt.statement(f'Current time: {self._clock_now()}.\n') question = ( 'Given the memories above, what kind of situation is' diff --git a/concordia/agents/components/somatic_state.py b/concordia/agents/components/somatic_state.py index f994cdf..cb8ec2d 100644 --- a/concordia/agents/components/somatic_state.py +++ b/concordia/agents/components/somatic_state.py @@ -14,12 +14,12 @@ """Agent component for tracking the somatic state.""" - import concurrent +import datetime +from typing import Callable 
from concordia.agents.components import characteristic from concordia.associative_memory import associative_memory from concordia.language_model import language_model -from concordia.typing import clock as game_clock from concordia.typing import component @@ -35,7 +35,7 @@ def __init__( model: language_model.LanguageModel, memory: associative_memory.AssociativeMemory, agent_name: str, - clock: game_clock.GameClock, + clock_now: Callable[[], datetime.datetime] | None = None, summarize: bool = True, ): """Initialize somatic state component. @@ -44,7 +44,7 @@ def __init__( model: a language model memory: an associative memory agent_name: the name of the agent - clock: the game clock is needed to know when is the current time + clock_now: time callback to use for the state. summarize: if True, the resulting state will be a one sentence summary, otherwise state it would be a concatentation of five separate characteristics @@ -53,7 +53,7 @@ def __init__( self._memory = memory self._state = '' self._agent_name = agent_name - self._clock = clock + self._clock_now = clock_now self._summarize = summarize self._characteristic_names = [ @@ -81,7 +81,7 @@ def __init__( memory=self._memory, agent_name=self._agent_name, characteristic_name=characteristic_name, - state_clock=self._clock, + state_clock_now=self._clock_now, extra_instructions=extra_instructions, ) ) @@ -97,12 +97,10 @@ def update(self): for c in self._characteristics: executor.submit(c.update) - self._state = '\n'.join( - [ - f"{self._agent_name}'s {c.name()}: " + c.state() - for c in self._characteristics - ] - ) + self._state = '\n'.join([ + f"{self._agent_name}'s {c.name()}: " + c.state() + for c in self._characteristics + ]) if self._summarize: prompt = ( f'Summarize the somatic state of {self._agent_name} in one' diff --git a/concordia/environment/components/conversation.py b/concordia/environment/components/conversation.py index e175a81..2c84aff 100644 --- a/concordia/environment/components/conversation.py +++ 
b/concordia/environment/components/conversation.py @@ -57,7 +57,8 @@ def __init__( players: A list of players to generate conversations for. model: A language model to use for generating utterances. memory: GM memory, used to add the summary of the conversation - clock: multi intercal game clock. + clock: multi interval game clock. If conversation happens, the clock will + advance in higher gear during the conversation scene. burner_memory_factory: a memory factory to create temporary memory for npcs and conversation gm cap_nonplayer_characters: The maximum number of non-player characters @@ -188,12 +189,10 @@ def _who_talked( who_talked = ( who_talked + 'Also present: ' - + ', '.join( - [ - npc_conversant.name - for npc_conversant in nonplayers_in_conversation - ] - ) + + ', '.join([ + npc_conversant.name + for npc_conversant in nonplayers_in_conversation + ]) + '.' ) return who_talked diff --git a/concordia/environment/components/schedule.py b/concordia/environment/components/schedule.py index ade853e..1648bfe 100644 --- a/concordia/environment/components/schedule.py +++ b/concordia/environment/components/schedule.py @@ -15,9 +15,9 @@ """This construct implements scheduled events.""" -from collections.abc import Callable import dataclasses import datetime +from typing import Callable from typing import Optional from concordia.typing import component @@ -44,10 +44,10 @@ class Schedule(component.Component): def __init__( self, - clock, - schedule, + clock_now: Callable[[], datetime.datetime], + schedule: dict[str, EventData], ): - self._clock = clock + self._clock_now = clock_now self._schedule = schedule self._state = None @@ -58,7 +58,7 @@ def state(self) -> str | None: return self._state def update(self) -> None: - now = self._clock.now() + now = self._clock_now() events = [] for _, event_data in self._schedule.items(): if now == event_data.time: diff --git a/concordia/examples/phone/calendar.ipynb b/concordia/examples/phone/calendar.ipynb index 5f22d76..104036d 
100644 --- a/concordia/examples/phone/calendar.ipynb +++ b/concordia/examples/phone/calendar.ipynb @@ -226,7 +226,7 @@ " get_state=clock.current_time_interval_str)\n", "\n", " somatic_state = components.somatic_state.SomaticState(\n", - " model, mem, agent_config.name, clock\n", + " model, mem, agent_config.name, clock.now\n", " )\n", " identity = components.identity.SimIdentity(model, mem, agent_config.name)\n", " goal_component = components.constant.ConstantConstruct(state=agent_config.goal)\n", diff --git a/concordia/examples/three_key_questions.ipynb b/concordia/examples/three_key_questions.ipynb index dc032ad..ce9359e 100644 --- a/concordia/examples/three_key_questions.ipynb +++ b/concordia/examples/three_key_questions.ipynb @@ -292,7 +292,7 @@ " model=model,\n", " memory=mem,\n", " agent_name=agent_config.name,\n", - " state_clock=clock,\n", + " clock_now=clock.now,\n", " verbose=True,\n", " )\n", " situation_perception = components.situation_perception.SituationPerception(\n", @@ -300,7 +300,7 @@ " model=model,\n", " memory=mem,\n", " agent_name=agent_config.name,\n", - " state_clock=clock,\n", + " clock_now=clock.now,\n", " verbose=True,\n", " )\n", " person_by_situation = components.person_by_situation.PersonBySituation(\n", @@ -308,7 +308,7 @@ " model=model,\n", " memory=mem,\n", " agent_name=agent_config.name,\n", - " state_clock=clock,\n", + " clock_now=clock.now,\n", " components=[self_perception, situation_perception],\n", " verbose=True,\n", " )\n", diff --git a/concordia/tests/concordia_integration_test.py b/concordia/tests/concordia_integration_test.py index f682b2b..b1397e2 100644 --- a/concordia/tests/concordia_integration_test.py +++ b/concordia/tests/concordia_integration_test.py @@ -137,7 +137,7 @@ def _make_environment( } schedule_construct = gm_components.schedule.Schedule( - clock=clock, schedule=schedule + clock_now=clock.now, schedule=schedule ) player_goals = {'Alice': 'win', 'Bob': 'win'} goal_metric = 
goal_achievement.GoalAchievementMetric( diff --git a/examples/village/riverbend_elections.ipynb b/examples/village/riverbend_elections.ipynb index 35b9803..38ba55a 100644 --- a/examples/village/riverbend_elections.ipynb +++ b/examples/village/riverbend_elections.ipynb @@ -244,7 +244,7 @@ " get_state=clock.current_time_interval_str)\n", "\n", " somatic_state = components.somatic_state.SomaticState(\n", - " model, mem, agent_config.name, clock\n", + " model, mem, agent_config.name, clock.now\n", " )\n", " identity = components.identity.SimIdentity(model, mem, agent_config.name)\n", " goal_component = components.constant.ConstantConstruct(state=agent_config.goal)\n", @@ -603,7 +603,7 @@ " trigger=election_externality.declare_winner)\n", "}\n", "\n", - "schedule_construct = gm_components.schedule.Schedule(clock=clock, schedule=schedule)\n" + "schedule_construct = gm_components.schedule.Schedule(clock_now=clock.now, schedule=schedule)\n" ] }, { From e0456841ea098f2019136d01438e15ad008c94ca Mon Sep 17 00:00:00 2001 From: Sasha Vezhnevets Date: Mon, 4 Dec 2023 07:38:04 -0800 Subject: [PATCH 08/14] Lock the clock PiperOrigin-RevId: 587715714 Change-Id: Iab97b7ba8e7a6b0e727d4d819cfc4637f63529d9 --- concordia/clocks/game_clock.py | 96 ++++++++++++++++++++-------------- 1 file changed, 57 insertions(+), 39 deletions(-) diff --git a/concordia/clocks/game_clock.py b/concordia/clocks/game_clock.py index b2d2845..65dae2e 100644 --- a/concordia/clocks/game_clock.py +++ b/concordia/clocks/game_clock.py @@ -18,6 +18,7 @@ from collections.abc import Sequence import contextlib import datetime +import threading from concordia.typing import clock @@ -45,21 +46,27 @@ def __init__( self._step_size = step_size self._step = 0 + self._step_lock = threading.Lock() + def advance(self): """Advances time by step_size.""" - self._step += 1 + with self._step_lock: + self._step += 1 def set(self, time: datetime.datetime): - self._step = (time - self._start) // self._step_size + with 
self._step_lock: + self._step = (time - self._start) // self._step_size def now(self) -> datetime.datetime: - return self._start + self._step * self._step_size + with self._step_lock: + return self._start + self._step * self._step_size def get_step_size(self) -> datetime.timedelta: return self._step_size def get_step(self) -> int: - return self._step + with self._step_lock: + return self._step def current_time_interval_str(self) -> str: this_time = self.now() @@ -108,56 +115,67 @@ def __init__( self._steps = [0] * len(step_sizes) self._current_gear = 0 - def gear_up(self) -> None: - if self._current_gear + 1 >= len(self._step_sizes): - raise RuntimeError('Already in highest gear.') - self._current_gear += 1 + self._step_lock = threading.RLock() + + def _gear_up(self) -> None: + with self._step_lock: + if self._current_gear + 1 >= len(self._step_sizes): + raise RuntimeError('Already in highest gear.') + self._current_gear += 1 - def gear_down(self) -> None: - if self._current_gear == 0: - raise RuntimeError('Already in lowest gear.') - self._current_gear -= 1 + def _gear_down(self) -> None: + with self._step_lock: + if self._current_gear == 0: + raise RuntimeError('Already in lowest gear.') + self._current_gear -= 1 @contextlib.contextmanager def higher_gear(self): - self.gear_up() - try: - yield - finally: - self.gear_down() + with self._step_lock: + self._gear_up() + try: + yield + finally: + self._gear_down() def advance(self): """Advances time by step_size.""" - self._steps[self._current_gear] += 1 - for gear in range(self._current_gear + 1, len(self._step_sizes)): - self._steps[gear] = 0 - self.set(self.now()) # resolve the higher gear running over the lower + with self._step_lock: + self._steps[self._current_gear] += 1 + for gear in range(self._current_gear + 1, len(self._step_sizes)): + self._steps[gear] = 0 + self.set(self.now()) # resolve the higher gear running over the lower def set(self, time: datetime.datetime): - remainder = time - self._start - for 
gear, step_size in enumerate(self._step_sizes): - self._steps[gear] = remainder // step_size - remainder -= step_size * self._steps[gear] + with self._step_lock: + remainder = time - self._start + for gear, step_size in enumerate(self._step_sizes): + self._steps[gear] = remainder // step_size + remainder -= step_size * self._steps[gear] def now(self) -> datetime.datetime: - output = self._start - for gear, step_size in enumerate(self._step_sizes): - output += self._steps[gear] * step_size - return output + with self._step_lock: + output = self._start + for gear, step_size in enumerate(self._step_sizes): + output += self._steps[gear] * step_size + return output def get_step_size(self) -> datetime.timedelta: - return self._step_sizes[self._current_gear] + with self._step_lock: + return self._step_sizes[self._current_gear] def get_step(self) -> int: """Returns the current step in the lowest gear.""" - # this is used for logging, so makes sense to use lowest gear - return self._steps[0] + with self._step_lock: + # this is used for logging, so makes sense to use lowest gear + return self._steps[0] def current_time_interval_str(self) -> str: - this_time = self.now() - next_time = this_time + self._step_sizes[self._current_gear] - - time_string = this_time.strftime( - ' %d %b %Y [%H:%M - ' - ) + next_time.strftime('%H:%M]') - return time_string + with self._step_lock: + this_time = self.now() + next_time = this_time + self._step_sizes[self._current_gear] + + time_string = this_time.strftime( + ' %d %b %Y [%H:%M - ' + ) + next_time.strftime('%H:%M]') + return time_string From 1bf328154527ed5e97db13c37855e6b21e5a65a7 Mon Sep 17 00:00:00 2001 From: "Joel Z. 
Leibo" Date: Mon, 4 Dec 2023 07:42:01 -0800 Subject: [PATCH 09/14] allow formative memories to include specific memories as part of the player config PiperOrigin-RevId: 587716912 Change-Id: Ia013dd77fe7d1bda5a633e4b465079de67d04318 --- concordia/associative_memory/formative_memories.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/concordia/associative_memory/formative_memories.py b/concordia/associative_memory/formative_memories.py index a74bcf5..e31e5e0 100644 --- a/concordia/associative_memory/formative_memories.py +++ b/concordia/associative_memory/formative_memories.py @@ -42,6 +42,8 @@ class AgentConfig: traits: any traits to use while generating formative memories. For example, big five. context: agent formative memories will be generated with this context + specific_memories: inject these specific memories. Split memories at newline + characters. Can be left blank if not used. goal: defines agents goal. Can be left blank if not used. date_of_birth: the date of birth for the agent. formative_ages: ages at which the formative episodes will be created @@ -54,6 +56,7 @@ class AgentConfig: gender: str traits: str context: str = '' + specific_memories: str = '' goal: str = '' date_of_birth: datetime.datetime = DEFAULT_DOB formative_ages: Iterable[int] = DEFAULT_FORMATIVE_AGES @@ -152,6 +155,12 @@ def make_memories( if item: mem.add(item, importance=1.0) + if agent_config.specific_memories: + specific_memories = agent_config.specific_memories.split('\n') + for item in specific_memories: + if item: + mem.add(item, importance=1.0) + return mem def add_memories( From 4b4f2f3120e6ccdbda8be418fb7a8b383cd90705 Mon Sep 17 00:00:00 2001 From: "Joel Z. 
Leibo" Date: Mon, 4 Dec 2023 07:52:07 -0800 Subject: [PATCH 10/14] Fix bug which was causing agents to repeat a single phrase over and over PiperOrigin-RevId: 587719414 Change-Id: If284d44d7c9ae93d40299537326bf2afc37259d1 --- concordia/agents/basic_agent.py | 13 +++++-------- concordia/typing/agent.py | 2 +- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/concordia/agents/basic_agent.py b/concordia/agents/basic_agent.py index 22b88be..d1ae536 100644 --- a/concordia/agents/basic_agent.py +++ b/concordia/agents/basic_agent.py @@ -70,7 +70,7 @@ def __init__( update_interval: how often to update components. In game time according to the clock argument. verbose: whether to print chains of thought or not - user_controlled: if True, would query user input for speach and action + user_controlled: if True, would query user input for speech and action print_colour: which colour to use for printing """ self._verbose = verbose @@ -261,19 +261,16 @@ def say(self, conversation: str) -> str: f'{self._agent_name} is in the following' f' conversation:\n{conversation}\n' ) - call_to_speach = ( - f'Given the above, what should {self._agent_name} say next? Respond in' - f' the format `{self._agent_name} says: "..."` For example, ' - 'Cristina says: "Hello! Mighty fine weather today, right?" 
' - 'or Ichabod says: "I wonder if the alfalfa is ready to harvest.\n' + call_to_speech = agent.DEFAULT_CALL_TO_SPEECH.format( + agent_name=self._agent_name, ) if self._user_controlled: utterance = self._ask_for_input( - convo_context + call_to_speach, f'{self._agent_name}:' + convo_context + call_to_speech, f'{self._agent_name}:' ) else: utterance = self.act( - action_spec=agent.ActionSpec(convo_context + call_to_speach, 'FREE'), + action_spec=agent.ActionSpec(convo_context + call_to_speech, 'FREE'), ) return utterance diff --git a/concordia/typing/agent.py b/concordia/typing/agent.py index 1d4e497..39257b7 100644 --- a/concordia/typing/agent.py +++ b/concordia/typing/agent.py @@ -47,7 +47,7 @@ class ActionSpec: OUTPUT_TYPES = ['FREE', 'CHOICE', 'FLOAT'] DEFAULT_CALL_TO_SPEECH = ( - 'Given the above, what did {agent_name} say? Respond in' + 'Given the above, what is {agent_name} likely to say next? Respond in' ' the format `{agent_name} says: "..."` For example, ' 'Cristina says: "Hello! Mighty fine weather today, right?" ' 'or Ichabod says: "I wonder if the alfalfa is ready to harvest.\n' From 3f48d0f33f981d9007beec42241c1672f82b5b85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Du=C3=A9=C3=B1ez-Guzm=C3=A1n?= Date: Mon, 4 Dec 2023 08:48:43 -0800 Subject: [PATCH 11/14] Declare basic agent's state so the attribute is known to pytype. 
PiperOrigin-RevId: 587734720 Change-Id: Ia27b46ec286ab1671a14e2449d2851492638d6e8 --- concordia/agents/basic_agent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/concordia/agents/basic_agent.py b/concordia/agents/basic_agent.py index d1ae536..47a5ccf 100644 --- a/concordia/agents/basic_agent.py +++ b/concordia/agents/basic_agent.py @@ -87,6 +87,7 @@ def __init__( self._under_interrogation = False self._state_lock = threading.Lock() + self._state: str | None self._components = {} for comp in components: From 08507428e7c4360a1d4a01e4716c1464c2e6f6d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Du=C3=A9=C3=B1ez-Guzm=C3=A1n?= Date: Mon, 4 Dec 2023 08:51:13 -0800 Subject: [PATCH 12/14] Make the single agent opinion function not a closure so it can be used when debugging. PiperOrigin-RevId: 587735488 Change-Id: I67c12bc68a9d822cbfe2b48fa69febfd9acba215 --- concordia/metrics/opinion_of_others.py | 81 +++++++++++++------------- 1 file changed, 41 insertions(+), 40 deletions(-) diff --git a/concordia/metrics/opinion_of_others.py b/concordia/metrics/opinion_of_others.py index b933213..2f30b8a 100644 --- a/concordia/metrics/opinion_of_others.py +++ b/concordia/metrics/opinion_of_others.py @@ -107,52 +107,53 @@ def name( """Returns the name of the measurement.""" return self._name - def update(self) -> None: - """See base class.""" - def get_opinion(of_player: str) -> None: - if of_player == self._player_name: - return # No self opinions. - - prompt = interactive_document.InteractiveDocument(self._model) - parent_state = self._context_fn() - prompt.statement(parent_state) - - question = self._question.format( - opining_player=self._player_name, - of_player=of_player, + def _get_opinion(self, of_player: str) -> None: + if of_player == self._player_name: + return # No self opinions. 
+ + prompt = interactive_document.InteractiveDocument(self._model) + parent_state = self._context_fn() + prompt.statement(parent_state) + + question = self._question.format( + opining_player=self._player_name, + of_player=of_player, + ) + + answer = prompt.multiple_choice_question( + question=question, answers=self._scale, + ) + answer_str = self._scale[answer] + + answer_float = float(answer) / float(len(self._scale) - 1) + datum = { + 'time_str': self._clock.now().strftime('%H:%M:%S'), + 'clock_step': self._clock.get_step(), + 'timestep': self._timestep, + 'value_float': answer_float, + 'value_str': answer_str, + 'opining_player': self._player_name, + 'of_player': of_player, + } + if self._measurements: + self._measurements.publish_datum(self._channel, datum) + + datum['time'] = self._clock.now() + if self._verbose: + print( + f'{self._name} of {of_player} as viewed by ' + f'{self._player_name}: {answer_str}' ) - answer = prompt.multiple_choice_question( - question=question, answers=self._scale, - ) - answer_str = self._scale[answer] - - answer_float = float(answer) / float(len(self._scale) - 1) - datum = { - 'time_str': self._clock.now().strftime('%H:%M:%S'), - 'clock_step': self._clock.get_step(), - 'timestep': self._timestep, - 'value_float': answer_float, - 'value_str': answer_str, - 'opining_player': self._player_name, - 'of_player': of_player, - } - if self._measurements: - self._measurements.publish_datum(self._channel, datum) - - datum['time'] = self._clock.now() - if self._verbose: - print( - f'{self._name} of {of_player} as viewed by ' - f'{self._player_name}: {answer_str}' - ) - - return + return + + def update(self) -> None: + """See base class.""" with concurrent.futures.ThreadPoolExecutor( max_workers=len(self._player_names) ) as executor: - executor.map(get_opinion, self._player_names) + executor.map(self._get_opinion, self._player_names) self._timestep += 1 def state( From 7668e1b988861bac2205c25464bddcaa2b5c6812 Mon Sep 17 00:00:00 2001 From: 
John Agapiou Date: Mon, 4 Dec 2023 08:59:30 -0800 Subject: [PATCH 13/14] Pin some dependencies PiperOrigin-RevId: 587738092 Change-Id: I84c56129d84b11cb86753868205ee164dc4be9e6 --- examples/requirements.txt | 3 ++- setup.py | 20 +++++++++++--------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/examples/requirements.txt b/examples/requirements.txt index 1f00ced..d94ac8f 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -1,2 +1,3 @@ +docstring-parser~=0.12.0 gdm-concordia -termcolor +termcolor~=1.1.0 diff --git a/setup.py b/setup.py index c17cf8b..567188a 100644 --- a/setup.py +++ b/setup.py @@ -54,18 +54,20 @@ install_requires=[ # TODO: b/312199199 - remove some requirements. 'absl-py', - 'ipython', - 'matplotlib', - 'numpy', - 'pandas<=1.5.3', - 'python_dateutil', - 'reactivex', - 'retry', + 'google-cloud-aiplatform', + 'ipython~=3.2.3', + 'matplotlib~=3.6.1', + 'numpy~=1.26.2', + 'pandas~=1.5.3', + 'python-dateutil~=2.8.2', + 'reactivex~=4.0.4', + 'retry~=0.9.2', 'saxml', - 'scipy', + 'scipy~=1.9.3', 'tensorflow', 'tensorflow_hub', - 'termcolor', + 'termcolor~=1.1.0', + 'typing-extensions~=4.5.0', ], extras_require={ # Used in development. 
From 6520408173406061bb17512accf362f8457245bb Mon Sep 17 00:00:00 2001 From: John Agapiou Date: Mon, 4 Dec 2023 10:41:36 -0800 Subject: [PATCH 14/14] Fix GCloud model and ensure consistent API for LanguageModel PiperOrigin-RevId: 587772666 Change-Id: I6ecf64887c4b0221d9be5ada27272b2645359745 --- concordia/language_model/gcloud_model.py | 37 ++++++++++---------- concordia/language_model/language_model.py | 10 ++++-- concordia/language_model/retry_wrapper.py | 17 ++++----- concordia/language_model/sax_model.py | 40 +++++----------------- concordia/tests/mock_model.py | 20 +++++------ 5 files changed, 54 insertions(+), 70 deletions(-) diff --git a/concordia/language_model/gcloud_model.py b/concordia/language_model/gcloud_model.py index 6894689..d7d2db2 100644 --- a/concordia/language_model/gcloud_model.py +++ b/concordia/language_model/gcloud_model.py @@ -14,15 +14,14 @@ """Google Cloud Language Model.""" from collections.abc import Collection, Sequence -import sys from concordia.language_model import language_model from concordia.utils import text from google import auth +from typing_extensions import override import vertexai from vertexai.preview import language_models as vertex_models -DEFAULT_MAX_TOKENS = 50 MAX_MULTIPLE_CHOICE_ATTEMPTS = 20 @@ -34,7 +33,7 @@ def __init__( project_id: str, model_name: str = 'text-bison@001', location: str = 'us-central1', - credentials: auth.credentials.Credentials = None + credentials: auth.credentials.Credentials | None = None, ) -> None: """Initializes a model instance using the Google Cloud language model API. @@ -45,26 +44,25 @@ def __init__( credentials: Custom credentials to use when making API calls. If not provided credentials will be ascertained from the environment. 
""" - if not credentials: - credentials = auth.default()[0] + if credentials is None: + credentials, _ = auth.default() vertexai.init( - project=project_id, location=location, credentials=credentials) + project=project_id, location=location, credentials=credentials + ) self._model = vertex_models.TextGenerationModel.from_pretrained(model_name) + @override def sample_text( self, prompt: str, *, - timeout: float = None, - max_tokens: int = DEFAULT_MAX_TOKENS, - max_characters: int = sys.maxsize, - terminators: Collection[str] = (), - temperature: float = 0.5, + max_tokens: int = language_model.DEFAULT_MAX_TOKENS, + max_characters: int = language_model.DEFAULT_MAX_CHARACTERS, + terminators: Collection[str] = language_model.DEFAULT_TERMINATORS, + temperature: float = language_model.DEFAULT_TEMPERATURE, + timeout: float = language_model.DEFAULT_TIMEOUT_SECONDS, seed: int | None = None, ) -> str: - """See base class.""" - if timeout is not None: - raise NotImplementedError('Unclear how to set timeout for cloud models.') if seed is not None: raise NotImplementedError('Unclear how to set seed for cloud models.') @@ -72,11 +70,13 @@ def sample_text( sample = self._model.predict( prompt, temperature=temperature, - max_output_tokens=max_tokens,) + max_output_tokens=max_tokens, + ) return text.truncate( sample.text, max_length=max_characters, delimiters=terminators ) + @override def sample_choice( self, prompt: str, @@ -84,7 +84,6 @@ def sample_choice( *, seed: int | None = None, ) -> tuple[int, str, dict[str, float]]: - """See base class.""" max_characters = max([len(response) for response in responses]) for _ in range(MAX_MULTIPLE_CHOICE_ATTEMPTS): @@ -93,7 +92,8 @@ def sample_choice( max_tokens=1, max_characters=max_characters, temperature=0.0, - seed=seed) + seed=seed, + ) try: idx = responses.index(sample) except ValueError: @@ -103,4 +103,5 @@ def sample_choice( return idx, responses[idx], debug raise language_model.InvalidResponseError( - 'Too many multiple choice 
attempts.') + 'Too many multiple choice attempts.' + ) diff --git a/concordia/language_model/language_model.py b/concordia/language_model/language_model.py index c0810fc..d2509ec 100644 --- a/concordia/language_model/language_model.py +++ b/concordia/language_model/language_model.py @@ -20,10 +20,11 @@ import sys from typing import Any -DEFAULT_MAX_TOKENS = 50 DEFAULT_TEMPERATURE = 0.5 -DEFAULT_MAX_CHARACTERS = sys.maxsize DEFAULT_TERMINATORS = () +DEFAULT_TIMEOUT_SECONDS = 60 +DEFAULT_MAX_CHARACTERS = sys.maxsize +DEFAULT_MAX_TOKENS = 50 class InvalidResponseError(Exception): @@ -43,6 +44,7 @@ def sample_text( max_characters: int = DEFAULT_MAX_CHARACTERS, terminators: Collection[str] = DEFAULT_TERMINATORS, temperature: float = DEFAULT_TEMPERATURE, + timeout: float = DEFAULT_TIMEOUT_SECONDS, seed: int | None = None, ) -> str: """Samples text from the model. @@ -57,10 +59,14 @@ def sample_text( terminators: the response will be terminated before any of these characters. temperature: temperature for the model. + timeout: timeout for the request. seed: optional seed for the sampling. If None a random seed will be used. Returns: The sampled response (i.e. does not iclude the prompt). + + Raises: + TimeoutError: if the operation times out. 
""" raise NotImplementedError diff --git a/concordia/language_model/retry_wrapper.py b/concordia/language_model/retry_wrapper.py index ad23401..f8f1b1f 100644 --- a/concordia/language_model/retry_wrapper.py +++ b/concordia/language_model/retry_wrapper.py @@ -14,12 +14,12 @@ """Wrapper to retry calls to an underlying language model.""" -from collections.abc import Collection, Sequence -import copy -from typing import Any, Mapping, Tuple, Type +from collections.abc import Collection, Sequence, Mapping +from typing import Any, Type from concordia.language_model import language_model import retry +from typing_extensions import override class RetryLanguageModel(language_model.LanguageModel): @@ -29,9 +29,9 @@ def __init__( self, model: language_model.LanguageModel, retry_on_exceptions: Collection[Type[Exception]] = (Exception,), - retry_tries: float = 3., + retry_tries: int = 3, retry_delay: float = 2., - jitter: Tuple[float, float] = (0.0, 1.0), + jitter: tuple[float, float] = (0.0, 1.0), ) -> None: """Wrap the underlying language model with retries on given exceptions. @@ -43,11 +43,12 @@ def __init__( jitter: tuple of minimum and maximum jitter to add to the retry. 
""" self._model = model - self._retry_on_exceptions = copy.deepcopy(retry_on_exceptions) + self._retry_on_exceptions = tuple(retry_on_exceptions) self._retry_tries = retry_tries self._retry_delay = retry_delay self._jitter = jitter + @override def sample_text( self, prompt: str, @@ -56,9 +57,9 @@ def sample_text( max_characters: int = language_model.DEFAULT_MAX_CHARACTERS, terminators: Collection[str] = language_model.DEFAULT_TERMINATORS, temperature: float = language_model.DEFAULT_TEMPERATURE, + timeout: float = language_model.DEFAULT_TIMEOUT_SECONDS, seed: int | None = None, ) -> str: - """See base class.""" @retry.retry(self._retry_on_exceptions, tries=self._retry_tries, delay=self._retry_delay, jitter=self._jitter) def _sample_text(model, prompt, *, max_tokens=max_tokens, @@ -72,6 +73,7 @@ def _sample_text(model, prompt, *, max_tokens=max_tokens, max_characters=max_characters, terminators=terminators, temperature=temperature, seed=seed) + @override def sample_choice( self, prompt: str, @@ -79,7 +81,6 @@ def sample_choice( *, seed: int | None = None, ) -> tuple[int, str, Mapping[str, Any]]: - """See base class.""" @retry.retry(self._retry_on_exceptions, tries=self._retry_tries, delay=self._retry_delay, jitter=self._jitter) def _sample_choice(model, prompt, responses, *, seed): diff --git a/concordia/language_model/sax_model.py b/concordia/language_model/sax_model.py index a6e7a1d..4e5713e 100644 --- a/concordia/language_model/sax_model.py +++ b/concordia/language_model/sax_model.py @@ -20,16 +20,14 @@ from collections.abc import Collection, Sequence import concurrent.futures -import sys from concordia.language_model import language_model from concordia.utils import text import numpy as np from saxml.client.python import sax from scipy import special +from typing_extensions import override -DEFAULT_MAX_TOKENS = 50 -DEFAULT_TIMEOUT_SECONDS = 60 DEFAULT_NUM_CONNECTIONS = 3 @@ -55,31 +53,18 @@ def __init__( self._model = sax.Model(path, options).LM() 
self._deterministic_multiple_choice = deterministic_multiple_choice + @override def sample_text( self, prompt: str, *, - timeout: float = DEFAULT_TIMEOUT_SECONDS, - max_tokens: int = DEFAULT_MAX_TOKENS, - max_characters: int = sys.maxsize, - terminators: Collection[str] = (), - temperature: float = 0.5, + max_tokens: int = language_model.DEFAULT_MAX_TOKENS, + max_characters: int = language_model.DEFAULT_MAX_CHARACTERS, + terminators: Collection[str] = language_model.DEFAULT_TERMINATORS, + temperature: float = language_model.DEFAULT_TEMPERATURE, + timeout: float = language_model.DEFAULT_TIMEOUT_SECONDS, seed: int | None = None, ) -> str: - """Samples a string from the model. - - Args: - prompt: the prompt to generate a response for. - timeout: timeout for the request. - max_tokens: maximum number of tokens to generate. - max_characters: maximum number of characters to generate. - terminators: delimiters to use in the generated response. - temperature: temperature for the model. - seed: seed for the random number generator. - - Returns: - A string of the generated response. - """ if seed is not None: raise NotImplementedError('Unclear how to set seed for sax models.') max_tokens = min(max_tokens, max_characters) @@ -92,6 +77,7 @@ def sample_text( sample, max_length=max_characters, delimiters=terminators ) + @override def sample_choice( self, prompt: str, @@ -99,16 +85,6 @@ def sample_choice( *, seed: int | None = None, ) -> tuple[int, str, dict[str, float]]: - """Samples a response from the model. - - Args: - prompt: the prompt to generate a response for. - responses: the responses to sample. - seed: seed for the random number generator. - - Returns: - A tuple of (index, response, debug). 
- """ scores = self._score_responses(prompt, responses) probs = special.softmax(scores) entropy = probs @ np.log(probs) diff --git a/concordia/tests/mock_model.py b/concordia/tests/mock_model.py index 24521ff..c9f93ec 100644 --- a/concordia/tests/mock_model.py +++ b/concordia/tests/mock_model.py @@ -14,9 +14,9 @@ """A mock Language Model.""" from collections.abc import Collection, Sequence -import sys from concordia.language_model import language_model +from typing_extensions import override class MockModel(language_model.LanguageModel): @@ -32,29 +32,30 @@ def __init__( """ self._response = response + @override def sample_text( self, prompt: str, *, - timeout: float = 0, - max_tokens: int = 0, - max_characters: int = sys.maxsize, - terminators: Collection[str] = (), - temperature: float = 0.5, + max_tokens: int = language_model.DEFAULT_MAX_TOKENS, + max_characters: int = language_model.DEFAULT_MAX_CHARACTERS, + terminators: Collection[str] = language_model.DEFAULT_TERMINATORS, + temperature: float = language_model.DEFAULT_TEMPERATURE, + timeout: float = language_model.DEFAULT_TIMEOUT_SECONDS, seed: int | None = None, ) -> str: - """See base class.""" del ( prompt, - timeout, max_tokens, max_characters, terminators, - seed, temperature, + timeout, + seed, ) return self._response + @override def sample_choice( self, prompt: str, @@ -62,6 +63,5 @@ def sample_choice( *, seed: int | None = None, ) -> tuple[int, str, dict[str, float]]: - """See base class.""" del prompt, seed return 0, responses[0], {}