EM as a special case of variational inference

Expectation maximization can be seen as a special case of variational inference when the approximating distribution for the parameters $q(\theta)$ is taken to be a point mass.

We first briefly review EM and variational inference, and then we show the connection between them.

Expectation maximization

EM consists of two steps: computing the expected complete-data log-likelihood with respect to the latent variables (E step), and maximizing this expectation with respect to the parameters (M step).

In the expectation step of EM, we compute the expected log-likelihood of the data and the latent variables, where the expectation is taken over the posterior of the latent variables under the current parameter estimate $\theta_t$. Specifically, we compute

$\mathbb{E}_{p(z | x, \theta_t)}[\log p(x, z | \theta)] = \int \log p(x, z | \theta) p(z | x, \theta_t) dz.$

In the M step, we maximize this quantity w.r.t. $\theta$ (treating the current estimate $\theta_t$ as given):

$\theta_{t+1} = \text{arg} \max_{\theta} \mathbb{E}_{p(z | x, \theta_t)}[\log p(x, z | \theta)].$

Thus, in EM, we essentially integrate out the latent variables $z$ in order to get a point estimate of the parameters $\theta$.
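As a concrete illustration of the two steps, here is a minimal sketch of EM for a toy model that is not part of the discussion above: a two-component Gaussian mixture in one dimension with unit variances and equal mixing weights, where only the two means are unknown. The function names (`e_step`, `m_step`, `run_em`) and the synthetic data are assumptions made for this example.

```python
import math
import random

def log_gauss(x, mu):
    """Log-density of N(mu, 1) evaluated at x."""
    return -0.5 * math.log(2 * math.pi) - 0.5 * (x - mu) ** 2

def e_step(xs, mu):
    """Return p(z_i = 1 | x_i, theta_t) for each point (equal mixing weights)."""
    resp = []
    for x in xs:
        a, b = log_gauss(x, mu[0]), log_gauss(x, mu[1])
        m = max(a, b)  # subtract the max for numerical stability
        resp.append(math.exp(b - m) / (math.exp(a - m) + math.exp(b - m)))
    return resp

def m_step(xs, resp):
    """Maximize the expected complete-data log-likelihood over the two means."""
    w1 = sum(resp)
    w0 = len(xs) - w1
    mu0 = sum((1 - r) * x for r, x in zip(resp, xs)) / w0
    mu1 = sum(r * x for r, x in zip(resp, xs)) / w1
    return [mu0, mu1]

def log_likelihood(xs, mu):
    """Marginal log-likelihood log p(x | theta), with z summed out."""
    total = 0.0
    for x in xs:
        a = math.log(0.5) + log_gauss(x, mu[0])
        b = math.log(0.5) + log_gauss(x, mu[1])
        m = max(a, b)
        total += m + math.log(math.exp(a - m) + math.exp(b - m))
    return total

def run_em(xs, mu_init, n_iters=50):
    """Alternate E and M steps, recording the log-likelihood after each one."""
    mu = list(mu_init)
    history = [log_likelihood(xs, mu)]
    for _ in range(n_iters):
        resp = e_step(xs, mu)
        mu = m_step(xs, resp)
        history.append(log_likelihood(xs, mu))
    return mu, history

# Synthetic data from two clusters centered at -2 and 3 (assumed values).
random.seed(0)
xs = [random.gauss(-2.0, 1.0) for _ in range(200)]
xs += [random.gauss(3.0, 1.0) for _ in range(200)]
mu_hat, history = run_em(xs, mu_init=[-1.0, 1.0])
```

In this sketch the latent assignments never appear explicitly in the M step; they enter only through the responsibilities computed in the E step, which is the sense in which $z$ is integrated out.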

Variational inference

In variational inference, both the latent variables $z$ and the parameters $\theta$ are treated as random variables. We approximate the posterior $p(z, \theta | x)$ with a distribution $q(z, \theta)$ drawn from a simpler family. Furthermore, in mean-field variational inference, we assume that this distribution factorizes as $q(z, \theta) = q_z(z) q_\theta(\theta)$.

To find the members of the families $q_z$ and $q_\theta$ that best approximate $p(z, \theta | x)$, we minimize a divergence (typically the KL divergence) between $q_z(z) q_\theta(\theta)$ and $p(z, \theta | x)$:

$\min D_{KL}(q || p) = \min \mathbb{E}_{q_z(z) q_\theta(\theta)}\left[\log \frac{q_z(z) q_\theta(\theta)}{p(z, \theta | x)}\right].$

Expanding this expression, we can obtain an expression for the log marginal probability of the data $p(x)$:

\begin{align} D_{KL}(q || p) &= \mathbb{E}_{q_z(z) q_\theta(\theta)}\left[\log \frac{q_z(z) q_\theta(\theta)}{p(z, \theta | x)}\right] \\ &= \mathbb{E}_{q_z(z) q_\theta(\theta)}\left[\log \frac{q_z(z) q_\theta(\theta) p(x)}{p(z, \theta, x)}\right] \\ &= \mathbb{E}_{q_z(z) q_\theta(\theta)}\left[\log \frac{q_z(z) q_\theta(\theta)}{p(z, \theta, x)} + \log p(x) \right] \\ &= \mathbb{E}_{q_z(z) q_\theta(\theta)}\left[\log \frac{q_z(z) q_\theta(\theta)}{p(z, \theta, x)}\right] + \mathbb{E}_{q_z(z) q_\theta(\theta)}[\log p(x)] \\ &= \mathbb{E}_{q_z(z) q_\theta(\theta)}\left[\log \frac{q_z(z) q_\theta(\theta)}{p(z, \theta, x)}\right] + \int \int \log p(x) q_z(z) q_\theta(\theta) dz d\theta \\ &= \mathbb{E}_{q_z(z) q_\theta(\theta)}\left[\log \frac{q_z(z) q_\theta(\theta)}{p(z, \theta, x)}\right] + \log p(x) \underbrace{\int \int q_z(z) q_\theta(\theta) dz d\theta}_{1} \\ &= \mathbb{E}_{q_z(z) q_\theta(\theta)}\left[\log \frac{q_z(z) q_\theta(\theta)}{p(z, \theta, x)}\right] + \log p(x) \\ \implies \log p(x) &= D_{KL}(q || p) - \mathbb{E}_{q_z(z) q_\theta(\theta)}\left[\log \frac{q_z(z) q_\theta(\theta)}{p(z, \theta, x)}\right] \\ \end{align}

Now, since the KL-divergence is always non-negative, we have the following lower bound on $\log p(x)$:

$\log p(x) \geq - \mathbb{E}_{q_z(z) q_\theta(\theta)}\left[\log \frac{q_z(z) q_\theta(\theta)}{p(z, \theta, x)}\right].$

This is known as the evidence lower bound (ELBO). By maximizing the ELBO w.r.t. the parameters of $q$, we tighten a lower bound on $\log p(x)$.
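The decomposition of $\log p(x)$ into the ELBO plus a KL term can be checked numerically. The sketch below uses a small discrete model in which $z$ takes two values and $\theta$ takes three; the joint table and the factorized $q$ are made-up numbers chosen only for illustration.

```python
import math

# Hypothetical joint p(z, theta, x) at the observed x, indexed as joint[z][theta].
joint = [[0.10, 0.05, 0.02],
         [0.03, 0.08, 0.12]]

# Evidence: sum the joint over z and theta.
log_px = math.log(sum(sum(row) for row in joint))

def elbo(q_z, q_theta):
    """ELBO = E_q[log p(z, theta, x) - log q_z(z) q_theta(theta)]."""
    total = 0.0
    for z, qz in enumerate(q_z):
        for t, qt in enumerate(q_theta):
            q = qz * qt
            if q > 0:
                total += q * math.log(joint[z][t] / q)
    return total

def kl_to_posterior(q_z, q_theta):
    """KL(q || p(z, theta | x)) for the factorized q."""
    px = math.exp(log_px)
    total = 0.0
    for z, qz in enumerate(q_z):
        for t, qt in enumerate(q_theta):
            q = qz * qt
            if q > 0:
                total += q * math.log(q / (joint[z][t] / px))
    return total

# An arbitrary factorized approximation (assumed values).
q_z = [0.3, 0.7]
q_theta = [0.2, 0.5, 0.3]

bound = elbo(q_z, q_theta)
gap = log_px - bound                 # equals the KL term in the derivation
kl = kl_to_posterior(q_z, q_theta)
```

Here `gap` and `kl` agree up to floating-point error, and `bound` never exceeds `log_px`, whatever factorized $q$ is plugged in.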

Connection between EM and VI

Expectation maximization can be seen as a special case of variational inference if we take $q(\theta)$ to be a point mass (recall that EM provides a point estimate of the parameters).

In particular, let $\phi$ be the (unknown) location of the point mass, and let $q(\theta)$ be a delta function centered at $\phi$:

$q(\theta) = \delta(\theta - \phi).$

Recall that the $\delta$ function is (informally) defined as

$\delta(x) = \begin{cases} \infty & \text{if } x = 0 \\ 0 & \text{if } x \neq 0 \end{cases}$

subject to the normalization $\int \delta(x) dx = 1$.

Then a variational approach to inference will maximize the ELBO:

$- \mathbb{E}_{q(\theta) q(z)}\left[\log \frac{q(z) q(\theta)}{p(z, \theta, x)}\right].$

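Because $q(\theta)$ is a point mass at $\phi$, the expectation with respect to $q(\theta)$ simply evaluates the integrand at $\theta = \phi$. Writing $q(\phi)$ for the density of the point mass at its center, this simplifies to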

$-\mathbb{E}_{q(z)}\left[\log \frac{q(z) q(\phi)}{p(z, \phi, x)}\right].$

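The term $\log q(\phi)$ is formally infinite, but it depends neither on $q(z)$ nor on the value of $\phi$, so we treat it as a constant and ignore it when maximizing. EM then corresponds to coordinate ascent on what remains, alternating between $q(z)$ and $\phi$.

First, fix $\phi$ at the current estimate $\theta_t$ and maximize w.r.t. $q(z)$. Using $p(z, \theta_t, x) = p(z | x, \theta_t) p(x, \theta_t)$, we have

$-\mathbb{E}_{q(z)}\left[\log \frac{q(z)}{p(z, \theta_t, x)}\right] = -D_{KL}(q(z) || p(z | x, \theta_t)) + \log p(x, \theta_t),$

which is maximized when $q(z) = p(z | x, \theta_t)$. Note that we must assume that $q(z)$ is unconstrained in this case for this to work. This is the E step of EM: it yields the posterior over the latent variables $z$ under which the expected log-likelihood is computed.

Next, fix $q(z) = p(z | x, \theta_t)$ and view the objective as a function of $\phi$. Collecting the terms that do not depend on $\phi$ into a constant, we have

\begin{align} -\mathbb{E}_{p(z | x, \theta_t)}\left[\log \frac{p(z | x, \theta_t)}{p(z, \phi, x)}\right] &= \mathbb{E}_{p(z | x, \theta_t)}\left[\log p(z, \phi, x)\right] - \mathbb{E}_{p(z | x, \theta_t)}\left[\log p(z | x, \theta_t)\right] \\ &= \mathbb{E}_{p(z | x, \theta_t)}\left[\log p(x, z | \phi)\right] + \log p(\phi) + \text{const}. \end{align}

The first term is the expected complete-data log-likelihood, averaged over the latent variables $z$ using the current estimate $\theta_t$.

Now, if we maximize w.r.t. $\phi$ to get a point estimate of $\theta$, and the prior $p(\phi)$ is flat, this is exactly the M step:

$\theta_{t+1} = \text{arg} \max_{\phi} \mathbb{E}_{p(z | x, \theta_t)}\left[\log p(x, z | \phi) \right].$

With a non-flat prior, the extra $\log p(\phi)$ term turns this into the MAP variant of EM.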
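To make the correspondence concrete, the following sketch evaluates the ELBO with a point mass on $\theta$ for a toy model and checks that one round of coordinate ascent behaves like one EM iteration. The model (a two-component, unit-variance Gaussian mixture with equal weights and a flat prior on the means), the data, and the function names are all assumptions made for this example; the constant contributed by the delta function is dropped.

```python
import math

def log_gauss(x, mu):
    """Log-density of N(mu, 1) evaluated at x."""
    return -0.5 * math.log(2 * math.pi) - 0.5 * (x - mu) ** 2

def log_joint(x, k, phi):
    """log p(x_i, z_i = k | phi) with equal mixing weights."""
    return math.log(0.5) + log_gauss(x, phi[k])

def posterior(xs, phi):
    """Maximizer over q(z) for fixed phi: q(z_i) = p(z_i | x_i, phi)."""
    out = []
    for x in xs:
        w = [math.exp(log_joint(x, k, phi)) for k in (0, 1)]
        s = sum(w)
        out.append([w[0] / s, w[1] / s])
    return out

def point_mass_elbo(xs, q, phi):
    """ELBO with q(theta) a point mass at phi (flat prior, delta constant dropped)."""
    total = 0.0
    for x, qi in zip(xs, q):
        for k in (0, 1):
            if qi[k] > 0:
                total += qi[k] * (log_joint(x, k, phi) - math.log(qi[k]))
    return total

def marginal_loglik(xs, phi):
    """log p(x | phi), with z summed out."""
    return sum(
        math.log(sum(math.exp(log_joint(x, k, phi)) for k in (0, 1)))
        for x in xs
    )

def argmax_phi(xs, q):
    """Maximizer over phi for fixed q(z): responsibility-weighted means."""
    return [
        sum(qi[k] * x for x, qi in zip(xs, q)) / sum(qi[k] for qi in q)
        for k in (0, 1)
    ]

xs = [-2.3, -1.7, -2.1, 2.8, 3.4, 3.1]   # made-up observations
phi_t = [-1.0, 1.0]                      # current estimate theta_t

q_t = posterior(xs, phi_t)               # update q(z): the E step
phi_next = argmax_phi(xs, q_t)           # update phi: the M step
```

With `q_t` set to the exact posterior, `point_mass_elbo(xs, q_t, phi_t)` coincides with `marginal_loglik(xs, phi_t)`, so the bound is tight after the update of $q(z)$, and the subsequent update of $\phi$ cannot decrease the marginal log-likelihood.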