-
Notifications
You must be signed in to change notification settings - Fork 1
/
run.py
73 lines (62 loc) · 2.46 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
Marked Point Process Learning via EM algorithm
'''
import sys
import arrow
import utils
import numpy as np
from vhawkes import VecMarkedMultivarHawkes
# Fixed row indices kept when exp_baselines() runs with category='other'.
# Despite the "random" in the names, each group is a contiguous range:
# rows 0-49, rows 3700-3899 and rows 10000-10055 (306 rows in total).
# Judging by the names, the groups presumably hold burglary, pedestrian
# robbery and other incidents. TODO confirm against the data loader.
burglary_with_random_indice = [*range(50)]
pedrobbery_with_random_indice = [*range(3700, 3900)]
others_with_random_indice = [*range(10000, 10056)]
indice = [
    *burglary_with_random_indice,
    *pedrobbery_with_random_indice,
    *others_with_random_indice,
]
def exp_baselines(
        retrieval_range=np.linspace(100, 1000, 51).astype(np.int32), n=10056,
        category='other', epoches=1, iters=5,
        csv_filename='data/beats_graph.csv'):
    """Fit a marked multivariate Hawkes model with EM and run retrieval tests.

    Loads police training data with ``utils.load_police_training_data`` and
    stacks it into a sequence of ``[time, u, marks...]`` rows. It then fits a
    ``VecMarkedMultivarHawkes`` model on that sequence. For every ``N`` in
    ``retrieval_range`` it calls ``utils.retrieval_test`` ``epoches`` times,
    printing and recording the precision and recall it returns.

    Args:
        retrieval_range: Iterable of ints, each passed as ``first_N`` to
            ``utils.retrieval_test``.
        n: Number of records requested from the data loader.
        category: Category passed to the data loader. For ``'other'`` only
            the rows listed in the module-level ``indice`` are kept.
        epoches: Number of retrieval-test repetitions for each ``N``.
        iters: Number of EM iterations used to fit the model.
        csv_filename: Currently unused; kept for backward compatibility.

    Returns:
        ``(precisions, recalls)``: two lists with one entry per ``N`` in
        ``retrieval_range``. Each entry is a list of the ``epoches`` values
        returned by ``utils.retrieval_test``.
    """
    # Load the dataset (the 2nd and 6th outputs are not used here).
    t, _, m, l, u, _, true_labels = utils.load_police_training_data(
        n=n, category=category)

    # Keep only a small, fixed subset of rows for category 'other'.
    if category == 'other':
        t = np.array([t[idx] for idx in indice])
        u = np.array([u[idx] for idx in indice])
        m = np.array([m[idx] for idx in indice])
        l = [l[idx] for idx in indice]
        # NOTE(review): true_labels is not subset with the same indices.
        # Confirm utils.retrieval_test expects the unfiltered labels here.

    # Normalize time into (0, 1). The earliest event maps to
    # 1000 / (span + 2000) and the latest to (span + 1000) / (span + 2000),
    # so every event lies strictly inside the horizon [0, T] with T = 1.
    t_min, t_max = np.min(t), np.max(t)
    t = np.expand_dims((t - t_min + 1000.) / (t_max - t_min + 2000.), -1)
    u = np.expand_dims(u, -1)
    m = m / 1000.  # scale marks down by a constant factor
    seq = np.concatenate([t, u, m], axis=1)
    n_dim = len(np.unique(u))  # one model dimension per distinct u value
    T = 1.

    # Build and fit the model. The original hard-coded em_fit(iters=2),
    # which silently ignored the ``iters`` argument.
    hawkes = VecMarkedMultivarHawkes(n_dim=n_dim, T=T, seq=seq)
    hawkes.em_fit(iters=iters)

    # Run the retrieval experiments.
    precisions = []
    recalls = []
    for N in retrieval_range:
        print('---------N = %d ----------' % N)
        precision = []
        recall = []
        for _ in range(epoches):
            p, r = utils.retrieval_test(
                hawkes, l, true_labels=true_labels, first_N=N)
            print(p, r)
            precision.append(p)
            recall.append(r)
        precisions.append(precision)
        recalls.append(recall)

    # Results are returned so that callers can persist them, e.g.:
    # np.savetxt("result/newsttpp+gbrbm1k_%s_precision_N_from%dto%d.txt" %
    #            (category, min(retrieval_range), max(retrieval_range)),
    #            precisions, delimiter=',')
    return precisions, recalls
if __name__ == '__main__':
    # Optional knobs for reproducible runs and compact float output:
    # np.random.seed(0)
    # np.set_printoptions(suppress=True)
    exp_baselines(
        n=10056,
        category='robbery',
        epoches=2,
        retrieval_range=np.linspace(100, 1000, 51).astype(np.int32),
    )