gpsa.models.vgpsa

import torch
import numpy as np
import torch.nn as nn
from sklearn.cluster import KMeans
from gpsa import GPSA
from ..util.util import rbf_kernel
from collections.abc import Iterable

# Debugging aid: traces autograd errors back to the op that produced them.
# Note that this slows training considerably.
torch.autograd.set_detect_anomaly(True)

device = "cuda" if torch.cuda.is_available() else "cpu"


class VariationalGPSA(GPSA):
    def __init__(
        self,
        data_dict,
        m_X_per_view,
        m_G,
        data_init=True,
        minmax_init=False,
        grid_init=False,
        n_spatial_dims=2,
        n_noise_variance_params=2,
        kernel_func_warp=rbf_kernel,
        kernel_func_data=rbf_kernel,
        n_latent_gps=None,
        mean_function="identity_fixed",
        mean_penalty_param=0.0,
        fixed_warp_kernel_variances=None,
        fixed_warp_kernel_lengthscales=None,
        fixed_data_kernel_lengthscales=None,
        fixed_view_idx=None,
    ):
        super(VariationalGPSA, self).__init__(
            data_dict,
            data_init=data_init,
            n_spatial_dims=n_spatial_dims,
            n_noise_variance_params=n_noise_variance_params,
            kernel_func_warp=kernel_func_warp,
            kernel_func_data=kernel_func_data,
            mean_penalty_param=mean_penalty_param,
            fixed_warp_kernel_variances=fixed_warp_kernel_variances,
            fixed_warp_kernel_lengthscales=fixed_warp_kernel_lengthscales,
            fixed_data_kernel_lengthscales=fixed_data_kernel_lengthscales,
        )
 47
        self.m_X_per_view = m_X_per_view
        self.m_G = m_G
        self.n_latent_gps = n_latent_gps
        self.n_latent_outputs = {}
        for mod in self.modality_names:
            # Under an LMC, each modality has n_latent_gps latent outputs;
            # otherwise there is one latent GP per observed output.
            curr_n_latent_outputs = (
                self.n_latent_gps[mod]
                if self.n_latent_gps[mod] is not None
                else self.Ps[mod]
            )
            self.n_latent_outputs[mod] = curr_n_latent_outputs
        self.fixed_view_idx = fixed_view_idx

        if data_init:
            # Initialize inducing locations with K-means centroids of the data
            Xtilde = torch.zeros([self.n_views, self.m_X_per_view, self.n_spatial_dims])
            for ii in range(self.n_views):
                curr_X_spatial_list = []
                for mod in self.modality_names:
                    curr_idx = self.view_idx[mod][ii]
                    curr_modality_and_view_spatial = data_dict[mod]["spatial_coords"][
                        curr_idx, :
                    ]
                    curr_X_spatial_list.append(curr_modality_and_view_spatial)
                curr_X_spatial = torch.cat(curr_X_spatial_list, dim=0)

                kmeans = KMeans(n_clusters=self.m_X_per_view)
                kmeans.fit(curr_X_spatial.detach().cpu().numpy())
                Xtilde[ii, :, :] = torch.tensor(kmeans.cluster_centers_)

            self.Xtilde = nn.Parameter(Xtilde.clone())

            all_X_spatial = torch.cat(
                [data_dict[mod]["spatial_coords"] for mod in self.modality_names]
            )
            kmeans = KMeans(n_clusters=self.m_G)
            kmeans.fit(all_X_spatial.detach().cpu().numpy())
            self.Gtilde = nn.Parameter(torch.tensor(kmeans.cluster_centers_))
 93
        elif grid_init:
            if self.n_spatial_dims == 2:
                # Lay inducing points on a regular grid spanning the spatial
                # extent of the first modality.
                xlow, ylow = (
                    data_dict[self.modality_names[0]]["spatial_coords"].numpy().min(0)
                )
                xhigh, yhigh = (
                    data_dict[self.modality_names[0]]["spatial_coords"].numpy().max(0)
                )
                xlimits = [xlow, xhigh]
                ylimits = [ylow, yhigh]
                numticks = np.ceil(np.sqrt(self.m_G)).astype(int)
                self.m_G = numticks**2
                self.m_X_per_view = numticks**2
                x1s = np.linspace(*xlimits, num=numticks)
                x2s = np.linspace(*ylimits, num=numticks)
                X1, X2 = np.meshgrid(x1s, x2s)
                Xtilde = np.vstack([X1.ravel(), X2.ravel()]).T
                Xtilde_torch = torch.zeros(
                    [self.n_views, Xtilde.shape[0], self.n_spatial_dims]
                )
                for vv in range(self.n_views):
                    Xtilde_torch[vv] = torch.tensor(Xtilde)

                self.Xtilde = nn.Parameter(Xtilde_torch.clone())
                self.Gtilde = nn.Parameter(torch.tensor(Xtilde).float())

        else:
            # Random initialization of inducing locations
            self.Xtilde = nn.Parameter(
                torch.randn([self.n_views, self.m_X_per_view, self.n_spatial_dims])
            )
            self.Gtilde = nn.Parameter(torch.randn([self.m_G, self.n_spatial_dims]))
129
        ## Variational covariance parameters
        Omega_sqt_G_list = torch.zeros(
            [self.n_views * self.n_spatial_dims, self.m_X_per_view, self.m_X_per_view],
            device=device,
        )
        for ii in range(self.n_views):
            for jj in range(self.n_spatial_dims):
                Omega_sqt = 0.1 * torch.randn(
                    size=[self.m_X_per_view, self.m_X_per_view]
                )
                # Stored dimension-major: the entry for view ii, spatial
                # dimension jj lives at index jj * n_views + ii.
                Omega_sqt_G_list[jj * self.n_views + ii, :, :] = Omega_sqt
        self.Omega_sqt_G_list = nn.Parameter(Omega_sqt_G_list)

        Omega_sqt_F_dict = torch.nn.ParameterDict()
        for mod in self.modality_names:
            curr_Omega = torch.zeros([self.n_latent_outputs[mod], self.m_G, self.m_G])
            for jj in range(self.n_latent_outputs[mod]):
                Omega_sqt = 0.1 * torch.randn(size=[self.m_G, self.m_G])
                curr_Omega[jj, :, :] = Omega_sqt
            Omega_sqt_F_dict[mod] = nn.Parameter(curr_Omega)
        self.Omega_sqt_F_dict = Omega_sqt_F_dict

        ## Variational mean parameters
        self.delta_G_list = nn.Parameter(self.Xtilde.clone())
        delta_F_dict = torch.nn.ParameterDict()
        for mod in self.modality_names:
            curr_delta = nn.Parameter(
                torch.randn(size=[self.m_G, self.n_latent_outputs[mod]], device=device)
            )
            delta_F_dict[mod] = curr_delta
        self.delta_F_dict = delta_F_dict

        ## LMC parameters
        self.W_dict = torch.nn.ParameterDict()
        for mod in self.modality_names:
            if self.n_latent_gps[mod] is not None:
                self.W_dict[mod] = nn.Parameter(
                    torch.randn([self.n_latent_gps[mod], self.Ps[mod]])
                )
173
    def compute_mean_and_var(
        self, Kff_diag, Kuf, Kuu_chol, mu_x, mu_z, delta, Omega_tril
    ):
        # alpha_x = Kuu^{-1} Kuf, via the Cholesky factor of Kuu
        alpha_x = torch.cholesky_solve(Kuf, Kuu_chol)

        # aKa = diag(Kuf^T Kuu^{-1} Kuf)
        a_t_Kchol = torch.matmul(alpha_x.transpose(-1, -2), Kuu_chol)
        aKa = torch.sum(torch.square(a_t_Kchol), dim=-1)

        # Predictive mean: mu_x + Kuf^T Kuu^{-1} (delta - mu_z)
        mu_tilde = mu_x.unsqueeze(0) + torch.matmul(
            alpha_x.transpose(-1, -2), delta - mu_z
        )

        # Predictive marginal variance:
        # diag(Kff) - aKa + diag(alpha^T Omega alpha), plus jitter
        if len(alpha_x.shape) == 2:
            a_t_Omega_tril = torch.matmul(
                alpha_x.transpose(-1, -2).unsqueeze(0), Omega_tril
            )
            aOmega_a = torch.sum(torch.square(a_t_Omega_tril), dim=-1)
            Sigma_tilde = Kff_diag - aKa + aOmega_a + self.diagonal_offset
        else:
            a_t_Omega_tril = torch.matmul(
                alpha_x.transpose(-1, -2).unsqueeze(1), Omega_tril.unsqueeze(0)
            )
            aOmega_a = torch.sum(torch.square(a_t_Omega_tril), dim=-1)
            Sigma_tilde = (
                Kff_diag.unsqueeze(1)
                - aKa.unsqueeze(1)
                + aOmega_a
                + self.diagonal_offset
            )

        return mu_tilde, Sigma_tilde + self.diagonal_offset

    def get_Omega_from_Omega_sqt(self, Omega_sqt):
        # Omega = Omega_sqt @ Omega_sqt^T + jitter, positive definite by
        # construction.
        return torch.matmul(
            Omega_sqt,
            torch.transpose(Omega_sqt, -1, -2),
        ) + self.diagonal_offset * torch.eye(Omega_sqt.shape[-1], device=device)
211
    def forward(self, X_spatial, view_idx, Ns, S=1, prediction_mode=False, G_test=None):
        if prediction_mode:
            self.eval()

        self.noise_variance_pos = torch.exp(self.noise_variance) + self.diagonal_offset

        # Warp-GP prior means at the inducing locations (NaN-filled so that
        # any entries left unwritten are easy to spot).
        self.mu_z_G = (
            torch.zeros(
                [self.n_views, self.m_X_per_view, self.n_spatial_dims], device=device
            )
            * np.nan
        )
        for vv in range(self.n_views):
            self.mu_z_G[vv] = (
                torch.mm(self.Xtilde[vv], self.mean_slopes[vv])
                + self.mean_intercepts[vv]
            )
            if self.fixed_view_idx is not None and (
                vv in self.fixed_view_idx
                if isinstance(self.fixed_view_idx, Iterable)
                else self.fixed_view_idx == vv
            ):
                self.mu_z_G[vv] *= 100.0

        self.Kuu_chol_list = (
            torch.zeros(
                [self.n_views, self.m_X_per_view, self.m_X_per_view], device=device
            )
            * np.nan
        )
        G_samples = {}
        for mod in self.modality_names:
            G_samples[mod] = (
                torch.zeros([S, Ns[mod], self.n_spatial_dims], device=device) * np.nan
            )

        G_means = {}
        for mod in self.modality_names:
            G_means[mod] = (
                torch.zeros([Ns[mod], self.n_spatial_dims], device=device) * np.nan
            )

        curr_Omega_G = self.get_Omega_from_Omega_sqt(self.Omega_sqt_G_list)
        self.curr_Omega_tril_list = torch.linalg.cholesky(curr_Omega_G)

        for vv in range(self.n_views):
            ## If this view is fixed (template-based alignment), its warp does
            ## not need to be sampled: the observed coordinates are used as-is.
            if self.fixed_view_idx is not None and (
                vv in self.fixed_view_idx
                if isinstance(self.fixed_view_idx, Iterable)
                else self.fixed_view_idx == vv
            ):
                for mm, mod in enumerate(self.modality_names):
                    observed_X_spatial = X_spatial[mod][view_idx[mod][vv]]
                    G_means[mod][view_idx[mod][vv]] = observed_X_spatial
                    G_samples[mod][:, view_idx[mod][vv], :] = observed_X_spatial
                continue

            kernel_G = lambda x1, x2, diag=False: self.kernel_func_warp(
                x1,
                x2,
                lengthscale_unconstrained=self.warp_kernel_lengthscales[vv],
                output_variance_unconstrained=self.warp_kernel_variances[vv],
                diag=diag,
            )

            ## Collect data from all modalities for this view
            curr_X_spatial_list = []
            curr_n = 0
            curr_mod_idx = []
            for mod in self.modality_names:
                curr_idx = view_idx[mod][vv]
                curr_mod_idx.append(np.arange(curr_n, curr_n + len(curr_idx)))
                curr_n += len(curr_idx)
                curr_modality_and_view_spatial = X_spatial[mod][curr_idx, :]
                curr_X_spatial_list.append(curr_modality_and_view_spatial)

            curr_X_spatial = torch.cat(curr_X_spatial_list, dim=0)

            if len(curr_X_spatial) == 0:
                continue

            curr_X_tilde = self.Xtilde[vv]

            mu_x_G = (
                torch.mm(curr_X_spatial, self.mean_slopes[vv])
                + self.mean_intercepts[vv]
            )

            # For an RBF kernel the prior variance is constant, so the
            # diagonal of Kff is filled directly from the kernel variance.
            Kff_diag = torch.ones((curr_X_spatial.shape[0]), device=device) * torch.exp(
                self.warp_kernel_variances[vv]
            )

            Kuu = kernel_G(
                curr_X_tilde, curr_X_tilde
            ) + self.diagonal_offset * torch.eye(self.m_X_per_view, device=device)

            Kuf = kernel_G(curr_X_tilde, curr_X_spatial)

            Kuu_chol = torch.linalg.cholesky(Kuu)
            self.Kuu_chol_list[vv, :, :] = Kuu_chol

            mu_tilde, Sigma_tilde = self.compute_mean_and_var(
                Kff_diag,
                Kuf,
                Kuu_chol,
                mu_x_G,
                self.mu_z_G,
                self.delta_G_list,
                self.curr_Omega_tril_list,
            )

            # Sample the warped coordinates for this view.
            # NOTE: Omega_sqt_G_list is filled dimension-major
            # (jj * n_views + ii) in __init__, while this slice reads
            # Sigma_tilde view-major; also, Sigma_tilde holds marginal
            # variances, whereas Normal's scale expects a standard deviation
            # (compare the sqrt applied to the F samples below).
            G_marginal_dist = torch.distributions.Normal(
                mu_tilde[vv],
                Sigma_tilde[
                    vv * self.n_spatial_dims : vv * self.n_spatial_dims
                    + self.n_spatial_dims
                ].t(),
            )

            for mm, mod in enumerate(self.modality_names):
                curr_idx = curr_mod_idx[mm]
                G_means[mod][view_idx[mod][vv]] = mu_tilde[vv][curr_idx]

            for ss in range(S):
                curr_G_sample = G_marginal_dist.rsample()
                for mm, mod in enumerate(self.modality_names):
                    curr_idx = curr_mod_idx[mm]
                    G_samples[mod][ss, view_idx[mod][vv]] = curr_G_sample[curr_idx]

        self.curr_Omega_tril_F = {}
        for mod in self.modality_names:
            self.curr_Omega_tril_F[mod] = torch.zeros(
                [self.n_latent_outputs[mod], self.m_G, self.m_G], device=device
            )

        self.F_latent_samples = {}
        self.F_observed_samples = {}
        for mod in self.modality_names:
            self.F_latent_samples[mod] = torch.zeros(
                [S, Ns[mod], self.n_latent_outputs[mod]], device=device
            )
            self.F_observed_samples[mod] = torch.zeros([S, Ns[mod], self.Ps[mod]])

        if G_test is not None:
            self.F_latent_samples_test = {}
            self.F_observed_samples_test = {}
            for mod in self.modality_names:
                n_test = G_test[mod].shape[1]
                self.F_latent_samples_test[mod] = torch.zeros(
                    [S, n_test, self.n_latent_outputs[mod]]
                )
                self.F_observed_samples_test[mod] = torch.zeros(
                    [S, n_test, self.Ps[mod]]
                )

        kernel_F = lambda x1, x2, diag=False: self.kernel_func_data(
            x1,
            x2,
            lengthscale_unconstrained=self.data_kernel_lengthscale,
            output_variance_unconstrained=self.data_kernel_variance,
            diag=diag,
        )

        Kuu = kernel_F(self.Gtilde, self.Gtilde) + self.diagonal_offset * torch.eye(
            self.m_G, device=device
        )
        self.Kuu_chol_F = torch.linalg.cholesky(Kuu)

        for mod in self.modality_names:
            mu_x_F = torch.zeros([Ns[mod], self.n_latent_outputs[mod]], device=device)
            mu_z_F = torch.zeros([self.m_G, self.n_latent_outputs[mod]], device=device)

            # Constant prior variance along the diagonal of Kff (RBF kernel).
            Kff_diag = torch.ones(
                (G_samples[mod].shape[:2]), device=device
            ) * torch.exp(self.data_kernel_variance)

            Kuf = kernel_F(self.Gtilde, G_samples[mod])
            curr_Omega = self.get_Omega_from_Omega_sqt(self.Omega_sqt_F_dict[mod])
            self.curr_Omega_tril_F[mod] = torch.linalg.cholesky(curr_Omega)

            mu_tilde, Sigma_tilde = self.compute_mean_and_var(
                Kff_diag,
                Kuf,
                self.Kuu_chol_F,
                mu_x_F,
                mu_z_F,
                self.delta_F_dict[mod],
                self.curr_Omega_tril_F[mod],
            )

            # Reparameterized samples of the latent functions F.
            eps = torch.randn(mu_tilde.shape, device=device)
            curr_F_latent_samples = (
                mu_tilde + torch.sqrt(torch.transpose(Sigma_tilde, 1, 2)) * eps
            )

            # Under an LMC, observed outputs are a linear mix of the latent GPs.
            if self.n_latent_gps[mod] is not None:
                curr_W = self.W_dict[mod]
                F_observed_mean = torch.matmul(curr_F_latent_samples, curr_W)
            else:
                F_observed_mean = curr_F_latent_samples

            self.F_latent_samples[mod] = curr_F_latent_samples
            self.F_observed_samples[mod] = F_observed_mean

            ## Same computation at the held-out test locations
            if G_test is not None:
                Kff_diag = torch.ones(
                    (G_test[mod].shape[:2]), device=device
                ) * torch.exp(self.data_kernel_variance)

                mu_x_F = torch.zeros(
                    [G_test[mod].shape[1], self.n_latent_outputs[mod]], device=device
                )

                Kuf = kernel_F(self.Gtilde, G_test[mod])

                mu_tilde, Sigma_tilde = self.compute_mean_and_var(
                    Kff_diag,
                    Kuf,
                    self.Kuu_chol_F,
                    mu_x_F,
                    mu_z_F,
                    self.delta_F_dict[mod],
                    self.curr_Omega_tril_F[mod],
                )

                eps = torch.randn(mu_tilde.shape, device=device)
                curr_F_latent_samples = (
                    mu_tilde + torch.sqrt(torch.transpose(Sigma_tilde, 1, 2)) * eps
                )

                if self.n_latent_gps[mod] is not None:
                    curr_W = self.W_dict[mod]
                    F_observed_mean = torch.matmul(curr_F_latent_samples, curr_W)
                else:
                    F_observed_mean = curr_F_latent_samples

                self.F_latent_samples_test[mod] = curr_F_latent_samples
                self.F_observed_samples_test[mod] = F_observed_mean

        if G_test is not None:
            return (
                G_means,
                G_samples,
                self.F_latent_samples,
                self.F_observed_samples,
                self.F_latent_samples_test,
                self.F_observed_samples_test,
            )
        else:
            return G_means, G_samples, self.F_latent_samples, self.F_observed_samples
490
    def loss_fn(self, data_dict, F_samples):
        # Computes the negative (approximate) ELBO:
        #   -E_q[log p(Y | F)] + KL(q(u) || p(u)),
        # with KL terms for both the warp GPs (G) and the data GPs (F).

        # Running sum for KL terms
        KL_div = 0

        ## G
        for vv in range(self.n_views):
            # Fixed (template) views have no warp posterior, hence no KL term.
            if self.fixed_view_idx is not None and (
                vv in self.fixed_view_idx
                if isinstance(self.fixed_view_idx, Iterable)
                else self.fixed_view_idx == vv
            ):
                continue
            for jj in range(self.n_spatial_dims):
                qu = torch.distributions.MultivariateNormal(
                    loc=self.delta_G_list[vv, :, jj],
                    scale_tril=self.curr_Omega_tril_list[jj * self.n_views + vv, :, :],
                )
                pu = torch.distributions.MultivariateNormal(
                    loc=self.mu_z_G[vv, :, jj],
                    scale_tril=self.Kuu_chol_list[vv, :, :],
                )
                curr_KL_div = torch.distributions.kl.kl_divergence(qu, pu)
                KL_div += curr_KL_div

        ## F
        LL = 0
        pu = torch.distributions.MultivariateNormal(
            loc=torch.zeros(self.m_G, device=device), scale_tril=self.Kuu_chol_F
        )
        for mm, mod in enumerate(self.modality_names):
            qu = torch.distributions.MultivariateNormal(
                loc=self.delta_F_dict[mod].t(),
                scale_tril=self.curr_Omega_tril_F[mod],
            )
            curr_KL_div = torch.distributions.kl.kl_divergence(qu, pu)
            KL_div += curr_KL_div.sum()

            # Monte Carlo estimate of the expected log-likelihood, averaged
            # over the S posterior samples.
            Y_distribution = torch.distributions.Normal(
                loc=F_samples[mod],
                scale=self.noise_variance_pos[-self.n_modalities + mm],
            )
            S = F_samples[mod].shape[0]
            LL += Y_distribution.log_prob(data_dict[mod]["outputs"]).sum() / S

        return -LL + KL_div


if __name__ == "__main__":
    pass
class VariationalGPSA(gpsa.models.gpsa.GPSA):

Args:
    data_dict (dict): Dictionary of data in the format {"modality": {"spatial_coords": X, "outputs": Y, "n_samples_list": n_samples_list}}.
    m_X_per_view (int): Number of inducing locations per view for the warp GPs.
    m_G (int): Number of inducing locations for the data GPs.
    data_init (bool, optional): Whether to initialize inducing locations with KMeans on data.
    grid_init (bool, optional): Whether to initialize inducing locations on a regular grid (2D only).
    n_spatial_dims (int, optional): Number of spatial dimensions (usually 2 or 3).
    n_noise_variance_params (int, optional): Number of noise variance parameters.
    kernel_func_warp (function, optional): Covariance function for warp GP.
    kernel_func_data (function, optional): Covariance function for output GP.
    n_latent_gps (dict, optional): Number of latent GPs per modality for the linear model of coregionalization (LMC); a value of None uses one latent GP per observed output.
    mean_function (str, optional): Mean function for warp GP. One of ["identity_fixed", "identity_initialized", None]. None results in a linear mean function.
    mean_penalty_param (float, optional): Weight of the penalty on the warp mean function.
    fixed_warp_kernel_variances (None, optional): If given, the warp kernel output variances are fixed to these values rather than learned.
    fixed_warp_kernel_lengthscales (None, optional): If given, the warp kernel lengthscales are fixed to these values rather than learned.
    fixed_data_kernel_lengthscales (None, optional): If given, the data kernel lengthscales are fixed to these values rather than learned.
    fixed_view_idx (int or iterable, optional): Index or indices of views to hold fixed as a template during alignment.

VariationalGPSA(data_dict, m_X_per_view, m_G, data_init=True, minmax_init=False, grid_init=False, n_spatial_dims=2, n_noise_variance_params=2, kernel_func_warp=<function rbf_kernel>, kernel_func_data=<function rbf_kernel>, n_latent_gps=None, mean_function='identity_fixed', mean_penalty_param=0.0, fixed_warp_kernel_variances=None, fixed_warp_kernel_lengthscales=None, fixed_data_kernel_lengthscales=None, fixed_view_idx=None)

Initializes internal Module state, shared by both nn.Module and ScriptModule.
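
A minimal usage sketch with synthetic data (hypothetical shapes and modality name; it assumes, as the source above suggests, that the parent GPSA class builds self.view_idx from n_samples_list and that Ns maps each modality to its total number of samples):

    import torch
    from gpsa.models.vgpsa import VariationalGPSA

    # One modality ("expression" is a placeholder name), two views.
    n_views, n_per_view, n_outputs = 2, 100, 10
    X = torch.randn(n_views * n_per_view, 2)
    Y = torch.randn(n_views * n_per_view, n_outputs)
    data_dict = {
        "expression": {
            "spatial_coords": X,
            "outputs": Y,
            "n_samples_list": [n_per_view, n_per_view],
        }
    }

    model = VariationalGPSA(
        data_dict,
        m_X_per_view=25,
        m_G=25,
        data_init=True,
        n_latent_gps={"expression": None},  # no LMC: one latent GP per output
    )

    Ns = {"expression": n_views * n_per_view}
    G_means, G_samples, F_latent, F_observed = model.forward(
        {"expression": X}, view_idx=model.view_idx, Ns=Ns, S=1
    )
    loss = model.loss_fn(data_dict, F_observed)
    loss.backward()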

def compute_mean_and_var(self, Kff_diag, Kuf, Kuu_chol, mu_x, mu_z, delta, Omega_tril):
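
Computes the standard sparse variational GP predictive at the input locations. In the notation of the source (a sketch of what the code evaluates, with alpha = Kuu^{-1} Kuf and epsilon the diagonal jitter):

    mu_tilde    = mu_x + Kuf^T Kuu^{-1} (delta - mu_z)
    Sigma_tilde = diag(Kff) - diag(alpha^T Kuu alpha)
                  + diag(alpha^T Omega alpha) + epsilon

where delta and Omega are the variational mean and covariance of the inducing outputs. Only the marginal (diagonal) variance is returned.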
def get_Omega_from_Omega_sqt(self, Omega_sqt):
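
Maps the unconstrained square-root parameterization to a valid covariance matrix, Omega = Omega_sqt @ Omega_sqt^T + diagonal_offset * I, which is positive definite by construction (the jitter keeps the Cholesky factorizations in forward stable).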
def forward(self, X_spatial, view_idx, Ns, S=1, prediction_mode=False, G_test=None):
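
Draws S reparameterized posterior samples of the warped coordinates G for each view (fixed template views keep their observed coordinates), then samples the latent functions F at those coordinates, mixing them through W under an LMC. Returns (G_means, G_samples, F_latent_samples, F_observed_samples), with the corresponding test-location samples appended when G_test is given.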
377                    [S, n_test, self.n_latent_outputs[mod]]
378                )
379                self.F_observed_samples_test[mod] = torch.zeros(
380                    [S, n_test, self.Ps[mod]]
381                )
382
383        kernel_F = lambda x1, x2, diag=False: self.kernel_func_data(
384            x1,
385            x2,
386            lengthscale_unconstrained=self.data_kernel_lengthscale,
387            output_variance_unconstrained=self.data_kernel_variance,
388            diag=diag,
389        )
390
391        Kuu = kernel_F(self.Gtilde, self.Gtilde) + self.diagonal_offset * torch.eye(
392            self.m_G, device=device
393        )
394
395        self.Kuu_chol_F = torch.cholesky(Kuu)
396
397        for mod in self.modality_names:
398
399            mu_x_F = torch.zeros([Ns[mod], self.n_latent_outputs[mod]], device=device)
400            mu_z_F = torch.zeros([self.m_G, self.n_latent_outputs[mod]], device=device)
401
402            # Kff_diag = (
403            #     kernel_F(G_samples[mod], G_samples[mod], diag=True)
404            #     + self.diagonal_offset
405            # )
406            Kff_diag = torch.ones(
407                (G_samples[mod].shape[:2]), device=device
408            ) * torch.exp(self.data_kernel_variance)
409
410            Kuf = kernel_F(self.Gtilde, G_samples[mod])
411            curr_Omega = self.get_Omega_from_Omega_sqt(self.Omega_sqt_F_dict[mod])
412
413            self.curr_Omega_tril_F[mod] = torch.cholesky(curr_Omega)
414            mu_tilde, Sigma_tilde = self.compute_mean_and_var(
415                Kff_diag,
416                Kuf,
417                self.Kuu_chol_F,
418                mu_x_F,
419                mu_z_F,
420                self.delta_F_dict[mod],
421                self.curr_Omega_tril_F[mod],
422            )
423
424            eps = torch.randn(mu_tilde.shape, device=device)
425            curr_F_latent_samples = (
426                mu_tilde + torch.sqrt(torch.transpose(Sigma_tilde, 1, 2)) * eps
427            )
428
429            if self.n_latent_gps[mod] is not None:
430                curr_W = self.W_dict[mod]
431                F_observed_mean = torch.matmul(curr_F_latent_samples, curr_W)
432            else:
433                F_observed_mean = curr_F_latent_samples
434
435            self.F_latent_samples[mod] = curr_F_latent_samples
436            self.F_observed_samples[mod] = F_observed_mean
437
438            ## For test samples
439            if G_test is not None:
440                # Kff_diag = (
441                #     kernel_F(G_samples[mod], G_samples[mod], diag=True)
442                #     + self.diagonal_offset
443                # )
444                Kff_diag = torch.ones(
445                    (G_test[mod].shape[:2]), device=device
446                ) * torch.exp(
447                    self.data_kernel_variance,
448                )
449
450                mu_x_F = torch.zeros(
451                    [G_test[mod].shape[1], self.n_latent_outputs[mod]], device=device
452                )
453
454                Kuf = kernel_F(self.Gtilde, G_test[mod])
455
456                mu_tilde, Sigma_tilde = self.compute_mean_and_var(
457                    Kff_diag,
458                    Kuf,
459                    self.Kuu_chol_F,
460                    mu_x_F,
461                    mu_z_F,
462                    self.delta_F_dict[mod],
463                    self.curr_Omega_tril_F[mod],
464                )
465
466                eps = torch.randn(mu_tilde.shape, device=device)
467                curr_F_latent_samples = (
468                    mu_tilde + torch.sqrt(torch.transpose(Sigma_tilde, 1, 2)) * eps
469                )
470
471                if self.n_latent_gps[mod] is not None:
472                    curr_W = self.W_dict[mod]
473                    F_observed_mean = torch.matmul(curr_F_latent_samples, curr_W)
474                else:
475                    F_observed_mean = curr_F_latent_samples
476
477                self.F_latent_samples_test[mod] = curr_F_latent_samples
478                self.F_observed_samples_test[mod] = F_observed_mean
479
480        if G_test is not None:
481            return (
482                G_means,
483                G_samples,
484                self.F_latent_samples,
485                self.F_observed_samples,
486                self.F_latent_samples_test,
487                self.F_observed_samples_test,
488            )
489        else:
490            return G_means, G_samples, self.F_latent_samples, self.F_observed_samples

Defines the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance itself rather than this method, since the former takes care of running the registered hooks while the latter silently ignores them.
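As this inherited nn.Module docstring notes, the module instance should be called rather than forward itself. Below is a minimal end-to-end sketch under assumed names: "expression" is a hypothetical modality, x (an (N, 2) coordinate tensor), y (an (N, P) outputs tensor), and n_samples_list (points per view) stand in for real data, and the data_dict layout and create_view_idx_dict unpacking follow the pattern in the package's usage examples.

    import torch

    # Hypothetical data_dict mirroring the constructor's expected layout.
    data_dict = {
        "expression": {
            "spatial_coords": x,        # (N, 2) stacked spatial coordinates
            "outputs": y,               # (N, P) stacked observed features
            "n_samples_list": n_samples_list,  # points per view, e.g. [N1, N2]
        }
    }

    model = VariationalGPSA(
        data_dict,
        m_X_per_view=50,
        m_G=50,
        n_latent_gps={"expression": None},
    ).to(device)

    view_idx, Ns, _, _ = model.create_view_idx_dict(data_dict)

    # Call the module instance (runs hooks) rather than model.forward directly.
    G_means, G_samples, F_latent_samples, F_samples = model(
        X_spatial={"expression": x}, view_idx=view_idx, Ns=Ns, S=5
    )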

def loss_fn(self, data_dict, F_samples):
492    def loss_fn(self, data_dict, F_samples):
493        # This computes the negative (approximate) ELBO
494
495        # Running sum for KL terms
496        KL_div = 0
497
498        ## G
499        for vv in range(self.n_views):
500            if self.fixed_view_idx is not None and (
501                vv in self.fixed_view_idx
502                if isinstance(self.fixed_view_idx, Iterable)
503                else self.fixed_view_idx == vv
504            ):
505                continue
506            for jj in range(self.n_spatial_dims):
507                qu = torch.distributions.MultivariateNormal(
508                    loc=self.delta_G_list[vv, :, jj],
509                    scale_tril=self.curr_Omega_tril_list[jj * self.n_views + vv, :, :],
510                )
511                pu = torch.distributions.MultivariateNormal(
512                    loc=self.mu_z_G[vv, :, jj],
513                    scale_tril=self.Kuu_chol_list[vv, :, :],
514                )
515                curr_KL_div = torch.distributions.kl.kl_divergence(qu, pu)
516
517                KL_div += curr_KL_div
518
519        ## F
520        LL = 0
521        pu = torch.distributions.MultivariateNormal(
522            loc=torch.zeros(self.m_G, device=device), scale_tril=self.Kuu_chol_F
523        )
524        for mm, mod in enumerate(self.modality_names):
525            qu = torch.distributions.MultivariateNormal(
526                loc=self.delta_F_dict[mod].t(),
527                scale_tril=self.curr_Omega_tril_F[mod],
528            )
529
530            curr_KL_div = torch.distributions.kl.kl_divergence(qu, pu)
531            KL_div += curr_KL_div.sum()
532
533            Y_distribution = torch.distributions.Normal(
534                loc=F_samples[mod],
535                scale=self.noise_variance_pos[-self.n_modalities + mm],
536            )
537            S = F_samples[mod].shape[0]
538
539            LL += Y_distribution.log_prob(data_dict[mod]["outputs"]).sum() / S
540
541        return -LL + KL_div
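Since loss_fn returns -LL + KL_div, minimizing it maximizes the approximate ELBO. A minimal training-step sketch, continuing the hypothetical setup from the forward example above (model, data_dict, x, view_idx, Ns):

    import torch

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

    def train_step():
        model.train()
        # Forward pass; the fourth return value is F_observed_samples.
        _, _, _, F_samples = model(
            X_spatial={"expression": x}, view_idx=view_idx, Ns=Ns, S=5
        )
        # Negative approximate ELBO: -(expected log-likelihood) + KL terms.
        loss = model.loss_fn(data_dict, F_samples)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()

    for epoch in range(1000):
        loss = train_step()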
Inherited Members
gpsa.models.gpsa.GPSA
create_view_idx_dict
compute_mean_penalty
torch.nn.modules.module.Module
dump_patches
register_buffer
register_parameter
add_module
apply
cuda
xpu
cpu
type
float
double
half
bfloat16
to
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
state_dict
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr