W3Cschool
恭喜您成為首批注冊用戶
獲得88經(jīng)驗(yàn)值獎(jiǎng)勵(lì)
在深度學(xué)習(xí)領(lǐng)域,對象檢測是一項(xiàng)關(guān)鍵技術(shù),它不僅可以識別圖像中的物體類別,還能精確定位它們的位置。PyTorch 作為一款功能強(qiáng)大的開源機(jī)器學(xué)習(xí)框架,在對象檢測任務(wù)中表現(xiàn)卓越。本教程將教你如何利用 PyTorch TorchVision 微調(diào)預(yù)訓(xùn)練模型進(jìn)行對象檢測。
在 PyTorch 中,定義數(shù)據(jù)集是進(jìn)行模型訓(xùn)練的第一步。我們需要?jiǎng)?chuàng)建一個(gè)自定義數(shù)據(jù)集類,繼承自 torch.utils.data.Dataset
。這個(gè)類要實(shí)現(xiàn) __len__
和 __getitem__
方法。
import os
import numpy as np
import torch
from PIL import Image
class PennFudanDataset(torch.utils.data.Dataset):
def __init__(self, root, transforms):
self.root = root
self.transforms = transforms
self.imgs = list(sorted(os.listdir(os.path.join(root, "PNGImages"))))
self.masks = list(sorted(os.listdir(os.path.join(root, "PedMasks"))))
def __getitem__(self, idx):
img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
img = Image.open(img_path).convert("RGB")
mask = Image.open(mask_path)
mask = np.array(mask)
obj_ids = np.unique(mask)
obj_ids = obj_ids[1:]
masks = mask == obj_ids[:, None, None]
num_objs = len(obj_ids)
boxes = []
for i in range(num_objs):
pos = np.where(masks[i])
xmin = np.min(pos[1])
xmax = np.max(pos[1])
ymin = np.min(pos[0])
ymax = np.max(pos[0])
boxes.append([xmin, ymin, xmax, ymax])
boxes = torch.as_tensor(boxes, dtype=torch.float32)
labels = torch.ones((num_objs,), dtype=torch.int64)
masks = torch.as_tensor(masks, dtype=torch.uint8)
image_id = torch.tensor([idx])
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
target = {}
target["boxes"] = boxes
target["labels"] = labels
target["masks"] = masks
target["image_id"] = image_id
target["area"] = area
target["iscrowd"] = iscrowd
if self.transforms is not None:
img, target = self.transforms(img, target)
return img, target
def __len__(self):
return len(self.imgs)
上面的代碼定義了一個(gè)賓夕法尼亞復(fù)旦數(shù)據(jù)集(PennFudan Dataset)的數(shù)據(jù)集類。__getitem__
方法根據(jù)索引返回圖像和目標(biāo)信息,目標(biāo)包括邊界框、標(biāo)簽、掩碼等。
接下來,我們需要定義用于對象檢測的模型。這里我們使用 Mask R-CNN,它是一種在 Faster R-CNN 基礎(chǔ)上擴(kuò)展而來的實(shí)例分割模型,能夠同時(shí)進(jìn)行對象檢測和分割。
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
def get_model_instance_segmentation(num_classes):
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
return model
上面的代碼首先加載了在 COCO 數(shù)據(jù)集上預(yù)訓(xùn)練的 Mask R-CNN 模型,然后替換了模型的分類器和掩碼預(yù)測器,使其適應(yīng)我們的自定義數(shù)據(jù)集。
現(xiàn)在我們已經(jīng)定義了數(shù)據(jù)集和模型,接下來需要將它們整合起來進(jìn)行訓(xùn)練。
from engine import train_one_epoch, evaluate
import utils
import transforms as T
def get_transform(train):
transforms = []
transforms.append(T.ToTensor())
if train:
transforms.append(T.RandomHorizontalFlip(0.5))
return T.Compose(transforms)
def main():
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
num_classes = 2
dataset = PennFudanDataset('PennFudanPed', get_transform(train=True))
dataset_test = PennFudanDataset('PennFudanPed', get_transform(train=False))
indices = torch.randperm(len(dataset)).tolist()
dataset = torch.utils.data.Subset(dataset, indices[:-50])
dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=2, shuffle=True, num_workers=4,
collate_fn=utils.collate_fn)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, batch_size=1, shuffle=False, num_workers=4,
collate_fn=utils.collate_fn)
model = get_model_instance_segmentation(num_classes)
model.to(device)
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
num_epochs = 10
for epoch in range(num_epochs):
train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
lr_scheduler.step()
evaluate(model, data_loader_test, device=device)
print("That's it!")
if __name__ == "__main__":
main()
上面的代碼首先定義了數(shù)據(jù)轉(zhuǎn)換函數(shù) get_transform
,用于將圖像轉(zhuǎn)換為張量,并在訓(xùn)練時(shí)進(jìn)行隨機(jī)水平翻轉(zhuǎn)數(shù)據(jù)增強(qiáng)。main
函數(shù)中,我們設(shè)置了訓(xùn)練設(shè)備(GPU 或 CPU),創(chuàng)建了數(shù)據(jù)集和數(shù)據(jù)加載器,定義了模型、優(yōu)化器和學(xué)習(xí)率調(diào)度器,然后進(jìn)行了模型的訓(xùn)練和評估。
恭喜你!通過以上步驟,你已經(jīng)成功地在 PyTorch 中利用 TorchVision 微調(diào)了一個(gè)預(yù)訓(xùn)練模型進(jìn)行對象檢測。在編程獅(W3Cschool)上,你可以找到更多關(guān)于 PyTorch 的詳細(xì)教程和實(shí)戰(zhàn)案例,幫助你進(jìn)一步提升深度學(xué)習(xí)技能,成為人工智能領(lǐng)域的編程大神。
Copyright©2021 w3cschool編程獅|閩ICP備15016281號-3|閩公網(wǎng)安備35020302033924號
違法和不良信息舉報(bào)電話:173-0602-2364|舉報(bào)郵箱:jubao@eeedong.com
掃描二維碼
下載編程獅App
編程獅公眾號
聯(lián)系方式:
更多建議: