gpsa.plotting.callbacks
"""Plotting callbacks that redraw alignment diagnostics on existing axes."""

import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from matplotlib.lines import Line2D
from scipy.stats import pearsonr
from torch.utils.data import Dataset, DataLoader

SCATTER_POINT_SIZE = 50


def callback_oned(
    model,
    X,
    Y,
    X_aligned,
    data_expression_ax,
    latent_expression_ax,
    prediction_ax=None,
    X_test=None,
    Y_pred=None,
    Y_test_true=None,
    X_test_aligned=None,
    F_samples=None,
):
    """Redraw the one-dimensional "observed vs. aligned" diagnostic plots.

    Both axes are cleared and redrawn: ``data_expression_ax`` shows the
    outcomes against the first column of ``X`` and ``latent_expression_ax``
    shows the same outcomes against the first column of the aligned
    coordinates. At most the first two outcome columns are drawn (blue, then
    orange). The function finishes with ``plt.draw()`` and a short
    ``plt.pause`` so the figure refreshes.

    Side effects: ``model.eval()`` is called, and when
    ``model.fixed_view_idx`` is not None the rows of
    ``X_aligned["expression"]`` belonging to that view are overwritten in
    place with the matching rows of ``X``.

    Args:
        model: Object exposing ``eval()``, ``n_views``, ``fixed_view_idx``
            and ``view_idx["expression"]`` (one row-index collection per
            view).
        X: Array of observed coordinates; column 0 is plotted.
        Y: Two-dimensional array of outcomes, row-aligned with ``X``.
        X_aligned: Dict whose ``"expression"`` entry is a torch tensor of
            aligned coordinates, row-aligned with ``X``.
        data_expression_ax: Axes receiving the observed-data scatter.
        latent_expression_ax: Axes receiving the aligned-data scatter.
        prediction_ax: Optional axes for a true-vs-predicted scatter. When
            given, ``Y_pred``, ``Y_test_true`` and ``X_test_aligned`` are
            required.
        X_test: Accepted for interface compatibility; not used.
        Y_pred: Torch tensor of predicted outcomes for the test points.
        Y_test_true: Array of true outcomes for the test points.
        X_test_aligned: Dict whose ``"expression"`` entry is a torch tensor
            of aligned test coordinates.
        F_samples: Optional torch tensor, row-aligned with ``Y``, drawn on
            the aligned axes in red (and green for the second column).

    Raises:
        ValueError: If ``prediction_ax`` is given without the prediction
            inputs it needs.
    """
    model.eval()
    markers = list(Line2D.markers.keys())
    view_idx = model.view_idx["expression"]

    if model.fixed_view_idx is not None:
        # The fixed view is the reference: show it at its observed location.
        curr_idx = view_idx[model.fixed_view_idx]
        X_aligned["expression"][curr_idx] = torch.tensor(
            X[curr_idx].astype(np.float32)
        )

    # Convert tensors once (after the overwrite above) instead of once per
    # scatter call; .cpu() keeps this working for tensors on other devices.
    aligned = X_aligned["expression"].detach().cpu().numpy()
    f_samples = None if F_samples is None else F_samples.detach().cpu().numpy()

    for ax, title in (
        (data_expression_ax, "Observed data"),
        (latent_expression_ax, "Aligned data"),
    ):
        ax.cla()
        ax.set_title(title)
        ax.set_xlabel("Spatial coordinate")
        ax.set_ylabel("Outcome")
        ax.set_xlim([X.min(), X.max()])

    n_outputs = 2 if Y.shape[1] > 1 else 1
    outcome_colors = ("blue", "orange")
    sample_colors = ("red", "green")

    for vv in range(model.n_views):
        idx = view_idx[vv]
        marker = markers[vv % len(markers)]
        label = "View {}".format(vv + 1)

        for dd in range(n_outputs):
            data_expression_ax.scatter(
                X[idx, 0],
                Y[idx, dd],
                label=label,
                marker=marker,
                s=SCATTER_POINT_SIZE,
                c=outcome_colors[dd],
            )
            latent_expression_ax.scatter(
                aligned[idx, 0],
                Y[idx, dd],
                c=outcome_colors[dd],
                label=label,
                marker=marker,
                s=SCATTER_POINT_SIZE,
            )

        if f_samples is not None:
            for dd in range(n_outputs):
                latent_expression_ax.scatter(
                    aligned[idx, 0],
                    f_samples[idx, dd],
                    c=sample_colors[dd],
                    marker=marker,
                    s=SCATTER_POINT_SIZE,
                )

    if prediction_ax is not None:
        if Y_pred is None or Y_test_true is None or X_test_aligned is None:
            raise ValueError(
                "Y_pred, Y_test_true and X_test_aligned are required when "
                "prediction_ax is given."
            )

        prediction_ax.cla()
        prediction_ax.set_title("Predictions")
        prediction_ax.set_xlabel("True outcome")
        prediction_ax.set_ylabel("Predicted outcome")

        test_aligned = X_test_aligned["expression"].detach().cpu().numpy()
        y_pred = Y_pred.detach().cpu().numpy()

        # Only touch the second column when it exists; single-output
        # predictions previously crashed with an IndexError here.
        n_pred_outputs = 2 if y_pred.shape[1] > 1 else 1
        # None selects matplotlib's default scatter marker.
        prediction_markers = (None, "^")

        for dd in range(n_pred_outputs):
            latent_expression_ax.scatter(
                test_aligned[:, 0],
                y_pred[:, dd],
                c=outcome_colors[dd],
                label="Prediction",
                marker="^",
                s=SCATTER_POINT_SIZE,
            )
            prediction_ax.scatter(
                Y_test_true[:, dd],
                y_pred[:, dd],
                c="black",
                s=SCATTER_POINT_SIZE,
                marker=prediction_markers[dd],
            )

    data_expression_ax.legend()
    plt.draw()
    plt.pause(1 / 60.0)


