#################################
### Author: Paul Soto         ###
### paul.soto@upf.edu         ###
#                               #
# This file is a class to run   #
# Gibbs sampling for Latent     #
# Dirichlet Allocation          #
#################################

import pandas as pd
from Text_Preprocessing import *
import re
from itertools import chain
import numpy as np
from collections import Counter


class Gibbs():
    """
    A class for collapsed Gibbs sampling on text
    """

    def __init__(self, text, K=2):
        """
        text: pandas Series in which each row is a list of words in the document
        K: number of topics
        """
        self.tokens = list(set(chain(*text.values)))
        self.V = len(self.tokens)
        self.K = K
        self.text = text

        # Randomly initialize each token's topic assignment
        self.doc_topic = text.apply(lambda x: np.random.randint(0, self.K, size=(len(x),)))
        self.doc_topic_counts = self.doc_topic.apply(Counter).apply(pd.Series)
        self.doc_topic_counts = self.doc_topic_counts.fillna(0)

        # Fill in missing topic columns (typically when K is large relative to the data)
        if list(self.doc_topic_counts.columns) != list(range(self.K)):
            needed = [el for el in range(self.K) if el not in self.doc_topic_counts.columns]
            for col in needed:
                self.doc_topic_counts[col] = 0
            # Keep columns in topic order so the positional draw in iterate()
            # maps back to the correct topic label
            self.doc_topic_counts = self.doc_topic_counts.sort_index(axis=1)

        self.term_topic_count = pd.DataFrame(index=self.tokens, columns=range(self.K),
                                             data=np.zeros((self.V, self.K)))
        for doc_ind in range(self.text.shape[0]):
            for (topic, word) in zip(self.doc_topic.iloc[doc_ind], self.text.iloc[doc_ind]):
                self.term_topic_count.loc[word, topic] += 1

        # Symmetric priors; alpha = 50/K follows the Griffiths & Steyvers (2004) heuristic
        self.alpha = 50.0 / self.K
        self.beta = 200.0 / self.V
        self.perplexity_scores = []

    def iterate(self, n=1000):
        """
        Run n steps of the Gibbs sampler.

        Relies on two quantities:
            word_given_topic: "probability" of observing a word given a topic
            topic_given_doc: "probability" of observing a topic in the document
        Each is calculated after removing the current word from the counts.
        """
        for step in range(n):
            if step % 25 == 0:
                print("Step %s of Gibbs sampling completed" % step)
                self.perplexity()
                print(self.perplexity_scores)
            for doc_ind, doc in enumerate(self.text):
                topics = self.doc_topic.iloc[doc_ind]
                for word_ind, word in enumerate(doc):
                    # Remove the current word from the counts
                    self.doc_topic_counts.loc[doc_ind, topics[word_ind]] -= 1
                    self.term_topic_count.loc[word, topics[word_ind]] -= 1

                    # Conditional probability: how much a word likes a given
                    # topic times how much the document likes that topic
                    word_given_topic = (self.term_topic_count.loc[word] + self.beta) / \
                                       (self.term_topic_count.sum() + self.V * self.beta)
                    topic_given_doc = (self.doc_topic_counts.loc[doc_ind] + self.alpha) / \
                                      (self.doc_topic_counts.sum(axis=1).loc[doc_ind] + self.K * self.alpha)
                    weights = word_given_topic * topic_given_doc
                    weights = weights / weights.sum()

                    # Draw a new topic and add the removed word back under it
                    new_topic = np.where(np.random.multinomial(1, weights) == 1)[0][0]
                    topics[word_ind] = new_topic
                    self.doc_topic_counts.loc[doc_ind, new_topic] += 1
                    self.term_topic_count.loc[word, new_topic] += 1
                self.doc_topic.iloc[doc_ind] = topics

    def perplexity(self):
        """
        Compute the perplexity of the current sample (currently in-sample)
        and append it to self.perplexity_scores.
        """
        dt = (self.doc_topic_counts + self.alpha).apply(lambda x: x / x.sum(), axis=1).fillna(0)
        tt = ((self.term_topic_count + self.beta) / (self.term_topic_count + self.beta).sum()).fillna(0)

        def prob(row):
            word_list = row[0]
            index = row['index']
            doc_perp = 0
            for each in word_list:
                doc_perp += np.log((tt.loc[each] * dt.loc[index]).sum())
            return doc_perp

        log_likelihood = self.text.reset_index().apply(prob, axis=1).sum()
        self.perplexity_scores.append(np.exp(-log_likelihood / self.text.apply(len).sum()))

    def top_n_words(self, n=10):
        """
        Print the n most frequent words from each topic.
        """
        for topic in range(self.K):
            top_n = self.term_topic_count.sort_values(topic, ascending=False)[topic].head(n)
            print("Top %s terms for topic %s" % (n, topic))
            for word in top_n.index:
                print(word)
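

# Toy corpus: each inner list is a pre-tokenized "document". The documents
# mix sports, economics, and technology vocabulary, hence K=3 below.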
data = [["rugby", "football", "competition", "ball", "games"],
        ["macro", "economics", "competition", "games"],
        ["technology", "computers", "apple", "AAPL", "internet"],
        ["football", "score", "touchdown", "team"],
        ["keynes", "macro", "friedman", "policy"],
        ["stocks", "AAPL", "gains", "analysis"],
        ["playoffs", "games", "season", "compete", "ball"],
        ["analysis", "economy", "economics", "government"],
        ["apple", "team", "jobs", "compete", "computers"]]

gibbsobj = Gibbs(pd.Series(data), K=3)
gibbsobj.iterate()
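
# A minimal follow-up sketch (not part of the original script): once
# sampling finishes, inspect the fitted topics and the in-sample perplexity
# trace. Note that iterate() defaults to n=1000 steps; pass a smaller n,
# e.g. gibbsobj.iterate(n=100), for a quicker demo.
gibbsobj.top_n_words(n=5)
print(gibbsobj.perplexity_scores)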