Skip to content

Commit

Permalink
Merge pull request #6 from samueledelia/1-multivariategbm
Browse files Browse the repository at this point in the history
Multivariate GBM implementation
  • Loading branch information
paolodelia99 authored Feb 20, 2024
2 parents 8e93b38 + 899542d commit 807d17e
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion .idea/py-exopricer.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions jaxfin/models/gbm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""
Geometric Brownian motion module
"""
from jaxfin.models.gbm.gbm import UnivGeometricBrownianMotion
from jaxfin.models.gbm.gbm import (
MultiGeometricBrownianMotion,
UnivGeometricBrownianMotion,
)

__all__ = ["UnivGeometricBrownianMotion"]
__all__ = ["UnivGeometricBrownianMotion", "MultiGeometricBrownianMotion"]
121 changes: 121 additions & 0 deletions jaxfin/models/gbm/gbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,124 @@ def simulate_paths(self, seed: int, maturity, n: int, n_sim: int) -> jax.Array:
Xt = jnp.vstack([jnp.ones(n_sim), Xt])

return self._s0 * Xt.cumprod(axis=0)


class MultiGeometricBrownianMotion:
"""
Geometric Brownian Motion
Represent a d-dimensional GBM
# Example usage:
params = {
's0' : [10, 12],
'dtype' : jnp.float32,
'mean' : 0.1,
'cov': [[0.3, 0.1], [0.1, 0.5]]
}
gmb_process = GeometricBrownianMotion(**params)
paths = gmb_process.simulate_paths(maturity=1.0, n=100, n_sim=100)
"""

def __init__(self, s0, mean, sigma, corr, dtype):
if dtype is None:
raise ValueError("dtype must not be None")

if not _check_symmetric(corr, 1e-8):
raise ValueError("Correlation matrix must be symmetric")

if not jnp.array_equal(jnp.diag(corr), jnp.ones(corr.shape[0])):
raise ValueError("Correlation matrix must have ones as diagonal elements")

self._dtype = dtype
self._s0 = jnp.asarray(s0, dtype=dtype)
self._mean = jnp.asarray(mean, dtype=dtype)
self._sigma = jnp.asarray(sigma, dtype=dtype)
self._corr = jnp.asarray(corr, dtype=dtype)
self._dim = self._s0.shape[0]

@property
def mean(self):
"""
:return: Returns the mean of the GBM
"""
return self._mean

@property
def sigma(self):
"""
:return: Returns the standard deviation of the GBM
"""
return self._sigma

@property
def corr(self):
"""
:return: Returns the correlation matrix of the Weiner processes
"""
return self._corr

@property
def s0(self):
"""
:return: Returns the initial value of the GBM
"""
return self._s0

@property
def dtype(self):
"""
:return: Returns the underlying dtype of the GBM
"""
return self._dtype

@property
def dimension(self):
"""
:return: Returns the dimension of the GBM
"""
return self._dim

def simulate_paths(self, seed: int, maturity, n: int, n_sim: int) -> jax.Array:
"""
Simulate a sample of paths from the GBM
:param maturity: time in years
:param n: (int): number of steps
:return: (jax.Array): Array containing the sample paths
"""
key = random.PRNGKey(seed)

dt = maturity / n

normal_draw = random.normal(key, shape=(n_sim, n * self._dim))
normal_draw = jnp.reshape(normal_draw, (n_sim, n, self._dim)).transpose(
(1, 0, 2)
)

cholesky_matrix = jnp.linalg.cholesky(self._corr)

stochastic_increments = normal_draw @ cholesky_matrix

log_increments = (self._mean - self._sigma**2 / 2) * dt + jnp.sqrt(
dt
) * self._sigma * stochastic_increments

once = jnp.ones([n, n], dtype=self._dtype)
lower_triangular = jnp.tril(once, k=-1)
cumsum = log_increments.transpose() @ lower_triangular
cumsum = cumsum.transpose((1, 2, 0))
samples = self._s0 * jnp.exp(cumsum)
return samples.transpose(1, 0, 2)[::-1, :, :]


def _check_symmetric(a, tol=1e-8):
"""
Check if a matrix is symmetric
:param a: (jax.Array): Matrix to check
:param tol: (float): Tolerance for the check
:return: (bool): True if the matrix is symmetric
"""
return jnp.all(jnp.abs(a - a.T) < tol)
44 changes: 43 additions & 1 deletion tests/models/test_gbm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import jax.numpy as jnp

from jaxfin.models.gbm import UnivGeometricBrownianMotion
from jaxfin.models.gbm import UnivGeometricBrownianMotion, MultiGeometricBrownianMotion

SEED: int = 42

class TestUnivGBM:

Expand All @@ -16,3 +17,44 @@ def test_init(self):
assert gbm.sigma == sigma
assert gbm.dtype == dtype
assert gbm.s0 == s0

def test_sim_paths_shape(self):
s0 = 10
mean = 0.1
sigma = 0.3
dtype = jnp.float32
gbm = UnivGeometricBrownianMotion(s0, mean, sigma, dtype)

stock_paths = gbm.simulate_paths(SEED, 1.0, 52, 100)

assert stock_paths.shape == (52, 100)


class TestMultiGBM:

def test_init(self):
s0 = jnp.array([10, 12])
mean = jnp.array([0.1, 0.0])
sigma = jnp.array([0.3, 0.5])
corr = jnp.array([[1, 0.1], [0.1, 1]])
dtype = jnp.float32
gbm = MultiGeometricBrownianMotion(s0, mean, sigma, corr, dtype)

assert jnp.array_equal(gbm.mean, mean)
assert jnp.array_equal(gbm.sigma, sigma)
assert jnp.array_equal(gbm.corr, corr)
assert gbm.dtype == dtype
assert jnp.array_equal(gbm.s0, s0)
assert gbm.dimension == 2


def test_sample_path(self):
s0 = jnp.array([10, 12])
mean = jnp.array([0.1, 0.0])
sigma = jnp.array([0.3, 0.5])
corr = jnp.array([[1, 0.1], [0.1, 1]])
dtype = jnp.float32
gbm = MultiGeometricBrownianMotion(s0, mean, sigma, corr, dtype)
sample_path = gbm.simulate_paths(SEED, 1.0, 52, 100)

assert sample_path.shape == (52, 100, 2)

0 comments on commit 807d17e

Please sign in to comment.