def callback_twod(
    model,
    X,
    Y,
    X_aligned,
    data_expression_ax,
    latent_expression_ax,
    is_mle=False,
    gene_idx=0,
    s=200,
    include_legend=False,
):
    """Redraw the two-dimensional observed and aligned scatter plots.

    Both axes are cleared, then the first two columns of the observed
    coordinates (``data_expression_ax``) and of the aligned coordinates
    (``latent_expression_ax``) are drawn with seaborn, coloured by column
    ``gene_idx`` of ``Y`` and styled by view.

    Side effects: ``model.eval()`` is called, each axes is made the current
    pyplot axes in turn, and when ``model.fixed_view_idx`` is not None and
    ``is_mle`` is false the rows of ``X_aligned["expression"]`` belonging to
    that view are overwritten in place with the matching rows of ``X``.

    Args:
        model: Object exposing ``eval()``, ``n_views``, ``fixed_view_idx``
            and ``view_idx["expression"]`` (one row-index collection per
            view).
        X: Array of observed coordinates with at least two columns.
        Y: Two-dimensional array of outcomes, row-aligned with ``X``.
        X_aligned: Dict whose ``"expression"`` entry is a torch tensor of
            aligned coordinates, row-aligned with ``X``.
        data_expression_ax: Axes receiving the observed-data plot.
        latent_expression_ax: Axes receiving the aligned-data plot.
        is_mle: When true, the fixed view is not overwritten.
        gene_idx: Column of ``Y`` used for the colour scale.
        s: Marker size forwarded to ``sns.scatterplot``.
        include_legend: Keep the seaborn legend when true; remove it
            otherwise.
    """
    curr_view_idx = model.view_idx["expression"]

    if model.fixed_view_idx is not None and not is_mle:
        # The fixed view is the reference: show it at its observed location.
        curr_idx = curr_view_idx[model.fixed_view_idx]
        X_aligned["expression"][curr_idx] = torch.tensor(
            X[curr_idx].astype(np.float32)
        )

    model.eval()
    markers = [".", "+", "^"]

    data_expression_ax.cla()
    latent_expression_ax.cla()
    data_expression_ax.set_title("Observed data")
    latent_expression_ax.set_title("Aligned data")

    # Convert once, after the overwrite above, rather than once per view.
    aligned = X_aligned["expression"].detach().cpu().numpy()

    Xs = []
    latent_Xs = []
    Ys = []
    markers_list = []
    viewname_list = []
    for vv in range(model.n_views):
        idx = curr_view_idx[vv]
        curr_latent_Xs = aligned[idx]
        n_points = curr_latent_Xs.shape[0]

        Xs.append(X[idx])
        latent_Xs.append(curr_latent_Xs)
        Ys.append(Y[idx, gene_idx])
        # Cycle so that more than three views no longer raise IndexError.
        markers_list.append([markers[vv % len(markers)]] * n_points)
        viewname_list.append(["Observation {}".format(vv + 1)] * n_points)

    Xs = np.concatenate(Xs, axis=0)
    latent_Xs = np.concatenate(latent_Xs, axis=0)
    Ys = np.concatenate(Ys)
    markers_list = np.concatenate(markers_list)
    viewname_list = np.concatenate(viewname_list)

    for ax, coords in (
        (data_expression_ax, Xs),
        (latent_expression_ax, latent_Xs),
    ):
        frame = pd.DataFrame(
            {
                "X1": coords[:, 0],
                "X2": coords[:, 1],
                "Y": Ys,
                "marker": markers_list,
                "view": viewname_list,
            }
        )
        plt.sca(ax)
        g = sns.scatterplot(
            data=frame,
            x="X1",
            y="X2",
            hue="Y",
            style="view",
            ax=ax,
            s=s,
            linewidth=1.8,
            edgecolor="black",
            palette="viridis",
        )
        # legend_ is None when no legend was created.
        if not include_legend and g.legend_ is not None:
            g.legend_.remove()


