gpsa.util.util

  1import numpy as np
  2import pandas as pd
  3import numpy.random as npr
  4import torch
  5from scipy.special import xlogy
  6
  7
  8def rbf_kernel(
  9    x1, x2, lengthscale_unconstrained, output_variance_unconstrained, diag=False
 10):
 11
 12    lengthscale = torch.exp(lengthscale_unconstrained)
 13    output_variance = torch.exp(output_variance_unconstrained)
 14
 15    if diag:
 16        diffs = x1 - x2
 17    else:
 18        diffs = x1.unsqueeze(-2) - x2.unsqueeze(-3)
 19
 20    K = output_variance * torch.exp(
 21        -0.5 * torch.sum(torch.square(diffs / lengthscale), dim=-1)
 22    )
 23    return K
 24
 25
 26def rbf_kernel_numpy(x, xp, kernel_params):
 27    output_scale = np.exp(kernel_params[0])
 28    lengthscales = np.exp(kernel_params[1:])
 29    diffs = np.expand_dims(x / lengthscales, 1) - np.expand_dims(xp / lengthscales, 0)
 30    return output_scale * np.exp(-0.5 * np.sum(diffs**2, axis=2))
 31
 32
 33def matern12_kernel(
 34    x1, x2, lengthscale_unconstrained, output_variance_unconstrained, diag=False
 35):
 36
 37    lengthscale = torch.exp(lengthscale_unconstrained)
 38    output_variance = torch.exp(output_variance_unconstrained)
 39
 40    if diag:
 41        diffs = x1 - x2
 42    else:
 43        diffs = x1.unsqueeze(-2) - x2.unsqueeze(-3)
 44    eps = 1e-10
 45    dists = torch.sqrt(torch.sum(torch.square(diffs), dim=-1) + eps)
 46
 47    return output_variance * torch.exp(-0.5 * dists / lengthscale)
 48
 49
 50def matern32_kernel(
 51    x1, x2, lengthscale_unconstrained, output_variance_unconstrained, diag=False
 52):
 53
 54    lengthscale = torch.exp(lengthscale_unconstrained)
 55    output_variance = torch.exp(output_variance_unconstrained)
 56
 57    if diag:
 58        diffs = x1 - x2
 59    else:
 60        diffs = x1.unsqueeze(-2) - x2.unsqueeze(-3)
 61    eps = 1e-10
 62    dists = torch.sqrt(torch.sum(torch.square(diffs), dim=-1) + eps)
 63
 64    inner_term = np.sqrt(3.0) * dists / lengthscale
 65    K = output_variance * (1 + inner_term) * torch.exp(-inner_term)
 66    return K
 67
 68
 69def polar_warp(X, r, theta):
 70    return np.array([X[:, 0] + r * np.cos(theta), X[:, 1] + r * np.sin(theta)]).T
 71
 72
 73def get_st_coordinates(df):
 74    """
 75    Extracts spatial coordinates from ST data with index in 'AxB' type format.
 76
 77    Return: pandas dataframe of coordinates
 78    """
 79    coor = []
 80    for spot in df.index:
 81        coordinates = spot.split("x")
 82        coordinates = [float(i) for i in coordinates]
 83        coor.append(coordinates)
 84    return np.array(coor)
 85
 86
 87def compute_distance(X1, X2):
 88    return np.mean(np.sqrt(np.sum((X1 - X2) ** 2, axis=1)))
 89
 90
 91def make_pinwheel(
 92    radial_std, tangential_std, num_classes, num_per_class, rate, rs=npr.RandomState(0)
 93):
 94    """Based on code by Ryan P. Adams."""
 95    rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False)
 96
 97    features = rs.randn(num_classes * num_per_class, 2) * np.array(
 98        [radial_std, tangential_std]
 99    )
