首页 > 其他 > 详细

GMM_GMVAE学习

时间:2021-01-22 22:38:50      阅读:143      评论:0      收藏:0      [点我收藏+]

1.GMVAE具体是怎么实现的呢?

多元高斯在代码中是怎么体现出来的?

技术分享图片

定义模型,并给出各个维度的参数,给出了有多少个中心点,即有多少类。 

技术分享图片

定义了ZINB分布的参数,这里的k类dropout和为1? 所以对每一类都有一个单独的ZINB分布了?

 技术分享图片

初始化模型中的GMM参数,用的sklearn中的?迷惑:

from sklearn.mixture import GaussianMixture

 https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html 这个搜了很多,反正就是讲解的不详细,

fit函数:使用EM算法估计模型参数,https://www.cnblogs.com/dahu-daqing/p/9456137.html 这个讲的还蛮好的,讲了它的计算过程,非常好。

技术分享图片

 EM算法迭代,先计算出每个样本的得分?score不太明白,predict就是计算出每个样本属于每个子高斯模型的概率:

技术分享图片

 比如上面这个4个子高斯模型,那么是233,

总之经过fit就能够得出每个样本属于哪一个子高斯分布的概率。

技术分享图片

这里的n_features应该是指z的维度吧,比如说z有10维,分3个类,那么返回的means_,size就是[3,10],每一类中对应每一个维度都有一个均值和方差,方差如果是diag那么size也是一样的。那么现在拟合这个函数了。

例子:

import torch
from sklearn.mixture import GaussianMixture
a=torch.randn(100,3)
gmm = GaussianMixture(n_components=3, covariance_type=‘diag‘)
gmm.fit(a)
print(gmm.means_)
print(gmm.covariances_)

#输出:
[[ 0.13355099 -0.87731539  0.11807917]
 [-0.02544988  0.36799893  0.93609104]
 [-0.1310032   0.52882658 -0.38371371]]
[[0.96960938 0.4109546  0.70868165]
 [0.74410499 0.72071033 0.31313133]
 [1.2172497  0.43499053 0.39583606]]

 

可以发现,最终输出的是3个子高斯分布的均值和方差,6了。

之后调用model.fit函数里直接计算损失?其实我不太明白这个fit函数有啥用?为什么不forward呢?emm应该是因为如果直接forward的话那就还要外围再写训练的for循环吧。

技术分享图片

损失函数中会直接计算, 出来损失,

技术分享图片

下面看一下encoder是如何建立的?

技术分享图片

可以发现,encoder的定义也是非常简单,但是我感觉build_mlp这个第一个参数有点问题啊,我试验了一下,h_dim只是一个数怎么能和[]直接相加呢? 

def fun(lay):
    for l in lay:
        print(l)
dim=[1,2]
[x,y]=dim
fun([x]+y)

#输出:
    fun([x]+y)
TypeError: can only concatenate list (not "int") to list

反正我大体知道代码表达的是什么意思就ok了。

下面是GaussianSample类,可以发现,它的in特征数目是隐层数目,然后对μ和方差都构建了一个线性输出层,在forward时,分别进行重采样。

技术分享图片

这里和普通的vae没什么差别啊。 

上面decoder的部分就比较单纯了, 

但是后来它是怎么把GaussianMixture 融入进去的???我迷惑了。

上面的损失函数错了,模型后来又重写了损失函数:

损失函数:

技术分享图片

 

我就很不明白,这是怎么算的呢? 

https://arxiv.org/pdf/1611.05148.pdf 这篇论文有讲啊! 

技术分享图片

 

 反正你让我推,我是不可能推出来的啦。

 

GMM_GMVAE学习

原文:https://www.cnblogs.com/BlueBlueSea/p/14315650.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!