Combine SAC with RNN (part1)

2019 年 6 月 28 日 CreateAMind

With the combination of sac and rnn. we can solve POMDP problem theoretically, but in practice, we face a lot problem.

One of the most important problem is what kind of structure should we use? There are batch of valid choice, for example we can use the full length episode to feed rnn, or we can use a fixed length.

With some investigate in this area, we choose a fixed length rnn. Another crutial problem is how to deal with different length of episodes. Me choice is discard any invalid sequence.

We will release more implement details, stay tuned.

class ReplayBuffer:    """    A simple FIFO experience replay buffer for SAC agents.    """
def __init__(self, obs_dim, act_dim, size, h_size, seq_length, flag="single"): self.flag = flag self.sequence_length = seq_length self.ptr, self.size, self.max_size = 0, 0, size self.obs_dim = obs_dim size += seq_length # in case index is out of range self.obs1_buf = np.zeros([size, obs_dim], dtype=np.float32) self.hidden_buf = np.zeros([size, h_size], dtype=np.float32) self.acts_buf = np.zeros([size, act_dim], dtype=np.float32) self.rews_buf = np.zeros([size, 1], dtype=np.float32) self.done_buf = np.zeros([size, 1], dtype=np.float32) self.target_done_ratio = 0
def store(self, obs, s_t_0, act, rew, done): self.obs1_buf[self.ptr] = obs self.hidden_buf[self.ptr] = s_t_0 self.acts_buf[self.ptr] = act self.rews_buf[self.ptr] = rew self.done_buf[self.ptr] = done self.ptr = (self.ptr + 1) % self.max_size # 1%20=1 2%20=2 21%20=1 self.size = min(self.size + 1, self.max_size) # use self.size to control sample range self.target_done_ratio = np.sum(self.done_buf) / self.size
def sample_batch(self, batch_size=32): """ :param batch_size: :return: s a r s' d """
idxs_c = np.empty([batch_size, self.sequence_length]) # N T+1
for i in range(batch_size): end = False while not end: ind = np.random.randint(0, self.size - 5) # random sample a starting point in current buffer idxs = np.arange(ind, ind + self.sequence_length) # extend seq from starting point is_valid_pos = True if sum(self.done_buf[idxs]) == 0 else (self.sequence_length - np.where(self.done_buf[idxs] == 1)[0][ 0]) == 2
end = True if is_valid_pos else False
idxs_c[i] = idxs
np.random.shuffle(idxs_c) idxs = idxs_c.astype(int) # print(self.target_done_ratio, np.sum(self.done_buf[idxs]) / batch_size) data = dict(obs1=self.obs1_buf[idxs], s_t_0=self.hidden_buf[idxs][:, 0, :], # slide N T H to N H acts=self.acts_buf[idxs], rews=self.rews_buf[idxs], done=self.done_buf[idxs]) return data

登录查看更多
3

相关内容

SAC:Selected Areas in Cryptography。 Explanation:密码术的选择区。 Publisher:Springer。 SIT:http://dblp.uni-trier.de/db/conf/sacrypt/
一份循环神经网络RNNs简明教程,37页ppt
专知会员服务
173+阅读 · 2020年5月6日
Stabilizing Transformers for Reinforcement Learning
专知会员服务
60+阅读 · 2019年10月17日
强化学习最新教程,17页pdf
专知会员服务
182+阅读 · 2019年10月11日
2019年机器学习框架回顾
专知会员服务
36+阅读 · 2019年10月11日
PLANET+SAC代码实现和解读
CreateAMind
3+阅读 · 2019年7月24日
误差反向传播——RNN
统计学习与视觉计算组
18+阅读 · 2018年9月6日
tensorflow LSTM + CTC实现端到端OCR
机器学习研究会
26+阅读 · 2017年11月16日
Auto-Encoding GAN
CreateAMind
7+阅读 · 2017年8月4日
强化学习 cartpole_a3c
CreateAMind
9+阅读 · 2017年7月21日
Arxiv
5+阅读 · 2018年1月29日
VIP会员
相关VIP内容
一份循环神经网络RNNs简明教程,37页ppt
专知会员服务
173+阅读 · 2020年5月6日
Stabilizing Transformers for Reinforcement Learning
专知会员服务
60+阅读 · 2019年10月17日
强化学习最新教程,17页pdf
专知会员服务
182+阅读 · 2019年10月11日
2019年机器学习框架回顾
专知会员服务
36+阅读 · 2019年10月11日
相关资讯
PLANET+SAC代码实现和解读
CreateAMind
3+阅读 · 2019年7月24日
误差反向传播——RNN
统计学习与视觉计算组
18+阅读 · 2018年9月6日
tensorflow LSTM + CTC实现端到端OCR
机器学习研究会
26+阅读 · 2017年11月16日
Auto-Encoding GAN
CreateAMind
7+阅读 · 2017年8月4日
强化学习 cartpole_a3c
CreateAMind
9+阅读 · 2017年7月21日
Top
微信扫码咨询专知VIP会员