Loading...
墨滴

廿

2021/10/27  阅读:17  主题:默认主题

KL Divergence

一、基本定义

假设给定事件 , 则我们有以下定义:

1.Probaility:
取值0~1

2.Information:
取对数,加符号得正值

概率越高,包含的信息小,因为事件越来越确定。相反,概率越低,包含的信息越多,因为事件具有很大的不确定性。

3.(Shannon)Entropy:
平均

熵是信息的平均,直观上,Shannon熵是信息在同一分布 下的平均。

4.Cross-Entropy
平均

熵是信息的平均,直观上,交叉熵是信息在不同分布下的平均。

5.KL divergence(Relative entropy/Information gain):

  1. 相对熵 = 交叉熵 - shannon熵
  2. 非对称 ,亦不满足三角不等式,故不是距离。
  3. 相对于 ,值非负,取零若 。从公式上看,就是拿 替代 后熵的变化。
  4. KL = Kullback-Leibler

二、 KL divergence 一些性质(非正式)证明

1. 非对称性

易知,当 上式不为0。故, 非对称,是不同的。(此部分侧重于说明它们不是不同的)

2. 非负性

其中,不等式部分使用了Jensen's inequality

3. 凹性

其中,不等式部分用到了log sum inequality

三、最小化KL divergence目标函数

为了方便说明,我们基于定义在某个空间 上的分布 来重写一下KL, 如下所示:

: P基于Q的KL,或从Q到P的KL,此处称为正向KL
: Q基于P的KL,或从P到Q的KL,此处称为反向KL

假设, 为真实的分布函数,我们想要用带参数 的分布函数 ,即 ,去近似 。也就是说,通过选取参数 , 让 在某种意义上具有相似性。下面,我们分别将选取正向KL和反向KL做为目标函数进行说明。为了方便,我们假设 为双峰分布, 为正太分布,故 包含均值和方差两个参数。

双峰分布

双峰分布

1. 最小化正向KL目标函数

目标函数如下:

(你也可以从信息/熵的角度去理解)从此处可以看出最小化正向KL目标函数,其实是等价于通过 进行最大似然估计。也就是说,数据 产生,基于这些数据,我们选取 让平均在 上的 似然函数最大,即:

平均概率高的地方, 概率也要高

所以我们有下图mean-seeking的结果 最小化正向KL

正向求最小

2. 最小化反向KL目标函数

目标函数如下:

此时,我们需要选取参数 ,让平均在 上的 似然函数最大;同时,让Shannon熵 也比较大,即约束 不要过于集中。总的来看,我们有:

平均概率高的地方, 概率也要高,但 不能过于集中

可以想象,如果没有 的约束,可能会调整\theta,让 集中于 最大的地方,得到的值也会比较大。所以, 起到了一个正则化(regularization)的效果。

所以我们有下图mode-seeking 的结果: Mode-seeking

反向求最小

正向最小化和反向最小化放在一起对比正反对比

正向和反向最小化
此部分代码来自3,但在调用logsumexp时,有点问题,故做了一个微小改动,代码放在最后。

选哪个方向最小化

  1. 更关注mean-seeking 还是mode-seeking
  2. 是否便于取sample进行计算,如果某个分布未知,取不了sample,则考虑用另外一个方向。
  3. 监督学习可考虑正向,强化学习可考虑反向。

相关参考资料

  1. Cover, T. M., and J. A. Thomas. "Elements of Information Theory,(2nd edn, 2006)." DOI: https://doi. org/10.1002 X 47174882 (2006).
  2. https://dibyaghosh.com/blog/probability/kldivergence.html
  3. https://www.tuananhle.co.uk/notes/reverse-forward-kl.html

正反最小化代码

import numpy as np
import scipy as sp
import scipy.stats
import matplotlib.pyplot as plt

class GaussianMixture1D:
    def __init__(self, mixture_probs, means, stds):
        self.num_mixtures = len(mixture_probs)
        self.mixture_probs = mixture_probs
        self.means = means
        self.stds = stds

    def sample(self, num_samples=1):
        mixture_ids = np.random.choice(self.num_mixtures, size=num_samples, p=self.mixture_probs)
        result = np.zeros([num_samples])
        for sample_idx in range(num_samples):
            result[sample_idx] = np.random.normal(
                loc=self.means[mixture_ids[sample_idx]],
                scale=self.stds[mixture_ids[sample_idx]]
            )
        return result

    def logpdf(self, samples):
        mixture_logpdfs = np.zeros([len(samples), self.num_mixtures])
        for mixture_idx in range(self.num_mixtures):
            mixture_logpdfs[:, mixture_idx] = scipy.stats.norm.logpdf(
                samples,
                loc=self.means[mixture_idx],
                scale=self.stds[mixture_idx]
            )
        return sp.special.logsumexp(mixture_logpdfs + np.log(self.mixture_probs), axis=1)

    def pdf(self, samples):
        return np.exp(self.logpdf(samples))


