W3Cschool
恭喜您成為首批注冊用戶
獲得88經(jīng)驗值獎勵
對抗示例是機器學(xué)習(xí)領(lǐng)域中的一個重要研究方向,它揭示了模型在面對惡意攻擊時的脆弱性。本教程教你如何生成對抗示例并攻擊一個圖像分類器。通過學(xué)習(xí) FGSM 攻擊方法,你將深入了解對抗示例的原理和實現(xiàn)方式。
對抗示例是指通過在輸入數(shù)據(jù)中添加精心設(shè)計的擾動,使機器學(xué)習(xí)模型產(chǎn)生錯誤輸出的樣本。這些擾動通常很小,以至于人類無法察覺,但卻能顯著影響模型的性能。對抗示例的存在提醒我們在開發(fā)機器學(xué)習(xí)模型時,不僅要關(guān)注模型的準(zhǔn)確性,還要重視其安全性和魯棒性。
在實際應(yīng)用中,攻擊者可能對模型有不同的了解程度,這引出了白盒攻擊和黑盒攻擊的概念:
此外,根據(jù)攻擊目標(biāo)的不同,對抗示例可以分為錯誤分類和源 / 目標(biāo)錯誤分類兩種類型。
FGSM 是一種簡單而有效的對抗示例生成方法。它的核心思想是利用模型的梯度信息來構(gòu)造對抗擾動。具體來說,F(xiàn)GSM 通過計算損失函數(shù)對輸入數(shù)據(jù)的梯度,然后根據(jù)梯度的方向調(diào)整輸入數(shù)據(jù),使損失最大化,從而生成對抗示例。
FGSM 的公式可以表示為:
[ x_{\text{adv}} = x + \epsilon \cdot \text{sign}(\nabla_x J(\theta, x, y)) ]
其中,(x) 是原始輸入,(\epsilon) 是擾動的幅度,(\text{sign}) 是取符號函數(shù),(\nabla_x J(\theta, x, y)) 是損失函數(shù)對輸入 (x) 的梯度。
我們首先導(dǎo)入實現(xiàn)對抗示例生成所需的庫和模塊。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
我們使用一個預(yù)訓(xùn)練的 MNIST 分類器作為受攻擊的模型。
## LeNet 模型定義
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x, dim=1)
## 加載 MNIST 測試數(shù)據(jù)集
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, download=True, transform=transforms.Compose([
transforms.ToTensor(),
])),
batch_size=1, shuffle=True)
## 檢測設(shè)備并初始化模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
## 加載預(yù)訓(xùn)練模型權(quán)重并設(shè)置為評估模式
model.load_state_dict(torch.load("data/lenet_mnist_model.pth", map_location=device))
model.eval()
def fgsm_attack(image, epsilon, data_grad):
# 獲取數(shù)據(jù)梯度的符號
sign_data_grad = data_grad.sign()
# 生成對抗示例
perturbed_image = image + epsilon * sign_data_grad
# 將對抗示例的像素值限制在 [0, 1] 范圍內(nèi)
perturbed_image = torch.clamp(perturbed_image, 0, 1)
return perturbed_image
def test(model, device, test_loader, epsilon):
correct = 0
adv_examples = []
for data, target in test_loader:
data, target = data.to(device), target.to(device)
data.requires_grad = True
output = model(data)
init_pred = output.max(1, keepdim=True)[1]
if init_pred.item() != target.item():
continue
loss = F.nll_loss(output, target)
model.zero_grad()
loss.backward()
data_grad = data.grad.data
perturbed_data = fgsm_attack(data, epsilon, data_grad)
output = model(perturbed_data)
final_pred = output.max(1, keepdim=True)[1]
if final_pred.item() == target.item():
correct += 1
if epsilon == 0 and len(adv_examples) < 5:
adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
adv_examples.append((init_pred.item(), final_pred.item(), adv_ex))
else:
if len(adv_examples) < 5:
adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
adv_examples.append((init_pred.item(), final_pred.item(), adv_ex))
final_acc = correct / float(len(test_loader))
print("Epsilon: {}\tTest Accuracy = {} / {} = {}".format(epsilon, correct, len(test_loader), final_acc))
return final_acc, adv_examples
epsilons = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3]
accuracies = []
examples = []
for eps in epsilons:
acc, ex = test(model, device, test_loader, eps)
accuracies.append(acc)
examples.append(ex)
## 繪制精度與 epsilon 的關(guān)系圖
plt.figure(figsize=(5, 5))
plt.plot(epsilons, accuracies, "*-")
plt.yticks(np.arange(0, 1.1, step=0.1))
plt.xticks(np.arange(0, 0.35, step=0.05))
plt.title("Accuracy vs Epsilon")
plt.xlabel("Epsilon")
plt.ylabel("Accuracy")
plt.show()
## 可視化對抗示例
cnt = 0
plt.figure(figsize=(8, 10))
for i in range(len(epsilons)):
for j in range(len(examples[i])):
cnt += 1
plt.subplot(len(epsilons), len(examples[0]), cnt)
plt.xticks([], [])
plt.yticks([], [])
if j == 0:
plt.ylabel("Eps: {}".format(epsilons[i]), fontsize=14)
orig, adv, ex = examples[i][j]
plt.title("{} -> {}".format(orig, adv))
plt.imshow(ex, cmap="gray")
plt.tight_layout()
plt.show()
通過運行上述代碼,我們可以得到不同 epsilon 值下模型的測試精度以及一些成功的對抗示例。
從精度與 epsilon 的關(guān)系圖中可以看到,隨著 epsilon 的增加,模型的測試精度逐漸下降。這表明對抗示例的擾動對模型的性能產(chǎn)生了顯著影響。
對抗示例的可視化結(jié)果展示了在不同 epsilon 值下,原始圖像被錯誤分類為其他類別的示例。盡管擾動很小,但模型的預(yù)測結(jié)果發(fā)生了變化,而人類仍然能夠正確識別圖像中的數(shù)字。
Epsilon | 測試精度 |
---|---|
0 | 0.981 |
0.05 | 0.9426 |
0.1 | 0.851 |
0.15 | 0.6826 |
0.2 | 0.4301 |
0.25 | 0.2082 |
0.3 | 0.0869 |
本教程介紹了對抗示例的概念和 FGSM 攻擊方法,并通過實驗展示了如何生成對抗示例并攻擊一個 MNIST 分類器。通過學(xué)習(xí)本教程,你了解了對抗示例的原理和實現(xiàn)方式,以及它們對模型性能的影響。在編程獅(W3Cschool)網(wǎng)站上,你可以找到更多關(guān)于 PyTorch 的詳細(xì)教程和實戰(zhàn)案例,幫助你進一步提升深度學(xué)習(xí)技能,成為人工智能領(lǐng)域的編程大神。
Copyright©2021 w3cschool編程獅|閩ICP備15016281號-3|閩公網(wǎng)安備35020302033924號
違法和不良信息舉報電話:173-0602-2364|舉報郵箱:jubao@eeedong.com
掃描二維碼
下載編程獅App
編程獅公眾號
聯(lián)系方式:
更多建議: