Loading...
墨滴

hugang

2021/10/15  阅读:46  主题:默认主题

重要性采样

重要性采样

一、理论部分

应用于 off-policy

当需要计算一个采样 p(x) 的期望时,使用

此时因为 p(x) 不易采样,所以当仍然需要计算 p(x) 的期望的时候,需要用到另一个简单的采样 q(x) 来近似替代 p(x),所以有:

所以在强化学习中,当时用旧策略的采样数据取更新新策略的网络参数的时候,需要新旧策略之间的差异性尽量的小,所以需要乘重要性权重,也就是

也就有损失函数中的


要求 p(x) 和 q(x) 之间的差异不大,虽然p(x) 和 q(x) 之间是无偏估计,但是当p(x) 和 q(x) 之间的差异变大时,为了保证无偏估计,需要进行大量采样,这样会增大两者之间的方差。

时, ;使用重要性采样,即 时,有如下表

采样次数 均值 方差
300 0.7845907879880836 12.730297162372766
1000 0.4805234602092745 3.8211458228792354
3000 0.858111538077771 79.2243886580859

二、代码:

import numpy as np
import pandas as pd

import seaborn as sns
import scipy.stats as stats
import matplotlib.pyplot as plt
%matplotlib inline

import warnings
warnings.filterwarnings('ignore')
def f_x(x):
    return 1/(1+np.exp(-x))
fig = plt.figure(figsize = (82))
ax = fig.add_subplot(111)

#绘制图形
x = np.linspace(04100)
ax.plot(x, f_x(x), 'b', label = r'$\frac{1}{1+exp(-x)} $')
ax.legend(fontsize = 17)
1
# 实验一
## 当 p(x) 和 q(x) 之间的差异不大时

#p(x)
mu_target = 3.5
sigma_target = 1
p_x = [np.random.normal(mu_target, sigma_target) for _ in range(3000)]

#q(x)
mu_appro = 3
sigma_appro = 1
q_x = [np.random.normal(mu_appro, sigma_appro) for _ in range(3000)]

fig = plt.figure(figsize = (106))
ax = fig.add_subplot(1,1,1)

#画出两个图形的分布
sns.distplot(p_x, label="distribution $p(x)$")
sns.distplot(q_x, label="distribution $q(x)$")

plt.title('Distribution', size = 16)
plt.legend()

![2](C:\Users\TR\Desktop\文档\10 深度强化学习\2.png)

# x~p(x)
np.mean([f_x(i) for i in p_x])

out = 0.9570351382988408

# x~q(x)
#stats.norm:随机变量的概率密度函数
p_pdf = stats.norm(mu_target, sigma_target)
q_pdf = stats.norm(mu_appro, sigma_appro)

print(np.mean([p_pdf.pdf(i) / q_pdf.pdf(i) * f_x(i) for i in q_x]))
print(np.var([p_pdf.pdf(i) / q_pdf.pdf(i) * f_x(i) for i in q_x]))

out = {0.96145124567806340.2937305136118007}
# 实验二
## 两个分布相差较远

#p(x)
mu_target = 3.5
sigma_target = 1
p_x = [np.random.normal(mu_target, sigma_target) for _ in range(3000)]

#q(x)
mu_appro = 1
sigma_appro = 1
q_x = [np.random.normal(mu_appro, sigma_appro) for _ in range(3000)]

fig = plt.figure(figsize = (10, 6))
ax = fig.add_subplot(1, 1, 1)

sns.distplot(p_x, label = '$p(x)$')
sns.distplot(q_x, label = '$q(x)$')

plt.title('Distributions')
plt.legend()

![3](C:\Users\TR\Desktop\文档\10 深度强化学习\3.png)

#x~p(X)
np.mean([f_x(i) for i in p_x])

out=0.9542214426246207

def get_Normal_And_Var(total):
    p_x = [np.random.normal(mu_target, sigma_target) for _ in range(total)]
    q_x = [np.random.normal(mu_appro, sigma_appro) for _ in range(total)]
    
    p_pdf = stats.norm(mu_target, sigma_target)
    q_pdf = stats.norm(mu_appro, sigma_appro)
    
    normal = np.mean([p_pdf.pdf(i) / q_pdf.pdf(i) * f_x(i) for i in q_x])
    var = np.var([p_pdf.pdf(i) / q_pdf.pdf(i) * f_x(i) for i in q_x])
    return normal, var

#当采样数量为300
normal, var = get_Normal_And_Var(300)
print(normal)
print(var)

out={0.784590787988083612.730297162372766}

#当采样数量为 1000
normal, var = get_Normal_And_Var(1000)
print(normal)
print(var)

out={0.48052346020927453.8211458228792354}

#当采样数量为 3000
normal, var = get_Normal_And_Var(3000)
print(normal)
print(var)

out={0.858111538077771,79.2243886580859}

三、链接

hugang

2021/10/15  阅读:46  主题:默认主题

作者介绍

hugang