-
-
Notifications
You must be signed in to change notification settings - Fork 48
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[solidago] feat: Asymetric uncertainty #1781
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,13 +4,14 @@ | |
using coordinate descent. | ||
""" | ||
import random | ||
from typing import Tuple, Callable | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from numba import njit | ||
|
||
from solidago.comparisons_to_scores.base import ComparisonsToScoresAlgorithm | ||
from solidago.solvers.optimize import brentq | ||
from solidago.solvers.optimize import brentq, SignChangeIntervalNotFoundError | ||
|
||
|
||
DEFAULT_ALPHA = 0.20 # Signal-to-noise hyperparameter | ||
|
@@ -29,6 +30,19 @@ def contributor_loss_partial_derivative(theta_a, theta_b, r_ab, alpha): | |
+ r_ab | ||
) | ||
|
||
@njit
def continuous_bradley_terry_log_likelihood(theta_a, theta_b, r_ab, r_max) -> float:
    """
    Sum, over comparisons, of the continuous Bradley-Terry log-likelihood term

        log(theta_ab / (exp((r + 1) * theta_ab) - exp((r - 1) * theta_ab)))

    where ``theta_ab = theta_a - theta_b`` and ``r = r_ab / r_max``.

    Parameters
    ----------
    theta_a:
        Score of the entity being evaluated.
        NOTE(review): callers in this file appear to pass a scalar - confirm.
    theta_b:
        Scores of the entities it was compared to (array-like; the result of
        the element-wise computation is reduced with ``.sum()``).
    r_ab:
        Comparison values, aligned with ``theta_b``.
    r_max:
        Normalisation constant dividing ``r_ab``.

    Returns
    -------
    float
        The summed log-likelihood.
    """
    theta_ab = theta_a - theta_b
    normalized_r_ab = r_ab / r_max
    positive_exponential_term = np.exp((normalized_r_ab + 1) * theta_ab)
    negative_exponential_term = np.exp((normalized_r_ab - 1) * theta_ab)
    return np.where(
        np.abs(theta_ab) < EPSILON,
        # The ratio is 0/0 at theta_ab == 0. Its limit is
        # theta / (exp(r * theta) * 2 * sinh(theta)) -> 1/2, so the
        # log-likelihood term tends to log(1/2), not 1/2.
        np.log(0.5),
        np.log(theta_ab / (positive_exponential_term - negative_exponential_term)),
    ).sum()
|
||
|
||
|
||
@njit | ||
def Delta_theta(theta_ab): | ||
|
@@ -39,6 +53,64 @@ def Delta_theta(theta_ab): | |
).sum() ** (-0.5) | ||
|
||
|
||
HIGH_LIKELIHOOD_RANGE_THRESHOLD = 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @lenhoanglnh Do you confirm the idea of using a likelihood lower bound to compute the uncertainty interval, rather than a more standard 90% confidence interval? |
||
|
||
|
||
@njit
def translated_function(x, f, translation, args=()):
    """
    Evaluate the shifted function ``x => f(x, *args) - translation``.

    Shifting by ``translation`` turns the problem "find x such that
    f(x) == translation" into a root-finding problem.
    """
    value = f(x, *args)
    return value - translation
|
||
def get_high_likelihood_range(
    log_likelihood,
    maximum_a_posteriori: float,
    threshold: float = HIGH_LIKELIHOOD_RANGE_THRESHOLD,
    args=(),
) -> Tuple[float, float]:
    """
    Compute the interval around `maximum_a_posteriori` on which the log
    likelihood stays within `threshold` of its value at that point.

    Each bound is obtained by solving
        log_likelihood(x, *args) == log_likelihood(maximum_a_posteriori, *args) - threshold
    with `brentq`, once on the left of `maximum_a_posteriori` and once on
    its right.

    Parameters
    ----------
    log_likelihood:
        Python function computing a log likelihood, called as
        `log_likelihood(x, *args)`.
        It must be continuous and concave.
        It must be jitted via numba.
    maximum_a_posteriori:
        The high likelihood position selected as most likely based on the
        prior distribution and the observed likelihood.
    threshold:
        The threshold used to compute the high likelihood range. The range
        is the interval where we have
            log_likelihood > log_likelihood(maximum_a_posteriori) - threshold
        The threshold must be strictly positive.
    args:
        Extra positional arguments forwarded to `log_likelihood`.

    Returns
    -------
    interval:
        A tuple `(lower_bound, upper_bound)` of floats representing an
        interval containing the maximum_a_posteriori. A bound is infinite
        (`-np.inf` / `np.inf`) when no sign changing interval could be
        found on that side.

    Raises
    ------
    ValueError
        If `threshold` is not strictly positive.
    """
    if threshold <= 0:
        raise ValueError("`threshold` must be strictly positive")

    # Level the log likelihood has to cross at both ends of the range.
    min_log_likelihood = log_likelihood(maximum_a_posteriori, *args) - threshold
    # translated_function(x, log_likelihood, min_log_likelihood, args) is
    # zero exactly where the log likelihood reaches that level.
    root_args = (log_likelihood, min_log_likelihood, args)

    # Left side: keep `b` pinned on the maximum and only move `a` outwards.
    try:
        lower_bound = brentq(
            translated_function,
            a=maximum_a_posteriori - 1,
            b=maximum_a_posteriori,
            search_b=False,
            args=root_args,
        )
    except SignChangeIntervalNotFoundError:
        lower_bound = -np.inf

    # Right side: keep `a` pinned on the maximum and only move `b` outwards.
    try:
        upper_bound = brentq(
            translated_function,
            a=maximum_a_posteriori,
            b=maximum_a_posteriori + 1,
            search_a=False,
            args=root_args,
        )
    except SignChangeIntervalNotFoundError:
        upper_bound = np.inf

    return lower_bound, upper_bound
|
||
|
||
@njit | ||
def coordinate_optimize(r_ab, theta_b, precision, alpha): | ||
return brentq( | ||
|
@@ -91,17 +163,17 @@ def pick_next_coordinate(): | |
unchanged.clear() | ||
return theta | ||
|
||
def compute_individual_scores(self, scores: pd.DataFrame, initial_entity_scores=None): | ||
scores = scores[["entity_a", "entity_b", "score"]] | ||
def compute_individual_scores(self, comparison_scores: pd.DataFrame, initial_entity_scores=None): | ||
comparison_scores = comparison_scores[["entity_a", "entity_b", "score"]] | ||
scores_sym = ( | ||
pd.concat( | ||
[ | ||
scores, | ||
comparison_scores, | ||
pd.DataFrame( | ||
{ | ||
"entity_a": scores.entity_b, | ||
"entity_b": scores.entity_a, | ||
"score": -1 * scores.score, | ||
"entity_a": comparison_scores.entity_b, | ||
"entity_b": comparison_scores.entity_a, | ||
"score": -1 * comparison_scores.score, | ||
} | ||
), | ||
] | ||
|
@@ -124,14 +196,25 @@ def compute_individual_scores(self, scores: pd.DataFrame, initial_entity_scores= | |
initial_scores = initial_scores.to_numpy() | ||
theta_star_numpy = self.coordinate_descent(coord_to_subset, initial_scores=initial_scores) | ||
delta_star_numpy = np.zeros(len(theta_star_numpy)) | ||
raw_score_lower_bound = np.zeros(len(theta_star_numpy)) | ||
raw_score_upper_bound = np.zeros(len(theta_star_numpy)) | ||
for idx_a in range(len(theta_star_numpy)): | ||
indices_b, _r_ab = coord_to_subset[idx_a] | ||
indices_b, r_ab = coord_to_subset[idx_a] | ||
lower_bound, upper_bound = get_high_likelihood_range( | ||
continuous_bradley_terry_log_likelihood, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @lenhoanglnh When calculating the high likelihood range, should we include only the likelihood of observed comparisons, or also the prior/regularization (something like |
||
theta_star_numpy[idx_a], | ||
args=(theta_star_numpy[indices_b], r_ab, self.r_max), | ||
) | ||
raw_score_lower_bound[idx_a] = lower_bound | ||
raw_score_upper_bound[idx_a] = upper_bound | ||
delta_star_numpy[idx_a] = Delta_theta(theta_star_numpy[idx_a] - theta_star_numpy[indices_b]) | ||
|
||
result = pd.DataFrame( | ||
{ | ||
"raw_score": theta_star_numpy, | ||
"raw_uncertainty": delta_star_numpy, | ||
"raw_score_lower_bound": raw_score_lower_bound, | ||
"raw_score_upper_bound": raw_score_upper_bound, | ||
}, | ||
index=entities_index, | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ | |
All rights reserved. | ||
""" | ||
# pylint: skip-file | ||
from typing import Tuple | ||
from typing import Tuple, Callable, Optional | ||
|
||
import numpy as np | ||
from numba import njit | ||
|
@@ -39,9 +39,64 @@ def _bisect_interval(a, b, fa, fb) -> Tuple[float, int]: | |
|
||
return root, status | ||
|
||
class SignChangeIntervalNotFoundError(RuntimeError):
    """Raised when no interval on which the function changes sign was found."""
