Skip to content

Commit

Permalink
Merge pull request #5 from camel-ai/twitter_simu
Browse files Browse the repository at this point in the history
add twitter gpt example
  • Loading branch information
yiyiyi0817 authored Nov 18, 2024
2 parents 2492a34 + acd0ee0 commit 0d07b6e
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 11 deletions.
14 changes: 8 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ ______________________________________________________________________
### Step 1: Clone the Repository

```bash
git clone https://github.com/camel-ai/social-simulation.git
git clone https://github.com/camel-ai/oasis.git

cd social-simulation
cd oasis
```

### Step 2: Create and Activate a Virtual Environment
Expand Down Expand Up @@ -94,23 +94,23 @@ First, you need to add your OpenAI API key to the system's environment variables
```bash
# Export your OpenAI API key
export OPENAI_API_KEY=<insert your OpenAI API key>
OPENAI_API_BASE_URL=<inert your OpenAI API BASE URL> #(Should you utilize an OpenAI proxy service, kindly specify this)
export OPENAI_API_BASE_URL=<insert your OpenAI API BASE URL> #(Should you utilize an OpenAI proxy service, kindly specify this)
```

- **For Windows Command Prompt:**

```cmd
REM export your OpenAI API key
set OPENAI_API_KEY=<insert your OpenAI API key>
set OPENAI_API_BASE_URL=<inert your OpenAI API BASE URL> #(Should you utilize an OpenAI proxy service, kindly specify this)
set OPENAI_API_BASE_URL=<insert your OpenAI API BASE URL> #(Should you utilize an OpenAI proxy service, kindly specify this)
```

- **For Windows PowerShell:**

```powershell
# Export your OpenAI API key
$env:OPENAI_API_KEY="<insert your OpenAI API key>"
$env:OPENAI_API_BASE_URL="<inert your OpenAI API BASE URL>" #(Should you utilize an OpenAI proxy service, kindly specify this)
$env:OPENAI_API_BASE_URL="<insert your OpenAI API BASE URL>" #(Should you utilize an OpenAI proxy service, kindly specify this)
```

Replace `<insert your OpenAI API key>` with your actual OpenAI API key in each case. Make sure there are no spaces around the `=` sign.
Expand All @@ -126,10 +126,12 @@ To import your own user and post data, please refer to the JSON file format loca
```bash
# For Reddit
python scripts/reddit_gpt_example/reddit_simulation_gpt.py --config_path scripts/reddit_gpt_example/gpt_example.yaml

# For Twitter
python scripts/twitter_gpt_example/twitter_simulation_large.py --config_path scripts/twitter_gpt_example/gpt_example.yaml
```

Note: without modifying the Configuration File, running this script requires approximately 14 API requests to call gpt-4, and the cost incurred is minimal. (October 29, 2024)
Note: without modifying the Configuration File, running the reddit script requires approximately 14 API requests to call gpt-4, and the cost incurred is minimal. (October 29, 2024)

## 📘 Comprehensive Guide (For Open Source Models)

