Skip to content

Commit

Permalink
Added faster KDE and CDF generation using upsampling
Browse files Browse the repository at this point in the history
  • Loading branch information
KulikDM committed Aug 18, 2024
1 parent 0798d36 commit c1072f7
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ repos:
name: Sort imports

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.5.6
rev: v0.6.1
hooks:
- id: ruff
args: [--exit-non-zero-on-fix, --fix, --line-length=180]
Expand Down
34 changes: 30 additions & 4 deletions pythresh/thresholds/thresh_utility.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import scipy.stats as stats
from scipy.interpolate import interp1d
from scipy.special import ndtr
from sklearn.decomposition import TruncatedSVD
from sklearn.utils import check_array
Expand All @@ -20,24 +21,49 @@ def cut(decision, limit):
return labels


def gen_interp(x, y):
    """Build a cubic interpolator through the points ``(x, y)``.

    Parameters
    ----------
    x : array-like
        Sample locations passed to ``scipy.interpolate.interp1d``.
    y : array-like
        Sample values at ``x``.

    Returns
    -------
    interpolator : scipy.interpolate.interp1d
        Callable that evaluates the cubic interpolant; queries outside the
        range of ``x`` are extrapolated instead of raising an error.
    """

    # 'extrapolate' keeps the callable usable slightly outside [x[0], x[-1]]
    interpolator = interp1d(x, y, kind='cubic',
                            fill_value='extrapolate')

    return interpolator


def gen_kde(data, lower, upper, size):
    """Evaluate a Gaussian KDE of ``data`` on an evenly spaced grid.

    The KDE itself is evaluated on at most 5000 points. When more points
    are requested, the density is computed on a 5000-point grid and
    upsampled to ``size`` points with a cubic interpolator, which avoids
    evaluating the KDE at every output point.

    Parameters
    ----------
    data : array-like
        Samples passed to ``scipy.stats.gaussian_kde``
        (presumably 1-D scores -- confirm against callers).
    lower : float
        First point of the evaluation grid.
    upper : float
        Last point of the evaluation grid.
    size : int
        Number of grid points to return.

    Returns
    -------
    density : numpy.ndarray
        KDE values (interpolated when ``size > 5000``) at ``dat_eval``.
    dat_eval : numpy.ndarray
        The ``size`` evenly spaced points between ``lower`` and ``upper``.
    """

    # Cap on the number of points at which the KDE is evaluated directly
    insize = min(size, 5000)

    # Create a KDE of the data
    kde = stats.gaussian_kde(data)
    dat_eval = np.linspace(lower, upper, size)

    # Use interpolation for fast KDE upsampling
    if size > insize:
        # The coarse grid is only needed on this path
        dat_range = np.linspace(lower, upper, insize)
        interpolator = gen_interp(dat_range, kde(dat_range))
        return interpolator(dat_eval), dat_eval

    return kde(dat_eval), dat_eval


def gen_cdf(data, lower, upper, size):
    """Evaluate a kernel-smoothed CDF of ``data`` on an evenly spaced grid.

    For each grid point the standard normal CDF of
    ``(point - sample) / kde.factor`` is averaged over all samples held by
    the fitted ``gaussian_kde``. This is done on at most 5000 points; when
    more are requested, the CDF is computed on a 5000-point grid and
    upsampled to ``size`` points with a cubic interpolator.

    Parameters
    ----------
    data : array-like
        Samples passed to ``scipy.stats.gaussian_kde``
        (presumably 1-D scores -- confirm against callers).
    lower : float
        First point of the evaluation grid.
    upper : float
        Last point of the evaluation grid.
    size : int
        Number of grid points to return.

    Returns
    -------
    cdf : numpy.ndarray
        CDF values (interpolated when ``size > 5000``) at ``dat_eval``.
    dat_eval : numpy.ndarray
        The ``size`` evenly spaced points between ``lower`` and ``upper``.
    """

    # Cap on the number of points at which the CDF is computed directly
    insize = min(size, 5000)

    # Create a KDE & CDF of the data
    kde = stats.gaussian_kde(data)
    dat_range = np.linspace(lower, upper, insize)
    dat_eval = np.linspace(lower, upper, size)

    # Loop invariants: flattened samples and the kernel scale.
    # NOTE(review): kde.factor is gaussian_kde's bandwidth *factor*; it only
    # equals the kernel standard deviation for unit-variance data -- confirm
    # this scaling is what callers expect.
    samples = np.ravel(kde.dataset)
    bandwidth = kde.factor

    # One grid point at a time keeps memory at O(n_samples)
    cdf = np.array([ndtr((item - samples) / bandwidth).mean()
                    for item in dat_range])

    # Use interpolation for fast CDF upsampling
    if size > insize:
        interpolator = gen_interp(dat_range, cdf)
        return interpolator(dat_eval), dat_eval

    return cdf, dat_eval


def check_scores(decision, random_state=1234):
Expand Down

0 comments on commit c1072f7

Please sign in to comment.