Reinforcement learning is a vibrant branch of machine learning that studies how an agent can learn an optimal behavior policy through trial and error in an environment, with the goal of maximizing cumulative reward. The Deep Q-Network (DQN), a landmark breakthrough in the field, combines the powerful function-approximation capability of deep learning with the Q-learning algorithm, making reinforcement learning tractable in high-dimensional state spaces. This article walks you through training a DQN agent on the OpenAI Gym CartPole-v0 task with PyTorch, step by step.
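Before diving into the code, it helps to recall the quantity DQN approximates. In standard Q-learning (stated here in general form, not as part of this tutorial's code), the optimal action-value function satisfies the Bellman optimality equation:

$$Q^*(s, a) = \mathbb{E}\big[\, r + \gamma \max_{a'} Q^*(s', a') \,\big]$$

DQN replaces the tabular Q with a neural network Q(s, a; θ) and trains θ so that its predictions move toward this one-step target.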
Before training the DQN agent, we need to set up the development environment and import the required packages.
Make sure PyTorch, OpenAI Gym, and the other required libraries are installed; you can install them with the command below. Note that the code in this tutorial uses the classic Gym API (env.step returning four values and env.render(mode='rgb_array')), so an older gym release (pre-0.26) is assumed.
pip install torch gym matplotlib numpy pillow
import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
## Create the CartPole-v0 environment
env = gym.make('CartPole-v0').unwrapped
## Select the device (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
## Enable matplotlib interactive mode (for live plotting in notebooks)
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display
plt.ion()
Experience replay stores the agent's interactions with the environment and learns from randomly sampled batches of them. Random sampling breaks the correlation between consecutive transitions, which stabilizes training and speeds up convergence.
## Define the Transition named tuple
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))
## Define the replay memory class
class ReplayMemory(object):

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
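As a quick illustration of how the buffer behaves (a throwaway sketch with placeholder tensors; the real training loop later in this article pushes actual screen-difference states):

demo_memory = ReplayMemory(100)                  # hypothetical small buffer, for demonstration only
s = torch.zeros(1, 3, 40, 90)                    # placeholder state tensor
a = torch.tensor([[0]], dtype=torch.long)        # placeholder action index
r = torch.tensor([1.0])                          # placeholder reward
demo_memory.push(s, a, s, r)                     # ordinary transition
demo_memory.push(s, a, None, r)                  # terminal transition: next_state is None
print(len(demo_memory))                          # -> 2
batch = demo_memory.sample(2)                    # list of randomly drawn Transition tuples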
The DQN itself is a convolutional neural network that predicts a Q-value for every action given the current state (here, a processed screen image), which the agent uses to choose the best action.
class DQN(nn.Module):

    def __init__(self, h, w, outputs):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(32)

        # Compute the spatial size of the conv stack's output so the
        # fully connected head can be sized correctly.
        def conv2d_size_out(size, kernel_size=5, stride=2):
            return (size - (kernel_size - 1) - 1) // stride + 1
        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
        linear_input_size = convw * convh * 32
        self.head = nn.Linear(linear_input_size, outputs)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        return self.head(x.view(x.size(0), -1))
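A small sanity check can confirm the sizing logic (illustrative only; the 40×90 input resolution is an assumption here, while the real one comes from get_screen below and depends on your render size):

_check_net = DQN(40, 90, 2)              # hypothetical 40x90 input, 2 actions
_dummy = torch.zeros(1, 3, 40, 90)       # one fake RGB frame
print(_check_net(_dummy).shape)          # expected: torch.Size([1, 2]) -- one Q-value per action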
We extract the current state the agent sees directly from the rendered screen and preprocess it so that it is suitable as input to the neural network.
resize = T.Compose([
    T.ToPILImage(),
    T.Resize(40, interpolation=Image.BICUBIC),  # Image.CUBIC is a deprecated alias removed in newer Pillow
    T.ToTensor()
])

def get_cart_location(screen_width):
    # Map the cart's x position (in world units) to a pixel column on the screen.
    world_width = env.x_threshold * 2
    scale = screen_width / world_width
    return int(env.state[0] * scale + screen_width / 2.0)

def get_screen():
    # Gym renders an HxWx3 image (typically 400x600x3); transpose it into CHW order for torch.
    screen = env.render(mode='rgb_array').transpose((2, 0, 1))
    _, screen_height, screen_width = screen.shape
    # Keep only the vertical band that contains the cart and pole.
    screen = screen[:, int(screen_height * 0.4):int(screen_height * 0.8)]
    view_width = int(screen_width * 0.6)
    cart_location = get_cart_location(screen_width)
    # Crop a horizontal window centered on the cart.
    if cart_location < view_width // 2:
        slice_range = slice(view_width)
    elif cart_location > (screen_width - view_width // 2):
        slice_range = slice(-view_width, None)
    else:
        slice_range = slice(cart_location - view_width // 2, cart_location + view_width // 2)
    screen = screen[:, :, slice_range]
    # Convert to a float tensor in [0, 1], resize, and add a batch dimension.
    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
    screen = torch.from_numpy(screen)
    return resize(screen).unsqueeze(0).to(device)
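To verify that the preprocessing produces something sensible, you can optionally display one extracted frame (a quick visual check, safe to skip):

env.reset()
plt.figure()
plt.imshow(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(), interpolation='none')
plt.title('Example extracted screen')
plt.show()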
## Hyperparameters
BATCH_SIZE = 128        # minibatch size sampled from replay memory
GAMMA = 0.999           # discount factor
EPS_START = 0.9         # initial exploration rate
EPS_END = 0.05          # final exploration rate
EPS_DECAY = 200         # controls how quickly epsilon decays
TARGET_UPDATE = 10      # episodes between target-network updates
## Get the screen size
init_screen = get_screen()
_, _, screen_height, screen_width = init_screen.shape
## Get the number of actions
n_actions = env.action_space.n
## Initialize the policy network and the target network
policy_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
## Define the optimizer
optimizer = optim.RMSprop(policy_net.parameters())
## Initialize the replay memory
memory = ReplayMemory(10000)
## Initialize the step counter
steps_done = 0
## Define the action-selection function (epsilon-greedy)
def select_action(state):
    global steps_done
    sample = random.random()
    # Epsilon decays exponentially from EPS_START toward EPS_END as steps_done grows.
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # Exploit: pick the action with the largest predicted Q-value.
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        # Explore: pick a random action.
        return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)
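To get a feel for the exploration schedule, this throwaway snippet (not part of the training code) evaluates the same epsilon formula at a few step counts; with the hyperparameters above it decays from 0.9 toward 0.05:

for steps in (0, 200, 1000, 5000):
    eps = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps / EPS_DECAY)
    print(f'steps_done={steps}: eps_threshold={eps:.3f}')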
## Define a helper to plot the training curve
episode_durations = []

def plot_durations():
    plt.figure(2)
    plt.clf()
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Also plot the running average over the last 100 episodes.
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())
    plt.pause(0.001)
    if is_ipython:
        display.clear_output(wait=True)
        display.display(plt.gcf())
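The optimization step below implements the standard DQN loss: a Huber (smooth L1) loss between the predicted Q-values and a bootstrapped target computed with the frozen target network, where the max term is taken to be 0 for terminal next states:

$$\mathcal{L}(\theta) = \mathrm{Huber}\Big( Q(s, a; \theta),\; r + \gamma \max_{a'} Q(s', a'; \theta^{-}) \Big)$$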
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Convert a batch of Transitions into a Transition of batches.
    batch = Transition(*zip(*transitions))
    # Mask of transitions whose next state is non-terminal.
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    # Q(s, a) for the actions that were actually taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)
    # max_a' Q_target(s', a'), with 0 for terminal next states.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
    # Huber loss between predicted Q-values and the bootstrapped targets.
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))
    optimizer.zero_grad()
    loss.backward()
    # Clip gradients to [-1, 1] for stability.
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()
## Train the agent
num_episodes = 50
for i_episode in range(num_episodes):
    # Reset the environment and build the initial state from screen differences.
    env.reset()
    last_screen = get_screen()
    current_screen = get_screen()
    state = current_screen - last_screen
    for t in count():
        # Select and perform an action.
        action = select_action(state)
        _, reward, done, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        # Observe the new state as the difference between consecutive screens.
        last_screen = current_screen
        current_screen = get_screen()
        if not done:
            next_state = current_screen - last_screen
        else:
            next_state = None
        # Store the transition and move to the next state.
        memory.push(state, action, next_state, reward)
        state = next_state
        # Perform one optimization step on the policy network.
        optimize_model()
        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break
    # Periodically copy the policy network's weights into the target network.
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())
print('Complete')
env.render()
env.close()
plt.ioff()
plt.show()
In this article you implemented a DQN agent with PyTorch and trained it on the CartPole-v0 task. The core idea of DQN is to approximate the Q-function with a neural network, which makes reinforcement learning feasible in high-dimensional state spaces. During training, experience replay and a separate target network stabilize learning, while an epsilon-greedy policy balances exploration and exploitation.
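If you want to watch the trained agent act greedily, a minimal sketch could look like the following (it reuses the objects defined above and the same old-style Gym step API; run it before the final env.close() above, or re-create the environment first):

policy_net.eval()                                         # use running BatchNorm statistics
with torch.no_grad():
    env.reset()
    last_screen = get_screen()
    current_screen = get_screen()
    state = current_screen - last_screen
    for t in count():
        action = policy_net(state).max(1)[1].view(1, 1)   # greedy action, no exploration
        _, _, done, _ = env.step(action.item())
        last_screen = current_screen
        current_screen = get_screen()
        state = current_screen - last_screen
        if done:
            print('Greedy episode lasted', t + 1, 'steps')
            break
env.close()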
Reinforcement learning is a broad and deep field, and DQN is only one of its gems. From here you can explore other algorithms, such as Deep Deterministic Policy Gradient (DDPG) and Proximal Policy Optimization (PPO), to handle more complex continuous action spaces and multi-agent settings. 編程獅 will continue to publish more high-quality reinforcement learning and deep learning tutorials to support you on your journey through artificial intelligence.