Skip to content

Commit

Permalink
#TIDY: Updated doc string in gmb.py file
Browse files Browse the repository at this point in the history
  • Loading branch information
paolodelia99 committed Feb 24, 2024
1 parent 807d17e commit ff6a4f3
Showing 1 changed file with 13 additions and 18 deletions.
31 changes: 13 additions & 18 deletions jaxfin/models/gbm/gbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,12 @@ class UnivGeometricBrownianMotion:
Represent a 1-dimensional GBM
# Example usage:
params = {
's0' : 10,
'dtype' : jnp.float32,
'mean' : 0.1,
'sigma': 0.3
}
gmb_process = GeometricBrownianMotion(**params)
paths = gmb_process.simulate_paths(maturity=1.0, n=100, n_sim=100)
s0 = jnp.array([100], dtype=jnp.float32)
mean = jnp.array([0.1], dtype=jnp.float32)
sigma = jnp.array([0.3], dtype=jnp.float32)
gbm_process = UnivGeometricBrownianMotion(s0, mean, sigma, dtype=jnp.float32)
paths = gbm_process.simulate_paths(maturity=1.0, n=100, n_sim=100)
"""

def __init__(self, s0, mean, sigma, dtype):
Expand Down Expand Up @@ -89,18 +86,16 @@ class MultiGeometricBrownianMotion:
Represent a d-dimensional GBM
# Example usage:
params = {
's0' : [10, 12],
'dtype' : jnp.float32,
'mean' : 0.1,
'cov': [[0.3, 0.1], [0.1, 0.5]]
}
gmb_process = GeometricBrownianMotion(**params)
s0 = jnp.array([100, 100], dtype=jnp.float32)
mean = jnp.array([0.1, 0.1], dtype=jnp.float32)
sigma = jnp.array([0.3, 0.3], dtype=jnp.float32)
corr = jnp.array([[1.0, 0.5], [0.5, 1.0]], dtype=jnp.float32)
m_gbm = MultiGeometricBrownianMotion(s0, mean, sigma, corr, jnp.float32)
paths = gmb_process.simulate_paths(maturity=1.0, n=100, n_sim=100)
"""

def __init__(self, s0, mean, sigma, corr, dtype):
def __init__(self, s0, mean, sigma, corr, dtype=jnp.float32):
if dtype is None:
raise ValueError("dtype must not be None")

Expand Down

0 comments on commit ff6a4f3

Please sign in to comment.