Latent Variable Generative Models


Three main Bayesian Networks (BNs) for modelling an observation \(x\rightarrow p(x)\)

Here we discuss how these models are linked to clustering.

To sample from these networks, take the GMM as an example: draw a random value from the uniform distribution \(\mathbb{U}(0,1)\); the range it falls into picks a cluster \(m\), and an observation is then generated from the Gaussian \(\mathcal{N}(\mu_m,\Sigma_m)\). Repeating this process 10000 times produces multiple clusters in the plot. Reverse engineering this process, we initialize multiple Gaussians, collect the data points relevant to each Gaussian, and then update the parameters; this is basically how EM works.
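
As a quick illustration, here is a minimal sketch of this ancestral sampling procedure for a 1-D GMM in NumPy; the priors, means, and standard deviations below are made up for the example, not taken from the text.

```python
import numpy as np

rng = np.random.default_rng(0)
pi = np.array([0.5, 0.3, 0.2])      # component priors P(c_m) (hypothetical)
mu = np.array([-4.0, 0.0, 3.0])     # component means mu_m (hypothetical)
sigma = np.array([1.0, 0.5, 1.5])   # component std devs (hypothetical)

def sample_gmm(n):
    # u ~ U(0,1); its position among the cumulative priors picks cluster m
    u = rng.uniform(size=n)
    m = np.searchsorted(np.cumsum(pi), u)
    # then draw the observation from that cluster's Gaussian
    return m, rng.normal(mu[m], sigma[m])

components, x = sample_gmm(10000)   # 10000 draws, as in the text
```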

Consider the Gaussian mixture model: fitting the generative model requires maximizing the following marginal likelihood,

\begin{align*} p(x)&=\sum^M_{m=1}P(c_m)p(x|c_m)=\sum^M_{m=1}P(c_m)\mathcal{N}(x;\mu_m,\Sigma_m) \\ \hat{\theta}&=\arg\max_\theta \mathcal{L}(\theta)=\arg\max_\theta\sum^N_{i=1}\log p(x_i|\theta) \end{align*}
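
For reference, a small sketch of evaluating this marginal likelihood numerically in the 1-D case; the helper name and its arguments are my own.

```python
import numpy as np
from scipy.stats import norm

def log_likelihood(x, pi, mu, sigma):
    # comp_pdf[i, m] = N(x_i; mu_m, sigma_m^2)
    comp_pdf = norm.pdf(x[:, None], loc=mu[None, :], scale=sigma[None, :])
    p_x = comp_pdf @ pi              # marginal p(x_i) = sum_m P(c_m) N(x_i; ...)
    return np.sum(np.log(p_x))       # L(theta) = sum_i log p(x_i | theta)
```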

However, there is no closed-form solution for the parameters that maximize this likelihood. A common alternative is to update the parameters iteratively.

The update must satisfy the following requirement, i.e. the log-likelihood must not decrease from one iteration to the next,

\begin{align*} \mathcal{L}(\theta^{(k+1)})-\mathcal{L}(\theta^{(k)})&=\sum^N_{i=1}\big(\log p(x_i|\theta^{(k+1)})-\log p(x_i|\theta^{(k)})\big)=\sum^N_{i=1}\log\frac{p(x_i|\theta^{(k+1)})}{p(x_i|\theta^{(k)})}\geq 0 \\ \sum^N_{i=1}\log\frac{p(x_i|\theta^{(k+1)})}{p(x_i|\theta^{(k)})}&=\sum^N_{i=1}\log\frac{1}{p(x_i|\theta^{(k)})}\sum^M_{m=1}p(x_i,c_m|\theta^{(k+1)}) = \sum^N_{i=1}\log\frac{1}{p(x_i|\theta^{(k)})}\sum^M_{m=1}\frac{P(c_m|x_i,\theta^{(k)})\,p(x_i,c_m|\theta^{(k+1)})}{P(c_m|x_i,\theta^{(k)})}\\ &=\sum^N_{i=1}\log\sum^M_{m=1}P(c_m|x_i,\theta^{(k)})\frac{p(x_i,c_m|\theta^{(k+1)})}{p(x_i,c_m|\theta^{(k)})}\geq\sum^N_{i=1}\sum^M_{m=1}P(c_m|x_i,\theta^{(k)})\log\frac{p(x_i,c_m|\theta^{(k+1)})}{p(x_i,c_m|\theta^{(k)})} \\ &= \mathcal{Q}(\theta^{(k)},\theta^{(k+1)})-\mathcal{Q}(\theta^{(k)},\theta^{(k)}) \end{align*}

The final inequality follows from Jensen's inequality applied to the concave \(\log\). Therefore, the parameters can be optimized through the auxiliary function \(\mathcal{Q}\),

\begin{align*} \mathcal{L}(\theta^{(k+1)})-\mathcal{L}(\theta^{(k)})\geq \mathcal{Q}(\theta^{(k)},\theta^{(k+1)})-\mathcal{Q}(\theta^{(k)},\theta^{(k)})\geq 0 \end{align*}
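
A quick numerical sanity check of this bound for a toy 1-D GMM; all parameter values below are hypothetical and the helper names are my own.

```python
import numpy as np
from scipy.stats import norm

def loglik(x, th):
    # L(theta) = sum_i log sum_m pi_m N(x_i; mu_m, sigma_m^2)
    return np.log((th["pi"] * norm.pdf(x[:, None], th["mu"], th["sigma"])).sum(1)).sum()

def Q(x, old, new):
    # Q(theta, theta') = sum_i sum_m P(c_m | x_i, theta) log p(x_i, c_m | theta')
    joint_old = old["pi"] * norm.pdf(x[:, None], old["mu"], old["sigma"])
    post_old = joint_old / joint_old.sum(axis=1, keepdims=True)
    joint_new = new["pi"] * norm.pdf(x[:, None], new["mu"], new["sigma"])
    return (post_old * np.log(joint_new)).sum()

x = np.array([-3.1, -0.2, 0.4, 2.8])
th0 = dict(pi=np.array([0.5, 0.5]), mu=np.array([-2.0, 2.0]), sigma=np.array([1.0, 1.0]))
th1 = dict(pi=np.array([0.5, 0.5]), mu=np.array([-2.5, 1.5]), sigma=np.array([1.0, 1.0]))

# The bound holds for any pair of parameter settings (up to float error)
assert loglik(x, th1) - loglik(x, th0) >= Q(x, th0, th1) - Q(x, th0, th0) - 1e-12
```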

However, the true posterior over the components may be difficult to estimate, so we rederive the log-likelihood in terms of a KL divergence and a free energy,

\begin{align*} \mathcal{L}(\theta)=\log p(x|\theta)=\int dz\,q(z|\hat{\theta})\log p(x|\theta)=\int dz\,q(z|\hat{\theta})\log\frac{p(x,z|\theta)}{p(z|x;\theta)}=\bigg\langle\log\frac{p(x,z|\theta)}{q(z|\hat{\theta})}\bigg\rangle_{q(z|\hat{\theta})}+\bigg\langle\log\frac{q(z|\hat{\theta})}{p(z|x;\theta)}\bigg\rangle_{q(z|\hat{\theta})}=\mathcal{F}(q(z|\hat{\theta}),\theta)+\mathcal{KL}(q(z|\hat{\theta})\|p(z|x;\theta)) \end{align*}

Here the first step uses \(\int dz\,q(z|\hat{\theta})=1\) (since \(\log p(x|\theta)\) does not depend on \(z\)), and the second uses \(p(x|\theta)=p(x,z|\theta)/p(z|x;\theta)\); the decomposition is exact for any distribution \(q(z|\hat{\theta})\).
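
The decomposition can be checked numerically for a single data point of a 1-D GMM; the parameters and the distribution \(q\) below are arbitrary choices for illustration only.

```python
import numpy as np
from scipy.stats import norm

pi = np.array([0.4, 0.6]); mu = np.array([-1.0, 2.0]); sigma = np.array([1.0, 0.7])
x = 0.5

joint = pi * norm.pdf(x, mu, sigma)          # p(x, z=k | theta)
log_px = np.log(joint.sum())                 # log p(x | theta)
posterior = joint / joint.sum()              # p(z=k | x, theta)

q = np.array([0.3, 0.7])                     # any valid q(z); chosen arbitrarily
free_energy = np.sum(q * np.log(joint / q))  # <log p(x,z|theta)/q(z)>_q
kl = np.sum(q * np.log(q / posterior))       # KL(q || p(z|x,theta))

assert np.isclose(log_px, free_energy + kl)  # log p(x|theta) = F + KL
```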

Expectation Maximization Algorithm (EM algorithm)

Clustering probability

\begin{align*} &p(x_n|z_n=k,\theta)=\mathcal{N}(x_n|\mu_k,\sigma^2_k) \\ &p(z_n=k|\theta)=\pi_k \\ &p(x_n,z_n=k|\theta)=\pi_k\mathcal{N}(x_n|\mu_k,\sigma^2_k)\\ \end{align*}

E step: evaluate the posterior probability that each observation was generated by each latent state,

\begin{align*} &q'(z|\hat{\theta})=\arg\max_q\mathcal{F}(q(z|\hat{\theta}),\theta)\\ &q(z_n=k|\hat{\theta})=p(z_n=k|x_n,\theta)=\frac{p(x_n,z_n=k|\theta)}{p(x_n|\theta)}=\frac{p(x_n,z_n=k|\theta)}{\sum_{k'}p(x_n,z_n=k'|\theta)}=\frac{u_{nk}}{\sum_{k'}u_{nk'}}=\frac{u_{nk}}{u_n} = r_{nk} \end{align*}
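
A sketch of this E step for the 1-D GMM (function and variable names are my own; SciPy's `norm.pdf` provides the Gaussian density):

```python
import numpy as np
from scipy.stats import norm

def e_step(x, pi, mu, sigma):
    # u[n, k] = pi_k * N(x_n | mu_k, sigma_k^2) = p(x_n, z_n=k | theta)
    u = pi * norm.pdf(x[:, None], mu, sigma)
    # r[n, k] = u_nk / sum_k' u_nk' = q(z_n = k), the responsibilities
    return u / u.sum(axis=1, keepdims=True)
```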

M step: obtain the maximum-likelihood estimate of the model parameters using the completed data (observations together with the responsibilities) from the E step.

\begin{align*} &\theta'=\arg\max_\theta\sum_zq(z|\hat{\theta})\log p(x,z|\theta)\\ &\log p(x|\theta)=\log\prod^N_{n=1}p(x_n|\theta)=\sum^N_{n=1}\log p(x_n|\theta)=\sum^N_{n=1}\log\sum^K_{k=1} p(x_n,z_n=k|\theta)\\ &=\sum^N_{n=1}\log\sum^K_{k=1}q(z_n=k)\frac{p(x_n,z_n=k|\theta)}{q(z_n=k)}\geq \sum^N_{n=1}\sum^K_{k=1}q(z_n=k)\log\frac{p(x_n,z_n=k|\theta)}{q(z_n=k)}\\ &=\sum^N_{n=1}\sum^K_{k=1}q(z_n=k)\log p(x_n,z_n=k|\theta)-\sum^N_{n=1}\sum^K_{k=1}q(z_n=k)\log q(z_n=k) \\ &=\sum^N_{n=1}\sum^K_{k=1}q(z_n=k)\log\bigg(\frac{\pi_k}{\sqrt{2\pi\sigma^2_k}}e^{-\frac{1}{2\sigma^2_k}(x_n-\mu_k)^2}\bigg)+C\\ &\mathcal{F}(q(z),\theta)=\sum^N_{n=1}\sum^K_{k=1}q(z_n=k)\bigg(\log\pi_k-\frac{1}{2}\log\sigma_k^2-\frac{1}{2\sigma^2_k}(x_n-\mu_k)^2\bigg)+C\\ &\frac{\partial\mathcal{F}}{\partial\mu_j}=\sum^N_{n=1}q(z_n=j)\frac{x_n-\mu_j}{\sigma^2_j}=0\\ &\mu_j = \frac{\sum^N_{n=1}q(z_n=j)x_n}{\sum^N_{n=1}q(z_n=j)}\\ &\frac{\partial\mathcal{F}}{\partial\sigma^2_j}=\sum^N_{n=1}q(z_n=j)\bigg(-\frac{1}{2\sigma^2_j}+\frac{(x_n-\mu_j)^2}{2\sigma^4_j}\bigg)=0\\ &\sigma^2_j = \frac{\sum^N_{n=1}q(z_n=j)(x_n-\mu_j)^2}{\sum^N_{n=1}q(z_n=j)}\\ &\frac{\partial}{\partial\pi_j}\big(\mathcal{F}+\lambda(1-\sum_k\pi_k)\big)=\sum^N_{n=1}\frac{q(z_n=j)}{\pi_j}-\lambda=0,\quad\lambda=N\\ &\pi_j =\frac{1}{N} \sum^N_{n=1}q(z_n=j)\\ \end{align*}
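
A sketch of these M-step updates, together with a small EM loop alternating the two steps; the function names and the initialization interface are my own, and this is only an illustrative sketch, not a reference implementation.

```python
import numpy as np
from scipy.stats import norm

def m_step(x, r):
    Nk = r.sum(axis=0)                               # effective counts sum_n q(z_n=k)
    mu = (r * x[:, None]).sum(axis=0) / Nk           # weighted means
    var = (r * (x[:, None] - mu) ** 2).sum(axis=0) / Nk  # weighted variances
    pi = Nk / x.shape[0]                             # mixing weights
    return pi, mu, np.sqrt(var)

def fit_gmm(x, pi, mu, sigma, n_iter=50):
    for _ in range(n_iter):
        u = pi * norm.pdf(x[:, None], mu, sigma)     # E step: joint p(x_n, z_n=k)
        r = u / u.sum(axis=1, keepdims=True)         # responsibilities r_nk
        pi, mu, sigma = m_step(x, r)                 # M step: closed-form updates
    return pi, mu, sigma
```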

Another EM-related algorithm is Baum-Welch estimation for HMMs.