Expand Down
6 changes: 4 additions & 2 deletions oasis/social_agent/agents_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ async def generate_agents(
model_random_seed: int = 42,
cfgs: list[Any] | None = None,
neo4j_config: Neo4jConfig | None = None,
is_openai_model: bool = False,
) -> AgentGraph:
"""Generate and return a dictionary of agents from the agent
information CSV file. Each agent is added to the database and
Expand Down Expand Up @@ -134,16 +135,17 @@ class instances.
model_type=model_type,
agent_graph=agent_graph,
action_space_prompt=action_space_prompt,
is_openai_model=is_openai_model,
)

agent_graph.add_agent(agent)
num_followings = 0
num_followers = 0
# print('agent_info["following_count"]', agent_info["following_count"])
if not agent_info["following_count"].empty:
num_followings = agent_info["following_count"][agent_id]
num_followings = int(agent_info["following_count"][agent_id])
if not agent_info["followers_count"].empty:
num_followers = agent_info["followers_count"][agent_id]
num_followers = int(agent_info["followers_count"][agent_id])

sign_up_list.append((
agent_id,
Expand Down
7 changes: 4 additions & 3 deletions oasis/social_platform/recsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,10 +467,9 @@ def rec_sys_personalized_twh(
fans_score.append(
np.log(u_items[post['user_id']] + 1) / np.log(1000))
except Exception as e:
print(e)
print(f"Error on fan score calculating: {e}")
import pdb

# pdb.set_trace()
pdb.set_trace()

date_score_np = np.array(date_score)
# fan_score [1, 2.x]
Expand All @@ -487,6 +486,8 @@ def rec_sys_personalized_twh(
ActionType.LIKE_POST.value,
trace_table)
like_post_ids_all.append(like_post_ids)
# enable fans_score when the broadcasting effect of the superuser should
# be taken into account
# scores = date_score_np * fans_score_np
scores = date_score_np
new_rec_matrix = []
Expand Down
11 changes: 11 additions & 0 deletions scripts/twitter_gpt_example/action_space_prompt.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# OBJECTIVE
You're a Twitter user, and I'll present you with some posts. After you see the posts, choose some actions from the following functions.
Suppose you are a real Twitter user. Please simulate real behavior.

- do_nothing: Most of the time, you just don't feel like reposting or liking a post, and you just want to look at it. In such cases, choose this action "do_nothing"
- repost: Repost a post.
- Arguments: "post_id" (integer) - The ID of the post to be reposted. You can `repost` when you want to spread it.
- like_post: Likes a specified post.
- Arguments: "post_id" (integer) - The ID of the tweet to be liked. You can `like` when you feel something interesting or you agree with.
- follow: Follow a user specified by 'followee_id'. You can `follow` when you respect someone, love someone, or care about someone.
- Arguments: "followee_id" (integer) - The ID of the user to be followed.
21 changes: 21 additions & 0 deletions scripts/twitter_gpt_example/gpt_example.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
---
data:
db_path: data/simu_db/yaml_gpt/False_Business_0.db
csv_path: data/twitter_dataset/anonymous_topic_200_1h/False_Business_0.csv
simulation:
num_timesteps: 10
clock_factor: 60
recsys_type: twhin-bert
model:
num_agents: 111
model_random_seed: 42
cfgs:
- model_type: gpt-4o-mini
num: 111
server_url: null
model_path: null
stop_tokens: null
temperature: null
inference:
model_type: gpt-4o-mini # Name of the OpenAI model
is_openai_model: true # Whether it is an OpenAI model
185 changes: 185 additions & 0 deletions scripts/twitter_gpt_example/twitter_simulation_large.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from __future__ import annotations

import argparse
import asyncio
import logging
import os
import random
from datetime import datetime
from pathlib import Path
from typing import Any

import pandas as pd
from colorama import Back
from yaml import safe_load

from oasis.clock.clock import Clock
from oasis.social_agent.agents_generator import generate_agents
from oasis.social_platform.channel import Channel
from oasis.social_platform.platform import Platform
from oasis.social_platform.typing import ActionType

# Module-level logger that mirrors everything to social.log and stderr.
social_log = logging.getLogger(name="social")
social_log.setLevel("DEBUG")

# One shared format for both sinks; configure them in a single pass.
_log_formatter = logging.Formatter(
    "%(levelname)s - %(asctime)s - %(name)s - %(message)s")
file_handler = logging.FileHandler("social.log")
stream_handler = logging.StreamHandler()
for handler in (file_handler, stream_handler):
    handler.setLevel("DEBUG")
    handler.setFormatter(_log_formatter)
    social_log.addHandler(handler)

# CLI: a single optional --config_path pointing at the YAML config file.
parser = argparse.ArgumentParser(description="Arguments for script.")
parser.add_argument(
    "--config_path",
    type=str,
    help="Path to the YAML config file.",
    required=False,
    default="",
)

# Default dataset locations relative to the scripts/ directory.
DATA_DIR = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    "data/twitter_dataset/anonymous_topic_200_1h",
)
DEFAULT_DB_PATH = ":memory:"
DEFAULT_CSV_PATH = os.path.join(DATA_DIR, "False_Business_0.csv")


async def running(
    db_path: str | None = DEFAULT_DB_PATH,
    csv_path: str | None = DEFAULT_CSV_PATH,
    num_timesteps: int = 3,
    clock_factor: int = 60,
    recsys_type: str = "twhin-bert",
    model_configs: dict[str, Any] | None = None,
    inference_configs: dict[str, Any] | None = None,
    action_space_file_path: str | None = None,
) -> None:
    """Run one Twitter simulation: build the platform, generate the agents
    from the profile CSV, then advance the simulation step by step.

    Args:
        db_path: SQLite database path; a pre-existing file is deleted first.
        csv_path: CSV file holding the agent profiles.
        num_timesteps: Number of steps to simulate (1 step ~= 3 minutes).
        clock_factor: Speed-up factor of the simulation clock.
        recsys_type: Recommendation-system backend to use.
        model_configs: Extra keyword arguments forwarded to
            ``generate_agents``.
        inference_configs: Inference settings; a ``model_type`` starting
            with "gpt" marks the run as using an OpenAI model.
        action_space_file_path: Path to the action-space prompt file;
            defaults to ``action_space_prompt.txt`` next to this script.
    """
    db_path = DEFAULT_DB_PATH if db_path is None else db_path
    csv_path = DEFAULT_CSV_PATH if csv_path is None else csv_path
    if action_space_file_path is None:
        # Fall back to the prompt shipped alongside this script instead of
        # crashing in open() below (the previous default was None).
        action_space_file_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "action_space_prompt.txt")
    if os.path.exists(db_path):
        os.remove(db_path)
    Path(db_path).parent.mkdir(parents=True, exist_ok=True)

    # Reddit simulations run on wall-clock time; Twitter ones on step counts.
    if recsys_type == "reddit":
        start_time = datetime.now()
    else:
        start_time = 0
    social_log.info(f"Start time: {start_time}")
    clock = Clock(k=clock_factor)
    twitter_channel = Channel()
    infra = Platform(
        db_path,
        twitter_channel,
        clock,
        start_time,
        recsys_type=recsys_type,
        refresh_rec_post_count=2,
        max_rec_post_len=2,
        following_post_count=3,
    )
    inference_channel = Channel()
    twitter_task = asyncio.create_task(infra.running())
    # Bug fix: the flag used to be assigned only inside an
    # `if model_type[:3] == "gpt"` branch, raising NameError for every
    # non-GPT model; it also crashed when inference_configs was None (the
    # default) or had no "model_type" key.
    model_type_name = (inference_configs or {}).get("model_type", "")
    is_openai_model = model_type_name[:3] == "gpt"

    # Default simulated start hour (1 PM). Also guards against start_hour
    # being left undefined when the csv name contains neither "False" nor
    # "True" and the lookup below raises no exception.
    start_hour = 13
    try:
        all_topic_df = pd.read_csv("data/twitter_dataset/all_topics.csv")
        if "False" in csv_path or "True" in csv_path:
            # Topic name is the csv stem, optionally trimmed at the first "-".
            if "-" not in csv_path:
                topic_name = csv_path.split("/")[-1].split(".")[0]
            else:
                topic_name = csv_path.split("/")[-1].split(".")[0].split(
                    "-")[0]
            source_post_time = (
                all_topic_df[all_topic_df["topic_name"] ==
                             topic_name]["start_time"].item().split(" ")[1])
            # Convert "HH:MM" into a fractional hour.
            start_hour = int(source_post_time.split(":")[0]) + float(
                int(source_post_time.split(":")[1]) / 60)
    except Exception:
        print("No real-world data, let start_hour be 1PM")
        start_hour = 13

    model_configs = model_configs or {}
    with open(action_space_file_path, "r", encoding="utf-8") as file:
        action_space = file.read()
    agent_graph = await generate_agents(
        agent_info_path=csv_path,
        twitter_channel=twitter_channel,
        inference_channel=inference_channel,
        start_time=start_time,
        recsys_type=recsys_type,
        twitter=infra,
        action_space_prompt=action_space,
        is_openai_model=is_openai_model,
        **model_configs,
    )
    # agent_graph.visualize("initial_social_graph.png")

    for timestep in range(1, num_timesteps + 1):
        os.environ["SANDBOX_TIME"] = str(timestep * 3)
        social_log.info(f"timestep:{timestep}")
        db_file = db_path.split("/")[-1]
        print(Back.GREEN + f"DB:{db_file} timestep:{timestep}" + Back.RESET)
        # if you want to disable recsys, please comment this line
        await infra.update_rec_table()

        # 0.05 * timestep here means 3 minutes / timestep
        simulation_time_hour = start_hour + 0.05 * timestep
        tasks = []
        for node_id, agent in agent_graph.get_agents():
            if agent.user_info.is_controllable is False:
                # Each agent acts only with its hour-dependent probability.
                agent_ac_prob = random.random()
                threshold = agent.user_info.profile["other_info"][
                    "active_threshold"][int(simulation_time_hour % 24)]
                if agent_ac_prob < threshold:
                    tasks.append(agent.perform_action_by_llm())
            else:
                await agent.perform_action_by_hci()

        await asyncio.gather(*tasks)
        # agent_graph.visualize(f"timestep_{timestep}_social_graph.png")

    # Tell the platform to shut down, then wait for its task to finish.
    await twitter_channel.write_to_receive_queue((None, None, ActionType.EXIT))
    await twitter_task


if __name__ == "__main__":
    cli_args = parser.parse_args()
    os.environ["SANDBOX_TIME"] = str(0)
    if os.path.exists(cli_args.config_path):
        # Drive the simulation from the YAML configuration file: data and
        # simulation sections expand directly into running()'s parameters.
        with open(cli_args.config_path, "r") as config_file:
            config = safe_load(config_file)
        run_kwargs = {
            **config.get("data"),
            **config.get("simulation"),
            "model_configs": config.get("model"),
            "inference_configs": config.get("inference"),
            "action_space_file_path": ("scripts/twitter_gpt_example/"
                                       "action_space_prompt.txt"),
        }
        asyncio.run(running(**run_kwargs))
    else:
        # No config file given (or path missing): run with all defaults.
        asyncio.run(running())
    social_log.info("Simulation finished.")

0 comments on commit 0d07b6e

Please sign in to comment.