Skip to content

Commit

Permalink
7 bugfix review on how vmap is used in the greeks calculations (#8)
Browse files Browse the repository at this point in the history
#7: Add JAX jit optimization to improve performance when calculating the Greeks, and able to use scalars
  • Loading branch information
paolodelia99 authored Feb 25, 2024
1 parent ff6a4f3 commit 4c45123
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 139 deletions.
118 changes: 81 additions & 37 deletions jaxfin/price_engine/black/black_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""
Black '76 prices for options on forwards and futures
"""
from typing import Union

import jax
import jax.numpy as jnp
from jax import grad, vmap
from jax import grad, jit, vmap

from ..common import compute_undiscounted_call_prices
from ..utils import cast_arrays
Expand Down Expand Up @@ -58,6 +60,7 @@ def black_price(
)


@jit
def _delta_black(
spots: jax.Array,
strikes: jax.Array,
Expand Down Expand Up @@ -86,6 +89,7 @@ def _delta_black(
)


@jit
def _gamma_black(
spots: jax.Array,
strikes: jax.Array,
Expand Down Expand Up @@ -115,56 +119,96 @@ def _gamma_black(


def delta_black(
spots: jax.Array,
strikes: jax.Array,
expires: jax.Array,
vols: jax.Array,
discount_rates: jax.Array,
dividend_rates: jax.Array = None,
are_calls: jax.Array = None,
spots: Union[jax.Array, float],
strikes: Union[jax.Array, float],
expires: Union[jax.Array, float],
vols: Union[jax.Array, float],
discount_rates: Union[jax.Array, float],
dividend_rates: Union[jax.Array, float] = None,
are_calls: Union[jax.Array, bool] = None,
dtype: jnp.dtype = None,
) -> jax.Array:
) -> Union[jax.Array, float]:
"""
Compute the option deltas for european options using the Black '76 model. (vectorized)
:param spots: (jax.Array): Array of current asset prices.
:param strikes: (jax.Array): Array of option strike prices.
:param expires: (jax.Array): Array of option expiration times.
:param vols: (jax.Array): Array of option volatility values.
:param discount_rates: (jax.Array): Array of risk-free interest rates. Defaults to None.
:param dividend_rates: (jax.Array): Array of dividend rates. Defaults to None.
:param are_calls: (jax.Array): Array of booleans indicating whether options are calls (True) or puts (False).
:param spots: (Union[jax.Array, float]): Current asset price or array of prices.
:param strikes: (Union[jax.Array, float]): Option strike price or array of prices.
:param expires: (Union[jax.Array, float]): Option expiration time or array of times.
:param vols: (Union[jax.Array, float]): Option volatility value or array of values.
:param discount_rates: (Union[jax.Array, float]): Risk-free interest rate or array of rates.
:param dividend_rates: (Union[jax.Array, float]): Dividend rate or array of rates. Defaults to None.
:param are_calls: (Union[jax.Array, bool]): Boolean indicating whether option is a
call or put, or array of booleans.
:param dtype: (jnp.dtype): Data type of the output. Defaults to None.
:return: (jax.Array): Array of computed option deltas.
:return: (Union[jax.Array, float]): Delta of the given option or array of deltas.
"""
return vmap(_delta_black, in_axes=(0, 0, 0, 0, 0, 0, 0, None))(
spots, strikes, expires, vols, discount_rates, dividend_rates, are_calls, dtype
if jnp.isscalar(spots) or spots.shape == ():
return _delta_black(
spots,
strikes,
expires,
vols,
discount_rates,
dividend_rates,
are_calls,
dtype,
)

return jit(vmap(_delta_black, in_axes=(0, 0, 0, 0, 0, 0, 0, None)))(
spots,
strikes,
expires,
vols,
discount_rates,
dividend_rates,
are_calls,
dtype,
)


def gamma_black(
spots: jax.Array,
strikes: jax.Array,
expires: jax.Array,
vols: jax.Array,
discount_rates: jax.Array,
dividend_rates: jax.Array = None,
are_calls: jax.Array = None,
spots: Union[jax.Array, float],
strikes: Union[jax.Array, float],
expires: Union[jax.Array, float],
vols: Union[jax.Array, float],
discount_rates: Union[jax.Array, float],
dividend_rates: Union[jax.Array, float] = None,
are_calls: Union[jax.Array, bool] = None,
dtype: jnp.dtype = None,
) -> jax.Array:
) -> Union[jax.Array, float]:
"""
Compute the option gammas for european options using the Black '76 model. (vectorized)
:param spots: (jax.Array): Array of current asset prices.
:param strikes: (jax.Array): Array of option strike prices.
:param expires: (jax.Array): Array of option expiration times.
:param vols: (jax.Array): Array of option volatility values.
:param discount_rates: (jax.Array): Array of risk-free interest rates. Defaults to None.
:param dividend_rates: (jax.Array): Array of dividend rates. Defaults to None.
:param are_calls: (jax.Array): Array of booleans indicating whether options are calls (True) or puts (False).
:param spots: (Union[jax.Array, float]): Current asset price or array of prices.
:param strikes: (Union[jax.Array, float]): Option strike price or array of prices.
:param expires: (Union[jax.Array, float]): Option expiration time or array of times.
:param vols: (Union[jax.Array, float]): Option volatility value or array of values.
:param discount_rates: (Union[jax.Array, float]): Risk-free interest rate or array of rates.
:param dividend_rates: (Union[jax.Array, float]): Dividend rate or array of rates. Defaults to None.
:param are_calls: (Union[jax.Array, bool]): Boolean indicating whether option is a
call or put, or array of booleans.
:param dtype: (jnp.dtype): Data type of the output. Defaults to None.
:return: (jax.Array): Array of computed option gammas.
:return: (Union[jax.Array, float]): Gamma of the given option or array of gammas.
"""
return vmap(_gamma_black, in_axes=(0, 0, 0, 0, 0, 0, 0, None))(
spots, strikes, expires, vols, discount_rates, dividend_rates, are_calls, dtype
if jnp.isscalar(spots) or spots.shape == ():
return _gamma_black(
spots,
strikes,
expires,
vols,
discount_rates,
dividend_rates,
are_calls,
dtype,
)

return jit(vmap(_gamma_black, in_axes=(0, 0, 0, 0, 0, 0, 0, None)))(
spots,
strikes,
expires,
vols,
discount_rates,
dividend_rates,
are_calls,
dtype,
)
Loading

0 comments on commit 4c45123

Please sign in to comment.