Skip to content

Commit

Permalink
Merge pull request #5 from samueledelia/3-autodiff-bs-greeks
Browse files Browse the repository at this point in the history
Computed greeks leveraging automatic differentation
  • Loading branch information
paolodelia99 authored Feb 19, 2024
2 parents 11083b4 + 1fc3132 commit 8e93b38
Show file tree
Hide file tree
Showing 10 changed files with 580 additions and 57 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ jaxfin/examples/requirements.txt
test_report.xml
test_report.html

## Coverage results
.coverage
htmlcov/

# Version file
jaxfin.VERSION
version.txt
4 changes: 3 additions & 1 deletion jaxfin/price_engine/black/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Option price computed with the Black'76 model"""
from jaxfin.price_engine.black.black_model import black_price
from jaxfin.price_engine.black.black_model import black_price, delta_black, gamma_black

future_option_price = black_price
future_option_delta = delta_black
future_option_gamma = gamma_black
122 changes: 117 additions & 5 deletions jaxfin/price_engine/black/black_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
import jax
import jax.numpy as jnp
from jax import grad, vmap

from ..common import compute_undiscounted_call_prices
from ..utils import cast_arrays
Expand All @@ -13,7 +14,7 @@ def black_price(
strikes: jax.Array,
expires: jax.Array,
vols: jax.Array,
discount_rates: jax.Array = None,
discount_rates: jax.Array,
dividend_rates: jax.Array = None,
are_calls: jax.Array = None,
dtype: jnp.dtype = None,
Expand All @@ -37,9 +38,6 @@ def black_price(
[spots, strikes, expires, vols], dtype
)

if discount_rates is None:
discount_rates = jnp.zeros(shape, dtype=dtype)

if dividend_rates is None:
dividend_rates = jnp.zeros(shape, dtype=dtype)

Expand All @@ -55,4 +53,118 @@ def black_price(

undiscounted_forwards = forwards - strikes
undiscouted_puts = undiscounted_calls - undiscounted_forwards
return discount_factors * jnp.where(are_calls, undiscounted_calls, undiscouted_puts)
return jnp.exp((-1 * discount_rates) * expires) * jnp.where(
are_calls, undiscounted_calls, undiscouted_puts
)


def _delta_black(
spots: jax.Array,
strikes: jax.Array,
expires: jax.Array,
vols: jax.Array,
discount_rates: jax.Array,
dividend_rates: jax.Array,
are_calls: jax.Array = None,
dtype: jnp.dtype = None,
) -> jax.Array:
"""
Compute the option deltas for european options using the Black '76 model.
:param spots: (jax.Array): Array of current asset prices.
:param strikes: (jax.Array): Array of option strike prices.
:param expires: (jax.Array): Array of option expiration times.
:param vols: (jax.Array): Array of option volatility values.
:param discount_rates: (jax.Array): Array of risk-free interest rates. Defaults to None.
:param dividend_rates: (jax.Array): Array of dividend rates. Defaults to None.
:param are_calls: (jax.Array): Array of booleans indicating whether options are calls (True) or puts (False).
:param dtype: (jnp.dtype): Data type of the output. Defaults to None.
:return: (jax.Array): Array of computed option deltas.
"""
return grad(black_price, argnums=0)(
spots, strikes, expires, vols, discount_rates, dividend_rates, are_calls, dtype
)


def _gamma_black(
spots: jax.Array,
strikes: jax.Array,
expires: jax.Array,
vols: jax.Array,
discount_rates: jax.Array,
dividend_rates: jax.Array,
are_calls: jax.Array = None,
dtype: jnp.dtype = None,
) -> jax.Array:
"""
Compute the option gammas for european options using the Black '76 model.
:param spots: (jax.Array): Array of current asset prices.
:param strikes: (jax.Array): Array of option strike prices.
:param expires: (jax.Array): Array of option expiration times.
:param vols: (jax.Array): Array of option volatility values.
:param discount_rates: (jax.Array): Array of risk-free interest rates. Defaults to None.
:param dividend_rates: (jax.Array): Array of dividend rates. Defaults to None.
:param are_calls: (jax.Array): Array of booleans indicating whether options are calls (True) or puts (False).
:param dtype: (jnp.dtype): Data type of the output. Defaults to None.
:return: (jax.Array): Array of computed option gammas.
"""
return grad(grad(black_price, argnums=0), argnums=0)(
spots, strikes, expires, vols, discount_rates, dividend_rates, are_calls, dtype
)


def delta_black(
spots: jax.Array,
strikes: jax.Array,
expires: jax.Array,
vols: jax.Array,
discount_rates: jax.Array,
dividend_rates: jax.Array = None,
are_calls: jax.Array = None,
dtype: jnp.dtype = None,
) -> jax.Array:
"""
Compute the option deltas for european options using the Black '76 model. (vectorized)
:param spots: (jax.Array): Array of current asset prices.
:param strikes: (jax.Array): Array of option strike prices.
:param expires: (jax.Array): Array of option expiration times.
:param vols: (jax.Array): Array of option volatility values.
:param discount_rates: (jax.Array): Array of risk-free interest rates. Defaults to None.
:param dividend_rates: (jax.Array): Array of dividend rates. Defaults to None.
:param are_calls: (jax.Array): Array of booleans indicating whether options are calls (True) or puts (False).
:param dtype: (jnp.dtype): Data type of the output. Defaults to None.
:return: (jax.Array): Array of computed option deltas.
"""
return vmap(_delta_black, in_axes=(0, 0, 0, 0, 0, 0, 0, None))(
spots, strikes, expires, vols, discount_rates, dividend_rates, are_calls, dtype
)


def gamma_black(
spots: jax.Array,
strikes: jax.Array,
expires: jax.Array,
vols: jax.Array,
discount_rates: jax.Array,
dividend_rates: jax.Array = None,
are_calls: jax.Array = None,
dtype: jnp.dtype = None,
) -> jax.Array:
"""
Compute the option gammas for european options using the Black '76 model. (vectorized)
:param spots: (jax.Array): Array of current asset prices.
:param strikes: (jax.Array): Array of option strike prices.
:param expires: (jax.Array): Array of option expiration times.
:param vols: (jax.Array): Array of option volatility values.
:param discount_rates: (jax.Array): Array of risk-free interest rates. Defaults to None.
:param dividend_rates: (jax.Array): Array of dividend rates. Defaults to None.
:param are_calls: (jax.Array): Array of booleans indicating whether options are calls (True) or puts (False).
:param dtype: (jnp.dtype): Data type of the output. Defaults to None.
:return: (jax.Array): Array of computed option gammas.
"""
return vmap(_gamma_black, in_axes=(0, 0, 0, 0, 0, 0, 0, None))(
spots, strikes, expires, vols, discount_rates, dividend_rates, are_calls, dtype
)
6 changes: 6 additions & 0 deletions jaxfin/price_engine/black_scholes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@
bs_price,
delta_vanilla,
gamma_vanilla,
rho_vanilla,
theta_vanilla,
vega_vanilla,
)

european_price = bs_price
delta_european = delta_vanilla
gamma_european = gamma_vanilla
theta_european = theta_vanilla
rho_european = rho_vanilla
vega_european = vega_vanilla
Loading

0 comments on commit 8e93b38

Please sign in to comment.