Deploying machine learning models as web services has become a standard part of production pipelines. In this tutorial, we will walk through a concise code snippet that combines FastAPI with a machine-learning model to serve predictions. In the previous step, we dumped the vectorizer and the trained model; we will now load them to make predictions.
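As a quick refresher, both artifacts were persisted with joblib along these lines. This is a minimal sketch, not the exact code from the earlier part: the MultinomialNB classifier and the train_texts / train_labels names are placeholders for whatever model and dataset you trained on.

import joblib
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB

# fit the vectorizer and the model on your labelled training data
# (train_texts is a list of raw messages, train_labels their 0/1 spam labels)
vectorizer = CountVectorizer()
X_train = vectorizer.fit_transform(train_texts)
model = MultinomialNB().fit(X_train, train_labels)

# dump both artifacts so the API can load them at startup
joblib.dump(model, "ml/model.pkl")
joblib.dump(vectorizer, "ml/vectorizer.pkl")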
First, let's put the requirements in the requirements.txt file:
fastapi==0.101.1
uvicorn==0.23.2
pandas==2.1.0
scikit-learn==1.3.0
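With the file in place, the dependencies can be installed with:

pip install -r requirements.txt

Note that joblib, which we import later, is pulled in automatically as a dependency of scikit-learn.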
Next, let's define a Pydantic schema that ensures the input arrives in the expected JSON format. Create a file named schema.py:
from pydantic import BaseModel, Field
class Input(BaseModel):
    input: str = Field(min_length=1)
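With this schema, FastAPI will reject a missing input field or an empty string with a 422 validation error. A valid request body simply looks like:

{"input": "Congratulations! You have won a free cruise. Reply now to claim."}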
Now let's wire everything together in the main.py file:
import joblib
import pandas as pd
from fastapi import FastAPI
from schema import Input
app = FastAPI()
model = joblib.load("ml/model.pkl")
vectorizer = joblib.load("ml/vectorizer.pkl")
@app.post("/api/v1/predict")
def get_prediction(body: Input):
    # convert the raw text into the numerical features the model expects
    vectorized_text = vectorizer.transform(pd.Series(body.input))
    # probability of the positive (spam) class, rounded to two decimals
    probability = round(model.predict_proba(vectorized_text)[0][1], 2)
    # anything above the 0.3 threshold is treated as spam
    is_spam = bool(probability > 0.3)
    return {"isSpam": is_spam, "spamProbability": probability}
joblib is a popular library for serializing and deserializing Python objects, and it is especially efficient for saving and loading scikit-learn models. First, a new FastAPI app instance is created. Next, the pre-trained model is loaded from ml/model.pkl. Similarly, the vectorizer is loaded, which is an essential step for text data: this CountVectorizer converts textual input into the numerical format our model can digest.

Then an API endpoint (/api/v1/predict) is set up to listen for POST requests. When data arrives at this endpoint, the get_prediction function is invoked. It takes the incoming data, expecting it to match the Input schema. The text (body.input) is transformed via the vectorizer, and the transformed data is fed to the model to retrieve a probability score indicating how likely the text is to be spam. A threshold, 0.3 in this instance, decides whether the input qualifies as spam. The function returns a dictionary containing the binary outcome (isSpam) and the associated probability (spamProbability).
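To make the indexing explicit: predict_proba returns one row per input and one column per class, so [0][1] picks the spam-class probability of our single input. A quick illustration, assuming the labels were encoded as 0 = not spam and 1 = spam (the numbers are made up):

>>> vectorized_text = vectorizer.transform(pd.Series("Win a free prize now!"))
>>> model.predict_proba(vectorized_text)
array([[0.13, 0.87]])   # [P(not spam), P(spam)] for the single row
>>> model.predict_proba(vectorized_text)[0][1]
0.87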
Now, we can start the server by executing:
docker compose up --build
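If you are not following along with the Docker setup, the same app can be started directly with Uvicorn, which we already listed in requirements.txt:

uvicorn main:app --reload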
And we can try a prediction from the interactive docs at http://127.0.0.1:8000/docs.
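Alternatively, we can hit the endpoint from the command line (the probability in the response is just illustrative):

curl -X POST "http://127.0.0.1:8000/api/v1/predict" \
     -H "Content-Type: application/json" \
     -d '{"input": "Congratulations, you won a free cruise!"}'

{"isSpam": true, "spamProbability": 0.87}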