100    features[:, 0] += 1
101    labels = np.repeat(np.arange(num_classes), num_per_class)
102
103    angles = rads[labels] + rate * np.exp(features[:, 0])
104    rotations = np.stack(
105        [np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)]
106    )
107    rotations = np.reshape(rotations.T, (-1, 2, 2))
108
109    return np.einsum("ti,tij->tj", features, rotations)
110
111
class ConvergenceChecker(object):
    """Detect convergence of a trace from a smoothed relative change.

    A window of the last ``span`` values is projected onto a cubic polynomial
    basis (orthonormalised via an SVD). The relative change between the two
    most recent (smoothed) points is then compared against a tolerance.
    """

    def __init__(self, span, dtp="float64"):
        self.span = span
        t = np.arange(span, dtype=dtp)
        t -= t.mean()  # centre the abscissa to improve conditioning
        design = np.column_stack((np.ones(shape=t.shape), t, t**2, t**3))
        # Left singular vectors: orthonormal basis of the cubic design's columns.
        self.U = np.linalg.svd(design, full_matrices=False)[0]

    def smooth(self, y):
        """Project ``y`` (length ``span``) onto the cubic basis."""
        coefs = self.U.T @ y
        return self.U @ coefs

    def subset(self, y, idx=-1):
        """Return the ``span`` values of ``y`` ending at (and including) ``idx``."""
        width = self.U.shape[0]
        start = idx - width + 1
        # For idx == -1, ``idx + 1`` would be 0 (an empty slice), so run to the end.
        stop = None if idx == -1 else idx + 1
        return y[start:stop]

    def relative_change(self, y, idx=-1, smooth=True):
        """Relative change between the last two points of the window ending at ``idx``."""
        window = self.subset(y, idx=idx)
        if smooth:
            window = self.smooth(window)
        last, prev = window[-1], window[-2]
        # The 0.1 offset keeps the ratio finite when ``prev`` is near zero.
        return (last - prev) / (0.1 + abs(prev))

    def converged(self, y, tol=1e-4, **kwargs):
        """True when the absolute relative change is below ``tol``."""
        change = self.relative_change(y, **kwargs)
        return abs(change) < tol

    def relative_change_all(self, y, smooth=True):
        """Relative change at every index from ``span`` on; earlier entries are NaN."""
        n = len(y)
        changes = np.full(n, np.nan)
        for end in range(self.U.shape[0], n):
            changes[end] = self.relative_change(y, idx=end, smooth=smooth)
        return changes

    def converged_all(self, y, tol=1e-4, smooth=True):
        """Elementwise convergence flags (NaN entries compare as False)."""
        return np.abs(self.relative_change_all(y, smooth=smooth)) < tol
152
153
def compute_size_factors(m):
    """Poisson size factors for a count matrix with samples in the columns.

    Column totals (counts per cell) are rescaled so that the geometric mean of
    the returned factors is 1.

    Args:
        m: pandas DataFrame (anything exposing ``.values``) of counts with one
            sample per column.

    Returns:
        1-D numpy array with one size factor per column.
    """
    log_totals = np.log(m.values.sum(axis=0))
    return np.exp(log_totals - log_totals.mean())
165
166
def poisson_deviance(X, sz):
    """Per-feature Poisson deviance of a count matrix against a size-factor null.

    Args:
        X: pandas DataFrame of counts with features in rows and observations
            in columns.
        sz: per-column size factors (broadcast against the columns of ``X``).

    Returns:
        pandas Series indexed like ``X.index`` holding
        ``2 * (sum_j x_ij * log(x_ij / sz_j) - n_i * log(n_i / sum(sz)))``,
        where ``n_i`` is the total count of feature ``i``. Features with zero
        total counts get a deviance of 0 rather than NaN.
    """
    # Broadcasting divides every column j by its size factor sz[j].
    LP = X.values / sz
    # Log-transform nonzero entries only; zero entries stay 0 so 0 * log(0) -> 0.
    LP[LP > 0] = np.log(LP[LP > 0])

    # Transpose to make features in cols, observations in rows.
    X = X.T
    ll_sat = np.sum(np.multiply(X, LP.T), axis=0)
    feature_sums = np.sum(X, axis=0)
    # xlogy returns 0 for an all-zero feature instead of 0 * log(0) = NaN.
    ll_null = xlogy(feature_sums, feature_sums / np.sum(sz))
    return 2 * (ll_sat - ll_null)
