Skip to content

Commit

Permalink
Added a predict function to get predictions
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles Pilgrim committed Dec 18, 2023
1 parent be3d610 commit c645ef1
Showing 1 changed file with 26 additions and 10 deletions.
36 changes: 26 additions & 10 deletions piecewise_regression/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,14 +867,19 @@ def plot_data(self, **kwargs):
"""
plt.scatter(self.xx, self.yy, **kwargs)

def plot_fit(self, **kwargs):
def predict(self, xx_predict):
"""
Plot the fitted model as a series of straight lines.
Passes any kwargs to the matplotlib plot function, e.g. color="red".
Predict y values from x values given the fitted model.
Returns predictions as a list of numbers.
:param xx: Data series in x-axis.
:type xx: list of floats
"""

xx_predict = validate_list_of_numbers(var=xx_predict, var_name="x values to predict", min_length=1)

if not self.best_muggeo:
print("Algorithm didn't converge. No fit to plot.")
print("Algorithm didn't converge. No model to use for prediction.")
else:
# Get the final results from the fitted model variables
# Params are in terms of [intercept, alpha, betas, gammas]
Expand All @@ -886,15 +891,26 @@ def plot_fit(self, **kwargs):
alpha_hat = final_params[1]
beta_hats = final_params[2:2 + len(breakpoints)]

xx_plot = np.linspace(min(self.xx), max(self.xx), 100)

# Build the fit plot segment by segment. Betas are defined as
# Predict y values segment by segment. Betas are defined as
# difference in gradient from previous section
yy_plot = intercept_hat + alpha_hat * xx_plot
yy_predict = intercept_hat + alpha_hat * xx_predict
for bp_count in range(len(breakpoints)):
yy_plot += beta_hats[bp_count] * \
np.maximum(xx_plot - breakpoints[bp_count], 0)
yy_predict += beta_hats[bp_count] * \
np.maximum(xx_predict - breakpoints[bp_count], 0)
return yy_predict

def plot_fit(self, **kwargs):
"""
Plot the fitted model as a series of straight lines.
Passes any kwargs to the matplotlib plot function, e.g. color="red".
"""
if not self.best_muggeo:
print("Algorithm didn't converge. No fit to plot.")
else:
# Plot model
xx_plot = np.linspace(min(self.xx), max(self.xx), 100)
yy_plot = self.predict(xx_plot)
plt.plot(xx_plot, yy_plot, **kwargs)

def plot_breakpoints(self, **kwargs):
Expand Down

0 comments on commit c645ef1

Please sign in to comment.