-
Notifications
You must be signed in to change notification settings - Fork 0
/
analyse_results.py
80 lines (68 loc) · 2.48 KB
/
analyse_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import src.utils.plot as plt
import src.utils.stats as sts
import src.utils.results as rpck
# Warnings are silenced so they do not clutter the printed tables and plots;
# comment out the two lines below to see them again.
import warnings
warnings.filterwarnings("ignore")

# 1. Defining analysis settings
NEXP = 50             # passed as `nexp` to rpck.read (presumably runs per method/env)
PATH = './results/'   # passed as `path` to rpck.read: where result files live
SAVE = True           # passed as `save` to the plot functions
SHOW_SUMMARY = True   # print sts.summary (LaTeX=True) for each target
SHOW_PVALUE = True    # print sts.pvalues for each target
PLOT = True           # draw a plot for each target
PLOT_TYPE = 'cumlines'  # one of the keys of `ylabel`: 'lines', 'cumlines', 'bars'

# Select the target data. Other names referenced by the original script are
# 'time', 'nrollouts' and 'nsimulations'; only 'reward' and 'time' have axis
# labels configured in `ylabel` below.
target_data = ['reward']

# Y-axis label per plot type and target data.
ylabel = {
    'lines': {
        'reward': 'Average Reward',
        'time': 'Average Time (s)',
    },
    'cumlines': {
        'reward': 'Cumulative Reward',
        'time': 'Cumulative Time (s)',
    },
    'bars': {
        'reward': 'Average Reward',
        'time': 'Average Time (s)',
    },
}

# Select the target environments.
envs = [
    'TigerEnv0',
    'MazeEnv0', 'MazeEnv1', 'MazeEnv2', 'MazeEnv3',
    'RockSampleEnv0', 'RockSampleEnv1', 'RockSampleEnv2', 'RockSampleEnv3',
    'LevelForagingEnv0', 'LevelForagingEnv1', 'LevelForagingEnv2',
    'LevelForagingEnv3', 'LevelForagingEnv4',
    'TagEnv0', 'LaserTagEnv0',
]

# Select the target methods: method id (passed to rpck.read) -> display name
# (used as the key of the per-environment results dict).
methods_dict = {
    'pomcp': 'POMCP',
    'prpomcp': 'PR-POMCP',
    'iucbpomcp': 'IUCB-POMCP',
    'ibpomcp': 'IB-POMCP',
    'tbrhopomcp': 'TB ρ-POMCP',
    'rhopomcp': 'ρ-POMCP',
}
# Method ids in insertion order of methods_dict.
methods = list(methods_dict)
# 2. Analysing via summaries, p-values and plots, one environment at a time.

# Directory the plot functions write figures to (passed as `savepath`).
PLOT_PATH = './plots/'

# Fail fast on a misconfigured PLOT_TYPE instead of silently producing no
# plots after every result file has already been read.
_PLOT_TYPES = ('lines', 'cumlines', 'bars')
if PLOT and PLOT_TYPE not in _PLOT_TYPES:
    raise ValueError(
        f"Unknown PLOT_TYPE {PLOT_TYPE!r}; expected one of {_PLOT_TYPES}")

for env in envs:
    print('>', env)

    # Results of every method for this environment, keyed by display name.
    results = {
        methods_dict[method]: rpck.read(
            nexp=NEXP, method=method, path=PATH, env=env)
        for method in methods
    }

    for td in target_data:
        if SHOW_SUMMARY:
            sts.summary(results=results, target_data=td, LaTeX=True)

        if SHOW_PVALUE:
            # Use the current target (`td`), not a hard-coded 'reward';
            # otherwise every target would repeat the reward test.
            # by_='iteration' was used here earlier as an alternative grouping.
            sts.pvalues(results=results, target_data=td, by_='experiment')

        if PLOT:
            # Targets without a configured label (e.g. 'nrollouts') fall back
            # to their raw name instead of raising KeyError.
            label = ylabel.get(PLOT_TYPE, {}).get(td, td)
            if PLOT_TYPE == 'lines':
                plt.lines(results=results, target_data=td, ylabel=label,
                          xlabel='Iteration', save=SAVE,
                          savepath=PLOT_PATH, env_name=env)
            elif PLOT_TYPE == 'cumlines':
                plt.cumlines(results=results, target_data=td, ylabel=label,
                             xlabel='Iteration', save=SAVE,
                             savepath=PLOT_PATH, env_name=env)
            elif PLOT_TYPE == 'bars':
                plt.bars(results=results, target_data=td, ylabel=label,
                         save=SAVE, savepath=PLOT_PATH, env_name=env)