gpsa.models.gpsa

  1import torch
  2import numpy as np
  3import torch.nn as nn
  4from ..util.util import rbf_kernel
  5
  6device = "cuda" if torch.cuda.is_available() else "cpu"
  7
  8# Define model
  9class GPSA(nn.Module):
 10    """
 11    Args:
 12        data_dict (dict): Dictionary of data in the format {"modality": {"spatial_coords": X, "outputs": Y, "n_samples_list": n_samples_list}}
 13        data_init (bool, optional): Whether to initialize inducing locations with KMeans on data.
 14        n_spatial_dims (int, optional): Number of spatial dimensions (usually 2 or 3).
 15        n_noise_variance_params (int, optional): Number of noise variance parameters.
 16        kernel_func_warp (function, optional): Covariance function for warp GP.
 17        kernel_func_data (function, optional): Covariance function for output GP.
 18        mean_function (str, optional): Mean function for warp GP. One of ["identity_fixed", "identity_initialized", or None]. None results in a linear mean function.
 19        mean_penalty_param (float, optional): Description
 20        fixed_warp_kernel_variances (None, optional): Description
 21        fixed_warp_kernel_lengthscales (None, optional): Description
 22        fixed_data_kernel_lengthscales (None, optional): Description
 23    """
 24
 25    def __init__(
 26        self,
 27        data_dict,
 28        data_init=True,
 29        n_spatial_dims=2,
 30        n_noise_variance_params=2,
 31        kernel_func_warp=rbf_kernel,
 32        kernel_func_data=rbf_kernel,
 33        mean_function="identity_fixed",
 34        mean_penalty_param=0.0,
 35        fixed_warp_kernel_variances=None,
 36        fixed_warp_kernel_lengthscales=None,
 37        fixed_data_kernel_lengthscales=None,
 38    ):
 39        # Constructor
 40        super(GPSA, self).__init__()
 41        self.modality_names = list(data_dict.keys())
 42        self.n_modalities = len(self.modality_names)
 43        self.mean_penalty_param = mean_penalty_param
 44
 45        ## Make sure all modalities have the same number of "views"
 46        n_views = np.unique(
 47            np.array(
 48                [len(data_dict[mod]["n_samples_list"]) for mod in self.modality_names]
 49            )
 50        )
 51        if len(n_views) != 1:
 52            raise ValueError("Each modality must have the same number of views.")
 53        self.n_views = n_views[0]
 54
 55        ## Make sure all modalities have the same domain for the spatial coordinates
 56        n_spatial_dims = np.unique(
 57            np.array(
 58                [
 59                    data_dict[mod]["spatial_coords"].shape[1]
 60                    for mod in self.modality_names
 61                ]
 62            )
 63        )
 64        if len(n_spatial_dims) != 1:
 65            raise ValueError(
 66                "Each modality must have the same number of spatial dimensions."
 67            )
 68        self.n_spatial_dims = n_spatial_dims[0]
 69
 70        view_idx, Ns, Ps, n_total = self.create_view_idx_dict(data_dict)
 71        self.view_idx = view_idx
 72        self.Ns = Ns
 73        self.Ps = Ps
 74        self.n_total = n_total
 75        # import ipdb; ipdb.set_trace()
 76
 77        ## Number of kernel parameters:
 78        ##		- 2 parameters for each view for warp GP (lengthscale and variance)
 79        ##		- 2 parameters for observation GP (lengthscale and variance)
 80        self.n_kernel_params = 2 * self.n_views + 2
 81        self.n_noise_variance_params = n_noise_variance_params
 82        self.kernel_func_warp = kernel_func_warp
 83        self.kernel_func_data = kernel_func_data
 84
 85        ## Parameters
 86        self.noise_variance = nn.Parameter(
 87            torch.randn([self.n_noise_variance_params]) - 1
 88        )
 89        # self.noise_variance = torch.log(torch.ones(2) * 0.001)
 90
 91        if fixed_warp_kernel_variances is None:
 92            # self.warp_kernel_variances = nn.Parameter(
 93            #     torch.randn([self.n_kernel_params // 2 - 1]) - 1
 94            # )
 95            self.warp_kernel_variances = nn.Parameter(
 96                torch.zeros(self.n_kernel_params // 2 - 1)
 97            )
 98        else:
 99            self.warp_kernel_variances = torch.log(
100                torch.tensor(fixed_warp_kernel_variances)
101            )
102
103        if fixed_warp_kernel_lengthscales is None:
104            # self.warp_kernel_lengthscales = nn.Parameter(
105            #     torch.randn([self.n_kernel_params // 2 - 1]) + 3
106            # )
107            self.warp_kernel_lengthscales = nn.Parameter(
108                torch.zeros(self.n_kernel_params // 2 - 1) + np.log(10)
109            )
110        else:
111            self.warp_kernel_lengthscales = torch.log(
112                torch.tensor(fixed_warp_kernel_lengthscales)
113            )
114
115        if fixed_data_kernel_lengthscales is None:
116            self.data_kernel_lengthscale = nn.Parameter(
117                torch.log(torch.exp(torch.randn(1)))
118            )
119        else:
120            self.data_kernel_lengthscale = torch.log(
121                torch.tensor(fixed_data_kernel_lengthscales).float()
122            )
123
124        self.data_kernel_variance = nn.Parameter(torch.randn(1))
125        # self.data_kernel_variance = nn.Parameter(torch.zeros(1))
126        # self.data_kernel_variance = torch.tensor(0.).float()
127
128        if mean_function == "identity_fixed":
129            self.mean_slopes = (
130                torch.eye(self.n_spatial_dims, device=device)
131                .unsqueeze(0)
132                .repeat(self.n_views, 1, 1)
133            )
134            self.mean_intercepts = torch.zeros(
135                [self.n_views, self.n_spatial_dims], device=device
136            )
137        elif mean_function == "identity_initialized":
138            self.mean_slopes = nn.Parameter(
139                torch.randn([self.n_views, self.n_spatial_dims, self.n_spatial_dims])
140            )
141            self.mean_intercepts = nn.Parameter(
142                torch.zeros([self.n_views, self.n_spatial_dims])
143            )
144        else:
145            self.mean_slopes = nn.Parameter(
146                torch.eye(self.n_spatial_dims).unsqueeze(0).repeat(self.n_views, 1, 1)
147            )
148            self.mean_intercepts = nn.Parameter(
149                torch.randn([self.n_views, self.n_spatial_dims]) * 0.1
150            )
151
152        # self.mean_intercepts = torch.zeros([self.n_views, self.n_spatial_dims])
153        self.diagonal_offset = 1e-5
154
155    def create_view_idx_dict(self, data_dict):
156        """Summary
157
158        Args:
159            data_dict (TYPE): Description
160
161        Returns:
162            TYPE: Description
163        """
164        view_idx, Ns, Ps = {}, {}, {}
165        n_total = 0
166        for mod in self.modality_names:
167            n_samples_list = data_dict[mod]["n_samples_list"]
168            # self.n_samples_lists[mod] = n_samples_list
169            curr_N = np.sum(n_samples_list)
170            Ns[mod] = curr_N
171            n_total += curr_N
172            Ps[mod] = data_dict[mod]["outputs"].shape[1]
173
174            # Compute the indices of each view for each modality
175            cumulative_sums = np.cumsum(n_samples_list)
176            cumulative_sums = np.insert(cumulative_sums, 0, 0)
177            curr_view_idx = [
178                np.arange(cumulative_sums[ii], cumulative_sums[ii + 1])
179                for ii in range(self.n_views)
180            ]
181            view_idx[mod] = curr_view_idx
182
183        return view_idx, Ns, Ps, n_total
184
185    def compute_mean_penalty(self):
186        return self.mean_penalty_param * torch.mean(
187            torch.square(
188                self.mean_slopes
189                - torch.eye(self.n_spatial_dims).unsqueeze(0).repeat(self.n_views, 1, 1)
190            )
191        )
192
193    def forward(self, X_spatial):
194        raise (NotImplementedError)
195
196    def loss_fn(self, data_dict, Gs, means_G_list, covs_G_list, means_Y, covs_Y):
197        raise (NotImplementedError)
198
199
def distance_matrix(X, Y):
    """Compute squared Euclidean distances between samples (rows) of two matrices.

    Note the orientation: rows of the result index ``Y`` and columns index ``X``.

    Args:
        X (torch.Tensor): n x D matrix of spatial locations.
        Y (torch.Tensor): m x D matrix of spatial locations.

    Returns:
        torch.Tensor: m x n matrix whose (j, i)'th element is the *squared*
        Euclidean distance between the i'th row of X and the j'th row of Y.
    """
    # (1, n, D) - (m, 1, D) broadcasts to (m, n, D); summing over D gives (m, n).
    squared_diffs = torch.square(torch.unsqueeze(X, 0) - torch.unsqueeze(Y, 1))
    return torch.sum(squared_diffs, dim=2)
213
214
if __name__ == "__main__":
    # Nothing to execute: this module only provides definitions.
    ...
class GPSA(torch.nn.modules.module.Module):
 10class GPSA(nn.Module):
 11    """
 12    Args:
 13        data_dict (dict): Dictionary of data in the format {"modality": {"spatial_coords": X, "outputs": Y, "n_samples_list": n_samples_list}}
 14        data_init (bool, optional): Whether to initialize inducing locations with KMeans on data.
 15        n_spatial_dims (int, optional): Number of spatial dimensions (usually 2 or 3).
 16        n_noise_variance_params (int, optional): Number of noise variance parameters.
 17        kernel_func_warp (function, optional): Covariance function for warp GP.
 18        kernel_func_data (function, optional): Covariance function for output GP.
 19        mean_function (str, optional): Mean function for warp GP. One of ["identity_fixed", "identity_initialized", or None]. None results in a linear mean function.
 20        mean_penalty_param (float, optional): Description
 21        fixed_warp_kernel_variances (None, optional): Description
 22        fixed_warp_kernel_lengthscales (None, optional): Description
 23        fixed_data_kernel_lengthscales (None, optional): Description
 24    """
 25
 26    def __init__(
 27        self,
 28        data_dict,
 29        data_init=True,
 30        n_spatial_dims=2,
 31        n_noise_variance_params=2,
 32        kernel_func_warp=rbf_kernel,
 33        kernel_func_data=rbf_kernel,
 34        mean_function="identity_fixed",
 35        mean_penalty_param=0.0,
 36        fixed_warp_kernel_variances=None,
 37        fixed_warp_kernel_lengthscales=None,
 38        fixed_data_kernel_lengthscales=None,
 39    ):
 40        # Constructor
 41        super(GPSA, self).__init__()
 42        self.modality_names = list(data_dict.keys())
 43        self.n_modalities = len(self.modality_names)
 44        self.mean_penalty_param = mean_penalty_param
 45
 46        ## Make sure all modalities have the same number of "views"
 47        n_views = np.unique(
 48            np.array(
 49                [len(data_dict[mod]["n_samples_list"]) for mod in self.modality_names]
 50            )
 51        )
 52        if len(n_views) != 1:
 53            raise ValueError("Each modality must have the same number of views.")
 54        self.n_views = n_views[0]
 55
 56        ## Make sure all modalities have the same domain for the spatial coordinates
 57        n_spatial_dims = np.unique(
 58            np.array(
 59                [
 60                    data_dict[mod]["spatial_coords"].shape[1]
 61                    for mod in self.modality_names
 62                ]
 63            )
 64        )
 65        if len(n_spatial_dims) != 1:
 66            raise ValueError(
 67                "Each modality must have the same number of spatial dimensions."
 68            )
 69        self.n_spatial_dims = n_spatial_dims[0]
 70
 71        view_idx, Ns, Ps, n_total = self.create_view_idx_dict(data_dict)
 72        self.view_idx = view_idx
 73        self.Ns = Ns
 74        self.Ps = Ps
 75        self.n_total = n_total
 76        # import ipdb; ipdb.set_trace()
 77
 78        ## Number of kernel parameters:
 79        ##		- 2 parameters for each view for warp GP (lengthscale and variance)
 80        ##		- 2 parameters for observation GP (lengthscale and variance)
 81        self.n_kernel_params = 2 * self.n_views + 2
 82        self.n_noise_variance_params = n_noise_variance_params
 83        self.kernel_func_warp = kernel_func_warp
 84        self.kernel_func_data = kernel_func_data
 85
 86        ## Parameters
 87        self.noise_variance = nn.Parameter(
 88            torch.randn([self.n_noise_variance_params]) - 1
 89        )
 90        # self.noise_variance = torch.log(torch.ones(2) * 0.001)
 91
 92        if fixed_warp_kernel_variances is None:
 93            # self.warp_kernel_variances = nn.Parameter(
 94            #     torch.randn([self.n_kernel_params // 2 - 1]) - 1
 95            # )
 96            self.warp_kernel_variances = nn.Parameter(
 97                torch.zeros(self.n_kernel_params // 2 - 1)
 98            )
 99        else:
100            self.warp_kernel_variances = torch.log(
101                torch.tensor(fixed_warp_kernel_variances)
102            )
103
104        if fixed_warp_kernel_lengthscales is None:
105            # self.warp_kernel_lengthscales = nn.Parameter(
106            #     torch.randn([self.n_kernel_params // 2 - 1]) + 3
107            # )
108            self.warp_kernel_lengthscales = nn.Parameter(
109                torch.zeros(self.n_kernel_params // 2 - 1) + np.log(10)
110            )
111        else:
112            self.warp_kernel_lengthscales = torch.log(
113                torch.tensor(fixed_warp_kernel_lengthscales)
114            )
115
116        if fixed_data_kernel_lengthscales is None:
117            self.data_kernel_lengthscale = nn.Parameter(
118                torch.log(torch.exp(torch.randn(1)))
119            )
120        else:
121            self.data_kernel_lengthscale = torch.log(
122                torch.tensor(fixed_data_kernel_lengthscales).float()
123            )
124
125        self.data_kernel_variance = nn.Parameter(torch.randn(1))
126        # self.data_kernel_variance = nn.Parameter(torch.zeros(1))
127        # self.data_kernel_variance = torch.tensor(0.).float()
128
129        if mean_function == "identity_fixed":
130            self.mean_slopes = (
131                torch.eye(self.n_spatial_dims, device=device)
132                .unsqueeze(0)
133                .repeat(self.n_views, 1, 1)
134            )
135            self.mean_intercepts = torch.zeros(
136                [self.n_views, self.n_spatial_dims], device=device
137            )
138        elif mean_function == "identity_initialized":
139            self.mean_slopes = nn.Parameter(
140                torch.randn([self.n_views, self.n_spatial_dims, self.n_spatial_dims])
141            )
142            self.mean_intercepts = nn.Parameter(
143                torch.zeros([self.n_views, self.n_spatial_dims])
144            )
145        else:
146            self.mean_slopes = nn.Parameter(
147                torch.eye(self.n_spatial_dims).unsqueeze(0).repeat(self.n_views, 1, 1)
148            )
149            self.mean_intercepts = nn.Parameter(
150                torch.randn([self.n_views, self.n_spatial_dims]) * 0.1
151            )
152
153        # self.mean_intercepts = torch.zeros([self.n_views, self.n_spatial_dims])
154        self.diagonal_offset = 1e-5
155
156    def create_view_idx_dict(self, data_dict):
157        """Summary
158
159        Args:
160            data_dict (TYPE): Description
161
162        Returns:
163            TYPE: Description
164        """
165        view_idx, Ns, Ps = {}, {}, {}
166        n_total = 0
167        for mod in self.modality_names:
168            n_samples_list = data_dict[mod]["n_samples_list"]
169            # self.n_samples_lists[mod] = n_samples_list
170            curr_N = np.sum(n_samples_list)
171            Ns[mod] = curr_N
172            n_total += curr_N
173            Ps[mod] = data_dict[mod]["outputs"].shape[1]
174
175            # Compute the indices of each view for each modality
176            cumulative_sums = np.cumsum(n_samples_list)
177            cumulative_sums = np.insert(cumulative_sums, 0, 0)
178            curr_view_idx = [
179                np.arange(cumulative_sums[ii], cumulative_sums[ii + 1])
180                for ii in range(self.n_views)
181            ]
182            view_idx[mod] = curr_view_idx
183
184        return view_idx, Ns, Ps, n_total
185
186    def compute_mean_penalty(self):
187        return self.mean_penalty_param * torch.mean(
188            torch.square(
189                self.mean_slopes
190                - torch.eye(self.n_spatial_dims).unsqueeze(0).repeat(self.n_views, 1, 1)
191            )
192        )
193
194    def forward(self, X_spatial):
195        raise (NotImplementedError)
196
197    def loss_fn(self, data_dict, Gs, means_G_list, covs_G_list, means_Y, covs_Y):
198        raise (NotImplementedError)

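Base class for GPSA models. It validates `data_dict`, records per-modality and per-view sample indices, and creates the noise, kernel and warp-mean parameters. `forward` and `loss_fn` must be implemented by subclasses.

Args:
- data_dict (dict): Dictionary of data in the format {"modality": {"spatial_coords": X, "outputs": Y, "n_samples_list": n_samples_list}}. The rows of X and Y are stacked view by view; n_samples_list[v] is the number of rows in view v.
- data_init (bool, optional): Whether to initialize inducing locations with KMeans on data. Not used by this base class.
- n_spatial_dims (int, optional): Number of spatial dimensions (usually 2 or 3). Overridden by the value inferred from spatial_coords.shape[1].
- n_noise_variance_params (int, optional): Number of noise variance parameters.
- kernel_func_warp (function, optional): Covariance function for warp GP.
- kernel_func_data (function, optional): Covariance function for output GP.
- mean_function (str, optional): Mean function for the warp GP. One of:
  - "identity_fixed": identity slopes and zero intercepts, not learned.
  - "identity_initialized": learned slopes drawn from a standard normal, and learned intercepts starting at zero.
  - None: a learned linear mean with identity-initialized slopes. Any other value behaves like None.
- mean_penalty_param (float, optional): Weight of the penalty on the squared deviation of the warp mean slopes from the identity (see compute_mean_penalty).
- fixed_warp_kernel_variances (array-like or None, optional): If given, fixed (non-learned) warp kernel variances, stored as their logarithm. If None, they are learned.
- fixed_warp_kernel_lengthscales (array-like or None, optional): If given, fixed (non-learned) warp kernel lengthscales, stored as their logarithm. If None, they are learned.
- fixed_data_kernel_lengthscales (float, array-like, or None, optional): If given, fixed (non-learned) data kernel lengthscale, stored as its logarithm. If None, it is learned.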

GPSA( data_dict, data_init=True, n_spatial_dims=2, n_noise_variance_params=2, kernel_func_warp=<function rbf_kernel>, kernel_func_data=<function rbf_kernel>, mean_function='identity_fixed', mean_penalty_param=0.0, fixed_warp_kernel_variances=None, fixed_warp_kernel_lengthscales=None, fixed_data_kernel_lengthscales=None)
 26    def __init__(
 27        self,
 28        data_dict,
 29        data_init=True,
 30        n_spatial_dims=2,
 31        n_noise_variance_params=2,
 32        kernel_func_warp=rbf_kernel,
 33        kernel_func_data=rbf_kernel,
 34        mean_function="identity_fixed",
 35        mean_penalty_param=0.0,
 36        fixed_warp_kernel_variances=None,
 37        fixed_warp_kernel_lengthscales=None,
 38        fixed_data_kernel_lengthscales=None,
 39    ):
 40        # Constructor
 41        super(GPSA, self).__init__()
 42        self.modality_names = list(data_dict.keys())
 43        self.n_modalities = len(self.modality_names)
 44        self.mean_penalty_param = mean_penalty_param
 45
 46        ## Make sure all modalities have the same number of "views"
 47        n_views = np.unique(
 48            np.array(
 49                [len(data_dict[mod]["n_samples_list"]) for mod in self.modality_names]
 50            )
 51        )
 52        if len(n_views) != 1:
 53            raise ValueError("Each modality must have the same number of views.")
 54        self.n_views = n_views[0]
 55
 56        ## Make sure all modalities have the same domain for the spatial coordinates
 57        n_spatial_dims = np.unique(
 58            np.array(
 59                [
 60                    data_dict[mod]["spatial_coords"].shape[1]
 61                    for mod in self.modality_names
 62                ]
 63            )
 64        )
 65        if len(n_spatial_dims) != 1:
 66            raise ValueError(
 67                "Each modality must have the same number of spatial dimensions."
 68            )
 69        self.n_spatial_dims = n_spatial_dims[0]
 70
 71        view_idx, Ns, Ps, n_total = self.create_view_idx_dict(data_dict)
 72        self.view_idx = view_idx
 73        self.Ns = Ns
 74        self.Ps = Ps
 75        self.n_total = n_total
 76        # import ipdb; ipdb.set_trace()
 77
 78        ## Number of kernel parameters:
 79        ##		- 2 parameters for each view for warp GP (lengthscale and variance)
 80        ##		- 2 parameters for observation GP (lengthscale and variance)
 81        self.n_kernel_params = 2 * self.n_views + 2
 82        self.n_noise_variance_params = n_noise_variance_params
 83        self.kernel_func_warp = kernel_func_warp
 84        self.kernel_func_data = kernel_func_data
 85
 86        ## Parameters
 87        self.noise_variance = nn.Parameter(
 88            torch.randn([self.n_noise_variance_params]) - 1
 89        )
 90        # self.noise_variance = torch.log(torch.ones(2) * 0.001)
 91
 92        if fixed_warp_kernel_variances is None:
 93            # self.warp_kernel_variances = nn.Parameter(
 94            #     torch.randn([self.n_kernel_params // 2 - 1]) - 1
 95            # )
 96            self.warp_kernel_variances = nn.Parameter(
 97                torch.zeros(self.n_kernel_params // 2 - 1)
 98            )
 99        else:
100            self.warp_kernel_variances = torch.log(
101                torch.tensor(fixed_warp_kernel_variances)
102            )
103
104        if fixed_warp_kernel_lengthscales is None:
105            # self.warp_kernel_lengthscales = nn.Parameter(
106            #     torch.randn([self.n_kernel_params // 2 - 1]) + 3
107            # )
108            self.warp_kernel_lengthscales = nn.Parameter(
109                torch.zeros(self.n_kernel_params // 2 - 1) + np.log(10)
110            )
111        else:
112            self.warp_kernel_lengthscales = torch.log(
113                torch.tensor(fixed_warp_kernel_lengthscales)
114            )
115
116        if fixed_data_kernel_lengthscales is None:
117            self.data_kernel_lengthscale = nn.Parameter(
118                torch.log(torch.exp(torch.randn(1)))
119            )
120        else:
121            self.data_kernel_lengthscale = torch.log(
122                torch.tensor(fixed_data_kernel_lengthscales).float()
123            )
124
125        self.data_kernel_variance = nn.Parameter(torch.randn(1))
126        # self.data_kernel_variance = nn.Parameter(torch.zeros(1))
127        # self.data_kernel_variance = torch.tensor(0.).float()
128
129        if mean_function == "identity_fixed":
130            self.mean_slopes = (
131                torch.eye(self.n_spatial_dims, device=device)
132                .unsqueeze(0)
133                .repeat(self.n_views, 1, 1)
134            )
135            self.mean_intercepts = torch.zeros(
136                [self.n_views, self.n_spatial_dims], device=device
137            )
138        elif mean_function == "identity_initialized":
139            self.mean_slopes = nn.Parameter(
140                torch.randn([self.n_views, self.n_spatial_dims, self.n_spatial_dims])
141            )
142            self.mean_intercepts = nn.Parameter(
143                torch.zeros([self.n_views, self.n_spatial_dims])
144            )
145        else:
146            self.mean_slopes = nn.Parameter(
147                torch.eye(self.n_spatial_dims).unsqueeze(0).repeat(self.n_views, 1, 1)
148            )
149            self.mean_intercepts = nn.Parameter(
150                torch.randn([self.n_views, self.n_spatial_dims]) * 0.1
151            )
152
153        # self.mean_intercepts = torch.zeros([self.n_views, self.n_spatial_dims])
154        self.diagonal_offset = 1e-5

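Build the model. The constructor:
- checks that every modality has the same number of views and the same number of spatial dimensions, raising `ValueError` otherwise;
- computes per-view sample indices via `create_view_idx_dict`;
- creates the noise-variance, kernel and warp-mean parameters.

(The text pdoc originally showed here, "Initializes internal Module state, shared by both nn.Module and ScriptModule.", is the docstring inherited from `torch.nn.Module.__init__`.)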

def create_view_idx_dict(self, data_dict):
156    def create_view_idx_dict(self, data_dict):
157        """Summary
158
159        Args:
160            data_dict (TYPE): Description
161
162        Returns:
163            TYPE: Description
164        """
165        view_idx, Ns, Ps = {}, {}, {}
166        n_total = 0
167        for mod in self.modality_names:
168            n_samples_list = data_dict[mod]["n_samples_list"]
169            # self.n_samples_lists[mod] = n_samples_list
170            curr_N = np.sum(n_samples_list)
171            Ns[mod] = curr_N
172            n_total += curr_N
173            Ps[mod] = data_dict[mod]["outputs"].shape[1]
174
175            # Compute the indices of each view for each modality
176            cumulative_sums = np.cumsum(n_samples_list)
177            cumulative_sums = np.insert(cumulative_sums, 0, 0)
178            curr_view_idx = [
179                np.arange(cumulative_sums[ii], cumulative_sums[ii + 1])
180                for ii in range(self.n_views)
181            ]
182            view_idx[mod] = curr_view_idx
183
184        return view_idx, Ns, Ps, n_total

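Build per-modality sample bookkeeping from `data_dict`. Rows of each modality are assumed to be stacked view by view.

Args: data_dict (dict): Mapping from modality name to a dict with at least "outputs" and "n_samples_list".

Returns: tuple: (view_idx, Ns, Ps, n_total), where:
- view_idx[mod] is a list of index arrays, one per view, selecting that view's rows;
- Ns[mod] is the modality's total sample count (the sum of n_samples_list);
- Ps[mod] is outputs.shape[1];
- n_total is the sum of Ns over all modalities.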

def compute_mean_penalty(self):
186    def compute_mean_penalty(self):
187        return self.mean_penalty_param * torch.mean(
188            torch.square(
189                self.mean_slopes
190                - torch.eye(self.n_spatial_dims).unsqueeze(0).repeat(self.n_views, 1, 1)
191            )
192        )
def forward(self, X_spatial):
194    def forward(self, X_spatial):
195        raise (NotImplementedError)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

def loss_fn(self, data_dict, Gs, means_G_list, covs_G_list, means_Y, covs_Y):
197    def loss_fn(self, data_dict, Gs, means_G_list, covs_G_list, means_Y, covs_Y):
198        raise (NotImplementedError)
Inherited Members
torch.nn.modules.module.Module
dump_patches
register_buffer
register_parameter
add_module
apply
cuda
xpu
cpu
type
float
double
half
bfloat16
to
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
state_dict
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
def distance_matrix(X, Y):
def distance_matrix(X, Y):
    """Compute squared Euclidean distances between samples (rows) of two matrices.

    Note the orientation: rows of the result index ``Y`` and columns index ``X``.

    Args:
        X (torch.Tensor): n x D matrix of spatial locations.
        Y (torch.Tensor): m x D matrix of spatial locations.

    Returns:
        torch.Tensor: m x n matrix whose (j, i)'th element is the *squared*
        Euclidean distance between the i'th row of X and the j'th row of Y.
    """
    # (1, n, D) - (m, 1, D) broadcasts to (m, n, D); summing over D gives (m, n).
    squared_diffs = torch.square(torch.unsqueeze(X, 0) - torch.unsqueeze(Y, 1))
    return torch.sum(squared_diffs, dim=2)

Compute squared Euclidean distances between samples (rows) of two matrices

Args: X (array): n x D matrix of spatial locations Y (array): m x D matrix of spatial locations

Returns: array: m x n matrix whose (j, i)'th element is the squared Euclidean distance between the i'th row of X and the j'th row of Y