W3Cschool
恭喜您成為首批注冊用戶
獲得88經(jīng)驗(yàn)值獎勵
將 PyTorch 模型從 Python 環(huán)境成功移植到 C++ 環(huán)境,能夠充分發(fā)揮 C++ 在高性能計(jì)算和生產(chǎn)部署中的優(yōu)勢,同時保留 PyTorch 模型的靈活性和強(qiáng)大功能。本文將詳細(xì)指導(dǎo)您完成這一過程,包括模型轉(zhuǎn)換、序列化、C++ 應(yīng)用程序開發(fā)以及模型加載與執(zhí)行等關(guān)鍵步驟。
對于大多數(shù)模型,尤其是控制流較為簡單的模型,使用追蹤方法可以輕松地將其轉(zhuǎn)換為 TorchScript 格式:
import torch
import torchvision
## 定義模型
model = torchvision.models.resnet18()
## 準(zhǔn)備示例輸入
example_input = torch.rand(1, 3, 224, 224)
## 使用追蹤生成 TorchScript 模型
traced_script_module = torch.jit.trace(model, example_input)
若模型包含復(fù)雜的控制流,需要使用注釋方法進(jìn)行轉(zhuǎn)換:
import torch
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
def forward(self, input):
if input.sum() > 0:
output = self.weight.mv(input)
else:
output = self.weight + input
return output
my_module = MyModule(10, 20)
scripted_module = torch.jit.script(my_module)
將生成的 TorchScript 模型保存到文件,以便在 C++ 應(yīng)用程序中加載和使用:
## 序列化追蹤得到的模型
traced_script_module.save("traced_resnet_model.pt")
## 序列化注釋得到的模型
scripted_module.save("scripted_my_module.pt")
編寫一個簡單的 C++ 程序來加載和執(zhí)行 TorchScript 模型:
#include <torch/script.h>
#include <iostream>
#include <memory>
int main(int argc, const char* argv[]) {
if (argc != 2) {
std::cerr << "Usage: " << argv[0] << " <path-to-model>\n";
return -1;
}
torch::jit::script::Module module;
try {
module = torch::jit::load(argv[1]);
} catch (const c10::Error& e) {
std::cerr << "Error loading the model: " << e.what() << std::endl;
return -1;
}
std::cout << "Model loaded successfully!\n";
// 準(zhǔn)備輸入數(shù)據(jù)
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));
// 執(zhí)行模型
at::Tensor output = module.forward(inputs).toTensor();
// 輸出結(jié)果
std::cout << "Model output:\n" << output.slice(1, 0, 5) << std::endl;
return 0;
}
使用 CMake 構(gòu)建上述 C++ 應(yīng)用程序:
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(PyTorchCppExample)
find_package(Torch REQUIRED)
add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)
mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
cmake --build . --config Release
將序列化后的模型文件路徑作為參數(shù)傳遞給構(gòu)建好的可執(zhí)行文件:
./example-app traced_resnet_model.pt
預(yù)期輸出示例:
Model loaded successfully!
Model output:
-0.2698 -0.0381 0.4023 -0.3010 -0.0448
[ Variable[CPUFloatType]{1,5} ]
在加載模型后,可以對模型進(jìn)行進(jìn)一步優(yōu)化以提升性能:
module.to(torch::kCUDA); // 將模型移動到 GPU
inputs[0] = inputs[0].to(torch::kCUDA); // 確保輸入數(shù)據(jù)也在 GPU 上
如果需要在 C++ 中實(shí)現(xiàn)自定義運(yùn)算符,可以參考 PyTorch C++ API 文檔進(jìn)行開發(fā)和集成。
利用 C++ 的多線程庫和 PyTorch 提供的并行計(jì)算功能,進(jìn)一步提升模型推理速度。
對模型進(jìn)行量化處理,減少模型大小并加速推理過程,適合在資源受限的環(huán)境中部署。
通過本文,您已經(jīng)學(xué)習(xí)了如何將 PyTorch 模型轉(zhuǎn)換為 TorchScript 格式,并在 C++ 應(yīng)用程序中加載和執(zhí)行。這一過程使您能夠在高性能、低延遲的生產(chǎn)環(huán)境中充分利用 PyTorch 模型的強(qiáng)大功能。未來,您可以探索更多高級特性,如在 C++ 中實(shí)現(xiàn)自定義運(yùn)算符、優(yōu)化模型性能以及與其他 C++ 框架和庫進(jìn)行集成等。編程獅將持續(xù)為您更新更多深度學(xué)習(xí)跨語言部署的實(shí)用教程,助力您的技術(shù)成長。
Copyright©2021 w3cschool編程獅|閩ICP備15016281號-3|閩公網(wǎng)安備35020302033924號
違法和不良信息舉報(bào)電話:173-0602-2364|舉報(bào)郵箱:jubao@eeedong.com
掃描二維碼
下載編程獅App
編程獅公眾號
聯(lián)系方式:
更多建議: