gpsa.models.gpsa
import torch
import numpy as np
import torch.nn as nn
from ..util.util import rbf_kernel

device = "cuda" if torch.cuda.is_available() else "cpu"


# Define model
class GPSA(nn.Module):
    """Base class for GPSA models. Subclasses are expected to implement
    `forward` and `loss_fn`.

    Args:
        data_dict (dict): Dictionary of data in the format {"modality": {"spatial_coords": X, "outputs": Y, "n_samples_list": n_samples_list}}
        data_init (bool, optional): Whether to initialize inducing locations with KMeans on the data.
        n_spatial_dims (int, optional): Number of spatial dimensions (usually 2 or 3).
        n_noise_variance_params (int, optional): Number of noise variance parameters.
        kernel_func_warp (function, optional): Covariance function for the warp GP.
        kernel_func_data (function, optional): Covariance function for the output GP.
        mean_function (str, optional): Mean function for the warp GP. One of ["identity_fixed", "identity_initialized", None]. None results in a linear mean function.
        mean_penalty_param (float, optional): Weight of the penalty on the deviation of the warp mean function's slopes from the identity matrix.
        fixed_warp_kernel_variances (list, optional): If given, the warp kernel variances are fixed to these values rather than learned.
        fixed_warp_kernel_lengthscales (list, optional): If given, the warp kernel lengthscales are fixed to these values rather than learned.
        fixed_data_kernel_lengthscales (float or list, optional): If given, the data kernel lengthscale is fixed to this value rather than learned.
    """

    def __init__(
        self,
        data_dict,
        data_init=True,
        n_spatial_dims=2,
        n_noise_variance_params=2,
        kernel_func_warp=rbf_kernel,
        kernel_func_data=rbf_kernel,
        mean_function="identity_fixed",
        mean_penalty_param=0.0,
        fixed_warp_kernel_variances=None,
        fixed_warp_kernel_lengthscales=None,
        fixed_data_kernel_lengthscales=None,
    ):
        super(GPSA, self).__init__()
        self.modality_names = list(data_dict.keys())
        self.n_modalities = len(self.modality_names)
        self.mean_penalty_param = mean_penalty_param

        ## Make sure all modalities have the same number of "views"
        n_views = np.unique(
            np.array(
                [len(data_dict[mod]["n_samples_list"]) for mod in self.modality_names]
            )
        )
        if len(n_views) != 1:
            raise ValueError("Each modality must have the same number of views.")
        self.n_views = n_views[0]

        ## Make sure all modalities have the same number of spatial dimensions
        n_spatial_dims = np.unique(
            np.array(
                [
                    data_dict[mod]["spatial_coords"].shape[1]
                    for mod in self.modality_names
                ]
            )
        )
        if len(n_spatial_dims) != 1:
            raise ValueError(
                "Each modality must have the same number of spatial dimensions."
            )
        self.n_spatial_dims = n_spatial_dims[0]

        view_idx, Ns, Ps, n_total = self.create_view_idx_dict(data_dict)
        self.view_idx = view_idx
        self.Ns = Ns
        self.Ps = Ps
        self.n_total = n_total

        ## Number of kernel parameters:
        ##   - 2 parameters for each view for the warp GP (lengthscale and variance)
        ##   - 2 parameters for the observation GP (lengthscale and variance)
        self.n_kernel_params = 2 * self.n_views + 2
        self.n_noise_variance_params = n_noise_variance_params
        self.kernel_func_warp = kernel_func_warp
        self.kernel_func_data = kernel_func_data

        ## Parameters (kernel and noise parameters are stored on the log scale)
        self.noise_variance = nn.Parameter(
            torch.randn([self.n_noise_variance_params]) - 1
        )

        if fixed_warp_kernel_variances is None:
            self.warp_kernel_variances = nn.Parameter(
                torch.zeros(self.n_kernel_params // 2 - 1)
            )
        else:
            self.warp_kernel_variances = torch.log(
                torch.tensor(fixed_warp_kernel_variances)
            )

        if fixed_warp_kernel_lengthscales is None:
            self.warp_kernel_lengthscales = nn.Parameter(
                torch.zeros(self.n_kernel_params // 2 - 1) + np.log(10)
            )
        else:
            self.warp_kernel_lengthscales = torch.log(
                torch.tensor(fixed_warp_kernel_lengthscales)
            )

        if fixed_data_kernel_lengthscales is None:
            self.data_kernel_lengthscale = nn.Parameter(torch.randn(1))
        else:
            self.data_kernel_lengthscale = torch.log(
                torch.tensor(fixed_data_kernel_lengthscales).float()
            )

        self.data_kernel_variance = nn.Parameter(torch.randn(1))

        if mean_function == "identity_fixed":
            self.mean_slopes = (
                torch.eye(self.n_spatial_dims, device=device)
                .unsqueeze(0)
                .repeat(self.n_views, 1, 1)
            )
            self.mean_intercepts = torch.zeros(
                [self.n_views, self.n_spatial_dims], device=device
            )
        elif mean_function == "identity_initialized":
            self.mean_slopes = nn.Parameter(
                torch.randn([self.n_views, self.n_spatial_dims, self.n_spatial_dims])
            )
            self.mean_intercepts = nn.Parameter(
                torch.zeros([self.n_views, self.n_spatial_dims])
            )
        else:
            self.mean_slopes = nn.Parameter(
                torch.eye(self.n_spatial_dims).unsqueeze(0).repeat(self.n_views, 1, 1)
            )
            self.mean_intercepts = nn.Parameter(
                torch.randn([self.n_views, self.n_spatial_dims]) * 0.1
            )

        self.diagonal_offset = 1e-5

    def create_view_idx_dict(self, data_dict):
        """Compute per-view sample indices and per-modality sizes.

        Args:
            data_dict (dict): Data dictionary in the same format as the constructor argument.

        Returns:
            tuple: (view_idx, Ns, Ps, n_total), where view_idx maps each modality to a
            list of index arrays (one per view), Ns maps each modality to its total
            number of samples, Ps maps each modality to its number of output features,
            and n_total is the total number of samples across modalities.
        """
        view_idx, Ns, Ps = {}, {}, {}
        n_total = 0
        for mod in self.modality_names:
            n_samples_list = data_dict[mod]["n_samples_list"]
            curr_N = np.sum(n_samples_list)
            Ns[mod] = curr_N
            n_total += curr_N
            Ps[mod] = data_dict[mod]["outputs"].shape[1]

            # Compute the indices of each view for each modality
            cumulative_sums = np.cumsum(n_samples_list)
            cumulative_sums = np.insert(cumulative_sums, 0, 0)
            curr_view_idx = [
                np.arange(cumulative_sums[ii], cumulative_sums[ii + 1])
                for ii in range(self.n_views)
            ]
            view_idx[mod] = curr_view_idx

        return view_idx, Ns, Ps, n_total

    def compute_mean_penalty(self):
        # Penalize the mean squared deviation of the warp mean slopes from the identity
        return self.mean_penalty_param * torch.mean(
            torch.square(
                self.mean_slopes
                - torch.eye(self.n_spatial_dims).unsqueeze(0).repeat(self.n_views, 1, 1)
            )
        )

    def forward(self, X_spatial):
        raise NotImplementedError

    def loss_fn(self, data_dict, Gs, means_G_list, covs_G_list, means_Y, covs_Y):
        raise NotImplementedError


def distance_matrix(X, Y):
    """Compute pairwise squared Euclidean distances between the rows of two matrices.

    Args:
        X (array): n x D matrix of spatial locations
        Y (array): m x D matrix of spatial locations

    Returns:
        array: m x n matrix whose [i, j] element is the squared Euclidean distance
        between the i'th row of Y and the j'th row of X
    """
    squared_diffs = torch.square(torch.unsqueeze(X, 0) - torch.unsqueeze(Y, 1))
    squared_distances = torch.sum(squared_diffs, dim=2)
    return squared_distances


if __name__ == "__main__":
    pass
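A minimal sketch of how the constructor expects its input to be organized. The modality name, shapes, and random values below are made up purely for illustration; GPSA itself is a base class (its forward and loss_fn raise NotImplementedError), so this only exercises the bookkeeping done in __init__.

    import torch
    from gpsa.models.gpsa import GPSA

    # One modality ("expression") observed in two views (e.g., two tissue slices),
    # with 2-D spatial coordinates and 5 output features per spot.
    n_samples_list = [30, 40]
    n_spots = sum(n_samples_list)

    data_dict = {
        "expression": {
            "spatial_coords": torch.randn(n_spots, 2),  # views stacked row-wise
            "outputs": torch.randn(n_spots, 5),
            "n_samples_list": n_samples_list,
        }
    }

    model = GPSA(data_dict, n_spatial_dims=2, mean_function="identity_fixed")
    print(model.n_views)                       # 2
    print(model.Ns, model.Ps, model.n_total)   # {'expression': 70} {'expression': 5} 70
    print(model.view_idx["expression"][1][:3]) # [30 31 32]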
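For clarity, here is the cumulative-sum bookkeeping from create_view_idx_dict in isolation, with made-up per-view counts:

    import numpy as np

    n_samples_list = [3, 4]                                       # samples per view
    cumulative_sums = np.insert(np.cumsum(n_samples_list), 0, 0)  # array([0, 3, 7])

    view_idx = [
        np.arange(cumulative_sums[ii], cumulative_sums[ii + 1])
        for ii in range(len(n_samples_list))
    ]
    # view_idx[0] -> array([0, 1, 2])     rows belonging to view 0
    # view_idx[1] -> array([3, 4, 5, 6])  rows belonging to view 1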
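compute_mean_penalty returns mean_penalty_param times the mean squared deviation of the warp mean slopes from the identity matrix, so it is zero whenever the warp mean is the identity map. A small standalone check with made-up values:

    import torch

    mean_penalty_param = 10.0
    n_views, n_spatial_dims = 2, 2

    identity = torch.eye(n_spatial_dims).unsqueeze(0).repeat(n_views, 1, 1)
    mean_slopes = identity + 0.1 * torch.ones_like(identity)  # perturbed away from identity

    penalty = mean_penalty_param * torch.mean(torch.square(mean_slopes - identity))
    print(penalty)  # tensor(0.1000): 10.0 * mean(0.1 ** 2)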
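Despite its name, distance_matrix returns squared Euclidean distances, with the rows of Y indexing the rows of the result. A quick cross-check against torch.cdist on random inputs (purely illustrative):

    import torch
    from gpsa.models.gpsa import distance_matrix

    X = torch.randn(4, 2)  # n = 4 locations
    Y = torch.randn(3, 2)  # m = 3 locations

    D2 = distance_matrix(X, Y)
    print(D2.shape)        # torch.Size([3, 4]); D2[i, j] = ||Y[i] - X[j]||^2

    # torch.cdist returns unsquared distances, so square it before comparing.
    assert torch.allclose(D2, torch.cdist(Y, X) ** 2, atol=1e-5)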