-
Notifications
You must be signed in to change notification settings - Fork 1
/
bland_altman.py
124 lines (103 loc) · 4.22 KB
/
bland_altman.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def bland_altman_plot(m1, m2,
                      sd_limit=1.96,
                      ax=None,
                      scatter_kwds=None,
                      mean_line_kwds=None,
                      limit_lines_kwds=None):
    """
    Bland-Altman Plot.

    A Bland-Altman plot is a graphical method to analyze the differences
    between two methods of measurement. The mean of each pair of measures is
    plotted against their difference (``m1 - m2``).

    Parameters
    ----------
    m1, m2 : pandas Series or array-like
        Paired measurements of equal length. The inputs are paired by
        position (pandas index labels are not used for alignment). Pairs in
        which either value is NaN are not drawn and are ignored when
        computing the mean and standard deviation of the differences.
    sd_limit : float, default 1.96
        The limits of agreement expressed in terms of the standard deviation
        of the differences. If `md` is the mean of the differences and `sd`
        is the (population, ddof=0) standard deviation of those differences,
        then the limits of agreement that will be plotted are

            md - sd_limit * sd, md + sd_limit * sd

        With the default of 1.96 and approximately normally distributed
        differences, about 95% of the differences are expected to fall
        between these limits. They are limits of agreement, not a confidence
        interval for the mean difference. The y-axis then spans
        ``1.5 * sd_limit`` standard deviations on either side of `md`.
        If sd_limit = 0, no limits are plotted, and the y-limits of the plot
        default to 3 standard deviations on either side of the mean.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. Defaults to the current axes (``plt.gca()``).
    scatter_kwds : dict, optional
        Options to style the scatter plot. Accepts any keywords for the
        matplotlib ``Axes.scatter`` method. The marker size ``s`` defaults
        to 20.
    mean_line_kwds : dict, optional
        Options to style the mean-difference line. Accepts any keywords for
        the matplotlib ``Axes.axhline`` method. Defaults: ``color='gray'``,
        ``linewidth=1``, ``linestyle='--'``.
    limit_lines_kwds : dict, optional
        Options to style the two limit-of-agreement lines. Accepts any
        keywords for the matplotlib ``Axes.axhline`` method. Defaults:
        ``color='gray'``, ``linewidth=1``, ``linestyle=':'``.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axes the plot was drawn on.

    Raises
    ------
    ValueError
        If `m1` and `m2` differ in length, `sd_limit` is negative, or the
        inputs are empty.

    Notes
    -----
    The keyword dictionaries supplied by the caller are copied, never
    modified.
    """
    import numpy as np

    if len(m1) != len(m2):
        raise ValueError('m1 does not have the same length as m2.')
    if sd_limit < 0:
        raise ValueError('sd_limit ({}) is less than 0.'.format(sd_limit))
    if len(m1) == 0:
        raise ValueError('m1 and m2 must not be empty.')

    # pyplot is only needed once the arguments are known to be valid.
    import matplotlib.pyplot as plt

    # Positional float arrays: plain lists support subtraction, unsigned
    # integer inputs cannot wrap around, and pandas Series are not
    # realigned by index (the pairwise means were always positional).
    m1 = np.asarray(m1, dtype=float)
    m2 = np.asarray(m2, dtype=float)
    means = (m1 + m2) / 2.0
    diffs = m1 - m2
    # NaN-aware reductions skip incomplete pairs, matching what pandas'
    # skipna already did for Series input. Population SD (ddof=0).
    mean_diff = np.nanmean(diffs)
    std_diff = np.nanstd(diffs)

    if ax is None:
        ax = plt.gca()

    # Copy so the defaults filled in below never leak into caller dicts.
    scatter_kwds = dict(scatter_kwds or {})
    scatter_kwds.setdefault('s', 20)
    mean_line_kwds = dict(mean_line_kwds or {})
    limit_lines_kwds = dict(limit_lines_kwds or {})
    # Fill each line's defaults on its own dict. A matplotlib alias the
    # caller already used ('c', 'lw', 'ls') counts as set, because
    # matplotlib rejects a property given under both of its names.
    for kwds, default_style in ((mean_line_kwds, '--'),
                                (limit_lines_kwds, ':')):
        for name, alias, value in (('color', 'c', 'gray'),
                                   ('linewidth', 'lw', 1),
                                   ('linestyle', 'ls', default_style)):
            if name not in kwds and alias not in kwds:
                kwds[name] = value

    ax.scatter(means, diffs, **scatter_kwds)
    ax.axhline(mean_diff, **mean_line_kwds)  # draw mean line.
    # Annotate mean line with mean difference.
    ax.annotate('mean diff:\n{}'.format(np.round(mean_diff, 2)),
                xy=(0.99, 0.5),
                horizontalalignment='right',
                verticalalignment='center',
                fontsize=14,
                xycoords='axes fraction')

    if sd_limit > 0:
        half_ylim = (1.5 * sd_limit) * std_diff
        ax.set_ylim(mean_diff - half_ylim,
                    mean_diff + half_ylim)

        limit_of_agreement = sd_limit * std_diff
        lower = mean_diff - limit_of_agreement
        upper = mean_diff + limit_of_agreement
        for lim in (lower, upper):
            ax.axhline(lim, **limit_lines_kwds)
        ax.annotate('-SD{}: {}'.format(sd_limit, np.round(lower, 2)),
                    xy=(0.99, 0.07),
                    horizontalalignment='right',
                    verticalalignment='bottom',
                    fontsize=14,
                    xycoords='axes fraction')
        ax.annotate('+SD{}: {}'.format(sd_limit, np.round(upper, 2)),
                    xy=(0.99, 0.92),
                    horizontalalignment='right',
                    fontsize=14,
                    xycoords='axes fraction')
    elif sd_limit == 0:
        half_ylim = 3 * std_diff
        ax.set_ylim(mean_diff - half_ylim,
                    mean_diff + half_ylim)

    ax.set_ylabel('Difference', fontsize=15)
    ax.set_xlabel('Means', fontsize=15)
    ax.tick_params(labelsize=13)
    # NOTE: plt.tight_layout() lays out the *current* pyplot figure, which
    # may differ from the figure owning a caller-supplied `ax`.
    plt.tight_layout()
    return ax