From c645ef1665d57eef9950738a7d3a138df9d431fd Mon Sep 17 00:00:00 2001 From: Charles Pilgrim Date: Mon, 18 Dec 2023 12:20:38 +0000 Subject: [PATCH] Added a predict function to get predictions --- piecewise_regression/main.py | 36 ++++++++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/piecewise_regression/main.py b/piecewise_regression/main.py index 36083f4..384eafa 100644 --- a/piecewise_regression/main.py +++ b/piecewise_regression/main.py @@ -867,14 +867,19 @@ def plot_data(self, **kwargs): """ plt.scatter(self.xx, self.yy, **kwargs) - def plot_fit(self, **kwargs): + def predict(self, xx_predict): """ - Plot the fitted model as a series of straight lines. - Passes any kwargs to the matplotlib plot function, e.g. color="red". + Predict y values from x values given the fitted model. + Returns predictions as a list of numbers. + :param xx: Data series in x-axis. + :type xx: list of floats """ + + xx_predict = validate_list_of_numbers(var=xx_predict, var_name="x values to predict", min_length=1) + if not self.best_muggeo: - print("Algorithm didn't converge. No fit to plot.") + print("Algorithm didn't converge. No model to use for prediction.") else: # Get the final results from the fitted model variables # Params are in terms of [intercept, alpha, betas, gammas] @@ -886,15 +891,26 @@ def plot_fit(self, **kwargs): alpha_hat = final_params[1] beta_hats = final_params[2:2 + len(breakpoints)] - xx_plot = np.linspace(min(self.xx), max(self.xx), 100) - - # Build the fit plot segment by segment. Betas are defined as + # Predict y values segment by segment. Betas are defined as # difference in gradient from previous section - yy_plot = intercept_hat + alpha_hat * xx_plot + yy_predict = intercept_hat + alpha_hat * xx_predict for bp_count in range(len(breakpoints)): - yy_plot += beta_hats[bp_count] * \ - np.maximum(xx_plot - breakpoints[bp_count], 0) + yy_predict += beta_hats[bp_count] * \ + np.maximum(xx_predict - breakpoints[bp_count], 0) + return yy_predict + + def plot_fit(self, **kwargs): + """ + Plot the fitted model as a series of straight lines. + Passes any kwargs to the matplotlib plot function, e.g. color="red". + """ + if not self.best_muggeo: + print("Algorithm didn't converge. No fit to plot.") + else: + # Plot model + xx_plot = np.linspace(min(self.xx), max(self.xx), 100) + yy_plot = self.predict(xx_plot) plt.plot(xx_plot, yy_plot, **kwargs) def plot_breakpoints(self, **kwargs):