Skip to content

Commit

Permalink
add fastapi for ml
Browse files Browse the repository at this point in the history
  • Loading branch information
khuyentran1401 committed Aug 6, 2024
1 parent 7b9948d commit 2c19720
Show file tree
Hide file tree
Showing 6 changed files with 630 additions and 2 deletions.
3 changes: 2 additions & 1 deletion Chapter5/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
.env
delta_lake*
employees
mlruns
mlruns
*.joblib
226 changes: 226 additions & 0 deletions Chapter5/machine_learning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2826,6 +2826,232 @@
"source": [
"[Learn more about MLFlow Models](https://bit.ly/46y6gpF)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8acadfca",
"metadata": {
"tags": [
"hide-cell"
]
},
"outputs": [],
"source": [
"!pip3 install joblib \"fastapi[standard]\""
]
},
{
"cell_type": "markdown",
"id": "e815f01b",
"metadata": {},
"source": [
"Imagine this scenario: You have just built a machine learning (ML) model with great performance, and you want to share this model with your team members so that they can develop a web application on top of your model.\n",
"\n",
"One way to share the model with your team members is to save the model to a file (e.g., using pickle, joblib, or framework-specific methods) and share the file directly\n",
"\n",
"\n",
"```python\n",
"import joblib\n",
"\n",
"model = ...\n",
"\n",
"# Save model\n",
"joblib.dump(model, \"model.joblib\")\n",
"\n",
"# Load model\n",
"model = joblib.load(model)\n",
"```\n",
"\n",
"However, this approach requires the same environment and dependencies, and it can pose potential security risks.\n"
]
},
{
"cell_type": "markdown",
"id": "b364f9fc",
"metadata": {},
"source": [
"An alternative is creating an API for your ML model. APIs define how software components interact, allowing:\n",
"\n",
"1. Access from various programming languages and platforms\n",
"2. Easier integration for developers unfamiliar with ML or Python\n",
"3. Versatile use across different applications (web, mobile, etc.)\n",
"\n",
"This approach simplifies model sharing and usage, making it more accessible for diverse development needs.\n",
"\n",
"Let's learn how to create an ML API with FastAPI, a modern and fast web framework for building APIs with Python. \n",
"\n",
"Before we begin constructing an API for a machine learning model, let's first develop a basic model that our API will use. In this example, we'll create a model that predicts the median house price in California."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "d7ea435d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean squared error: 0.56\n"
]
},
{
"data": {
"text/plain": [
"['lr.joblib']"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.datasets import fetch_california_housing\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.metrics import mean_squared_error\n",
"import joblib\n",
"\n",
"# Load dataset\n",
"X, y = fetch_california_housing(as_frame=True, return_X_y=True)\n",
"\n",
"# Split dataset into training and test sets\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
" X, y, test_size=0.2, random_state=42\n",
")\n",
"\n",
"# Initialize and train the logistic regression model\n",
"model = LinearRegression()\n",
"model.fit(X_train, y_train)\n",
"\n",
"# Predict and evaluate the model\n",
"y_pred = model.predict(X_test)\n",
"mse = mean_squared_error(y_test, y_pred)\n",
"print(f\"Mean squared error: {mse:.2f}\")\n",
"\n",
"# Save model\n",
"joblib.dump(model, \"lr.joblib\")"
]
},
{
"cell_type": "markdown",
"id": "a5aaad8e",
"metadata": {},
"source": [
"Once we have our model, we can create an API for it using FastAPI. We'll define a POST endpoint for making predictions and use the model to make predictions.\n",
"\n",
"Here's an example of how to create an API for a machine learning model using FastAPI:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "581f789f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Overwriting ml_app.py\n"
]
}
],
"source": [
"%%writefile ml_app.py\n",
"from fastapi import FastAPI\n",
"import joblib\n",
"import pandas as pd \n",
"\n",
"# Create a FastAPI application instance\n",
"app = FastAPI()\n",
"\n",
"# Load the pre-trained machine learning model\n",
"model = joblib.load(\"lr.joblib\")\n",
"\n",
"# Define a POST endpoint for making predictions\n",
"@app.post(\"/predict/\")\n",
"def predict(data: list[float]):\n",
" # Define the column names for the input features\n",
" columns = [\n",
" \"MedInc\",\n",
" \"HouseAge\",\n",
" \"AveRooms\",\n",
" \"AveBedrms\",\n",
" \"Population\",\n",
" \"AveOccup\",\n",
" \"Latitude\",\n",
" \"Longitude\",\n",
" ]\n",
" \n",
" # Create a pandas DataFrame from the input data\n",
" features = pd.DataFrame([data], columns=columns)\n",
" \n",
" # Use the model to make a prediction\n",
" prediction = model.predict(features)[0]\n",
" \n",
" # Return the prediction as a JSON object, rounding to 2 decimal places\n",
" return {\"price\": round(prediction, 2)}"
]
},
{
"cell_type": "markdown",
"id": "34aba2f2",
"metadata": {},
"source": [
"To run your FastAPI app for development, use the `fastapi dev` command:\n",
"```bash\n",
"$ fastapi dev ml_app.py\n",
"``` "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "375b4bce",
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"!fastapi dev ml_app.py"
]
},
{
"cell_type": "markdown",
"id": "3fff7352",
"metadata": {},
"source": [
"This will start the development server and open the API documentation in your default browser.\n",
"\n",
"You can now use the API to make predictions by sending a POST request to the `/predict/` endpoint with the input data. For example:"
]
},
{
"cell_type": "markdown",
"id": "8fb49c7c",
"metadata": {},
"source": [
"Running this cURL command on your terminal:\n",
"```bash\n",
"curl -X 'POST' \\\n",
" 'http://127.0.0.1:8000/predict/' \\\n",
" -H 'accept: application/json' \\\n",
" -H 'Content-Type: application/json' \\\n",
" -d '[\n",
" 1.68, 25, 4, 2, 1400, 3, 36.06, -119.01\n",
"]'\n",
"```\n",
"This will return the predicted price as a JSON object, rounded to 2 decimal places:\n",
"```python\n",
"{\"price\":1.51}\n",
"```"
]
}
],
"metadata": {
Expand Down
33 changes: 33 additions & 0 deletions Chapter5/ml_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from fastapi import FastAPI
import joblib
import pandas as pd

# Create a FastAPI application instance
app = FastAPI()

# Load the pre-trained machine learning model
model = joblib.load("lr.joblib")

# Define a POST endpoint for making predictions
@app.post("/predict/")
def predict(data: list[float]):
# Define the column names for the input features
columns = [
"MedInc",
"HouseAge",
"AveRooms",
"AveBedrms",
"Population",
"AveOccup",
"Latitude",
"Longitude",
]

# Create a pandas DataFrame from the input data
features = pd.DataFrame([data], columns=columns)

# Use the model to make a prediction
prediction = model.predict(features)[0]

# Return the prediction as a JSON object, rounding to 2 decimal places
return {"price": round(prediction, 2)}
Loading

0 comments on commit 2c19720

Please sign in to comment.