179
180
def deviance_feature_selection(X):
    """Score features (genes) by Poisson deviance.

    Args:
        X: pandas DataFrame of counts with genes in rows and cells in columns.

    Returns:
        Tuple ``(deviances, gene_names)`` of numpy arrays aligned row for row.
        Genes with zero total counts are dropped.
    """
    # Drop genes (rows) without any counts; they carry no information.
    X = X[np.sum(X, axis=1) > 0]

    # Drop cells (columns) without any counts: a zero total gives log(0) = -inf
    # in the size factors, which would make every deviance NaN.
    X = X.loc[:, np.sum(X, axis=0) > 0]

    # Compute size factors
    sz = compute_size_factors(X)

    # Compute deviances
    devs = poisson_deviance(X, sz)

    # Get associated gene names
    gene_names = X.index.values

    assert gene_names.shape[0] == devs.values.shape[0]

    return devs.values, gene_names
198
199
def deviance_residuals(x, theta, mu=None):
    """Compute deviance residuals for an NB model with a fixed theta.

    An infinite ``theta`` selects the Poisson deviance instead.

    Args:
        x: array of counts (must be 2-D when ``mu`` is omitted).
        theta: NB inverse-dispersion parameter; ``np.inf`` means Poisson.
        mu: optional expected counts, broadcastable against ``x``. When omitted
            it is the rank-one estimate ``row_sums @ col_sums / total``.

    Returns:
        Array ``sign(x - mu) * sqrt(d)``, where ``d`` is the unit deviance;
        negative ``d`` entries are clamped to 0 (with a printed notice).
    """
    if mu is None:
        counts_sum0 = np.sum(x, axis=0, keepdims=True)
        counts_sum1 = np.sum(x, axis=1, keepdims=True)
        counts_sum = np.sum(x)
        mu = counts_sum1 @ counts_sum0 / counts_sum

    def remove_negatives(sqrt_term):
        # Clamp negative entries in place so np.sqrt below does not yield NaN.
        negatives_idx = sqrt_term < 0
        if np.any(negatives_idx):
            n_negatives = np.sum(negatives_idx)
            # np.product was removed in NumPy 2.0, so use the array size.
            # The message prints a percentage, so scale the fraction by 100.
            print(
                "Setting %u negative sqrt term values to 0 (%f%%)"
                % (n_negatives, 100.0 * n_negatives / sqrt_term.size)
            )
            sqrt_term[negatives_idx] = 0

    if np.isinf(theta):  # Poisson
        x_minus_mu = x - mu
        # xlogy(x, x / mu) computes x * log(x / mu) and returns 0 where x == 0.
        sqrt_term = 2 * (xlogy(x, x / mu) - x_minus_mu)
        remove_negatives(sqrt_term)
        dev = np.sign(x_minus_mu) * np.sqrt(sqrt_term)
    else:  # negative binomial
        x_plus_theta = x + theta
        sqrt_term = 2 * (
            xlogy(x, x / mu) - x_plus_theta * np.log(x_plus_theta / (mu + theta))
        )
        remove_negatives(sqrt_term)
        dev = np.sign(x - mu) * np.sqrt(sqrt_term)

    return dev
236
237
def pearson_residuals(counts, theta, clipping=True):
    """Analytic Pearson residuals of an NB model with a fixed theta.

    The expected counts are the rank-one estimate
    ``mu = row_sums @ col_sums / total``, and the residuals are
    ``(counts - mu) / sqrt(mu + mu**2 / theta)``. When ``clipping`` is True,
    residuals are clipped to ``[-sqrt(n), sqrt(n)]``, where ``n`` is the number
    of rows of ``counts``.
    """
    row_totals = np.sum(counts, axis=1, keepdims=True)
    col_totals = np.sum(counts, axis=0, keepdims=True)
    grand_total = np.sum(counts)

    expected = row_totals @ col_totals / grand_total
    residuals = (counts - expected) / np.sqrt(expected + expected**2 / theta)

    if not clipping:
        return residuals

    bound = np.sqrt(counts.shape[0])
    residuals[residuals > bound] = bound
    residuals[residuals < -bound] = -bound
    return residuals