|
||
@njit | ||
def brentq(f, args=(), xtol=_xtol, rtol=_rtol, maxiter=_iter, disp=True, a: float=-1.0, b: float=1.0) -> float: | ||
def search_sign_change_interval(
    f: Callable,
    a: float,
    b: float,
    args: Tuple = (),
    max_iterations: int = 32,
    search_a: bool = True,
    search_b: bool = True,
):
    """
    Searches bounds a and b of interval where `f` changes sign. This is
    achieved by increasing the size of the interval iteratively.
    Note that the method is not guaranteed to succeed for most functions
    and highly depends on the initial bounds.

    Parameters
    ----------
    f : jitted and callable
        Python function returning a number. `f` must be continuous.
    a : number
        One end of the bracketing interval [a,b].
    b : number
        The other end of the bracketing interval [a,b].
    args : tuple, optional(default=())
        Extra arguments to be used in the function call.
    max_iterations:
        The maximum number of expansions performed by the search.
        Warning: when using a large number of iterations, bounds would
        become very large and functions may not be well behaved.
    search_a:
        If true, the value of `a` provided will be updated to search for an
        interval where `f` changes sign
    search_b:
        If true, the value of `b` provided will be updated to search for an
        interval where `f` changes sign

    Returns
    -------
    a, b:
        An interval on which the continuous function `f` changes sign

    Raises
    ------
    ValueError
        If the initial bounds do not satisfy a < b.
    SignChangeIntervalNotFoundError
        If no sign change was found after `max_iterations` expansions.
    """
    if a >= b:
        raise ValueError(f"Initial interval bounds should be such that a < b. Found a={a} and b={b}")

    # Cache the function values so `f` is only re-evaluated at a bound that
    # actually moved.
    f_a = f(a, *args)
    f_b = f(b, *args)
    iteration_count = 0
    while f_a * f_b > 0:
        if iteration_count >= max_iterations:
            raise SignChangeIntervalNotFoundError("Could not find a sign changing interval")
        iteration_count += 1
        # Snapshot the width before moving any bound, so both sides grow by
        # the same amount and the update does not depend on statement order.
        width = b - a
        if search_a:
            a = a - width
            f_a = f(a, *args)
        if search_b:
            b = b + width
            f_b = f(b, *args)
    return a, b
