-
Notifications
You must be signed in to change notification settings - Fork 0
/
EMalgorithm_run.py
40 lines (30 loc) · 1.1 KB
/
EMalgorithm_run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import numpy as np
import EMalgorithm
import matplotlib.pyplot as plt
def half_squared_errors(beta_history, beta_true):
    """Return the half squared Euclidean error of each coefficient estimate.

    For every estimate ``b`` in *beta_history* this computes
    ``0.5 * ||b - beta_true||_2 ** 2``. Both arrays are flattened first, so
    column vectors, row vectors and 1-D arrays are all accepted, and the
    coefficient vector may have any length (it was previously hard-coded
    to 2 via ``reshape(2, 1)``).

    NOTE: the plot labels this quantity "MSE", but it is half the squared
    L2 distance, not a mean over coefficients.

    Args:
        beta_history: Iterable of array-likes, one coefficient estimate each.
        beta_true: Array-like reference coefficients; must have the same
            number of elements as every estimate in *beta_history*.

    Returns:
        A list with one error value per entry of *beta_history*
        (empty if *beta_history* is empty).
    """
    target = np.ravel(beta_true)
    return [
        0.5 * np.linalg.norm(np.ravel(estimate) - target) ** 2
        for estimate in beta_history
    ]


def plot_convergence(errors, loglikelihood_history):
    """Plot the estimation error and the log-likelihood on twin y-axes.

    The error curve (blue) uses the left y-axis. The log-likelihood (red)
    uses a right-hand y-axis that shares the x-axis, so both series can be
    compared despite different scales. Finishes with ``plt.show()``.

    Args:
        errors: Sequence of per-iteration error values.
        loglikelihood_history: Sequence of per-iteration log-likelihoods.
    """
    _, ax1 = plt.subplots()
    line1, = ax1.plot(errors, 'b', label="MSE")
    ax1.set_title("MSE and log-likelihood")
    # The x label must live on ax1: twinx() hides the twin's x-axis, so a
    # label set on ax2 is never drawn.
    ax1.set_xlabel('Number of Iteration Steps')
    ax1.set_ylabel('Mean Squared Error')

    ax2 = ax1.twinx()
    line2, = ax2.plot(loglikelihood_history, 'r', label="log-likelihood")
    ax2.set_ylabel('log-likelihood')

    # One combined legend for both curves, attached to the twin axes (the
    # axes plt.legend() targeted before). A single legend needs no extra
    # add_artist() call; doing both registered the same legend twice.
    ax2.legend(handles=[line1, line2], loc='upper right')
    plt.show()


def main():
    """Fit EMalgorithm.EM to the CSV data in the working directory and plot it.

    Reads y.csv, X.csv, Z.csv and beta.csv (the coefficients the estimates
    are compared against), prints the shapes of y, X and Z, fits
    ``EMalgorithm.EM(maxItr=100)``, prints its log-likelihood history, then
    plots each iteration's coefficient error alongside the log-likelihood.
    """
    y = np.loadtxt('y.csv', delimiter=',', dtype='float')
    X = np.loadtxt('X.csv', delimiter=',', dtype='float')
    Z = np.loadtxt('Z.csv', delimiter=',', dtype='float')
    beta = np.loadtxt('beta.csv', delimiter=',', dtype='float')
    print('The dimension of y is:' + str(y.shape))
    print('The dimension of X is:' + str(X.shape))
    print('The dimension of Z is:' + str(Z.shape))

    em = EMalgorithm.EM(maxItr=100)
    em.fit(y=y, X=X, Z=Z)
    beta_history = em.beta_history
    loglikelihood_history = em.loglikelihood_history
    print(loglikelihood_history)

    errors = half_squared_errors(beta_history, beta)
    plot_convergence(errors, loglikelihood_history)


if __name__ == "__main__":
    main()