255
256
class LossNotDecreasingChecker:
    """Flag convergence once the loss stops decreasing meaningfully.

    The per-iteration decrease ``loss[t-1] - loss[t]`` is recorded. Once
    ``window_size`` decreases are available, their mean over the most recent
    ``window_size`` iterations (including the current one) is compared against
    ``atol``.

    Args:
        max_epochs: length of the preallocated history buffers; ``iternum``
            passed to ``check_loss`` must be smaller than this.
        atol: threshold on the windowed mean decrease.
        window_size: number of most recent decreases to average.
    """

    def __init__(self, max_epochs, atol=1e-2, window_size=10):
        self.max_epochs = max_epochs
        self.atol = atol
        self.window_size = window_size
        # decrease_in_loss[t] = loss[t-1] - loss[t]; index 0 is never written.
        self.decrease_in_loss = np.zeros(max_epochs)
        # Windowed mean decrease, stored at each iteration where it is computed.
        self.average_decrease_in_loss = np.zeros(max_epochs)

    def check_loss(self, iternum, loss_trace):
        """Record the latest decrease and report whether training has stalled.

        Args:
            iternum: index of the current iteration in ``loss_trace``.
            loss_trace: indexable sequence of losses up to ``iternum``.

        Returns:
            True when the mean decrease over the last ``window_size``
            iterations (ending at ``iternum``) is below ``atol``; else False.
        """
        if iternum < 1:
            return False

        self.decrease_in_loss[iternum] = loss_trace[iternum - 1] - loss_trace[iternum]
        if iternum < self.window_size:
            return False

        # Include ``iternum`` itself: the previous exclusive end averaged only
        # window_size - 1 values and ignored the decrease just recorded (and
        # produced an empty, NaN mean for window_size == 1).
        window = self.decrease_in_loss[iternum - self.window_size + 1 : iternum + 1]
        self.average_decrease_in_loss[iternum] = np.mean(window)
        return self.average_decrease_in_loss[iternum] < self.atol
def rbf_kernel(
    x1, x2, lengthscale_unconstrained, output_variance_unconstrained, diag=False
):
    """Squared-exponential (RBF) kernel with log-parameterised hyperparameters.

    Args:
        x1, x2: tensors whose last dimension indexes input features.
        lengthscale_unconstrained: log of the lengthscale(s); the differences
            are divided by its exponential (broadcasting applies).
        output_variance_unconstrained: log of the output variance.
        diag: if True, compare ``x1`` and ``x2`` elementwise (``x1 - x2``)
            instead of forming every pair.

    Returns:
        ``sigma^2 * exp(-0.5 * ||(x1 - x2) / lengthscale||^2)``, of shape
        ``(..., n1, n2)`` for all pairs, or elementwise when ``diag`` is True.
    """
    if diag:
        pairwise = x1 - x2
    else:
        # (..., n1, 1, d) - (..., 1, n2, d) -> (..., n1, n2, d)
        pairwise = x1.unsqueeze(-2) - x2.unsqueeze(-3)
    scaled = pairwise / torch.exp(lengthscale_unconstrained)
    sq_dist = scaled.square().sum(dim=-1)
    return torch.exp(output_variance_unconstrained) * torch.exp(-0.5 * sq_dist)
