create_conversational_dataset.py
import pandas as pd
import csv
import mysql.connector
from tqdm import tqdm
from multiprocessing import Pool, current_process
import multiprocessing
import glob
from config import *
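# config.py is expected to provide the names used below (inferred from usage;
# any concrete values would be project-specific and are not shown here):
#   HOST, USER, PASSWORD, DATABASE_NAME   -> MySQL connection parameters
#   SUBREDDIT_NAME                        -> used to build the data/<subreddit> paths
#   NUMBER_OF_PROCESSES_OVERRIDE          -> falsy to fall back to cpu_count()
#   NUMBER_OF_CONTEXTS_OVERRIDE           -> falsy to fall back to 11 context turns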
NUMBER_OF_PROCESSES = NUMBER_OF_PROCESSES_OVERRIDE or multiprocessing.cpu_count()
NUMBER_OF_CONTEXTS = NUMBER_OF_CONTEXTS_OVERRIDE or 11
# connecting to the sql db
reddit_db = mysql.connector.connect(
    host=HOST, user=USER, password=PASSWORD, database=DATABASE_NAME
)
# loading sql db posts
cursor = reddit_db.cursor()
cursor.execute("select * from posts")
posts = cursor.fetchall()
posts = pd.DataFrame(
    posts, columns=["post_id", "post_title", "post_body", "subreddit_name", "timestamp"]
).drop(columns=["timestamp"])
# loading sql db comments
cursor.execute("select * from comments")
comments = cursor.fetchall()
comments = pd.DataFrame(
    comments, columns=["id", "comment_body", "post_id", "comment_timestamp", "parent_comment"]
)
# creating one CSV per worker process to write conversations to
for i in range(NUMBER_OF_PROCESSES):
    with open(f"data/{SUBREDDIT_NAME}/conversations{str(i).zfill(2)}.csv", "w") as f:
        wr = csv.writer(f, quoting=csv.QUOTE_ALL)
        # header matches the rows written below: id, response, then
        # NUMBER_OF_CONTEXTS context turns (most recent first)
        wr.writerow(
            ["id", "response", "context"]
            + [f"context/{j}" for j in range(NUMBER_OF_CONTEXTS - 1)]
        )

def generate_comment_chain(comment_id, conversation):
    # Walks up the reply chain starting from comment_id, appending each comment
    # body to `conversation`; when the top-level comment is reached, the post
    # body is appended last.
    comment = comments[comments.id == comment_id]
    if len(comment) == 0 or len(conversation) > (NUMBER_OF_CONTEXTS + 2):
        return
    if comment.parent_comment.values[0] is None:
        conversation.append(comment.comment_body.values[0])
        conversation.append(posts[posts.post_id == comment.post_id.values[0]].post_body.values[0])
    else:
        conversation.append(comment.comment_body.values[0])
        parent_comment_id = comment.parent_comment.values[0]
        generate_comment_chain(parent_comment_id, conversation)

def generate_post_conversations(post_id):
    # Builds one conversation per comment on the post and appends it to this
    # worker's CSV shard.
    for comment_id in comments[comments.post_id == post_id].id:
        conversation = []
        generate_comment_chain(comment_id, conversation)
        conversation.insert(0, comment_id)
        # pool workers are named like "ForkPoolWorker-<n>"; use <n> to pick the shard
        worker_index = int(current_process().name.split("-")[-1]) % NUMBER_OF_PROCESSES
        with open(f"data/{SUBREDDIT_NAME}/conversations{str(worker_index).zfill(2)}.csv", "a") as f:
            writer = csv.writer(f, quoting=csv.QUOTE_ALL)
            writer.writerow(conversation[:NUMBER_OF_CONTEXTS + 2])

# adding conversations to the per-process CSVs in parallel; this relies on the
# workers inheriting the module-level posts/comments dataframes (fork start method)
post_ids = posts.post_id
with Pool(NUMBER_OF_PROCESSES) as pool:
    list(tqdm(pool.imap(generate_post_conversations, post_ids), total=len(post_ids)))
# concatenate different CSVs
local_path = f"data/{SUBREDDIT_NAME}"
filenames = glob.glob(local_path + "/*.csv")
parallel_conversations = [pd.read_csv(filename) for filename in filenames]
conversations = pd.concat(parallel_conversations, ignore_index=True)
conversations.to_pickle(f"data/{SUBREDDIT_NAME}_conversations.pickle")
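# Expected outputs (as written above): one conversations<NN>.csv shard per worker
# under data/<SUBREDDIT_NAME>/ with columns id, response, context, context/0, ...,
# and the combined dataset pickled to data/<SUBREDDIT_NAME>_conversations.pickle.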