def callback_twod_aligned_only(
    model,
    X,
    Y,
    X_aligned,
    latent_expression_ax1,
    latent_expression_ax2,
    is_mle=False,
    gene_idx=0,
):
    """Redraw the aligned coordinates of the first two views, one per axes.

    ``latent_expression_ax1`` receives view 0 and ``latent_expression_ax2``
    receives view 1. Each scatter uses the first two columns of the aligned
    coordinates and is coloured by column ``gene_idx`` of ``Y``.

    Side effects: ``model.eval()`` is called, ``plt.axis("off")`` is applied
    to the current pyplot axes, and when ``model.fixed_view_idx`` is not None
    and ``is_mle`` is false the rows of ``X_aligned["expression"]`` belonging
    to that view are overwritten in place with the matching rows of ``X``.

    Args:
        model: Object exposing ``eval()``, ``fixed_view_idx`` and
            ``view_idx["expression"]`` (row indices for at least two views).
        X: Array of observed coordinates, row-aligned with ``Y``.
        Y: Two-dimensional array of outcomes.
        X_aligned: Dict whose ``"expression"`` entry is a torch tensor of
            aligned coordinates, row-aligned with ``X``.
        latent_expression_ax1: Axes for view 0.
        latent_expression_ax2: Axes for view 1.
        is_mle: When true, the fixed view is not overwritten.
        gene_idx: Column of ``Y`` used for the colour scale.
    """
    curr_view_idx = model.view_idx["expression"]

    if model.fixed_view_idx is not None and not is_mle:
        # The fixed view is the reference: show it at its observed location.
        curr_idx = curr_view_idx[model.fixed_view_idx]
        X_aligned["expression"][curr_idx] = torch.tensor(
            X[curr_idx].astype(np.float32)
        )

    model.eval()

    latent_expression_ax1.cla()
    latent_expression_ax2.cla()
    # NOTE(review): both panels show aligned coordinates, so these titles
    # look copied from callback_twod; kept unchanged pending confirmation.
    latent_expression_ax1.set_title("Observed data")
    latent_expression_ax2.set_title("Aligned data")

    aligned_coords = X_aligned["expression"].detach().cpu().numpy()

    for ax, vv in ((latent_expression_ax1, 0), (latent_expression_ax2, 1)):
        idx = curr_view_idx[vv]
        ax.scatter(
            aligned_coords[idx][:, 0],
            aligned_coords[idx][:, 1],
            c=Y[idx][:, gene_idx].squeeze(),
            s=24,
            marker="h",
        )

    # NOTE(review): this only affects pyplot's *current* axes, which is not
    # necessarily either axes passed in; confirm which panels should be off.
    plt.axis("off")
def callback_twod_multimodal(
    model, data_dict, X_aligned, axes, rgb=False, scatterpoint_size=100
):
    """Redraw observed and aligned scatter plots for two modalities.

    All axes are cleared, then for each of the modalities ``"expression"``
    and ``"histology"`` every view is drawn twice: at its observed
    coordinates (``axes[0]`` / ``axes[2]``) and at its aligned coordinates
    (``axes[1]`` / ``axes[3]``). Points are coloured by the first output
    column, or by all output columns for histology when ``rgb`` is true.

    Side effects: ``model.eval()`` is called.

    Args:
        model: Object exposing ``eval()``, ``n_views`` and ``view_idx[mod]``
            (one row-index collection per view) for both modalities.
        data_dict: Dict keyed by modality; each entry provides
            ``"spatial_coords"`` and ``"outputs"`` arrays indexed by row.
        X_aligned: Dict keyed by modality holding torch tensors of aligned
            coordinates, row-aligned with ``"spatial_coords"``.
        axes: Sequence of at least four axes objects.
        rgb: Colour histology points with all output columns instead of
            only the first.
        scatterpoint_size: Marker size forwarded to ``scatter``.
    """
    model.eval()
    markers = [".", "+", "^"]

    for ax in axes:
        ax.cla()

    axes[0].set_title("Observed expression")
    axes[1].set_title("Aligned expression")
    axes[2].set_title("Observed histology")
    axes[3].set_title("Aligned histology")

    # Each modality owns two consecutive panels: observed, then aligned.
    panels_per_mod = 2
    for mod_pos, mod in enumerate(["expression", "histology"]):
        observed_ax = axes[panels_per_mod * mod_pos]
        aligned_ax = axes[panels_per_mod * mod_pos + 1]

        curr_view_idx = model.view_idx[mod]
        curr_coords = data_dict[mod]["spatial_coords"]
        all_outputs = data_dict[mod]["outputs"]
        # Convert once per modality; .cpu() keeps this working for tensors
        # that live on another device.
        aligned = X_aligned[mod].detach().cpu().numpy()
        use_all_channels = rgb and mod == "histology"

        for vv in range(model.n_views):
            idx = curr_view_idx[vv]
            # Cycle so that more than three views no longer raise IndexError.
            marker = markers[vv % len(markers)]
            label = "View {}".format(vv + 1)

            if use_all_channels:
                curr_outputs = all_outputs[idx, :]
            else:
                curr_outputs = all_outputs[idx, 0]

            observed_ax.scatter(
                curr_coords[idx, 0],
                curr_coords[idx, 1],
                c=curr_outputs,
                label=label,
                marker=marker,
                s=scatterpoint_size,
            )
            aligned_ax.scatter(
                aligned[idx, 0],
                aligned[idx, 1],
                c=curr_outputs,
                label=label,
                marker=marker,
                s=scatterpoint_size,
            )
def
callback_oned( model, X, Y, X_aligned, data_expression_ax, latent_expression_ax, prediction_ax=None, X_test=None, Y_pred=None, Y_test_true=None, X_test_aligned=None, F_samples=None):
def callback_oned(
    model,
    X,
    Y,
    X_aligned,
    data_expression_ax,
    latent_expression_ax,
    prediction_ax=None,
    X_test=None,
    Y_pred=None,
    Y_test_true=None,
    X_test_aligned=None,
    F_samples=None,
):
    """Redraw the one-dimensional "observed vs. aligned" diagnostic plots.

    Both axes are cleared and redrawn: ``data_expression_ax`` shows the
    outcomes against the first column of ``X`` and ``latent_expression_ax``
    shows the same outcomes against the first column of the aligned
    coordinates. At most the first two outcome columns are drawn (blue, then
    orange). The function finishes with ``plt.draw()`` and a short
    ``plt.pause`` so the figure refreshes.

    Side effects: ``model.eval()`` is called, and when
    ``model.fixed_view_idx`` is not None the rows of
    ``X_aligned["expression"]`` belonging to that view are overwritten in
    place with the matching rows of ``X``.

    Args:
        model: Object exposing ``eval()``, ``n_views``, ``fixed_view_idx``
            and ``view_idx["expression"]`` (one row-index collection per
            view).
        X: Array of observed coordinates; column 0 is plotted.
        Y: Two-dimensional array of outcomes, row-aligned with ``X``.
        X_aligned: Dict whose ``"expression"`` entry is a torch tensor of
            aligned coordinates, row-aligned with ``X``.
        data_expression_ax: Axes receiving the observed-data scatter.
        latent_expression_ax: Axes receiving the aligned-data scatter.
        prediction_ax: Optional axes for a true-vs-predicted scatter. When
            given, ``Y_pred``, ``Y_test_true`` and ``X_test_aligned`` are
            required.
        X_test: Accepted for interface compatibility; not used.
        Y_pred: Torch tensor of predicted outcomes for the test points.
        Y_test_true: Array of true outcomes for the test points.
        X_test_aligned: Dict whose ``"expression"`` entry is a torch tensor
            of aligned test coordinates.
        F_samples: Optional torch tensor, row-aligned with ``Y``, drawn on
            the aligned axes in red (and green for the second column).

    Raises:
        ValueError: If ``prediction_ax`` is given without the prediction
            inputs it needs.
    """
    model.eval()
    markers = list(Line2D.markers.keys())
    view_idx = model.view_idx["expression"]

    if model.fixed_view_idx is not None:
        # The fixed view is the reference: show it at its observed location.
        curr_idx = view_idx[model.fixed_view_idx]
        X_aligned["expression"][curr_idx] = torch.tensor(
            X[curr_idx].astype(np.float32)
        )

    # Convert tensors once (after the overwrite above) instead of once per
    # scatter call; .cpu() keeps this working for tensors on other devices.
    aligned = X_aligned["expression"].detach().cpu().numpy()
    f_samples = None if F_samples is None else F_samples.detach().cpu().numpy()

    for ax, title in (
        (data_expression_ax, "Observed data"),
        (latent_expression_ax, "Aligned data"),
    ):
        ax.cla()
        ax.set_title(title)
        ax.set_xlabel("Spatial coordinate")
        ax.set_ylabel("Outcome")
        ax.set_xlim([X.min(), X.max()])

    n_outputs = 2 if Y.shape[1] > 1 else 1
    outcome_colors = ("blue", "orange")
    sample_colors = ("red", "green")

    for vv in range(model.n_views):
        idx = view_idx[vv]
        marker = markers[vv % len(markers)]
        label = "View {}".format(vv + 1)

        for dd in range(n_outputs):
            data_expression_ax.scatter(
                X[idx, 0],
                Y[idx, dd],
                label=label,
                marker=marker,
                s=SCATTER_POINT_SIZE,
                c=outcome_colors[dd],
            )
            latent_expression_ax.scatter(
                aligned[idx, 0],
                Y[idx, dd],
                c=outcome_colors[dd],
                label=label,
                marker=marker,
                s=SCATTER_POINT_SIZE,
            )

        if f_samples is not None:
            for dd in range(n_outputs):
                latent_expression_ax.scatter(
                    aligned[idx, 0],
                    f_samples[idx, dd],
                    c=sample_colors[dd],
                    marker=marker,
                    s=SCATTER_POINT_SIZE,
                )

    if prediction_ax is not None:
        if Y_pred is None or Y_test_true is None or X_test_aligned is None:
            raise ValueError(
                "Y_pred, Y_test_true and X_test_aligned are required when "
                "prediction_ax is given."
            )

        prediction_ax.cla()
        prediction_ax.set_title("Predictions")
        prediction_ax.set_xlabel("True outcome")
        prediction_ax.set_ylabel("Predicted outcome")

        test_aligned = X_test_aligned["expression"].detach().cpu().numpy()
        y_pred = Y_pred.detach().cpu().numpy()

        # Only touch the second column when it exists; single-output
        # predictions previously crashed with an IndexError here.
        n_pred_outputs = 2 if y_pred.shape[1] > 1 else 1
        # None selects matplotlib's default scatter marker.
        prediction_markers = (None, "^")

        for dd in range(n_pred_outputs):
            latent_expression_ax.scatter(
                test_aligned[:, 0],
                y_pred[:, dd],
                c=outcome_colors[dd],
                label="Prediction",
                marker="^",
                s=SCATTER_POINT_SIZE,
            )
            prediction_ax.scatter(
                Y_test_true[:, dd],
                y_pred[:, dd],
                c="black",
                s=SCATTER_POINT_SIZE,
                marker=prediction_markers[dd],
            )

    data_expression_ax.legend()
    plt.draw()
    plt.pause(1 / 60.0)
def
callback_twod( model, X, Y, X_aligned, data_expression_ax, latent_expression_ax, is_mle=False, gene_idx=0, s=200, include_legend=False):
def callback_twod(
    model,
    X,
    Y,
    X_aligned,
    data_expression_ax,
    latent_expression_ax,
    is_mle=False,
    gene_idx=0,
    s=200,
    include_legend=False,
):
    """Redraw the two-dimensional observed and aligned scatter plots.

    Both axes are cleared, then the first two columns of the observed
    coordinates (``data_expression_ax``) and of the aligned coordinates
    (``latent_expression_ax``) are drawn with seaborn, coloured by column
    ``gene_idx`` of ``Y`` and styled by view.

    Side effects: ``model.eval()`` is called, each axes is made the current
    pyplot axes in turn, and when ``model.fixed_view_idx`` is not None and
    ``is_mle`` is false the rows of ``X_aligned["expression"]`` belonging to
    that view are overwritten in place with the matching rows of ``X``.

    Args:
        model: Object exposing ``eval()``, ``n_views``, ``fixed_view_idx``
            and ``view_idx["expression"]`` (one row-index collection per
            view).
        X: Array of observed coordinates with at least two columns.
        Y: Two-dimensional array of outcomes, row-aligned with ``X``.
        X_aligned: Dict whose ``"expression"`` entry is a torch tensor of
            aligned coordinates, row-aligned with ``X``.
        data_expression_ax: Axes receiving the observed-data plot.
        latent_expression_ax: Axes receiving the aligned-data plot.
        is_mle: When true, the fixed view is not overwritten.
        gene_idx: Column of ``Y`` used for the colour scale.
        s: Marker size forwarded to ``sns.scatterplot``.
        include_legend: Keep the seaborn legend when true; remove it
            otherwise.
    """
    curr_view_idx = model.view_idx["expression"]

    if model.fixed_view_idx is not None and not is_mle:
        # The fixed view is the reference: show it at its observed location.
        curr_idx = curr_view_idx[model.fixed_view_idx]
        X_aligned["expression"][curr_idx] = torch.tensor(
            X[curr_idx].astype(np.float32)
        )

    model.eval()
    markers = [".", "+", "^"]

    data_expression_ax.cla()
    latent_expression_ax.cla()
    data_expression_ax.set_title("Observed data")
    latent_expression_ax.set_title("Aligned data")

    # Convert once, after the overwrite above, rather than once per view.
    aligned = X_aligned["expression"].detach().cpu().numpy()

    Xs = []
    latent_Xs = []
    Ys = []
    markers_list = []
    viewname_list = []
    for vv in range(model.n_views):
        idx = curr_view_idx[vv]
        curr_latent_Xs = aligned[idx]
        n_points = curr_latent_Xs.shape[0]

        Xs.append(X[idx])
        latent_Xs.append(curr_latent_Xs)
        Ys.append(Y[idx, gene_idx])
        # Cycle so that more than three views no longer raise IndexError.
        markers_list.append([markers[vv % len(markers)]] * n_points)
        viewname_list.append(["Observation {}".format(vv + 1)] * n_points)

    Xs = np.concatenate(Xs, axis=0)
    latent_Xs = np.concatenate(latent_Xs, axis=0)
    Ys = np.concatenate(Ys)
    markers_list = np.concatenate(markers_list)
    viewname_list = np.concatenate(viewname_list)

    for ax, coords in (
        (data_expression_ax, Xs),
        (latent_expression_ax, latent_Xs),
    ):
        frame = pd.DataFrame(
            {
                "X1": coords[:, 0],
                "X2": coords[:, 1],
                "Y": Ys,
                "marker": markers_list,
                "view": viewname_list,
            }
        )
        plt.sca(ax)
        g = sns.scatterplot(
            data=frame,
            x="X1",
            y="X2",
            hue="Y",
            style="view",
            ax=ax,
            s=s,
            linewidth=1.8,
            edgecolor="black",
            palette="viridis",
        )
        # legend_ is None when no legend was created.
        if not include_legend and g.legend_ is not None:
            g.legend_.remove()
