
PyTorch Reinforcement Learning (DQN) Tutorial

Updated 2025-06-18 17:14

Reinforcement learning is a vibrant branch of machine learning that studies how an agent can learn an optimal behavior policy through trial and error in an environment, so as to maximize cumulative reward. The Deep Q-Network (DQN), a major breakthrough in the field, combines the powerful function-approximation capability of deep learning with the Q-learning algorithm and successfully tackles reinforcement-learning problems with high-dimensional state spaces. This article walks you through training a DQN agent on the OpenAI Gym CartPole-v0 task with PyTorch.

1. Environment Setup and Preparation

Before training the DQN agent, we first need to set up the development environment and import the required packages.

(1) Installing the Dependencies

Make sure PyTorch, OpenAI Gym, and the other required libraries are installed. They can be installed with the following command:

pip install torch gym matplotlib numpy pillow
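Note that the code in this tutorial uses the classic Gym API: env.step() returns four values and env.render() accepts a mode argument. Both were changed in gym 0.26, so if the command above pulls in a newer release, pinning an older version (shown below as one possible choice, not the only one) keeps the examples runnable as written:

pip install torch "gym<0.26" matplotlib numpy pillow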

(2) Importing the Required Modules

import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image


import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

(3) Initializing the Environment and Device

# Create the CartPole-v0 environment
env = gym.make('CartPole-v0').unwrapped


# Select the device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Set up matplotlib interactive mode
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display
plt.ion()
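As an optional sanity check (my own addition, not part of the original tutorial), you can confirm which device was selected and inspect the action space; CartPole has two discrete actions, push left and push right:

print(device)                # e.g. "cuda" or "cpu"
print(env.action_space)      # Discrete(2)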

2. The Experience Replay Mechanism

Experience replay stores the transitions the agent collects while interacting with the environment and later samples from them at random for learning. This breaks the correlation between consecutive samples and improves the stability and convergence speed of training.

(1) Defining Transition and ReplayMemory

# Define the Transition named tuple
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))


# Define the replay memory class (a fixed-size ring buffer of transitions)
class ReplayMemory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0


    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        # Overwrite the oldest entries once capacity is reached
        self.position = (self.position + 1) % self.capacity


    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)


    def __len__(self):
        return len(self.memory)
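For illustration only (a small sketch of my own, not part of the original tutorial), the memory can be exercised in isolation with dummy tensors to see the push/sample interface and the ring-buffer behavior:

demo = ReplayMemory(3)
for i in range(5):
    demo.push(torch.tensor([float(i)]), torch.tensor([[i % 2]]),
              torch.tensor([float(i) + 1]), torch.tensor([1.0]))
print(len(demo))          # 3 -- the two oldest transitions were overwritten
print(demo.sample(2))     # 2 randomly chosen Transition tuples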

3. Building the DQN Model

The DQN is a convolutional neural network that predicts the Q-value of every action for the current state; the agent uses these predictions to choose the best action.

(1) Defining the DQN Network Architecture

class DQN(nn.Module):
    def __init__(self, h, w, outputs):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(32)


        # Spatial output size of a conv layer with the given kernel size and stride
        def conv2d_size_out(size, kernel_size=5, stride=2):
            return (size - (kernel_size - 1) - 1) // stride + 1


        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
        linear_input_size = convw * convh * 32
        # Fully connected head: one Q-value per action
        self.head = nn.Linear(linear_input_size, outputs)


    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Flatten the feature maps and map them to Q-values
        return self.head(x.view(x.size(0), -1))
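As a quick sketch (my own addition, assuming a 3x40x90 input purely for illustration), you can check that the network maps a batch of screen patches to one Q-value per action:

net = DQN(40, 90, 2).to(device)
dummy = torch.zeros(1, 3, 40, 90, device=device)
print(net(dummy).shape)   # torch.Size([1, 2])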

4. Input Processing and State Extraction

We extract the current state the agent needs from the environment and preprocess it into a form suitable as input to the neural network.

(1) Defining the Image Processing Functions

resize = T.Compose([
    T.ToPILImage(),
    # Image.CUBIC was an alias of BICUBIC and was removed in Pillow 10; use BICUBIC
    T.Resize(40, interpolation=Image.BICUBIC),
    T.ToTensor()
])


def get_cart_location(screen_width):
    world_width = env.x_threshold * 2
    scale = screen_width / world_width
    # Horizontal position of the cart's center, in screen pixels
    return int(env.state[0] * scale + screen_width / 2.0)


def get_screen():
    # gym returns the screen as HWC; transpose to CHW for PyTorch
    screen = env.render(mode='rgb_array').transpose((2, 0, 1))
    _, screen_height, screen_width = screen.shape
    # Keep only the vertical band that contains the cart and pole
    screen = screen[:, int(screen_height * 0.4):int(screen_height * 0.8)]
    view_width = int(screen_width * 0.6)
    cart_location = get_cart_location(screen_width)
    # Crop a horizontal window centered on the cart
    if cart_location < view_width // 2:
        slice_range = slice(view_width)
    elif cart_location > (screen_width - view_width // 2):
        slice_range = slice(-view_width, None)
    else:
        slice_range = slice(cart_location - view_width // 2, cart_location + view_width // 2)
    screen = screen[:, :, slice_range]
    # Convert to a float tensor in [0, 1], resize, and add a batch dimension
    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
    screen = torch.from_numpy(screen)
    return resize(screen).unsqueeze(0).to(device)
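To see what the network actually receives, a short sketch like the following (my own addition) resets the environment and displays one extracted screen patch:

env.reset()
plt.figure()
plt.imshow(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(), interpolation='none')
plt.title('Example extracted screen')
plt.show()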

5. Training the DQN Agent

(1) Hyperparameters and Model Initialization

# Hyperparameters
BATCH_SIZE = 128
GAMMA = 0.999
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
TARGET_UPDATE = 10


# Get the screen size
init_screen = get_screen()
_, _, screen_height, screen_width = init_screen.shape


# Get the number of actions in the action space
n_actions = env.action_space.n


# Initialize the policy network and the target network
policy_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()


# Define the optimizer
optimizer = optim.RMSprop(policy_net.parameters())


# Initialize the replay memory
memory = ReplayMemory(10000)


# Initialize the step counter
steps_done = 0


# Define the action-selection function (epsilon-greedy)
def select_action(state):
    global steps_done
    sample = random.random()
    # The exploration rate decays exponentially from EPS_START towards EPS_END
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # Exploit: pick the action with the largest predicted Q-value
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        # Explore: pick a random action
        return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)


# Track episode durations and define the training-curve plotting function
episode_durations = []


def plot_durations():
    plt.figure(2)
    plt.clf()
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Also plot the 100-episode moving average once enough episodes have run
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())
    plt.pause(0.001)
    if is_ipython:
        display.clear_output(wait=True)
        display.display(plt.gcf())
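For intuition (a small sketch of my own, not part of the original code), you can print eps_threshold at a few step counts to see how quickly the agent shifts from exploring to exploiting with EPS_DECAY = 200:

for steps in (0, 200, 400, 1000):
    eps = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps / EPS_DECAY)
    print(steps, round(eps, 3))   # roughly 0.9, 0.363, 0.165, 0.056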

(2) Defining the Model Optimization Function

def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Convert a batch of Transitions into a Transition of batches
    batch = Transition(*zip(*transitions))
    # Mask of non-terminal states (next_state is None at the end of an episode)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    # Q(s, a) for the actions that were actually taken
    state_action_values = policy_net(state_batch).gather(1, action_batch)
    # max_a' Q_target(s', a') for non-terminal next states; 0 for terminal ones
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    # TD target: r + GAMMA * max_a' Q_target(s', a')
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
    # Huber loss between the predicted Q-values and the TD target
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))
    optimizer.zero_grad()
    loss.backward()
    # Clip gradients to [-1, 1] for stability
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()

(3) Running the Training Loop

# Train the agent
num_episodes = 50


for i_episode in range(num_episodes):
    # Initialize the environment; the state is the difference between two consecutive screens
    env.reset()
    last_screen = get_screen()
    current_screen = get_screen()
    state = current_screen - last_screen


    for t in count():
        # Select and perform an action
        action = select_action(state)
        _, reward, done, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        # Observe the new state
        last_screen = current_screen
        current_screen = get_screen()
        if not done:
            next_state = current_screen - last_screen
        else:
            next_state = None
        # Store the transition and move to the next state
        memory.push(state, action, next_state, reward)
        state = next_state
        # Perform one optimization step on the policy network
        optimize_model()
        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break


    # Periodically copy the policy network's weights into the target network
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())


print('Complete')
env.render()
env.close()
plt.ioff()
plt.show()
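If you want to see how the learned policy behaves (my own addition, not part of the original tutorial), a minimal sketch like the following can be run before env.close() is called above; it plays one episode taking only greedy actions:

env.reset()
last_screen = get_screen()
current_screen = get_screen()
state = current_screen - last_screen
for _ in range(500):            # cap the episode length for safety
    with torch.no_grad():
        action = policy_net(state).max(1)[1].view(1, 1)
    _, _, done, _ = env.step(action.item())
    last_screen = current_screen
    current_screen = get_screen()
    state = current_screen - last_screen
    if done:
        break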

6. Summary and Outlook

In this article you implemented a DQN agent with PyTorch and trained it on the CartPole-v0 task. The core idea of DQN is to approximate the Q-function with a neural network, making reinforcement learning tractable in high-dimensional state spaces. During training, experience replay and a target network stabilize learning, while an epsilon-greedy policy balances exploration and exploitation.
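In symbols, optimize_model above minimizes the Huber (smooth L1) loss on the temporal-difference error

\delta = Q_{\text{policy}}(s, a) - \bigl( r + \gamma \max_{a'} Q_{\text{target}}(s', a') \bigr),

where the \max term is set to zero for terminal states, which is exactly how the code handles next_state = None.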

Reinforcement learning is a broad and deep field, and DQN is only one of its gems. From here, you can explore other algorithms such as Deep Deterministic Policy Gradient (DDPG) and Proximal Policy Optimization (PPO) to handle more complex continuous action spaces and multi-agent environments. 编程狮 will continue to bring you high-quality tutorials on reinforcement learning and deep learning to support your journey in artificial intelligence.