def rbf_kernel_numpy(x, xp, kernel_params):
    """NumPy RBF kernel matrix between the rows of ``x`` and ``xp``.

    ``kernel_params[0]`` is the log output scale and ``kernel_params[1:]`` the
    log lengthscale(s); inputs are divided by the lengthscales before the
    squared distances are taken.

    Returns:
        Array of shape ``(len(x), len(xp))`` for 2-D inputs.
    """
    amplitude = np.exp(kernel_params[0])
    ls = np.exp(kernel_params[1:])
    left = (x / ls)[:, np.newaxis]
    right = (xp / ls)[np.newaxis]
    sq_dist = np.sum((left - right) ** 2, axis=2)
    return amplitude * np.exp(-0.5 * sq_dist)
def matern12_kernel(
    x1, x2, lengthscale_unconstrained, output_variance_unconstrained, diag=False
):
    """Exponential (Matérn-1/2-type) kernel with log-parameterised hyperparameters.

    Computes ``sigma^2 * exp(-0.5 * r / lengthscale)`` where
    ``r = sqrt(||x1 - x2||^2 + 1e-10)``, for every pair of rows or
    elementwise when ``diag`` is True.

    NOTE(review): the textbook Matérn-1/2 kernel is ``exp(-r / lengthscale)``;
    the extra 0.5 factor here is equivalent to doubling the lengthscale.
    The 1e-10 jitter presumably keeps the sqrt gradient finite at r = 0; it
    also makes the diagonal very slightly smaller than sigma^2.
    """
    if diag:
        delta = x1 - x2
    else:
        delta = x1.unsqueeze(-2) - x2.unsqueeze(-3)
    jitter = 1e-10
    r = (delta.square().sum(dim=-1) + jitter).sqrt()
    ell = torch.exp(lengthscale_unconstrained)
    variance = torch.exp(output_variance_unconstrained)
    return variance * torch.exp(-0.5 * r / ell)
def matern32_kernel(
    x1, x2, lengthscale_unconstrained, output_variance_unconstrained, diag=False
):
    """Matérn-3/2 kernel with log-parameterised hyperparameters.

    Computes ``sigma^2 * (1 + s) * exp(-s)`` with ``s = sqrt(3) * r / lengthscale``
    and ``r = sqrt(||x1 - x2||^2 + 1e-10)``, for every pair of rows or
    elementwise when ``diag`` is True. The tiny jitter keeps the square root
    away from zero.
    """
    delta = (x1 - x2) if diag else (x1.unsqueeze(-2) - x2.unsqueeze(-3))
    r = torch.sqrt(torch.sum(torch.square(delta), dim=-1) + 1e-10)
    s = np.sqrt(3.0) * r / torch.exp(lengthscale_unconstrained)
    return torch.exp(output_variance_unconstrained) * (1 + s) * torch.exp(-s)
def polar_warp(X, r, theta):
    """Shift 2-D points by a displacement given in polar form.

    Each row ``(x, y)`` of ``X`` becomes
    ``(x + r * cos(theta), y + r * sin(theta))``; ``r`` and ``theta`` may be
    anything that broadcasts against ``X[:, 0]``.

    Returns:
        Array of shape ``(len(X), 2)``.
    """
    shifted_x = X[:, 0] + r * np.cos(theta)
    shifted_y = X[:, 1] + r * np.sin(theta)
    return np.array([shifted_x, shifted_y]).T
def get_st_coordinates(df):
    """
    Extract spatial coordinates from ST data whose index labels use the 'AxB' format.

    Each index label is split on the character "x" and every piece is parsed
    as a float, e.g. "12x7.5" -> [12.0, 7.5].

    Args:
        df: object with an iterable ``index`` of string labels (e.g. a pandas
            DataFrame).

    Returns:
        numpy.ndarray with one row of float coordinates per index label
        (not a pandas DataFrame).

    Raises:
        ValueError: if a piece of a label cannot be parsed as a float.
    """
    return np.array([[float(part) for part in spot.split("x")] for spot in df.index])

Extracts spatial coordinates from ST data whose index labels use the 'AxB' format.

Return: numpy array of coordinates, one row per spot.

def compute_distance(X1, X2):
    """Mean Euclidean distance between corresponding rows of ``X1`` and ``X2``."""
    diff = X1 - X2
    per_row = np.sqrt(np.sum(diff ** 2, axis=1))
    return np.mean(per_row)
def make_pinwheel(
    radial_std, tangential_std, num_classes, num_per_class, rate, rs=npr.RandomState(0)
):
    """Sample a 2-D "pinwheel" point cloud with ``num_classes`` arms.

    Based on code by Ryan P. Adams.

    Points are drawn from a Gaussian centred at (1, 0) with standard deviations
    ``radial_std`` (x) and ``tangential_std`` (y), then rotated by the arm's
    base angle plus ``rate * exp(x)``. Points come out grouped by arm:
    ``num_per_class`` points of arm 0 first, then arm 1, and so on. Labels are
    not returned.

    NOTE(review): the default ``rs`` is created once when the function is
    defined and shared by every call that relies on it, so successive such
    calls draw different samples.

    Returns:
        Array of shape ``(num_classes * num_per_class, 2)``.
    """
    base_angles = np.linspace(0, 2 * np.pi, num_classes, endpoint=False)
    n_points = num_classes * num_per_class

    pts = rs.randn(n_points, 2) * np.array([radial_std, tangential_std])
    pts[:, 0] += 1
    arm = np.repeat(np.arange(num_classes), num_per_class)

    phi = base_angles[arm] + rate * np.exp(pts[:, 0])
    cos_phi, sin_phi = np.cos(phi), np.sin(phi)
    rot = np.stack([cos_phi, -sin_phi, sin_phi, cos_phi]).T.reshape(-1, 2, 2)
    return np.einsum("ti,tij->tj", pts, rot)

Based on code by Ryan P. Adams.

class ConvergenceChecker(object):
    """Detect convergence of a trace from a smoothed relative change.

    A window of the last ``span`` values is projected onto a cubic polynomial
    basis (orthonormalised via an SVD). The relative change between the two
    most recent (smoothed) points is then compared against a tolerance.
    """

    def __init__(self, span, dtp="float64"):
        self.span = span
        t = np.arange(span, dtype=dtp)
        t -= t.mean()  # centre the abscissa to improve conditioning
        design = np.column_stack((np.ones(shape=t.shape), t, t**2, t**3))
        # Left singular vectors: orthonormal basis of the cubic design's columns.
        self.U = np.linalg.svd(design, full_matrices=False)[0]

    def smooth(self, y):
        """Project ``y`` (length ``span``) onto the cubic basis."""
        coefs = self.U.T @ y
        return self.U @ coefs

    def subset(self, y, idx=-1):
        """Return the ``span`` values of ``y`` ending at (and including) ``idx``."""
        width = self.U.shape[0]
        start = idx - width + 1
        # For idx == -1, ``idx + 1`` would be 0 (an empty slice), so run to the end.
        stop = None if idx == -1 else idx + 1
        return y[start:stop]

    def relative_change(self, y, idx=-1, smooth=True):
        """Relative change between the last two points of the window ending at ``idx``."""
        window = self.subset(y, idx=idx)
        if smooth:
            window = self.smooth(window)
        last, prev = window[-1], window[-2]
        # The 0.1 offset keeps the ratio finite when ``prev`` is near zero.
        return (last - prev) / (0.1 + abs(prev))

    def converged(self, y, tol=1e-4, **kwargs):
        """True when the absolute relative change is below ``tol``."""
        change = self.relative_change(y, **kwargs)
        return abs(change) < tol

    def relative_change_all(self, y, smooth=True):
        """Relative change at every index from ``span`` on; earlier entries are NaN."""
        n = len(y)
        changes = np.full(n, np.nan)
        for end in range(self.U.shape[0], n):
            changes[end] = self.relative_change(y, idx=end, smooth=smooth)
        return changes

    def converged_all(self, y, tol=1e-4, smooth=True):
        """Elementwise convergence flags (NaN entries compare as False)."""
        return np.abs(self.relative_change_all(y, smooth=smooth)) < tol
def compute_size_factors(m):
    """Poisson size factors for a count matrix with samples in the columns.

    Column totals (counts per cell) are rescaled so that the geometric mean of
    the returned factors is 1.

    Args:
        m: pandas DataFrame (anything exposing ``.values``) of counts with one
            sample per column.

    Returns:
        1-D numpy array with one size factor per column.
    """
    log_totals = np.log(m.values.sum(axis=0))
    return np.exp(log_totals - log_totals.mean())
def poisson_deviance(X, sz):
    """Per-feature Poisson deviance of a count matrix against a size-factor null.

    Args:
        X: pandas DataFrame of counts with features in rows and observations
            in columns.
        sz: per-column size factors (broadcast against the columns of ``X``).

    Returns:
        pandas Series indexed like ``X.index`` holding
        ``2 * (sum_j x_ij * log(x_ij / sz_j) - n_i * log(n_i / sum(sz)))``,
        where ``n_i`` is the total count of feature ``i``. Features with zero
        total counts get a deviance of 0 rather than NaN.
    """
    # Broadcasting divides every column j by its size factor sz[j].
    LP = X.values / sz
    # Log-transform nonzero entries only; zero entries stay 0 so 0 * log(0) -> 0.
    LP[LP > 0] = np.log(LP[LP > 0])

    # Transpose to make features in cols, observations in rows.
    X = X.T
    ll_sat = np.sum(np.multiply(X, LP.T), axis=0)
    feature_sums = np.sum(X, axis=0)
    # xlogy returns 0 for an all-zero feature instead of 0 * log(0) = NaN.
    ll_null = xlogy(feature_sums, feature_sums / np.sum(sz))
    return 2 * (ll_sat - ll_null)
def deviance_feature_selection(X):
    """Score features (genes) by Poisson deviance.

    Args:
        X: pandas DataFrame of counts with genes in rows and cells in columns.

    Returns:
        Tuple ``(deviances, gene_names)`` of numpy arrays aligned row for row.
        Genes with zero total counts are dropped.
    """
    # Drop genes (rows) without any counts; they carry no information.
    X = X[np.sum(X, axis=1) > 0]

    # Drop cells (columns) without any counts: a zero total gives log(0) = -inf
    # in the size factors, which would make every deviance NaN.
    X = X.loc[:, np.sum(X, axis=0) > 0]

    # Compute size factors
    sz = compute_size_factors(X)

    # Compute deviances
    devs = poisson_deviance(X, sz)

    # Get associated gene names
    gene_names = X.index.values

    assert gene_names.shape[0] == devs.values.shape[0]

    return devs.values, gene_names
def deviance_residuals(x, theta, mu=None):
    """Compute deviance residuals for an NB model with a fixed theta.

    An infinite ``theta`` selects the Poisson deviance instead.

    Args:
        x: array of counts (must be 2-D when ``mu`` is omitted).
        theta: NB inverse-dispersion parameter; ``np.inf`` means Poisson.
        mu: optional expected counts, broadcastable against ``x``. When omitted
            it is the rank-one estimate ``row_sums @ col_sums / total``.

    Returns:
        Array ``sign(x - mu) * sqrt(d)``, where ``d`` is the unit deviance;
        negative ``d`` entries are clamped to 0 (with a printed notice).
    """
    if mu is None:
        counts_sum0 = np.sum(x, axis=0, keepdims=True)
        counts_sum1 = np.sum(x, axis=1, keepdims=True)
        counts_sum = np.sum(x)
        mu = counts_sum1 @ counts_sum0 / counts_sum

    def remove_negatives(sqrt_term):
        # Clamp negative entries in place so np.sqrt below does not yield NaN.
        negatives_idx = sqrt_term < 0
        if np.any(negatives_idx):
            n_negatives = np.sum(negatives_idx)
            # np.product was removed in NumPy 2.0, so use the array size.
            # The message prints a percentage, so scale the fraction by 100.
            print(
                "Setting %u negative sqrt term values to 0 (%f%%)"
                % (n_negatives, 100.0 * n_negatives / sqrt_term.size)
            )
            sqrt_term[negatives_idx] = 0

    if np.isinf(theta):  # Poisson
        x_minus_mu = x - mu
        # xlogy(x, x / mu) computes x * log(x / mu) and returns 0 where x == 0.
        sqrt_term = 2 * (xlogy(x, x / mu) - x_minus_mu)
        remove_negatives(sqrt_term)
        dev = np.sign(x_minus_mu) * np.sqrt(sqrt_term)
    else:  # negative binomial
        x_plus_theta = x + theta
        sqrt_term = 2 * (
            xlogy(x, x / mu) - x_plus_theta * np.log(x_plus_theta / (mu + theta))
        )
        remove_negatives(sqrt_term)
        dev = np.sign(x - mu) * np.sqrt(sqrt_term)

    return dev