|
||
@njit | ||
def brentq(f, args=(), xtol=_xtol, rtol=_rtol, maxiter=_iter, disp=True, a: float=-1.0, b: float=1.0, search_a: bool=True, search_b: bool = True) -> float: | ||
""" | ||
Find a root of a function in a bracketing interval using Brent's method | ||
adapted from Scipy's brentq. | ||
|
@@ -69,14 +124,23 @@ def brentq(f, args=(), xtol=_xtol, rtol=_rtol, maxiter=_iter, disp=True, a: floa | |
Maximum number of iterations. | ||
disp : bool, optional(default=True) | ||
If True, raise a RuntimeError if the algorithm didn't converge. | ||
search_a: | ||
If true, the value of `a` provided will be updated to search for an | ||
interval where `f` changes sign | ||
search_b: | ||
If true, the value of `b` provided will be updated to search for an | ||
interval where `f` changes sign | ||
Returns | ||
------- | ||
root : float | ||
""" | ||
while f(a, *args) > 0: | ||
a = a - 2 * (b-a) | ||
while f(b, *args) < 0: | ||
b = b + 2 * (b-a) | ||
a, b = search_sign_change_interval(f, a, b, args=args, search_a=search_a, search_b=search_b) | ||
if f(a, *args) == 0: | ||
return a | ||
if f(b, *args) == 0: | ||
return b | ||
if f(a, *args) * f(b, *args) > 0: | ||
raise ValueError("Function `f` should have opposite sign on bounds `a` and `b`") | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Optimize by not calling f(a) and f(b) twice. |
||
|
||
if xtol <= 0: | ||
raise ValueError("xtol is too small (<= 0)") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably log of 1/2 instead