W3Cschool
恭喜您成為首批注冊用戶
獲得88經(jīng)驗值獎勵
在機器學(xué)習(xí)項目的實際應(yīng)用中,將訓(xùn)練好的模型部署為服務(wù),使其能夠接收外部請求并返回預(yù)測結(jié)果,是實現(xiàn)模型價值的關(guān)鍵一步。Flask 作為 Python 的輕量級 Web 框架,憑借其簡潔易用的特性,成為部署 PyTorch 模型的理想選擇之一。本文將詳細(xì)指導(dǎo)您如何使用 Flask 將 PyTorch 模型部署為 REST API 服務(wù),以預(yù)訓(xùn)練的 DenseNet 121 模型為例,實現(xiàn)圖像分類功能。
在開始部署之前,確保已安裝所需的依賴庫。運行以下命令以安裝 Flask 和 torchvision:
pip install Flask torchvision
首先,我們創(chuàng)建一個基本的 Flask Web 服務(wù)器,后續(xù)將在此基礎(chǔ)上添加模型推理功能。
from flask import Flask
app = Flask(__name__)
@app.route('/')
def hello():
return 'Hello World!'
if __name__ == '__main__':
app.run()
保存上述代碼為 app.py
,運行 Flask 開發(fā)服務(wù)器:
FLASK_ENV=development FLASK_APP=app.py flask run
訪問 http://localhost:5000/
,您將看到 "Hello World!" 文字,這表明服務(wù)器已成功啟動。
我們將定義一個 /predict
端點,用于接收包含圖像文件的 HTTP POST 請求,并返回預(yù)測結(jié)果。
DenseNet 121 模型要求輸入圖像為 224 x 224 的 3 通道 RGB 圖像,且需進(jìn)行歸一化處理。我們使用 torchvision.transforms
構(gòu)建圖像預(yù)處理管道:
import io
import torchvision.transforms as transforms
from PIL import Image
def transform_image(image_bytes):
my_transforms = transforms.Compose([
transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image = Image.open(io.BytesIO(image_bytes))
return my_transforms(image).unsqueeze(0)
加載預(yù)訓(xùn)練的 DenseNet 121 模型,并設(shè)置為評估模式:
from torchvision import models
model = models.densenet121(pretrained=True)
model.eval()
編寫函數(shù)以獲取圖像的預(yù)測類別:
import json
imagenet_class_index = json.load(open('imagenet_class_index.json')) # 請?zhí)鎿Q為實際文件路徑
def get_prediction(image_bytes):
tensor = transform_image(image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
predicted_idx = str(y_hat.item())
return imagenet_class_index[predicted_idx]
將模型推理功能整合到 Flask 服務(wù)器中,完成 API 的定義:
from flask import jsonify, request
@app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
file = request.files['file']
img_bytes = file.read()
class_id, class_name = get_prediction(image_bytes=img_bytes)
return jsonify({'class_id': class_id, 'class_name': class_name})
if __name__ == '__main__':
app.run()
完整的 app.py
文件如下:
import io
import json
from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request
app = Flask(__name__)
imagenet_class_index = json.load(open('imagenet_class_index.json')) # 請?zhí)鎿Q為實際文件路徑
model = models.densenet121(pretrained=True)
model.eval()
def transform_image(image_bytes):
my_transforms = transforms.Compose([
transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image = Image.open(io.BytesIO(image_bytes))
return my_transforms(image).unsqueeze(0)
def get_prediction(image_bytes):
tensor = transform_image(image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
predicted_idx = str(y_hat.item())
return imagenet_class_index[predicted_idx]
@app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
file = request.files['file']
img_bytes = file.read()
class_id, class_name = get_prediction(image_bytes=img_bytes)
return jsonify({'class_id': class_id, 'class_name': class_name})
if __name__ == '__main__':
app.run()
運行 Flask 服務(wù)器:
FLASK_ENV=development FLASK_APP=app.py flask run
使用 requests
庫發(fā)送 POST 請求進(jìn)行測試:
import requests
resp = requests.post("http://localhost:5000/predict", files={"file": open('cat.jpg', 'rb')}) # 請?zhí)鎿Q為實際圖像文件路徑
print(resp.json())
成功返回結(jié)果示例:
{"class_id": "n02124075", "class_name": "Egyptian_cat"}
通過本文,您已成功使用 Flask 部署了一個 PyTorch 模型,并通過 REST API 提供圖像分類服務(wù)。然而,當(dāng)前的實現(xiàn)較為基礎(chǔ),對于生產(chǎn)環(huán)境,您可以考慮以下優(yōu)化措施:
模型部署是連接模型開發(fā)與實際應(yīng)用的橋梁,掌握這一技能,能夠使您的模型真正發(fā)揮價值,解決實際問題。編程獅將持續(xù)為您帶來更多模型部署與應(yīng)用開發(fā)的實用教程。
Copyright©2021 w3cschool編程獅|閩ICP備15016281號-3|閩公網(wǎng)安備35020302033924號
違法和不良信息舉報電話:173-0602-2364|舉報郵箱:jubao@eeedong.com
掃描二維碼
下載編程獅App
編程獅公眾號
聯(lián)系方式:
更多建議: