【深度强化学习】9. Policy Gradient实现中核心部分torch.distributions

news/2024/5/19 0:11:15 标签: 人工智能, 强化学习, 深度学习, 神经网络

【导语】:在深度强化学习第四篇中,讲了Policy Gradient的理论。通过最终推导得到的公式,本文用PyTorch简单实现以下,并且尽可能搞清楚torch.distribution的使用方法。代码参考了LeeDeepRl-Notes中的实现。

1. 复习

θ ← θ + η ∇ R θ ˉ ∇ R θ ˉ = 1 N ∑ n = 1 N ∑ t = 1 T n R ( τ n ) ∇ l o g p θ ( a t n ∣ s t n ) \theta \leftarrow \theta+\eta \nabla \bar{R_\theta} \\\nabla \bar{R_\theta}=\frac{1}{N}\sum^N_{n=1}\sum^{T_n}_{t=1}R(\tau^n)\nabla log p_\theta(a_t^n|s_t^n) θθ+ηRθˉRθˉ=N1n=1Nt=1TnR(τn)logpθ(atnstn)

θ \theta θ代表模型的参数,第一行公式代表了模型进行更新的方法, η \eta η 代表的是学习率。

第二行是推导得到的,和CrossEntropy可以对照着理解记忆。

2. Torch.Distributions

distributions包主要是实现了参数化的概率分布和采样函数。参数化是为了让模型能够具有反向传播的能力,这样才可以用随机梯度下降的方法来进行优化。随机采样的话没办法直接反向传播,有两个方法,REINFORCE和pathwise derivative estimator。

Torch中提供两个方法,sample()和log_prob(),就可以实现REINFORCE
Δ θ = α r ∂ log ⁡ p ( a ∣ π θ ( s ) ) ∂ θ \Delta \theta=\alpha r \frac{\partial \log p\left(a \mid \pi^{\theta}(s)\right)}{\partial \theta} Δθ=αrθlogp(aπθ(s))

θ \theta θ是模型参数, α \alpha α代表的是学习率,r代表reward, p ( a ∣ π θ ( s ) ) p\left(a \mid \pi^{\theta}(s)\right) p(aπθ(s))代表在状态s下,使用策略 π θ \pi^{\theta} πθ采取a动作的概率。

2.1 REINFORCE

实现的时候,会先从网络输出构造一个分布,然后从分布中采样一个action,将action作用于环境,然后使用log_prob()函数来构建一个损失函数,代码如下(PyTorch官方提供):

probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()

对照一下,这个-m.log_prob(action)应该对应上述公式: log ⁡ p ( a ∣ π θ ( s ) ) \log p\left(a \mid \pi^{\theta}(s)\right) logp(aπθ(s)), 加负号的原因是,在公式中应该是实现的梯度上升算法,而loss一般使用随机梯度下降的,所以加个负号保持一致性。

2.2 PathWise Derivative Estimator

这是一种重参数化技巧,主要是通过调用rsample()函数来实现的,参数化随机变量可以通过无参数随机变量的参数化确定性函数来构造。参数化以后,采样过程就变得可微分了,也就支持了网络的后向传播。实现如下(PyTorch官方实现):

params = policy_network(state)
m = Normal(*params)
# Any distribution with .has_rsample == True could work based on the application
action = m.rsample()
next_state, reward = env.step(action)  # Assuming that reward is differentiable
loss = -reward
loss.backward()

这样的话,可以直接对-reward使用随机梯度下降,因为rsample后可微分,可以后向传播。

3. 源码

主要看agent对象的实现:

class PolicyGradient:
    def __init__(self, state_dim, device='cpu', gamma=0.99, lr=0.01, batch_size=5):
        self.gamma = gamma
        self.policy_net = FCN(state_dim)
        self.optimizer = torch.optim.RMSprop(
            self.policy_net.parameters(), lr=lr)
        self.batch_size = batch_size

    def choose_action(self, state):
        state = torch.from_numpy(state).float()
        state = Variable(state)
        probs = self.policy_net(state)
        m = Bernoulli(probs)
        action = m.sample()

        action = action.data.numpy().astype(int)[0]  # 转为标量
        return action

    def update(self, reward_pool, state_pool, action_pool):
        # Discount reward
        running_add = 0 # 就是那个有discount的公式

        for i in reversed(range(len(reward_pool))): # 倒数
            if reward_pool[i] == 0:
                running_add = 0
            else:
                running_add = running_add * self.gamma + reward_pool[i]
                reward_pool[i] = running_add
        # 得到G

        # Normalize reward
        reward_mean = np.mean(reward_pool)
        reward_std = np.std(reward_pool)
        for i in range(len(reward_pool)):
            reward_pool[i] = (reward_pool[i] - reward_mean) / reward_std
        # 归一化

        # Gradient Desent
        self.optimizer.zero_grad()

        for i in range(len(reward_pool)): # 从前往后
            state = state_pool[i] 
            action = Variable(torch.FloatTensor([action_pool[i]]))
            reward = reward_pool[i]

            state = Variable(torch.from_numpy(state).float())
            probs = self.policy_net(state)
            m = Bernoulli(probs)
            # Negtive score function x reward
            loss = -m.log_prob(action) * reward # 核心
            # print(loss)
            loss.backward()

        self.optimizer.step()

    def save_model(self, path):
        torch.save(self.policy_net.state_dict(), path)

    def load_model(self, path):
        self.policy_net.load_state_dict(torch.load(path))