def approx_kl(gmm_1, gmm_2, xs):
    ys = gmm_1.pdf(xs) * (gmm_1.logpdf(xs) - gmm_2.logpdf(xs))
    return np.trapz(ys, xs)


def minimize_pq(p, xs, q_means, q_stds):
    q_mean_best = None
    q_std_best = None
    kl_best = np.inf
    for q_mean in q_means:
        for q_std in q_stds:
            q = GaussianMixture1D(np.array([1]), np.array([q_mean]), np.array([q_std]))
            kl = approx_kl(p, q, xs)
            if kl < kl_best:
                kl_best = kl
                q_mean_best = q_mean
                q_std_best = q_std

    q_best = GaussianMixture1D(np.array([1]), np.array([q_mean_best]), np.array([q_std_best]))
    return q_best, kl_best


def minimize_qp(p, xs, q_means, q_stds):
    q_mean_best = None
    q_std_best = None
    kl_best = np.inf
    for q_mean in q_means:
        for q_std in q_stds:
            q = GaussianMixture1D(np.array([1]), np.array([q_mean]), np.array([q_std]))
            kl = approx_kl(q, p, xs)
            if kl < kl_best:
                kl_best = kl
                q_mean_best = q_mean
                q_std_best = q_std

    q_best = GaussianMixture1D(np.array([1]), np.array([q_mean_best]), np.array([q_std_best]))
    return q_best, kl_best
    
def main():
    p_second_means_min = 0
    p_second_means_max = 10
    num_p_second_means = 5
    p_second_mean_list = np.linspace(p_second_means_min, p_second_means_max, num_p_second_means)

    p = [None] * num_p_second_means
    q_best_forward = [None] * num_p_second_means
    kl_best_forward = [None] * num_p_second_means
    q_best_reverse = [None] * num_p_second_means
    kl_best_reverse = [None] * num_p_second_means

    for p_second_mean_idx, p_second_mean in enumerate(p_second_mean_list):
        p_mixture_probs = np.array([0.50.5])
        p_means = np.array([0, p_second_mean])
        p_stds = np.array([11])
        p[p_second_mean_idx] = GaussianMixture1D(p_mixture_probs, p_means, p_stds)

        q_means_min = np.min(p_means) - 1
        q_means_max = np.max(p_means) + 1
        num_q_means = 20
        q_means = np.linspace(q_means_min, q_means_max, num_q_means)

        q_stds_min = 0.1
        q_stds_max = 5
        num_q_stds = 20
        q_stds = np.linspace(q_stds_min, q_stds_max, num_q_stds)

        trapz_xs_min = np.min(np.append(p_means, q_means_min)) - 3 * np.max(np.append(p_stds, q_stds_max))
        trapz_xs_max = np.max(np.append(p_means, q_means_min)) + 3 * np.max(np.append(p_stds, q_stds_max))
        num_trapz_points = 1000
        trapz_xs = np.linspace(trapz_xs_min, trapz_xs_max, num_trapz_points)

        q_best_forward[p_second_mean_idx], kl_best_forward[p_second_mean_idx] = minimize_pq(
            p[p_second_mean_idx], trapz_xs, q_means, q_stds
        )
        q_best_reverse[p_second_mean_idx], kl_best_reverse[p_second_mean_idx] = minimize_qp(
            p[p_second_mean_idx], trapz_xs, q_means, q_stds
        )

    # plotting
    fig, axs = plt.subplots(nrows=1, ncols=num_p_second_means, sharex=True, sharey=True)
    fig.set_size_inches(81.5)
    for p_second_mean_idx, p_second_mean in enumerate(p_second_mean_list):
        xs_min = -5
        xs_max = 15
        num_plot_points = 1000
        xs = np.linspace(xs_min, xs_max, num_plot_points)
        axs[p_second_mean_idx].plot(xs, p[p_second_mean_idx].pdf(xs), label='$p$', color='black')
        axs[p_second_mean_idx].plot(xs, q_best_forward[p_second_mean_idx].pdf(xs), label='$\mathrm{argmin}_q \,\mathrm{KL}(p || q)$', color='black', linestyle='dashed')
        axs[p_second_mean_idx].plot(xs, q_best_reverse[p_second_mean_idx].pdf(xs), label='$\mathrm{argmin}_q \,\mathrm{KL}(q || p)$', color='black', linestyle='dotted')

        axs[p_second_mean_idx].spines['right'].set_visible(False)
        axs[p_second_mean_idx].spines['top'].set_visible(False)
        axs[p_second_mean_idx].set_yticks([])
        axs[p_second_mean_idx].set_xticks([])

    axs[2].legend(ncol=3, loc='upper center', bbox_to_anchor=(0.50), fontsize='small')
    filenames = ['reverse_forward_kl.pdf''reverse_forward_kl.png']
    for filename in filenames:
        fig.savefig(filename, bbox_inches='tight', dpi=200)
        print('Saved to {}'.format(filename))


if __name__ == '__main__':
    main()

廿

2021/10/27  阅读:17  主题:默认主题

作者介绍

廿

香港城市大学-数据科学博士在读