Skip to content

Commit

Permalink
Update plots in dist_select
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander März committed Aug 24, 2023
1 parent f267cdd commit 5a5e45f
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 103 deletions.
40 changes: 14 additions & 26 deletions docs/examples/How_To_Select_A_Multivariate_Distribution.ipynb

Large diffs are not rendered by default.

54 changes: 21 additions & 33 deletions docs/examples/How_To_Select_A_Univariate_Distribution.ipynb

Large diffs are not rendered by default.

32 changes: 10 additions & 22 deletions xgboostlss/distributions/distribution_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from tqdm import tqdm

from typing import Any, Dict, Optional, List, Tuple
from plotnine import *
import matplotlib.pyplot as plt
import seaborn as sns
import warnings


Expand Down Expand Up @@ -566,7 +567,6 @@ def dist_select(self,
target: np.ndarray,
candidate_distributions: List,
max_iter: int = 100,
n_samples: int = 1000,
plot: bool = False,
figure_size: tuple = (10, 5),
) -> pd.DataFrame:
Expand All @@ -582,8 +582,6 @@ def dist_select(self,
List of candidate distributions.
max_iter: int
Maximum number of iterations for the optimization.
n_samples: int
Number of samples to draw from the fitted distribution.
plot: bool
If True, a density plot of the actual and fitted distribution is created.
figure_size: tuple
Expand Down Expand Up @@ -643,29 +641,19 @@ def dist_select(self,
axis=1,
)
fitted_params = pd.DataFrame(fitted_params, columns=best_dist_sel.param_dict.keys())
fitted_params.columns = best_dist_sel.param_dict.keys()
n_samples = np.max([10000, target.shape[0]])
n_samples = np.where(n_samples > 500000, 100000, n_samples)
dist_samples = best_dist_sel.draw_samples(fitted_params,
n_samples=n_samples,
seed=123).values

# Plot actual and fitted distribution
plot_df_actual = pd.DataFrame({"y": target.reshape(-1,), "type": "Actual"})
plot_df_fitted = pd.DataFrame({"y": dist_samples.reshape(-1,),
"type": f"Best-Fit: {best_dist['distribution'].values[0]}"})
plot_df = pd.concat([plot_df_actual, plot_df_fitted])

print(
ggplot(plot_df,
aes(x="y",
color="type")) +
geom_density(alpha=0.5) +
theme_bw(base_size=15) +
theme(figure_size=figure_size,
legend_position="right",
legend_title=element_blank(),
plot_title=element_text(hjust=0.5)) +
labs(title=f"Actual vs. Fitted Density")
)
plt.figure(figsize=figure_size)
sns.kdeplot(target.reshape(-1, ), label="Actual")
sns.kdeplot(dist_samples.reshape(-1, ), label=f"Best-Fit: {best_dist['distribution'].values[0]}")
plt.legend()
plt.title("Actual vs. Best-Fit Density")
plt.show()

fit_df.drop(columns=["rank", "params"], inplace=True)

Expand Down
64 changes: 42 additions & 22 deletions xgboostlss/distributions/multivariate_distribution_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from tqdm import tqdm

from typing import Any, Dict, Optional, List, Tuple, Callable
from plotnine import *
import seaborn as sns
import warnings


Expand Down Expand Up @@ -506,10 +506,11 @@ def dist_select(self,
target: np.ndarray,
candidate_distributions: List,
max_iter: int = 100,
n_samples: int = 1000,
plot: bool = False,
ncol: int = 3,
figure_size: tuple = (10, 5),
height: float = 4,
sharex: bool = True,
sharey: bool = True,
) -> pd.DataFrame:
"""
Function that selects the most suitable distribution among the candidate_distributions for the target variable,
Expand All @@ -523,14 +524,16 @@ def dist_select(self,
List of candidate distributions.
max_iter: int
Maximum number of iterations for the optimization.
n_samples: int
Number of samples drawn from the fitted distribution.
plot: bool
If True, a density plot of the actual and fitted distribution is created.
ncol: int
Number of columns for the facetting of the density plots.
figure_size: tuple
Figure size of the density plot.
height: Float
Height (in inches) of each facet.
sharex: bool
Whether to share the x-axis across the facets.
sharey: bool
Whether to share the y-axis across the facets.
Returns
-------
Expand Down Expand Up @@ -572,6 +575,7 @@ def dist_select(self,
pbar.set_description(f"Fitting of candidate distributions completed")

if plot:
warnings.simplefilter(action='ignore', category=UserWarning)
# Select distribution
best_dist = fit_df[fit_df["rank"] == 1].reset_index(drop=True)
for dist in candidate_distributions:
Expand All @@ -597,6 +601,8 @@ def dist_select(self,
else:
dist_kwargs = dict(zip(best_dist_sel.distribution_arg_names, dist_params))
dist_fit = best_dist_sel.distribution(**dist_kwargs)
n_samples = np.max([1000, target.shape[0]])
n_samples = np.where(n_samples > 10000, 1000, n_samples)
df_samples = best_dist_sel.draw_samples(dist_fit, n_samples=n_samples, seed=123)

# Plot actual and fitted distribution
Expand All @@ -610,21 +616,35 @@ def dist_select(self,

plot_df = pd.concat([df_actual, df_samples])

print(
ggplot(plot_df,
aes(x="value",
color="type")) +
geom_density(alpha=0.5) +
facet_wrap("target",
scales="free",
ncol=ncol) +
theme_bw(base_size=15) +
theme(figure_size=figure_size,
legend_position="right",
legend_title=element_blank(),
plot_title=element_text(hjust=0.5)) +
labs(title=f"Actual vs. Fitted Density")
)
g = sns.FacetGrid(plot_df,
col="target",
hue="type",
col_wrap=ncol,
height=height,
sharex=sharex,
sharey=sharey,
)
g.map(sns.kdeplot, "value", lw=2.5)
handles, labels = g.axes[0].get_legend_handles_labels()
g.fig.legend(handles, labels, loc='upper center', ncol=len(labels), title="", bbox_to_anchor=(0.5, 0.92))
g.fig.suptitle("Actual vs. Best-Fit Density", weight="bold", fontsize=16)
g.fig.tight_layout(rect=[0, 0, 1, 0.9])

# print(
# ggplot(plot_df,
# aes(x="value",
# color="type")) +
# geom_density(alpha=0.5) +
# facet_wrap("target",
# scales="free",
# ncol=ncol) +
# theme_bw(base_size=15) +
# theme(figure_size=figure_size,
# legend_position="right",
# legend_title=element_blank(),
# plot_title=element_text(hjust=0.5)) +
# labs(title=f"Actual vs. Fitted Density")
# )

fit_df.drop(columns=["rank", "params"], inplace=True)

Expand Down

0 comments on commit 5a5e45f

Please sign in to comment.