可以看到核心实现是以下几句:

state = Variable(torch.from_numpy(state).float())
probs = self.policy_net(state)
m = Bernoulli(probs)
# Negtive score function x reward
loss = -m.log_prob(action) * reward # 核心
# print(loss)
loss.backward()

这里采用的是伯努利分布,二项分布,举个例子:

Example::
        >>> m = Bernoulli(torch.tensor([0.3]))
        >>> m.sample()  # 30% chance 1; 70% chance 0
        tensor([ 0.])

采样结果是0或者1,1对应的概率是p,0对应概率是1-p。

为神马要用这个伯努利分布呢?因为这个这个问题是CartPole-v0 ,其动作空间只有0或1,所以这里采用了Bernoulli,其他情况要使用不同的分布才能满足要求。

得到了采样结果以后,就是用了第二节提到的REINFORCE的方法计算loss,进行loss反向传播。

4. 总结

简单介绍了以下如何使用,但并没有深究背后的原理,这个系列会继续更新,同时我也会继续加强我的数学功底。


http://www.niftyadmin.cn/n/1003391.html

相关文章

如何最快的融入游戏

刚才看阿俊的博客,他提到一句话“无论做什么事业,如果你想成功,你要跟两个人玩,第一个就是懂游戏规则的人、第二就是和顾问一起玩”。感觉很有道理,和大家分享一下! 这句话给我感触比较深,…

Android提高第五篇之Service

上次介绍了 Activity以及Intent的使用 , 这次就介绍Service,如果把Activity比喻为前台程序,那么Service就是后台程序,Service的整个生命周期都只会在后台执行。 Service跟Activity一样也由Intent调用。在工程里想要添加一个Ser…

超外差和超再生模块有何区别?

http://www.xie-gang.com/df.htm 一、超再生接收电路 超再生解调电路也称超再生检波电路,它实际上是工作在间歇振荡状态下的再生检波电路。一般再生检波电路在中波段工作时灵敏度很高,所以常用来制作简易晶体管收音机。对于工作于短波段的无线遥控或通…

遇到的PyTorch API

文章目录1. torch.chunk2. nn.GroupNorm3. torch.permute1. torch.chunk torch.chunk(input, chunks, dim0) → List of Tensors 将input tensor划分成特定的块数,每个块都是input tensor的一个视图,最后一个块可能会小一点,因为不能被dim整…

转:IPhone之ASIFormDataRequest POST操作架构设计/ 处理网络超时问题

//开启iphone网络开关 [UIApplication sharedApplication].networkActivityIndicatorVisible YES;ASIFormDataRequest *request [[ASIFormDataRequest alloc] initWithURL:[NSURLURLWithString:host]]; //超时时间request.timeOutSeconds 30;//定义异步方法[request setDele…

Android提高第六篇之BroadcastReceiver

前面分别讨论了Activity 和Service , 这次就轮到BroastcastReceiver,Broastcast是应用程序间通信的手段。BroastcastReceiver也是跟 Intent紧密相连的,动态/静态注册了BroastcastReceiver之后,使用sendBroadcast把Intent发送之后&…

MongoDB之基础篇

MongoDB基础篇一、走进MongoDBMongoDB 是一个高性能,开源,面向集合,无模式的文档型数据库。它在许多场景下可用于替代传统的关系型数据库或键-值存储方式,MongoDB 使用C开发。1.1、初识MongoDBMongoDB 是一个介于关系数据库和非关…

Linux - date命令

一、设置时间1、只修改日期:#date -s 2007-08-032、只修改时间:#date -s 14:15:003、同时修改日期和时间(加双引号,日期与时间之间加空格):#date -s "2007-08-03 14:15:00"4、以 root 身分更改了系统时间之后&#xff0…