W3Cschool
恭喜您成為首批注冊用戶
獲得88經(jīng)驗(yàn)值獎(jiǎng)勵(lì)
在數(shù)據(jù)科學(xué)和機(jī)器學(xué)習(xí)領(lǐng)域,張量操作是構(gòu)建復(fù)雜模型的核心。PyTorch 作為一種廣泛使用的深度學(xué)習(xí)框架,不斷引入新特性以提升開發(fā)效率和代碼可讀性。命名張量(Named Tensor)作為 PyTorch 的一項(xiàng)實(shí)驗(yàn)性功能,旨在通過為張量維度賦予 meaningful 的名稱,簡化張量操作,減少因維度順序錯(cuò)誤導(dǎo)致的 bug,并提升代碼的可維護(hù)性。本文將深入探討 PyTorch 中命名張量的使用方法、優(yōu)勢以及在實(shí)際項(xiàng)目中的應(yīng)用,幫助讀者快速掌握這一強(qiáng)大工具。
在 PyTorch 中,可以通過在創(chuàng)建張量時(shí)指定 names
參數(shù)來賦予張量維度名稱。以下示例展示了如何創(chuàng)建具有命名維度的張量:
import torch
## 創(chuàng)建具有命名維度的張量
imgs = torch.randn(1, 2, 2, 3, names=('N', 'C', 'H', 'W'))
print(imgs.names)
命名張量的維度名稱并非一成不變,可以通過以下方式對維度進(jìn)行重命名或刪除名稱:
## 方法一:直接設(shè)置 .names 屬性(原地操作)
imgs.names = ['batch', 'channel', 'width', 'height']
print(imgs.names)
## 方法二:指定新名稱(創(chuàng)建新張量)
imgs = imgs.rename(channel='C', width='W', height='H')
print(imgs.names)
## 刪除名稱
imgs = imgs.rename(None)
print(imgs.names)
命名張量與未命名張量可以共存。若只想為部分維度指定名稱,其余維度保持未命名狀態(tài),可以通過以下方式實(shí)現(xiàn):
## 創(chuàng)建部分維度命名的張量
imgs = torch.randn(3, 1, 1, 2, names=('N', None, None, None))
print(imgs.names)
大多數(shù)張量操作(如 .abs()
)會(huì)保留維度名稱,使得操作結(jié)果的可讀性得以保持:
## 基本操作后名稱傳播
print(imgs.abs().names)
可以通過維度名稱進(jìn)行索引和規(guī)約操作,使代碼更具語義化:
## 按名稱進(jìn)行求和操作
output = imgs.sum('C')
print(output.names)
## 按名稱選擇特定維度數(shù)據(jù)
img0 = imgs.select('N', 0)
print(img0.names)
在張量操作過程中,PyTorch 會(huì)根據(jù)名稱推斷規(guī)則對輸出張量的維度名稱進(jìn)行推斷。這包括檢查輸入張量的名稱是否匹配,并傳播合適的名稱至輸出張量。
命名張量在廣播操作中會(huì)檢查維度名稱是否匹配,避免因維度對齊錯(cuò)誤導(dǎo)致的意外結(jié)果:
## 廣播操作中的名稱檢查
imgs = torch.randn(2, 2, 2, 2, names=('N', 'C', 'H', 'W'))
per_batch_scale = torch.rand(2, names=('N',))
## 嘗試進(jìn)行廣播操作
try:
imgs * per_batch_scale
except RuntimeError as e:
print("錯(cuò)誤信息:", e)
在矩陣乘法操作中,PyTorch 會(huì)根據(jù)輸入張量的維度名稱推斷輸出張量的維度名稱:
## 矩陣乘法中的名稱傳播
markov_states = torch.randn(128, 5, names=('batch', 'D'))
transition_matrix = torch.randn(5, 5, names=('in', 'out'))
new_state = markov_states @ transition_matrix
print(new_state.names)
命名張量支持通過 align_as
或 align_to
方法進(jìn)行顯式廣播,使張量對齊操作更加直觀:
## 按名稱顯式廣播
imgs = imgs.refine_names('N', 'C', 'H', 'W')
per_batch_scale = per_batch_scale.refine_names('N')
named_result = imgs * per_batch_scale.align_as(imgs)
命名張量提供了 flatten
和 unflatten
方法,支持按名稱對維度進(jìn)行展平和展開操作:
## 按名稱展平維度
imgs = imgs.flatten(['C', 'H', 'W'], 'features')
print(imgs.names)
## 按名稱展開維度
imgs = imgs.unflatten('features', (('C', 2), ('H', 2), ('W', 2)))
print(imgs.names)
為了展示命名張量在實(shí)際項(xiàng)目中的優(yōu)勢,以下是一個(gè)使用命名張量實(shí)現(xiàn)多頭注意力模塊的示例:
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
def __init__(self, n_heads, dim, dropout=0):
super(MultiHeadAttention, self).__init__()
self.n_heads = n_heads
self.dim = dim
self.attn_dropout = nn.Dropout(p=dropout)
self.q_lin = nn.Linear(dim, dim)
self.k_lin = nn.Linear(dim, dim)
self.v_lin = nn.Linear(dim, dim)
self.out_lin = nn.Linear(dim, dim)
def forward(self, query, key=None, value=None, mask=None):
query = query.refine_names(..., 'T', 'D')
self_attn = key is None and value is None
if self_attn:
mask = mask.refine_names(..., 'T')
else:
mask = mask.refine_names(..., 'T', 'T_key')
dim = query.size('D')
n_heads = self.n_heads
dim_per_head = dim // n_heads
scale = math.sqrt(dim_per_head)
def prepare_head(tensor):
tensor = tensor.refine_names(..., 'T', 'D')
return (tensor.unflatten('D', [('H', n_heads), ('D_head', dim_per_head)])
.align_to(..., 'H', 'T', 'D_head'))
if self_attn:
key = value = query
elif value is None:
key = key.refine_names(..., 'T', 'D')
value = key
k = prepare_head(self.k_lin(key)).rename(T='T_key')
v = prepare_head(self.v_lin(value)).rename(T='T_key')
q = prepare_head(self.q_lin(query))
dot_prod = q.div_(scale).matmul(k.align_to(..., 'D_head', 'T_key'))
dot_prod.refine_names(..., 'H', 'T', 'T_key')
attn_mask = (mask == 0).align_as(dot_prod)
dot_prod.masked_fill_(attn_mask, -float(1e20))
attn_weights = self.attn_dropout(F.softmax(dot_prod, dim='T_key'))
attentioned = (attn_weights.matmul(v).refine_names(..., 'H', 'T', 'D_head')
.align_to(..., 'T', 'H', 'D_head')
.flatten(['H', 'D_head'], 'D'))
return self.out_lin(attentioned).refine_names(..., 'T', 'D')
命名張量作為 PyTorch 的一項(xiàng)創(chuàng)新特性,通過為張量維度賦予名稱,極大地提升了代碼的可讀性和可維護(hù)性,減少了因維度順序錯(cuò)誤導(dǎo)致的 bug。本文詳細(xì)介紹了命名張量的創(chuàng)建、操作、名稱傳播機(jī)制以及在多頭注意力模塊中的應(yīng)用。盡管命名張量目前仍處于實(shí)驗(yàn)階段,但其在提升開發(fā)效率和代碼質(zhì)量方面的潛力不容小覷。隨著 PyTorch 的不斷發(fā)展,命名張量有望成為深度學(xué)習(xí)開發(fā)中的標(biāo)準(zhǔn)工具之一。編程獅將持續(xù)關(guān)注 PyTorch 的最新動(dòng)態(tài),并為讀者帶來更多實(shí)用的深度學(xué)習(xí)技術(shù)教程。
Copyright©2021 w3cschool編程獅|閩ICP備15016281號-3|閩公網(wǎng)安備35020302033924號
違法和不良信息舉報(bào)電話:173-0602-2364|舉報(bào)郵箱:jubao@eeedong.com
掃描二維碼
下載編程獅App
編程獅公眾號
聯(lián)系方式:
更多建議: