Some implementation magic.
RL — Cumulative sum: the discounted cumulative sum used for the state-value function.
# scipy.signal.lfilter
def discount_cumsum(x, discount):
    """Return the discounted cumulative sums of ``x`` along axis 0.

    Trick borrowed from rllab: running the linear filter
    ``y[n] = x[n] + discount * y[n - 1]`` over the *reversed* input and
    reversing the result gives, for every index t,
    ``out[t] = x[t] + discount * out[t + 1]``.

    Example:
        x = [x0, x1, x2]  ->
        [x0 + discount*x1 + discount**2*x2, x1 + discount*x2, x2]
    """
    backwards = x[::-1]
    # lfilter(b, a, ...) with b = [1], a = [1, -discount] implements
    # y[n] = x[n] + discount * y[n - 1].
    numerator = [1]
    denominator = [1, float(-discount)]
    filtered = scipy.signal.lfilter(numerator, denominator, backwards, axis=0)
    return filtered[::-1]
# TD(n)
def discount_with_dones(rewards, dones, gamma):
    """Return discounted returns that are cut at episode boundaries.

    Walks the sequence backwards, accumulating
    ``R_t = r_t + gamma * R_{t+1} * (1 - done_t)``, so a ``done`` flag of 1
    stops later rewards from flowing into earlier steps.

    Args:
        rewards: sequence of per-step rewards.
        dones: sequence of per-step termination flags (0/1 or bool), aligned
            with ``rewards``.
        gamma: discount factor.

    Returns:
        list of discounted returns in the original (forward) order.
    """
    discounted = []
    r = 0
    for reward, done in zip(rewards[::-1], dones[::-1]):
        r = reward + gamma * r * (1. - done)
        discounted.append(r)
    return discounted[::-1]


# Usage excerpt (from inside a runner method): `replay_buffer`, `last_values`
# and `self.gamma` belong to the surrounding code and are not defined here.
mb_rewards, mb_dones, _, _ = replay_buffer.sample(...)
for n, (rewards, dones, value) in enumerate(zip(mb_rewards, mb_dones, last_values)):
    # Work on plain lists: if these rows are NumPy arrays, `rewards + [value]`
    # would broadcast-add `value` to every reward instead of appending it.
    rewards = list(rewards)
    dones = list(dones)
    if dones[-1] == 0:
        # Segment did not end in a terminal step: append `value` (presumably
        # the value estimate of the next state — confirm against caller) as a
        # bootstrap pseudo-reward, then drop that extra entry.
        rewards = discount_with_dones(rewards + [value], dones + [0], self.gamma)[:-1]
    else:
        rewards = discount_with_dones(rewards, dones, self.gamma)
    mb_rewards[n] = rewards
# Training tricks / Entropy normalization
def entropy(logits):
    """Entropy of the categorical distribution given by ``logits`` (PyTorch).

    Reduces over the last dimension. For numerical stability, the per-row
    maximum is subtracted before exponentiating. ``-log p_i`` is then
    evaluated as ``log(z) - a_i`` rather than taking the log of
    probabilities that may underflow to zero.

    Args:
        logits: tensor of unnormalized log-probabilities; the last dimension
            indexes the categories.

    Returns:
        Tensor with the last dimension reduced away, holding
        ``-sum_i p_i * log(p_i)`` for each distribution.
    """
    row_max, _ = torch.max(logits, -1, True)
    a0 = logits - row_max            # shifted logits (row max becomes 0)
    ea0 = torch.exp(a0)
    z0 = torch.sum(ea0, -1, True)    # normalizer of the shifted logits
    p0 = ea0 / z0                    # softmax probabilities
    # H = sum_i p_i * (log z - a_i). The parentheses are essential: without
    # them this evaluates sum_i (p_i * log z - a_i), which is not the entropy.
    return torch.sum(p0 * (torch.log(z0) - a0), -1)


# Usage: mean entropy over the leading (batch) dimensions; `logits` comes from
# the surrounding training code and is not defined in this snippet.
ent = entropy(logits).mean()
# NOTE: same name as the PyTorch `entropy`; use one or the other in a module.
def entropy(logits):
    """Entropy of the categorical distribution given by ``logits`` (TensorFlow).

    Reduces over the last axis. For numerical stability, the per-row maximum
    is subtracted before exponentiating. ``-log p_i`` is evaluated as
    ``log(z) - a_i``.

    Args:
        logits: tensor of unnormalized log-probabilities; the last axis
            indexes the categories.

    Returns:
        Tensor with the last axis reduced away, holding
        ``-sum_i p_i * log(p_i)`` for each distribution.
    """
    a0 = logits - tf.reduce_max(logits, axis=-1, keepdims=True)
    ea0 = tf.exp(a0)
    z0 = tf.reduce_sum(ea0, axis=-1, keepdims=True)
    p0 = ea0 / z0
    # `tf.log` no longer exists in TF2; `tf.math.log` works in TF >= 1.13 and TF2.
    return tf.reduce_sum(p0 * (tf.math.log(z0) - a0), axis=-1)
# all-reduce
def avg_grad(layers: tuple):
    """Average gradients across all processes of the default group (PyTorch).

    Each parameter's gradient is summed over every rank with an in-place
    ``all_reduce``, then divided by the world size. Afterwards every process
    holds the same mean gradient.

    Args:
        layers: iterable of modules (anything exposing ``.parameters()``)
            whose gradients should be averaged.
    """
    size = float(dist.get_world_size())
    for layer in layers:
        for param in layer.parameters():
            grad = param.grad
            # Skip parameters without a gradient (e.g. frozen or unused in
            # backward). Assumes this set is identical on every rank;
            # otherwise the collectives would not line up.
            if grad is None:
                continue
            dist.all_reduce(grad, op=dist.ReduceOp.SUM)
            grad /= size
# DL / One-hot encoding
def one_hot(col: int, row: int, one_hot_index: torch.LongTensor):
    """Build a float32 one-hot matrix (PyTorch).

    Args:
        col: number of samples, i.e. rows of the result (despite the name).
        row: number of classes, i.e. columns of the result (despite the name).
        one_hot_index: integer class indices, shape ``(col, 1)`` or ``(col,)``.

    Returns:
        ``(col, row)`` float32 tensor on the same device as ``one_hot_index``,
        with 1 at ``[i, one_hot_index[i]]`` and 0 elsewhere.
    """
    if one_hot_index.dim() == 1:
        # scatter_ requires the index to have the same rank as the target.
        one_hot_index = one_hot_index.unsqueeze(1)
    y_one_hot = torch.zeros(col, row, dtype=torch.float32, device=one_hot_index.device)
    y_one_hot.scatter_(1, one_hot_index, 1)
    return y_one_hot