def
callback_twod_aligned_only( model, X, Y, X_aligned, latent_expression_ax1, latent_expression_ax2, is_mle=False, gene_idx=0):
def callback_twod_aligned_only(
    model,
    X,
    Y,
    X_aligned,
    latent_expression_ax1,
    latent_expression_ax2,
    is_mle=False,
    gene_idx=0,
):
    """Redraw the aligned coordinates of the first two views, one per axes.

    ``latent_expression_ax1`` receives view 0 and ``latent_expression_ax2``
    receives view 1. Each scatter uses the first two columns of the aligned
    coordinates and is coloured by column ``gene_idx`` of ``Y``.

    Side effects: ``model.eval()`` is called, ``plt.axis("off")`` is applied
    to the current pyplot axes, and when ``model.fixed_view_idx`` is not None
    and ``is_mle`` is false the rows of ``X_aligned["expression"]`` belonging
    to that view are overwritten in place with the matching rows of ``X``.

    Args:
        model: Object exposing ``eval()``, ``fixed_view_idx`` and
            ``view_idx["expression"]`` (row indices for at least two views).
        X: Array of observed coordinates, row-aligned with ``Y``.
        Y: Two-dimensional array of outcomes.
        X_aligned: Dict whose ``"expression"`` entry is a torch tensor of
            aligned coordinates, row-aligned with ``X``.
        latent_expression_ax1: Axes for view 0.
        latent_expression_ax2: Axes for view 1.
        is_mle: When true, the fixed view is not overwritten.
        gene_idx: Column of ``Y`` used for the colour scale.
    """
    curr_view_idx = model.view_idx["expression"]

    if model.fixed_view_idx is not None and not is_mle:
        # The fixed view is the reference: show it at its observed location.
        curr_idx = curr_view_idx[model.fixed_view_idx]
        X_aligned["expression"][curr_idx] = torch.tensor(
            X[curr_idx].astype(np.float32)
        )

    model.eval()

    latent_expression_ax1.cla()
    latent_expression_ax2.cla()
    # NOTE(review): both panels show aligned coordinates, so these titles
    # look copied from callback_twod; kept unchanged pending confirmation.
    latent_expression_ax1.set_title("Observed data")
    latent_expression_ax2.set_title("Aligned data")

    aligned_coords = X_aligned["expression"].detach().cpu().numpy()

    for ax, vv in ((latent_expression_ax1, 0), (latent_expression_ax2, 1)):
        idx = curr_view_idx[vv]
        ax.scatter(
            aligned_coords[idx][:, 0],
            aligned_coords[idx][:, 1],
            c=Y[idx][:, gene_idx].squeeze(),
            s=24,
            marker="h",
        )

    # NOTE(review): this only affects pyplot's *current* axes, which is not
    # necessarily either axes passed in; confirm which panels should be off.
    plt.axis("off")
