gpsa.models.vgpsa
import torch
import numpy as np
import torch.nn as nn
from sklearn.cluster import KMeans
from gpsa import GPSA
from ..util.util import rbf_kernel
from collections.abc import Iterable

# Surface the source of NaN/Inf gradients during debugging (adds overhead).
torch.autograd.set_detect_anomaly(True)

device = "cuda" if torch.cuda.is_available() else "cpu"
class VariationalGPSA(GPSA):

Variational approximation to GPSA (Gaussian process spatial alignment), fit with inducing points.

Args:
    data_dict (dict): Dictionary of data in the format
        {"modality": {"spatial_coords": X, "outputs": Y, "n_samples_list": n_samples_list}}.
    m_X_per_view (int): Number of inducing locations per view for the warp GP.
    m_G (int): Number of inducing locations for the data GP in the common coordinate system.
    data_init (bool, optional): Whether to initialize inducing locations with KMeans on the data.
    grid_init (bool, optional): Whether to initialize inducing locations on a regular grid spanning the data.
    n_spatial_dims (int, optional): Number of spatial dimensions (usually 2 or 3).
    n_noise_variance_params (int, optional): Number of noise variance parameters.
    kernel_func_warp (function, optional): Covariance function for the warp GP.
    kernel_func_data (function, optional): Covariance function for the output GP.
    n_latent_gps (dict, optional): Mapping from modality name to the number of latent GPs used in the
        linear model of coregionalization (LMC); a value of None uses one latent GP per output.
    mean_function (str, optional): Mean function for the warp GP. One of
        ["identity_fixed", "identity_initialized"] or None. None results in a linear mean function.
    mean_penalty_param (float, optional): Strength of the penalty on the warp GP's mean function.
    fixed_warp_kernel_variances (None, optional): If given, the warp kernel variances are fixed to
        these values rather than learned.
    fixed_warp_kernel_lengthscales (None, optional): If given, the warp kernel lengthscales are fixed
        to these values rather than learned.
    fixed_data_kernel_lengthscales (None, optional): If given, the data kernel lengthscales are fixed
        to these values rather than learned.
    fixed_view_idx (int or iterable, optional): Index or indices of views to hold fixed as the
        template during alignment.
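A minimal construction sketch; the toy arrays and the "expression" modality name are illustrative, not from the source:

import torch
from gpsa.models.vgpsa import VariationalGPSA

# Two views of 100 samples each, stacked along the first axis.
X = torch.randn(200, 2)   # spatial coordinates
Y = torch.randn(200, 10)  # 10 outputs per sample (e.g., gene expression)

data_dict = {
    "expression": {
        "spatial_coords": X,
        "outputs": Y,
        "n_samples_list": [100, 100],
    }
}

model = VariationalGPSA(
    data_dict,
    m_X_per_view=25,                 # inducing points per view for the warp GP
    m_G=25,                          # inducing points for the data GP
    data_init=True,                  # KMeans initialization of inducing locations
    n_latent_gps={"expression": 5},  # 5 latent GPs mixed into 10 outputs via LMC
    fixed_view_idx=0,                # hold view 0 fixed as the template
)

Note that n_latent_gps is indexed per modality inside the constructor, so it should be passed as a dict (with None values to disable the LMC) rather than left as None.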
    def __init__(
        self,
        data_dict,
        m_X_per_view,
        m_G,
        data_init=True,
        minmax_init=False,
        grid_init=False,
        n_spatial_dims=2,
        n_noise_variance_params=2,
        kernel_func_warp=rbf_kernel,
        kernel_func_data=rbf_kernel,
        n_latent_gps=None,
        mean_function="identity_fixed",
        mean_penalty_param=0.0,
        fixed_warp_kernel_variances=None,
        fixed_warp_kernel_lengthscales=None,
        fixed_data_kernel_lengthscales=None,
        fixed_view_idx=None,
    ):
        super(VariationalGPSA, self).__init__(
            data_dict,
            data_init=True,
            n_spatial_dims=2,
            n_noise_variance_params=2,
            kernel_func_warp=kernel_func_warp,
            kernel_func_data=kernel_func_data,
            mean_penalty_param=mean_penalty_param,
            fixed_warp_kernel_variances=fixed_warp_kernel_variances,
            fixed_warp_kernel_lengthscales=fixed_warp_kernel_lengthscales,
            fixed_data_kernel_lengthscales=fixed_data_kernel_lengthscales,
        )

        self.m_X_per_view = m_X_per_view
        self.m_G = m_G
        self.n_latent_gps = n_latent_gps
        self.n_latent_outputs = {}
        for mod in self.modality_names:
            # One latent GP per output unless an LMC dimensionality is given.
            self.n_latent_outputs[mod] = (
                self.n_latent_gps[mod]
                if self.n_latent_gps[mod] is not None
                else self.Ps[mod]
            )
        self.fixed_view_idx = fixed_view_idx

        if data_init:
            # Initialize inducing locations with KMeans centroids of the data.
            Xtilde = torch.zeros([self.n_views, self.m_X_per_view, self.n_spatial_dims])
            for ii in range(self.n_views):
                curr_X_spatial_list = []
                for mod in self.modality_names:
                    curr_idx = self.view_idx[mod][ii]
                    curr_X_spatial_list.append(
                        data_dict[mod]["spatial_coords"][curr_idx, :]
                    )
                curr_X_spatial = torch.cat(curr_X_spatial_list, dim=0)

                kmeans = KMeans(n_clusters=self.m_X_per_view)
                kmeans.fit(curr_X_spatial.detach().cpu().numpy())
                Xtilde[ii, :, :] = torch.tensor(kmeans.cluster_centers_)

            self.Xtilde = nn.Parameter(Xtilde.clone())

            all_X_spatial = torch.cat(
                [data_dict[mod]["spatial_coords"] for mod in self.modality_names]
            )
            kmeans = KMeans(n_clusters=self.m_G)
            kmeans.fit(all_X_spatial.detach().cpu().numpy())
            self.Gtilde = nn.Parameter(torch.tensor(kmeans.cluster_centers_))

        elif grid_init:
            if self.n_spatial_dims == 2:
                # Lay the inducing points out on a square grid spanning the data.
                xlow, ylow = (
                    data_dict[self.modality_names[0]]["spatial_coords"].numpy().min(0)
                )
                xhigh, yhigh = (
                    data_dict[self.modality_names[0]]["spatial_coords"].numpy().max(0)
                )
                xlimits = [xlow, xhigh]
                ylimits = [ylow, yhigh]
                numticks = np.ceil(np.sqrt(self.m_G)).astype(int)
                self.m_G = numticks**2
                self.m_X_per_view = numticks**2
                x1s = np.linspace(*xlimits, num=numticks)
                x2s = np.linspace(*ylimits, num=numticks)
                X1, X2 = np.meshgrid(x1s, x2s)
                Xtilde = np.vstack([X1.ravel(), X2.ravel()]).T
                Xtilde_torch = torch.zeros(
                    [self.n_views, Xtilde.shape[0], self.n_spatial_dims]
                )
                for vv in range(self.n_views):
                    Xtilde_torch[vv] = torch.tensor(Xtilde)

                self.Xtilde = nn.Parameter(Xtilde_torch.clone())
                self.Gtilde = nn.Parameter(torch.tensor(Xtilde).float())

        else:
            # Random initialization of inducing locations.
            self.Xtilde = nn.Parameter(
                torch.randn([self.n_views, self.m_X_per_view, self.n_spatial_dims])
            )
            self.Gtilde = nn.Parameter(torch.randn([self.m_G, self.n_spatial_dims]))

        ## Variational covariance parameters (unconstrained square roots of Omega)
        Omega_sqt_G_list = torch.zeros(
            [self.n_views * self.n_spatial_dims, self.m_X_per_view, self.m_X_per_view],
            device=device,
        )
        for ii in range(self.n_views):
            for jj in range(self.n_spatial_dims):
                Omega_sqt = 0.1 * torch.randn(
                    size=[self.m_X_per_view, self.m_X_per_view]
                )
                Omega_sqt_G_list[jj * self.n_views + ii, :, :] = Omega_sqt
        self.Omega_sqt_G_list = nn.Parameter(Omega_sqt_G_list)

        Omega_sqt_F_dict = torch.nn.ParameterDict()
        for mod in self.modality_names:
            curr_Omega = torch.zeros([self.n_latent_outputs[mod], self.m_G, self.m_G])
            for jj in range(self.n_latent_outputs[mod]):
                curr_Omega[jj, :, :] = 0.1 * torch.randn(size=[self.m_G, self.m_G])
            Omega_sqt_F_dict[mod] = nn.Parameter(curr_Omega)
        self.Omega_sqt_F_dict = Omega_sqt_F_dict

        ## Variational mean parameters
        self.delta_G_list = nn.Parameter(self.Xtilde.clone())
        delta_F_dict = torch.nn.ParameterDict()
        for mod in self.modality_names:
            delta_F_dict[mod] = nn.Parameter(
                torch.randn(size=[self.m_G, self.n_latent_outputs[mod]], device=device)
            )
        self.delta_F_dict = delta_F_dict

        ## LMC mixing weights
        self.W_dict = torch.nn.ParameterDict()
        for mod in self.modality_names:
            if self.n_latent_gps[mod] is not None:
                self.W_dict[mod] = nn.Parameter(
                    torch.randn([self.n_latent_gps[mod], self.Ps[mod]])
                )
Initializes internal Module state, shared by both nn.Module and ScriptModule.
    def compute_mean_and_var(
        self, Kff_diag, Kuf, Kuu_chol, mu_x, mu_z, delta, Omega_tril
    ):
        # alpha = Kuu^{-1} Kuf, computed with the Cholesky factor of Kuu.
        alpha_x = torch.cholesky_solve(Kuf, Kuu_chol)

        # diag(Kfu Kuu^{-1} Kuf), via the squared rows of alpha^T L.
        a_t_Kchol = torch.matmul(alpha_x.transpose(-1, -2), Kuu_chol)
        aKa = torch.sum(torch.square(a_t_Kchol), dim=-1)

        mu_tilde = mu_x.unsqueeze(0) + torch.matmul(
            alpha_x.transpose(-1, -2), delta - mu_z
        )

        if len(alpha_x.shape) == 2:
            a_t_Omega_tril = torch.matmul(
                alpha_x.transpose(-1, -2).unsqueeze(0), Omega_tril
            )
            aOmega_a = torch.sum(torch.square(a_t_Omega_tril), dim=-1)
            Sigma_tilde = Kff_diag - aKa + aOmega_a + self.diagonal_offset
        else:
            a_t_Omega_tril = torch.matmul(
                alpha_x.transpose(-1, -2).unsqueeze(1), Omega_tril.unsqueeze(0)
            )
            aOmega_a = torch.sum(torch.square(a_t_Omega_tril), dim=-1)
            Sigma_tilde = (
                Kff_diag.unsqueeze(1)
                - aKa.unsqueeze(1)
                + aOmega_a
                + self.diagonal_offset
            )

        return mu_tilde, Sigma_tilde + self.diagonal_offset

    def get_Omega_from_Omega_sqt(self, Omega_sqt):
        # Omega = Omega_sqt @ Omega_sqt^T + eps * I, guaranteeing positive definiteness.
        return torch.matmul(
            Omega_sqt,
            torch.transpose(Omega_sqt, -1, -2),
        ) + self.diagonal_offset * torch.eye(Omega_sqt.shape[-1], device=device)
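For orientation, a restatement (not text from the source) of what these two methods compute. get_Omega_from_Omega_sqt builds each variational covariance from its unconstrained square root, and compute_mean_and_var evaluates the standard sparse variational GP marginals for an inducing-point posterior q(u) = N(delta, Omega):

\Omega = \Omega_{\mathrm{sqt}} \Omega_{\mathrm{sqt}}^{\top} + \epsilon I

\tilde{\mu} = \mu_x + K_{fu} K_{uu}^{-1} (\delta - \mu_z)

\tilde{\sigma}^2 = \operatorname{diag}\!\left( K_{ff} - K_{fu} K_{uu}^{-1} K_{uf} + K_{fu} K_{uu}^{-1} \Omega K_{uu}^{-1} K_{uf} \right)

Here \epsilon is self.diagonal_offset. Only the diagonal of the predictive covariance is returned, since forward() samples each coordinate independently.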
    def forward(self, X_spatial, view_idx, Ns, S=1, prediction_mode=False, G_test=None):
        if prediction_mode:
            self.eval()

        self.noise_variance_pos = torch.exp(self.noise_variance) + self.diagonal_offset

        # Prior mean of the inducing values for G; NaN-initialized buffers make
        # unfilled entries easy to spot.
        self.mu_z_G = (
            torch.zeros(
                [self.n_views, self.m_X_per_view, self.n_spatial_dims], device=device
            )
            * np.nan
        )
        for vv in range(self.n_views):
            self.mu_z_G[vv] = (
                torch.mm(self.Xtilde[vv], self.mean_slopes[vv])
                + self.mean_intercepts[vv]
            )
            if self.fixed_view_idx is not None and (
                vv in self.fixed_view_idx
                if isinstance(self.fixed_view_idx, Iterable)
                else self.fixed_view_idx == vv
            ):
                self.mu_z_G[vv] *= 100.0

        self.Kuu_chol_list = (
            torch.zeros(
                [self.n_views, self.m_X_per_view, self.m_X_per_view], device=device
            )
            * np.nan
        )
        G_samples = {}
        for mod in self.modality_names:
            G_samples[mod] = (
                torch.zeros([S, Ns[mod], self.n_spatial_dims], device=device) * np.nan
            )

        G_means = {}
        for mod in self.modality_names:
            G_means[mod] = (
                torch.zeros([Ns[mod], self.n_spatial_dims], device=device) * np.nan
            )

        curr_Omega_G = self.get_Omega_from_Omega_sqt(self.Omega_sqt_G_list)
        self.curr_Omega_tril_list = torch.cholesky(curr_Omega_G)

        for vv in range(self.n_views):
            ## If this view is fixed (template-based alignment), we don't need to sample for it.
            if self.fixed_view_idx is not None and (
                vv in self.fixed_view_idx
                if isinstance(self.fixed_view_idx, Iterable)
                else self.fixed_view_idx == vv
            ):
                for mm, mod in enumerate(self.modality_names):
                    observed_X_spatial = X_spatial[mod][view_idx[mod][vv]]
                    G_means[mod][view_idx[mod][vv]] = observed_X_spatial
                    G_samples[mod][:, view_idx[mod][vv], :] = observed_X_spatial
                continue

            kernel_G = lambda x1, x2, diag=False: self.kernel_func_warp(
                x1,
                x2,
                lengthscale_unconstrained=self.warp_kernel_lengthscales[vv],
                output_variance_unconstrained=self.warp_kernel_variances[vv],
                diag=diag,
            )

            ## Collect data from all modalities for this view.
            curr_X_spatial_list = []
            curr_n = 0
            curr_mod_idx = []
            for mod in self.modality_names:
                curr_idx = view_idx[mod][vv]
                curr_mod_idx.append(np.arange(curr_n, curr_n + len(curr_idx)))
                curr_n += len(curr_idx)
                curr_X_spatial_list.append(X_spatial[mod][curr_idx, :])

            curr_X_spatial = torch.cat(curr_X_spatial_list, dim=0)
            if len(curr_X_spatial) == 0:
                continue

            curr_X_tilde = self.Xtilde[vv]

            mu_x_G = (
                torch.mm(curr_X_spatial, self.mean_slopes[vv])
                + self.mean_intercepts[vv]
            )

            # For the RBF kernel, k(x, x) is constant and equal to the
            # (exponentiated) output variance.
            Kff_diag = torch.ones((curr_X_spatial.shape[0]), device=device) * torch.exp(
                self.warp_kernel_variances[vv]
            )

            Kuu = kernel_G(
                curr_X_tilde, curr_X_tilde
            ) + self.diagonal_offset * torch.eye(self.m_X_per_view, device=device)
            Kuf = kernel_G(curr_X_tilde, curr_X_spatial)

            Kuu_chol = torch.cholesky(Kuu)
            self.Kuu_chol_list[vv, :, :] = Kuu_chol

            mu_tilde, Sigma_tilde = self.compute_mean_and_var(
                Kff_diag,
                Kuf,
                Kuu_chol,
                mu_x_G,
                self.mu_z_G,
                self.delta_G_list,
                self.curr_Omega_tril_list,
            )

            # Sample the warped coordinates for this view.
            G_marginal_dist = torch.distributions.Normal(
                mu_tilde[vv],
                Sigma_tilde[
                    vv * self.n_spatial_dims : vv * self.n_spatial_dims
                    + self.n_spatial_dims
                ].t(),
            )

            for mm, mod in enumerate(self.modality_names):
                curr_idx = curr_mod_idx[mm]
                G_means[mod][view_idx[mod][vv]] = mu_tilde[vv][curr_idx]

            for ss in range(S):
                curr_G_sample = G_marginal_dist.rsample()
                for mm, mod in enumerate(self.modality_names):
                    curr_idx = curr_mod_idx[mm]
                    G_samples[mod][ss, view_idx[mod][vv]] = curr_G_sample[curr_idx]

        self.curr_Omega_tril_F = {}
        for mod in self.modality_names:
            self.curr_Omega_tril_F[mod] = torch.zeros(
                [self.n_latent_outputs[mod], self.m_G, self.m_G], device=device
            )

        self.F_latent_samples = {}
        self.F_observed_samples = {}
        for mod in self.modality_names:
            self.F_latent_samples[mod] = torch.zeros(
                [S, Ns[mod], self.n_latent_outputs[mod]], device=device
            )
            self.F_observed_samples[mod] = torch.zeros([S, Ns[mod], self.Ps[mod]])

        if G_test is not None:
            self.F_latent_samples_test = {}
            self.F_observed_samples_test = {}
            for mod in self.modality_names:
                n_test = G_test[mod].shape[1]
                self.F_latent_samples_test[mod] = torch.zeros(
                    [S, n_test, self.n_latent_outputs[mod]]
                )
                self.F_observed_samples_test[mod] = torch.zeros(
                    [S, n_test, self.Ps[mod]]
                )

        kernel_F = lambda x1, x2, diag=False: self.kernel_func_data(
            x1,
            x2,
            lengthscale_unconstrained=self.data_kernel_lengthscale,
            output_variance_unconstrained=self.data_kernel_variance,
            diag=diag,
        )

        Kuu = kernel_F(self.Gtilde, self.Gtilde) + self.diagonal_offset * torch.eye(
            self.m_G, device=device
        )
        self.Kuu_chol_F = torch.cholesky(Kuu)

        for mod in self.modality_names:
            mu_x_F = torch.zeros([Ns[mod], self.n_latent_outputs[mod]], device=device)
            mu_z_F = torch.zeros([self.m_G, self.n_latent_outputs[mod]], device=device)

            # RBF kernel diagonal: constant at the (exponentiated) output variance.
            Kff_diag = torch.ones(
                (G_samples[mod].shape[:2]), device=device
            ) * torch.exp(self.data_kernel_variance)

            Kuf = kernel_F(self.Gtilde, G_samples[mod])
            curr_Omega = self.get_Omega_from_Omega_sqt(self.Omega_sqt_F_dict[mod])
            self.curr_Omega_tril_F[mod] = torch.cholesky(curr_Omega)

            mu_tilde, Sigma_tilde = self.compute_mean_and_var(
                Kff_diag,
                Kuf,
                self.Kuu_chol_F,
                mu_x_F,
                mu_z_F,
                self.delta_F_dict[mod],
                self.curr_Omega_tril_F[mod],
            )

            # Reparameterized sample of the latent functions.
            eps = torch.randn(mu_tilde.shape, device=device)
            curr_F_latent_samples = (
                mu_tilde + torch.sqrt(torch.transpose(Sigma_tilde, 1, 2)) * eps
            )

            if self.n_latent_gps[mod] is not None:
                # Mix the latent GPs into observed outputs (LMC).
                curr_W = self.W_dict[mod]
                F_observed_mean = torch.matmul(curr_F_latent_samples, curr_W)
            else:
                F_observed_mean = curr_F_latent_samples

            self.F_latent_samples[mod] = curr_F_latent_samples
            self.F_observed_samples[mod] = F_observed_mean

            ## For test samples
            if G_test is not None:
                Kff_diag = torch.ones(
                    (G_test[mod].shape[:2]), device=device
                ) * torch.exp(self.data_kernel_variance)

                mu_x_F = torch.zeros(
                    [G_test[mod].shape[1], self.n_latent_outputs[mod]], device=device
                )

                Kuf = kernel_F(self.Gtilde, G_test[mod])

                mu_tilde, Sigma_tilde = self.compute_mean_and_var(
                    Kff_diag,
                    Kuf,
                    self.Kuu_chol_F,
                    mu_x_F,
                    mu_z_F,
                    self.delta_F_dict[mod],
                    self.curr_Omega_tril_F[mod],
                )

                eps = torch.randn(mu_tilde.shape, device=device)
                curr_F_latent_samples = (
                    mu_tilde + torch.sqrt(torch.transpose(Sigma_tilde, 1, 2)) * eps
                )

                if self.n_latent_gps[mod] is not None:
                    curr_W = self.W_dict[mod]
                    F_observed_mean = torch.matmul(curr_F_latent_samples, curr_W)
                else:
                    F_observed_mean = curr_F_latent_samples

                self.F_latent_samples_test[mod] = curr_F_latent_samples
                self.F_observed_samples_test[mod] = F_observed_mean

        if G_test is not None:
            return (
                G_means,
                G_samples,
                self.F_latent_samples,
                self.F_observed_samples,
                self.F_latent_samples_test,
                self.F_observed_samples_test,
            )
        else:
            return G_means, G_samples, self.F_latent_samples, self.F_observed_samples
Defines the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
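A hypothetical call sketch; view_idx (per-modality lists of per-view index arrays) and Ns (per-modality sample counts) are assumed to have been built to match data_dict, and the names are illustrative rather than from the source:

G_means, G_samples, F_latent, F_observed = model(
    X_spatial={"expression": X},  # spatial coordinates per modality
    view_idx=view_idx,
    Ns=Ns,
    S=5,                          # number of Monte Carlo samples
)
G_means["expression"].shape      # (N, n_spatial_dims): aligned coordinates
F_observed["expression"].shape   # (S, N, P): samples on the observed-output scale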
    def loss_fn(self, data_dict, F_samples):
        # Computes the negative (approximate) ELBO.

        # Running sum for the KL terms.
        KL_div = 0

        ## KL terms for the warp GP G (fixed template views contribute nothing).
        for vv in range(self.n_views):
            if self.fixed_view_idx is not None and (
                vv in self.fixed_view_idx
                if isinstance(self.fixed_view_idx, Iterable)
                else self.fixed_view_idx == vv
            ):
                continue
            for jj in range(self.n_spatial_dims):
                qu = torch.distributions.MultivariateNormal(
                    loc=self.delta_G_list[vv, :, jj],
                    scale_tril=self.curr_Omega_tril_list[jj * self.n_views + vv, :, :],
                )
                pu = torch.distributions.MultivariateNormal(
                    loc=self.mu_z_G[vv, :, jj],
                    scale_tril=self.Kuu_chol_list[vv, :, :],
                )
                KL_div += torch.distributions.kl.kl_divergence(qu, pu)

        ## KL terms and expected log-likelihood for the data GP F.
        LL = 0
        pu = torch.distributions.MultivariateNormal(
            loc=torch.zeros(self.m_G, device=device), scale_tril=self.Kuu_chol_F
        )
        for mm, mod in enumerate(self.modality_names):
            qu = torch.distributions.MultivariateNormal(
                loc=self.delta_F_dict[mod].t(),
                scale_tril=self.curr_Omega_tril_F[mod],
            )
            KL_div += torch.distributions.kl.kl_divergence(qu, pu).sum()

            Y_distribution = torch.distributions.Normal(
                loc=F_samples[mod],
                scale=self.noise_variance_pos[-self.n_modalities + mm],
            )
            S = F_samples[mod].shape[0]

            # Monte Carlo estimate of the expected log-likelihood.
            LL += Y_distribution.log_prob(data_dict[mod]["outputs"]).sum() / S

        return -LL + KL_div
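In symbols, loss_fn returns a Monte Carlo estimate of the negative ELBO (again a restatement for orientation, not text from the source):

\mathcal{L} = -\frac{1}{S} \sum_{s=1}^{S} \sum_{m} \log p\!\left(Y_m \mid F_m^{(s)}\right)
    + \sum_{v,j} \mathrm{KL}\!\left[\, q(u^{G}_{v,j}) \,\|\, p(u^{G}_{v,j}) \,\right]
    + \sum_{m,q} \mathrm{KL}\!\left[\, q(u^{F}_{m,q}) \,\|\, p(u^{F}_{m,q}) \,\right]

A minimal training-loop sketch, assuming data_dict, view_idx, and Ns are prepared as in the forward() example above (none of this appears in the source):

import torch

model = VariationalGPSA(data_dict, m_X_per_view=25, m_G=25, data_init=True,
                        n_latent_gps={"expression": None})
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

for step in range(1000):
    optimizer.zero_grad()
    G_means, G_samples, F_latent, F_observed = model(
        X_spatial={"expression": data_dict["expression"]["spatial_coords"]},
        view_idx=view_idx,
        Ns=Ns,
        S=1,
    )
    # loss_fn scores samples on the observed-output scale against the data.
    loss = model.loss_fn(data_dict, F_observed)
    loss.backward()
    optimizer.step()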
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- register_buffer
- register_parameter
- add_module
- apply
- cuda
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- state_dict
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr