-
Notifications
You must be signed in to change notification settings - Fork 1
/
correlation.py
64 lines (46 loc) · 2.1 KB
/
correlation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from scipy import stats
from scipy.optimize import curve_fit
import numpy as np
def fit_curve(x, y, curve_type='logistic_4params'):
    r'''Map predicted scores onto the MOS scale with a fitted logistic curve.

    A logistic function is fitted from ``x`` to ``y`` with
    ``scipy.optimize.curve_fit`` and then evaluated at ``x``. This is the
    logistic regression step suggested by VQEG. The 4-parameter form is the
    more commonly used one; the 5-parameter form is taken from DBCNN:
        - https://github.com/zwx8981/DBCNN/blob/master/dbcnn/tools/verify_performance.m

    Args:
        x: Predicted scores.
        y: MOS scores the predictions are fitted to.
        curve_type: ``'logistic_4params'`` or ``'logistic_5params'``.

    Returns:
        The fitted curve evaluated at ``x``.
    '''
    supported = ('logistic_4params', 'logistic_5params')
    assert curve_type in supported, f'curve type should be in [logistic_4params, logistic_5params], but got {curve_type}.'

    def four_param(scores, y_hi, y_lo, x_mid, x_scale):
        return (y_hi - y_lo) / (1 + np.exp(-(scores - x_mid) / x_scale)) + y_lo

    def five_param(scores, gain, rate, shift, slope, offset):
        sigmoid = 0.5 - 1. / (1 + np.exp(rate * (scores - shift)))
        return gain * sigmoid + slope * scores + offset

    # Each entry pairs a model with its starting guess for the optimizer.
    candidates = {
        'logistic_4params': (
            four_param,
            [np.max(y), np.min(y), np.mean(x), np.std(x) / 4.],
        ),
        'logistic_5params': (
            five_param,
            [10, 0, np.mean(y), 0.1, 0.1],
        ),
    }
    model, start = candidates[curve_type]

    fitted_params, _ = curve_fit(model, x, y, p0=start, maxfev=10000)
    return model(x, *fitted_params)
def calculate_rmse(x, y, fit_scale=None, eps=1e-8):
    """Return the root-mean-square error between ``x`` and ``y``.

    Args:
        x: Predicted scores.
        y: Target scores.
        fit_scale: If not None, ``x`` is first replaced by
            ``fit_curve(x, y, fit_scale)``.
        eps: Small constant added to the mean squared error before the
            square root is taken.

    Returns:
        ``sqrt(mean((x - y) ** 2) + eps)``.
    """
    if fit_scale is not None:
        x = fit_curve(x, y, fit_scale)
    # Coerce to float arrays so plain Python sequences can be subtracted and
    # unsigned-integer inputs do not wrap around on negative differences.
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return np.sqrt(np.mean(diff ** 2) + eps)
def calculate_plcc(x, y, fit_scale=None):
    """Return the Pearson correlation coefficient between ``x`` and ``y``.

    When ``fit_scale`` is not None, ``x`` is first replaced by
    ``fit_curve(x, y, fit_scale)``.
    """
    scores = x if fit_scale is None else fit_curve(x, y, fit_scale)
    coefficient, _pvalue = stats.pearsonr(scores, y)
    return coefficient
def calculate_srcc(x, y):
    """Return the Spearman rank correlation coefficient between ``x`` and ``y``."""
    coefficient, _pvalue = stats.spearmanr(x, y)
    return coefficient
def calculate_krcc(x, y):
    """Return the Kendall rank correlation coefficient (tau) between ``x`` and ``y``."""
    coefficient, _pvalue = stats.kendalltau(x, y)
    return coefficient
def cal_plcc_srcc_rmse(x, y):
    """Compute PLCC, SRCC and RMSE between predictions ``x`` and targets ``y``.

    PLCC and RMSE are computed on the scores produced by fitting ``x`` to
    ``y`` with the 4-parameter logistic curve; SRCC is computed on the raw
    ``x``.

    Returns:
        A tuple ``(plcc, srcc, rmse)``.
    """
    # Run the logistic fit a single time and share the fitted scores between
    # PLCC and RMSE instead of repeating the same optimisation for each.
    fitted = fit_curve(x, y, 'logistic_4params')
    cc = calculate_plcc(fitted, y)
    srcc = calculate_srcc(x, y)
    rmse = calculate_rmse(fitted, y)
    return cc, srcc, rmse