全部代码
https://github.com/ColinFred/Reinforce_Learning_Pytorch/tree/main/RL/DQN
一、优先回放
在经验回放中是利用均匀分布采样,而这种方式看上去并不高效,对于智能体而言,这些数据的重要程度并不一样,因此提出优先回放(Prioritized Replay)的方法。优先回放的基本思想就是打破均匀采样,赋予学习效率高的样本以更大的采样权重。
一个理想的标准是智能体学习的效率越高,权重越大。符合该标准的一个选择是TD偏差δ。TD偏差越大,说明该状态处的值函数与TD目标的差距越大,智能体的更新量越大,因此该处的学习效率越高。
简而言之,就是在原来的replay buffer中给每个Transition增加了抽样的优先级(priority)
优先回放DQN主要有三点改变:
1, 为了方便优先回放存储与及采样,采用sumTree树来存储;
原文有两种方法计算样本抽样概率:proportional priority和rank-based priority。proportional priority就是样本被sample到的概率是正比于TD偏差的priority;rank-based priority就是概率正比于Transition priority的排序(rank)。这里考虑proportional priority,Transition被抽到的概率与TD偏差成正比。
并且,为保证每一个存入的Transition都能被sample到,新Transition会被赋予一个很大的priority。
2, 目标函数在计算时根据样本的TD偏差添加了权重(权重和TD偏差有关,偏差越大,权重越大):
1 m ∑ j = 1 m w j ( y j − Q ( s j , a j , w ) ) 2 \frac{1}{m}\sum\limits_{j=1}^m w_j (y_j-Q(s_j, a_j, w))^2 m1j=1∑mwj(yj−Q(sj,aj,w))2
3,每次更新Q网络参数时,都需要重新计算TD误差 δ j = y j − Q ( s j , a j , w ) \delta_j = y_j- Q(s_j, a_j, w) δj=yj−Q(sj,aj,w)
二、代码
Prioritized experience replay 结合之前的 Double DQN 和 Dueling DQN
SumTree和ReplayMemory_Per
SumTree主要实现:add()添加experience;get()按priority抽样;update()更新某个Transition的priority。
ReplayMemory_Per主要实现:push()插入新experience;sample()按priority抽样Transition;update()更新已有经验的priority
class SumTree:
write = 0
def __init__(self, capacity):
self.capacity = capacity
self.tree = np.zeros(2 * capacity - 1)
self.data = np.zeros(capacity, dtype=object)
self.n_entries = 0
# update to the root node
def _propagate(self, idx, change):
parent = (idx - 1) // 2
self.tree[parent] += change
if parent != 0:
self._propagate(parent, change)
# find sample on leaf node
def _retrieve(self, idx, s):
left = 2 * idx + 1
right = left + 1
if left >= len(self.tree):
return idx
if s <= self.tree[left]:
return self._retrieve(left, s)
else:
return self._retrieve(right, s - self.tree[left])
def total(self):
return self.tree[0]
# store priority and sample
def add(self, p, data):
idx = self.write + self.capacity - 1
self.data[self.write] = data
self.update(idx, p)
self.write += 1
if self.write >= self.capacity:
self.write = 0
if self.n_entries < self.capacity:
self.n_entries += 1
# update priority
def update(self, idx, p):
change = p - self.tree[idx]
self.tree[idx] = p
self._propagate(idx, change)
# get priority and sample
def get(self, s):
idx = self._retrieve(0, s)
dataIdx = idx - self.capacity + 1
return (idx, self.tree[idx], self.data[dataIdx])
class ReplayMemory_Per(object):
# stored as ( s, a, r, s_ ) in SumTree
def __init__(self, capacity=1000, a=0.6, e=0.01):
self.tree = SumTree(capacity)
self.memory_size = capacity
self.prio_max = 0.1
self.a = a
self.e = e
def push(self, *args):
data = Transition(*args)
p = (np.abs(self.prio_max) + self.e) ** self.a # proportional priority
self.tree.add(p, data)
def sample(self, batch_size):
idxs = []
segment = self.tree.total() / batch_size
sample_datas = []
for i in range(batch_size):
a = segment * i
b = segment * (i + 1)
s = uniform(a, b)
idx, p, data = self.tree.get(s)
sample_datas.append(data)
idxs.append(idx)
return idxs, sample_datas
def update(self, idxs, errors):
self.prio_max = max(self.prio_max, max(np.abs(errors)))
for i, idx in enumerate(idxs):
p = (np.abs(errors[i]) + self.e) ** self.a
self.tree.update(idx, p)
def size(self):
return self.tree.n_entries
每次更新Q网络参数时,都需要重新计算TD误差,并且更新SumTree。
关于目标函数在计算时根据样本的TD偏差添加了权重这一点并未采用
class PerDQN:
def __init__(self, n_action, n_state, learning_rate):
self.n_action = n_action
self.n_state = n_state
self.memory = ReplayMemory_Per(capacity=100)
self.memory_counter = 0
self.model_policy = DNN(self.n_state, self.n_action)
self.model_target = DNN(self.n_state, self.n_action)
self.model_target.load_state_dict(self.model_policy.state_dict())
self.model_target.eval()
self.optimizer = optim.Adam(self.model_policy.parameters(), lr=learning_rate)
def store_transition(self, s, a, r, s_):
state = torch.FloatTensor([s])
action = torch.LongTensor([a])
reward = torch.FloatTensor([r])
next_state = torch.FloatTensor([s_])
self.memory.push(state, action, next_state, reward)
def choose_action(self, state):
state = torch.FloatTensor(state)
if np.random.randn() <= EPISILO: # greedy policy
with torch.no_grad():
q_value = self.model_policy(state)
action = q_value.max(0)[1].view(1, 1).item()
else: # random policy
action = torch.tensor([randrange(self.n_action)], dtype=torch.long).item()
return action
def learn(self):
if self.memory.size() < BATCH_SIZE:
return
idxs, transitions = self.memory.sample(BATCH_SIZE)
batch = Transition(*zip(*transitions))
state_batch = torch.cat(batch.state)
action_batch = torch.cat(batch.action).unsqueeze(1)
reward_batch = torch.cat(batch.reward)
next_state_batch = torch.cat(batch.next_state)
state_action_values = self.model_policy(state_batch).gather(1, action_batch)
next_action_batch = torch.unsqueeze(self.model_policy(next_state_batch).max(1)[1], 1)
next_state_values = self.model_target(next_state_batch).gather(1, next_action_batch)
expected_state_action_values = (next_state_values * GAMMA) + reward_batch.unsqueeze(1)
td_errors = (state_action_values - expected_state_action_values).detach().squeeze().tolist()
self.memory.update(idxs, td_errors) # update td error
loss = F.mse_loss(state_action_values, expected_state_action_values)
self.optimizer.zero_grad()
loss.backward()
for param in self.model_policy.parameters():
param.grad.data.clamp_(-1, 1)
self.optimizer.step()
def update_target_network(self):
self.model_target.load_state_dict(self.model_policy.state_dict())
参考
- https://zhuanlan.zhihu.com/p/128176891
- https://www.cnblogs.com/jiangxinyang/p/10112381.html