-
Notifications
You must be signed in to change notification settings - Fork 1
/
interpret_with_ground_truth.py
208 lines (176 loc) · 7.57 KB
/
interpret_with_ground_truth.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import os, gc
import pandas as pd
from datetime import datetime
from os.path import join
import warnings
warnings.filterwarnings('ignore')
from captum.attr import Lime, DeepLift, IntegratedGradients, GradientShap
from explainers import MorrisSensitivty
from tint.attr import (
AugmentedOcclusion,
Occlusion,
FeatureAblation
)
from run import stringify_setting, get_parser as get_run_parser, initial_setup
from exp.config import FeatureFiles, DataConfig
from exp.exp_forecasting import Exp_Forecast
from exp.exp_interpret import initialize_explainer, explainer_name_map
from interpret_with_ground_truth import *
from utils.interpreter import *
# Map CLI explainer names to their attribution method classes.
# NOTE(review): `main` below resolves explainers through
# `explainer_name_map` / `initialize_explainer`; confirm whether this
# module-level map is still consumed elsewhere.
explainer_map = dict(
    feature_ablation=FeatureAblation,
    occlusion=Occlusion,
    augmented_occlusion=AugmentedOcclusion,
    lime=Lime,
    deep_lift=DeepLift,
    integrated_gradients=IntegratedGradients,
    gradient_shap=GradientShap,
    morris_sensitivity=MorrisSensitivty,
)
def main(args):
    """Interpret a trained forecasting model and compare its feature
    attributions against ground-truth COVID cases grouped by age.

    For each requested explainer: compute attributions over the chosen
    data split, weight the age-group scores by county population,
    compare the resulting global/local age-group rankings with the
    reported weekly cases, and write CSV results to the experiment's
    ``interpretation`` folder.

    Args:
        args: parsed CLI namespace from :func:`get_parser` (expects at
            least ``flag``, ``explainers``, ``baseline_mode``,
            ``seq_len``, ``pred_len``).
    """
    print(f'Experiment started at {datetime.now()}')

    # Static reals are the per-county age-group features; observed reals
    # also contain the past targets.
    features = DataConfig.static_reals + DataConfig.observed_reals
    age_features = DataConfig.static_reals

    # Update arguments in place, then build the experiment identifier.
    initial_setup(args)
    setting = stringify_setting(args)

    # Initialize the experiment and restore the trained model weights.
    exp = Exp_Forecast(args, setting)  # set experiments
    exp.load_model()

    # Dataset and dataloader for the requested split.
    flag = args.flag
    dataset, dataloader = exp.get_data(flag)

    # Raw dataframe for the same split, sorted so rows align with the
    # attribution windows produced below.
    df = exp.data_map[flag]
    df.sort_values(by=['Date', 'FIPS'], inplace=True)

    # Read weekly ground-truth cases by age group and county populations.
    group_cases = pd.read_csv(
        join(FeatureFiles.root_folder, 'Cases by age groups.csv')
    )
    group_cases['end_of_week'] = pd.to_datetime(group_cases['end_of_week'])

    population = pd.read_csv(join(FeatureFiles.root_folder, 'Population.csv'))
    population = population[['FIPS', 'POPESTIMATE']]

    # Per-county age-group weights: static reals are constant within a
    # county, so the first row of each FIPS group carries them.
    weights = df.groupby('FIPS').first()[age_features].reset_index()

    # Create result folder if not present.
    result_folder = os.path.join(exp.output_folder, 'interpretation')
    if not os.path.exists(result_folder):
        os.makedirs(result_folder, exist_ok=True)
    print(f'Interpretation results will be saved in {result_folder}')

    for explainer_name in args.explainers:
        # Compute attributions, timing the whole pass.
        start = datetime.now()
        print(f'{explainer_name} interpretation started at {start}')

        explainer = initialize_explainer(
            explainer_name, exp, dataloader, args, add_x_mark=False
        )
        # Attribution tensor shape: batch x pred_len x seq_len x features
        attr = batch_compute_attr(
            dataloader, exp, explainer,
            baseline_mode=args.baseline_mode,
            add_x_mark=False  # only interpret the static and dynamic features
        )
        # Average over the lookback window:
        # batch x pred_len x seq_len x features -> batch x pred_len x features
        attr = attr.mean(axis=2)
        # Reorder to batch x features x pred_len
        attr = attr.permute(0, 2, 1)

        end = datetime.now()
        print(f'{explainer_name} interpretation ended at {end}, total time {end - start}')

        # Taking absolute since we want the magnitude of feature importance only.
        attr_numpy = np.abs(attr.detach().cpu().numpy())

        # Align attribution scores to the Date/FIPS index of the input data.
        attr_df = align_interpretation(
            ranges=dataset.ranges,
            attr=attr_numpy,
            features=features,
            min_date=df['Date'].min(),
            seq_len=args.seq_len, pred_len=args.pred_len
        )

        print('Attribution statistics')
        print(attr_df.describe())
        gc.collect()

        # Scale each county's age-group importance by that county's
        # age-group ratios and total population estimate.
        groups = []
        for FIPS, group_df in attr_df.groupby('FIPS'):
            county_age_weights = weights[weights['FIPS']==FIPS][age_features].values
            total_population = population[
                population['FIPS']==FIPS]['POPESTIMATE'].values[0]
            group_df[age_features] *= county_age_weights * total_population
            groups.append(group_df)

        groups = pd.concat(groups, axis=0)
        weighted_attr_df = groups[['FIPS', 'Date'] + age_features].reset_index(drop=True)

        # Aggregate the weighted attributions across counties per date.
        weighted_attr_by_date = weighted_attr_df.groupby('Date')[
            age_features].aggregate('sum').reset_index()

        # Date range common to the predictions and the weekly ground truth.
        dates = weighted_attr_by_date['Date'].values
        first_common_date = find_first_common_date(group_cases, dates)
        last_common_date = find_last_common_date(group_cases, dates)

        # Mean of ground-truth cases per age group within the common window.
        # (NOTE: despite the variable name, this is a mean, not a sum.)
        summed_ground_truth = group_cases[
            (group_cases['end_of_week']>=first_common_date) &
            (group_cases['end_of_week']<=last_common_date)
        ][age_features].mean(axis=0).T.reset_index()
        summed_ground_truth.columns = ['age_group', 'cases']

        # Mean of predicted weighted age relevance scores within the common
        # window. The start is pushed back 6 days so the daily scores cover
        # the full week ending at first_common_date.
        summed_weighted_attr = weighted_attr_df[
            (weighted_attr_df['Date']>=(first_common_date-pd.to_timedelta(6, unit='D'))) &
            (weighted_attr_df['Date']<=last_common_date)
        ][age_features].mean(axis=0).T.reset_index()
        summed_weighted_attr.columns = ['age_group', 'attr']

        # Merge ground truth with predicted scores, normalize each column
        # to percentages, then rank age groups (1 = most important).
        global_rank = summed_ground_truth.merge(
            summed_weighted_attr, on='age_group', how='inner'
        )
        global_rank[['cases', 'attr']] = global_rank[['cases', 'attr']].div(
            global_rank[['cases', 'attr']].sum(axis=0)/100, axis=1
        ).fillna(0)  # will be null when all attributions are zero
        global_rank['cases_rank'] = global_rank['cases'].rank(
            axis=0, ascending=False
        )
        global_rank['attr_rank'] = global_rank['attr'].rank(
            axis=0, ascending=False
        )
        print('Global rank comparison')
        print(global_rank)
        global_rank.to_csv(
            join(
                result_folder,
                f'{flag}_global_rank_{explainer.get_name()}.csv'
            ),
            index=False
        )

        print('\nEvaluating local ranks')
        # Since the age-group ground truth is weekly aggregated,
        # do the same for the predicted importance before evaluation.
        weekly_agg_scores_df = aggregate_importance_by_window(
            weighted_attr_by_date, age_features, first_common_date
        )
        result_df = evaluate_interpretation(
            group_cases, weekly_agg_scores_df, age_features
        )
        result_df.to_csv(
            join(
                result_folder,
                f'{flag}_int_metrics_{explainer.get_name()}.csv'
            ),
            index=False
        )
def get_parser():
    """Build the CLI parser for the interpretation script.

    Extends the base training/run parser with interpretation-specific
    options: which explainers to run, which data split to interpret,
    and how to construct baselines for the attribution methods.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = get_run_parser()
    parser.description = 'Interpret Timeseries Models'

    parser.add_argument('--explainers', nargs='*', default=['feature_ablation'],
        choices=list(explainer_name_map.keys()),
        # fixed typo: "explaination" -> "explanation"
        help='explanation method names')
    parser.add_argument('--flag', type=str, default='test',
        choices=['train', 'val', 'test', 'updated'],
        help='flag for data split'
    )
    parser.add_argument('--baseline_mode', type=str, default='random',
        choices=['random', 'aug', 'zero', 'mean'],
        # fixed typo: "interepretation" -> "interpretation"
        help='how to create the baselines for the interpretation methods'
    )
    return parser
if __name__ == '__main__':
    # Script entry point: parse CLI arguments and run the pipeline.
    main(get_parser().parse_args())