Computes deviance residuals for an NB model with a fixed theta (the Poisson deviance when theta is infinite).

def pearson_residuals(counts, theta, clipping=True):
    """Analytic Pearson residuals of an NB model with a fixed theta.

    The expected counts are the rank-one estimate
    ``mu = row_sums @ col_sums / total``, and the residuals are
    ``(counts - mu) / sqrt(mu + mu**2 / theta)``. When ``clipping`` is True,
    residuals are clipped to ``[-sqrt(n), sqrt(n)]``, where ``n`` is the number
    of rows of ``counts``.
    """
    row_totals = np.sum(counts, axis=1, keepdims=True)
    col_totals = np.sum(counts, axis=0, keepdims=True)
    grand_total = np.sum(counts)

    expected = row_totals @ col_totals / grand_total
    residuals = (counts - expected) / np.sqrt(expected + expected**2 / theta)

    if not clipping:
        return residuals

    bound = np.sqrt(counts.shape[0])
    residuals[residuals > bound] = bound
    residuals[residuals < -bound] = -bound
    return residuals

Computes analytic Pearson residuals for an NB model with a fixed theta, clipping outlier residuals to ±sqrt(N), where N is the number of rows of the count matrix.

class LossNotDecreasingChecker:
    """Flag convergence once the loss stops decreasing meaningfully.

    The per-iteration decrease ``loss[t-1] - loss[t]`` is recorded. Once
    ``window_size`` decreases are available, their mean over the most recent
    ``window_size`` iterations (including the current one) is compared against
    ``atol``.

    Args:
        max_epochs: length of the preallocated history buffers; ``iternum``
            passed to ``check_loss`` must be smaller than this.
        atol: threshold on the windowed mean decrease.
        window_size: number of most recent decreases to average.
    """

    def __init__(self, max_epochs, atol=1e-2, window_size=10):
        self.max_epochs = max_epochs
        self.atol = atol
        self.window_size = window_size
        # decrease_in_loss[t] = loss[t-1] - loss[t]; index 0 is never written.
        self.decrease_in_loss = np.zeros(max_epochs)
        # Windowed mean decrease, stored at each iteration where it is computed.
        self.average_decrease_in_loss = np.zeros(max_epochs)

    def check_loss(self, iternum, loss_trace):
        """Record the latest decrease and report whether training has stalled.

        Args:
            iternum: index of the current iteration in ``loss_trace``.
            loss_trace: indexable sequence of losses up to ``iternum``.

        Returns:
            True when the mean decrease over the last ``window_size``
            iterations (ending at ``iternum``) is below ``atol``; else False.
        """
        if iternum < 1:
            return False

        self.decrease_in_loss[iternum] = loss_trace[iternum - 1] - loss_trace[iternum]
        if iternum < self.window_size:
            return False

        # Include ``iternum`` itself: the previous exclusive end averaged only
        # window_size - 1 values and ignored the decrease just recorded (and
        # produced an empty, NaN mean for window_size == 1).
        window = self.decrease_in_loss[iternum - self.window_size + 1 : iternum + 1]
        self.average_decrease_in_loss[iternum] = np.mean(window)
        return self.average_decrease_in_loss[iternum] < self.atol