Skip to content

Commit

Permalink
add docstring on njit_brentq
Browse files Browse the repository at this point in the history
  • Loading branch information
amatissart committed Jun 1, 2024
1 parent 6004536 commit d059bb2
Showing 1 changed file with 34 additions and 14 deletions.
48 changes: 34 additions & 14 deletions solidago/src/solidago/solvers/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Copyright © 2013-2021 Thomas J. Sargent and John Stachurski: BSD-3
All rights reserved.
"""
# pylint: skip-file

from typing import Callable, Tuple, Literal

import numpy as np
Expand Down Expand Up @@ -47,29 +47,49 @@ def njit_brentq(
xtol=_xtol,
rtol=_rtol,
maxiter=_iter,
a: float=-1.0,
b: float=1.0,
a: float = -1.0,
b: float = 1.0,
extend_bounds: Literal["ascending", "descending", "no"] = "ascending",
) -> float:
""" `Accelerated brentq. Requires f to be itself jitted via numba.
"""Accelerated brentq. Requires f to be itself jitted via numba.
Essentially, numba optimizes the execution by running an optimized compilation
of the function when it is first called, and by then running the compiled function.
Parameters
----------
f : jitted and callable
Python function returning a number. `f` must be continuous.
args : tuple, optional(default=())
Extra arguments to be used in the function call.
xtol : number, optional(default=2e-12)
The computed root ``x0`` will satisfy ``np.allclose(x, x0,
atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The
parameter must be nonnegative.
rtol : number, optional(default=`4*np.finfo(float).eps`)
The computed root ``x0`` will satisfy ``np.allclose(x, x0,
atol=xtol, rtol=rtol)``, where ``x`` is the exact root.
maxiter : number, optional(default=100)
Maximum number of iterations.
a : number
One end of the bracketing interval [a,b].
b : number
The other end of the bracketing interval [a,b].
extend_bounds: default: "ascending",
Whether to extend the interval [a,b] to find a root.
('no': to keep the bounds [a, b],
'ascending': extend the bounds assuming `f` is ascending,
'descending': extend the bounds assuming `f` is descending)
"""
if extend_bounds == "ascending":
while f(a, *args) > 0:
a = a - 2 * (b-a)
a = a - 2 * (b - a)
while f(b, *args) < 0:
b = b + 2 * (b-a)
b = b + 2 * (b - a)
elif extend_bounds == "descending":
while f(a, *args) < 0:
a = a - 2 * (b-a)
a = a - 2 * (b - a)
while f(b, *args) > 0:
b = b + 2 * (b-a)
b = b + 2 * (b - a)

if xtol <= 0:
raise ValueError("xtol is too small (<= 0)")
Expand Down Expand Up @@ -157,13 +177,13 @@ def njit_brentq(
def coordinate_descent(
update_coordinate_function: Callable[[Tuple, float], float],
get_args: Callable[[int, np.ndarray], Tuple],
initialization: np.ndarray,
initialization: np.ndarray,
updated_coordinates: list[int],
error: float = 1e-5
error: float = 1e-5,
):
""" Minimize a loss function with coordinate descent,
"""Minimize a loss function with coordinate descent,
by leveraging the partial derivatives of the loss
Parameters
----------
loss_partial_derivative: callable
Expand All @@ -174,7 +194,7 @@ def coordinate_descent(
Initialization point of the coordinate descent
error: float
Tolerated error
Returns
-------
out: stationary point of the loss
Expand Down

0 comments on commit d059bb2

Please sign in to comment.