def
callback_twod_multimodal(model, data_dict, X_aligned, axes, rgb=False, scatterpoint_size=100):
def callback_twod_multimodal(
    model, data_dict, X_aligned, axes, rgb=False, scatterpoint_size=100
):
    """Redraw observed and aligned scatter plots for two modalities.

    All axes are cleared, then for each of the modalities ``"expression"``
    and ``"histology"`` every view is drawn twice: at its observed
    coordinates (``axes[0]`` / ``axes[2]``) and at its aligned coordinates
    (``axes[1]`` / ``axes[3]``). Points are coloured by the first output
    column, or by all output columns for histology when ``rgb`` is true.

    Side effects: ``model.eval()`` is called.

    Args:
        model: Object exposing ``eval()``, ``n_views`` and ``view_idx[mod]``
            (one row-index collection per view) for both modalities.
        data_dict: Dict keyed by modality; each entry provides
            ``"spatial_coords"`` and ``"outputs"`` arrays indexed by row.
        X_aligned: Dict keyed by modality holding torch tensors of aligned
            coordinates, row-aligned with ``"spatial_coords"``.
        axes: Sequence of at least four axes objects.
        rgb: Colour histology points with all output columns instead of
            only the first.
        scatterpoint_size: Marker size forwarded to ``scatter``.
    """
    model.eval()
    markers = [".", "+", "^"]

    for ax in axes:
        ax.cla()

    axes[0].set_title("Observed expression")
    axes[1].set_title("Aligned expression")
    axes[2].set_title("Observed histology")
    axes[3].set_title("Aligned histology")

    # Each modality owns two consecutive panels: observed, then aligned.
    panels_per_mod = 2
    for mod_pos, mod in enumerate(["expression", "histology"]):
        observed_ax = axes[panels_per_mod * mod_pos]
        aligned_ax = axes[panels_per_mod * mod_pos + 1]

        curr_view_idx = model.view_idx[mod]
        curr_coords = data_dict[mod]["spatial_coords"]
        all_outputs = data_dict[mod]["outputs"]
        # Convert once per modality; .cpu() keeps this working for tensors
        # that live on another device.
        aligned = X_aligned[mod].detach().cpu().numpy()
        use_all_channels = rgb and mod == "histology"

        for vv in range(model.n_views):
            idx = curr_view_idx[vv]
            # Cycle so that more than three views no longer raise IndexError.
            marker = markers[vv % len(markers)]
            label = "View {}".format(vv + 1)

            if use_all_channels:
                curr_outputs = all_outputs[idx, :]
            else:
                curr_outputs = all_outputs[idx, 0]

            observed_ax.scatter(
                curr_coords[idx, 0],
                curr_coords[idx, 1],
                c=curr_outputs,
                label=label,
                marker=marker,
                s=scatterpoint_size,
            )
            aligned_ax.scatter(
                aligned[idx, 0],
                aligned[idx, 1],
                c=curr_outputs,
                label=label,
                marker=marker,
                s=scatterpoint_size,
            )