gpsa.plotting.callbacks

  1import torch
  2import numpy as np
  3import matplotlib.pyplot as plt
  4from torch.utils.data import Dataset, DataLoader
  5import time
  6import pandas as pd
  7from scipy.stats import pearsonr
  8
  9from matplotlib.lines import Line2D
 10
 11import seaborn as sns
 12
 13SCATTER_POINT_SIZE = 50
 14
 15
 16def callback_oned(
 17    model,
 18    X,
 19    Y,
 20    X_aligned,
 21    data_expression_ax,
 22    latent_expression_ax,
 23    prediction_ax=None,
 24    X_test=None,
 25    Y_pred=None,
 26    Y_test_true=None,
 27    X_test_aligned=None,
 28    F_samples=None,
 29):
 30    model.eval()
 31    markers = list(Line2D.markers.keys())
 32    colors = ["blue", "orange"]
 33
 34    if model.fixed_view_idx is not None:
 35        curr_idx = model.view_idx["expression"][model.fixed_view_idx]
 36        X_aligned["expression"][curr_idx] = torch.tensor(X[curr_idx].astype(np.float32))
 37
 38    data_expression_ax.cla()
 39    latent_expression_ax.cla()
 40
 41    data_expression_ax.set_title("Observed data")
 42    latent_expression_ax.set_title("Aligned data")
 43
 44    data_expression_ax.set_xlabel("Spatial coordinate")
 45    latent_expression_ax.set_xlabel("Spatial coordinate")
 46
 47    data_expression_ax.set_ylabel("Outcome")
 48    latent_expression_ax.set_ylabel("Outcome")
 49
 50    data_expression_ax.set_xlim([X.min(), X.max()])
 51    latent_expression_ax.set_xlim([X.min(), X.max()])
 52
 53    for vv in range(model.n_views):
 54
 55        view_idx = model.view_idx["expression"]
 56
 57        data_expression_ax.scatter(
 58            X[view_idx[vv], 0],
 59            Y[view_idx[vv], 0],
 60            label="View {}".format(vv + 1),
 61            marker=markers[vv],
 62            s=SCATTER_POINT_SIZE,
 63            c="blue",
 64        )
 65        if Y.shape[1] > 1:
 66            data_expression_ax.scatter(
 67                X[view_idx[vv], 0],
 68                Y[view_idx[vv], 1],
 69                label="View {}".format(vv + 1),
 70                marker=markers[vv],
 71                s=SCATTER_POINT_SIZE,
 72                c="orange",
 73            )
 74        latent_expression_ax.scatter(
 75            # model.G_means["expression"].detach().numpy()[view_idx[vv], 0],
 76            X_aligned["expression"].detach().numpy()[view_idx[vv], 0],
 77            Y[view_idx[vv], 0],
 78            c="blue",
 79            label="View {}".format(vv + 1),
 80            marker=markers[vv],
 81            s=SCATTER_POINT_SIZE,
 82        )
 83        if Y.shape[1] > 1:
 84            latent_expression_ax.scatter(
 85                # model.G_means["expression"].detach().numpy()[view_idx[vv], 0],
 86                X_aligned["expression"].detach().numpy()[view_idx[vv], 0],
 87                Y[view_idx[vv], 1],
 88                c="orange",
 89                label="View {}".format(vv + 1),
 90                marker=markers[vv],
 91                s=SCATTER_POINT_SIZE,
 92            )
 93        # latent_expression_ax.scatter(
 94        # 	model.Xtilde.detach().numpy()[vv, :, 0],
 95        # 	model.delta_list.detach().numpy()[vv][:, 0],
 96        # 	c="red",
 97        # 	label="View {}".format(vv + 1),
 98        # 	marker="^",
 99        # 	s=100,
100        # )
101
102        if F_samples is not None:
103            latent_expression_ax.scatter(
104                X_aligned["expression"].detach().numpy()[view_idx[vv], 0],
105                F_samples.detach().numpy()[view_idx[vv], 0],
106                c="red",
107                marker=markers[vv],
108                s=SCATTER_POINT_SIZE,
109            )
110            if Y.shape[1] > 1:
111                latent_expression_ax.scatter(
112                    X_aligned["expression"].detach().numpy()[view_idx[vv], 0],
113                    F_samples.detach().numpy()[view_idx[vv], 1],
114                    c="green",
115                    marker=markers[vv],
116                    s=SCATTER_POINT_SIZE,
117                )
118
119    if prediction_ax is not None:
120
121        prediction_ax.cla()
122        prediction_ax.set_title("Predictions")
123        prediction_ax.set_xlabel("True outcome")
124        prediction_ax.set_ylabel("Predicted outcome")
125
126        ### Plots the warping function
127        # prediction_ax.scatter(
128        # 	X[view_idx[vv], 0],
129        # 	X_aligned["expression"].detach().numpy()[view_idx[vv], 0],
130        # 	label="View {}".format(vv + 1),
131        # 	marker=markers[vv],
132        # 	s=100,
133        # 	c="blue",
134        # )
135        # prediction_ax.scatter(
136        # 	model.Xtilde.detach().numpy()[vv, :, 0],
137        # 	model.delta_list.detach().numpy()[vv][:, 0],
138        # 	c="red",
139        # 	label="View {}".format(vv + 1),
140        # 	marker="^",
141        # 	s=100,
142        # )
143        latent_expression_ax.scatter(
144            X_test_aligned["expression"].detach().numpy()[:, 0],
145            Y_pred.detach().numpy()[:, 0],
146            c="blue",
147            label="Prediction",
148            marker="^",
149            s=SCATTER_POINT_SIZE,
150        )
151        latent_expression_ax.scatter(
152            X_test_aligned["expression"].detach().numpy()[:, 0],
153            Y_pred.detach().numpy()[:, 1],
154            c="orange",
155            label="Prediction",
156            marker="^",
157            s=SCATTER_POINT_SIZE,
158        )
159        prediction_ax.scatter(
160            Y_test_true[:, 0],
161            Y_pred.detach().numpy()[:, 0],
162            c="black",
163            s=SCATTER_POINT_SIZE,
164        )
165        prediction_ax.scatter(
166            Y_test_true[:, 1],
167            Y_pred.detach().numpy()[:, 1],
168            c="black",
169            s=SCATTER_POINT_SIZE,
170            marker="^",
171        )
172
173    data_expression_ax.legend()
174    plt.draw()
175    plt.pause(1 / 60.0)
176
177
def callback_twod(
    model,
    X,
    Y,
    X_aligned,
    data_expression_ax,
    latent_expression_ax,
    is_mle=False,
    gene_idx=0,
    s=200,
    include_legend=False,
):
    """Redraw observed and aligned 2-D coordinates, coloured by one output column.

    Clears both axes, then draws one seaborn scatter of the observed ``X`` and
    one of ``X_aligned["expression"]``. Hue is ``Y[:, gene_idx]`` and marker
    style distinguishes views ("Observation 1", "Observation 2", ...).

    Args:
        model: Object providing ``eval()``, ``fixed_view_idx``, ``n_views`` and
            ``view_idx["expression"]`` (an int-indexable collection of per-view
            row indices). ``model.eval()`` is called and not undone here.
        X: Observed coordinates; columns 0 and 1 are plotted. Must support
            ``astype`` (presumably NumPy).
        Y: 2-D outputs; column ``gene_idx`` sets the colour.
        X_aligned: Dict whose ``"expression"`` entry is a torch tensor of
            aligned coordinates. Unless ``is_mle`` is true, the fixed view's rows
            are overwritten IN PLACE with the corresponding rows of ``X`` when
            ``model.fixed_view_idx`` is not ``None``.
        data_expression_ax: Matplotlib axes for the observed coordinates.
        latent_expression_ax: Matplotlib axes for the aligned coordinates.
        is_mle: When true, skip resetting the fixed view's aligned rows.
        gene_idx: Column of ``Y`` used for the colour scale.
        s: Marker size forwarded to ``seaborn.scatterplot``.
        include_legend: Keep seaborn's legends when true; remove them otherwise.

    Side effects:
        Leaves ``latent_expression_ax`` as pyplot's current axes (``plt.sca``).
    """
    view_idx = model.view_idx["expression"]

    if model.fixed_view_idx is not None and not is_mle:
        # Reset (in place) the fixed view's aligned rows to its observed ones.
        curr_idx = view_idx[model.fixed_view_idx]
        X_aligned["expression"][curr_idx] = torch.tensor(
            X[curr_idx].astype(np.float32)
        )

    model.eval()

    data_expression_ax.cla()
    latent_expression_ax.cla()
    data_expression_ax.set_title("Observed data")
    latent_expression_ax.set_title("Aligned data")

    aligned = X_aligned["expression"].detach().numpy()

    observed_blocks, aligned_blocks, output_blocks, name_blocks = [], [], [], []
    for vv in range(model.n_views):
        rows = view_idx[vv]
        view_aligned = aligned[rows]
        observed_blocks.append(X[rows])
        aligned_blocks.append(view_aligned)
        output_blocks.append(Y[rows, gene_idx])
        name_blocks.append(["Observation {}".format(vv + 1)] * view_aligned.shape[0])

    outputs = np.concatenate(output_blocks)
    view_names = np.concatenate(name_blocks)

    for ax, coords in (
        (data_expression_ax, np.concatenate(observed_blocks, axis=0)),
        (latent_expression_ax, np.concatenate(aligned_blocks, axis=0)),
    ):
        frame = pd.DataFrame(
            {
                "X1": coords[:, 0],
                "X2": coords[:, 1],
                "Y": outputs,
                "view": view_names,
            }
        )
        # Keep pyplot's "current axes" in step with the axes being drawn.
        plt.sca(ax)
        sns.scatterplot(
            data=frame,
            x="X1",
            y="X2",
            hue="Y",
            style="view",
            ax=ax,
            s=s,
            linewidth=1.8,
            edgecolor="black",
            palette="viridis",
        )
        if not include_legend:
            # Guard against seaborn not having created a legend at all.
            legend = ax.get_legend()
            if legend is not None:
                legend.remove()


def callback_twod_aligned_only(
    model,
    X,
    Y,
    X_aligned,
    latent_expression_ax1,
    latent_expression_ax2,
    is_mle=False,
    gene_idx=0,
):
    """Scatter the aligned 2-D coordinates of views 1 and 2 on separate axes.

    View 1 (index 0) is drawn on ``latent_expression_ax1`` and view 2 (index 1)
    on ``latent_expression_ax2``. Points are coloured by ``Y[:, gene_idx]`` and
    drawn with hexagon markers of size 24.

    Args:
        model: Object providing ``eval()``, ``fixed_view_idx`` and
            ``view_idx["expression"]``, which must hold at least two views.
            ``model.eval()`` is called and not undone here.
        X: Observed coordinates; only used to reset the fixed view. Must support
            ``astype`` (presumably NumPy).
        Y: 2-D outputs; column ``gene_idx`` sets the colour.
        X_aligned: Dict whose ``"expression"`` entry is a torch tensor of
            aligned coordinates. Unless ``is_mle`` is true, the fixed view's rows
            are overwritten IN PLACE with the corresponding rows of ``X`` when
            ``model.fixed_view_idx`` is not ``None``.
        latent_expression_ax1: Matplotlib axes for view 1.
        latent_expression_ax2: Matplotlib axes for view 2.
        is_mle: When true, skip resetting the fixed view's aligned rows.
        gene_idx: Column of ``Y`` used for the colour scale.
    """
    view_idx = model.view_idx["expression"]

    if model.fixed_view_idx is not None and not is_mle:
        # Reset (in place) the fixed view's aligned rows to its observed ones.
        curr_idx = view_idx[model.fixed_view_idx]
        X_aligned["expression"][curr_idx] = torch.tensor(
            X[curr_idx].astype(np.float32)
        )

    model.eval()

    latent_expression_ax1.cla()
    latent_expression_ax2.cla()
    # NOTE(review): both panels plot *aligned* coordinates (view 1 and view 2).
    # "Observed data" is only accurate when view 1 is the fixed view reset
    # above; titles kept unchanged for compatibility — confirm intent.
    latent_expression_ax1.set_title("Observed data")
    latent_expression_ax2.set_title("Aligned data")

    aligned = X_aligned["expression"].detach().numpy()
    for ax, vv in ((latent_expression_ax1, 0), (latent_expression_ax2, 1)):
        rows = view_idx[vv]
        view_coords = aligned[rows]
        ax.scatter(
            view_coords[:, 0],
            view_coords[:, 1],
            c=Y[rows][:, gene_idx].squeeze(),
            s=24,
            marker="h",
        )

    # NOTE(review): this acts on pyplot's *current* axes, which this function
    # never sets, so it may be neither of the two axes above — confirm target.
    plt.axis("off")


def callback_twod_multimodal(
    model, data_dict, X_aligned, axes, rgb=False, scatterpoint_size=100
):
    """Scatter observed and aligned 2-D coordinates for expression and histology.

    Every axes in ``axes`` is cleared. The first four are then used as
    ``[observed expression, aligned expression, observed histology,
    aligned histology]``. Each view gets its own marker; markers cycle through
    ".", "+", "^" when there are more than three views.

    Unlike the other callbacks in this module, the fixed view's aligned
    coordinates are not reset here.

    Args:
        model: Object providing ``eval()``, ``n_views`` and ``view_idx[mod]``
            (an int-indexable collection of per-view row indices) for
            ``mod`` in ``("expression", "histology")``. ``model.eval()`` is
            called and not undone here.
        data_dict: ``data_dict[mod]`` holds 2-D ``"spatial_coords"`` and 2-D
            ``"outputs"`` arrays for each modality.
        X_aligned: ``X_aligned[mod]`` is a torch tensor of aligned coordinates.
        axes: Sequence of at least four matplotlib axes.
        rgb: When true, histology points are coloured by their full output row
            (presumably RGB triples); otherwise every modality is coloured by
            output column 0.
        scatterpoint_size: Marker size passed as ``s=`` to every scatter call.
    """
    model.eval()
    markers = [".", "+", "^"]

    for ax in axes:
        ax.cla()

    axes[0].set_title("Observed expression")
    axes[1].set_title("Aligned expression")
    axes[2].set_title("Observed histology")
    axes[3].set_title("Aligned histology")

    # Each modality owns two consecutive axes: observed, then aligned.
    for mod_idx, mod in enumerate(["expression", "histology"]):
        observed_ax = axes[2 * mod_idx]
        aligned_ax = axes[2 * mod_idx + 1]
        view_idx = model.view_idx[mod]
        coords = data_dict[mod]["spatial_coords"]
        outputs = data_dict[mod]["outputs"]
        aligned = X_aligned[mod].detach().numpy()
        colour_by_row = mod == "histology" and rgb

        for vv in range(model.n_views):
            rows = view_idx[vv]
            colour = outputs[rows, :] if colour_by_row else outputs[rows, 0]
            label = "View {}".format(vv + 1)
            marker = markers[vv % len(markers)]
            observed_ax.scatter(
                coords[rows, 0],
                coords[rows, 1],
                c=colour,
                label=label,
                marker=marker,
                s=scatterpoint_size,
            )
            aligned_ax.scatter(
                aligned[rows, 0],
                aligned[rows, 1],
                c=colour,
                label=label,
                marker=marker,
                s=scatterpoint_size,
            )
def callback_oned(
    model,
    X,
    Y,
    X_aligned,
    data_expression_ax,
    latent_expression_ax,
    prediction_ax=None,
    X_test=None,
    Y_pred=None,
    Y_test_true=None,
    X_test_aligned=None,
    F_samples=None,
):
    """Redraw observed vs. aligned 1-D data, plus optional samples and predictions.

    Both ``data_expression_ax`` and ``latent_expression_ax`` are cleared and
    redrawn. For every view, outcome column 0 (blue) and, when ``Y`` has more
    than one column, outcome column 1 (orange) are scattered against the first
    observed coordinate of ``X`` (observed panel) and against the first aligned
    coordinate of ``X_aligned["expression"]`` (aligned panel).

    Args:
        model: Object providing ``eval()``, ``fixed_view_idx``, ``n_views`` and
            ``view_idx["expression"]`` (an int-indexable collection of per-view
            row indices). ``model.eval()`` is called and not undone here.
        X: Observed coordinates; a 2-D array with an ``astype`` method
            (presumably NumPy). Only column 0 is plotted.
        Y: Observed outcomes, 2-D; column 0 and (if present) column 1 are
            plotted.
        X_aligned: Dict whose ``"expression"`` entry is a torch tensor of
            aligned coordinates. When ``model.fixed_view_idx`` is not ``None``,
            the fixed view's rows of this tensor are overwritten IN PLACE with
            the corresponding rows of ``X``.
        data_expression_ax: Matplotlib axes for the observed data.
        latent_expression_ax: Matplotlib axes for the aligned data.
        prediction_ax: Optional axes for a true-vs-predicted scatter. When
            given, ``Y_pred``, ``Y_test_true`` and ``X_test_aligned`` are
            required.
        X_test: Unused; accepted for backward compatibility.
        Y_pred: 2-D torch tensor of predictions, one row per aligned test point.
        Y_test_true: 2-D array of true outcomes with the same rows as ``Y_pred``.
        X_test_aligned: Dict whose ``"expression"`` entry is a torch tensor of
            aligned test coordinates.
        F_samples: Optional 2-D torch tensor indexed by the same rows as ``Y``.
            Column 0 is drawn in red; column 1 is drawn in green, but only when
            ``Y`` has more than one column.

    Raises:
        ValueError: If ``prediction_ax`` is given without all of ``Y_pred``,
            ``Y_test_true`` and ``X_test_aligned``.
    """
    if prediction_ax is not None and (
        Y_pred is None or Y_test_true is None or X_test_aligned is None
    ):
        # Fail before mutating X_aligned or clearing any axes, instead of
        # crashing halfway through the drawing code.
        raise ValueError(
            "prediction_ax requires Y_pred, Y_test_true and X_test_aligned"
        )

    model.eval()
    markers = list(Line2D.markers.keys())
    view_idx = model.view_idx["expression"]

    if model.fixed_view_idx is not None:
        # Overwrite (in place) the fixed view's aligned coordinates with its
        # observed coordinates, so that view is drawn unwarped.
        curr_idx = view_idx[model.fixed_view_idx]
        X_aligned["expression"][curr_idx] = torch.tensor(
            X[curr_idx].astype(np.float32)
        )

    # Both panels share the observed coordinate range on the x axis.
    x_limits = [X.min(), X.max()]
    for ax, title in (
        (data_expression_ax, "Observed data"),
        (latent_expression_ax, "Aligned data"),
    ):
        ax.cla()
        ax.set_title(title)
        ax.set_xlabel("Spatial coordinate")
        ax.set_ylabel("Outcome")
        ax.set_xlim(x_limits)

    # Convert to NumPy once; the per-view loop below only slices these arrays.
    aligned = X_aligned["expression"].detach().numpy()
    samples = F_samples.detach().numpy() if F_samples is not None else None
    has_second_output = Y.shape[1] > 1

    for vv in range(model.n_views):
        rows = view_idx[vv]
        label = "View {}".format(vv + 1)
        marker = markers[vv]

        data_expression_ax.scatter(
            X[rows, 0],
            Y[rows, 0],
            label=label,
            marker=marker,
            s=SCATTER_POINT_SIZE,
            c="blue",
        )
        if has_second_output:
            data_expression_ax.scatter(
                X[rows, 0],
                Y[rows, 1],
                label=label,
                marker=marker,
                s=SCATTER_POINT_SIZE,
                c="orange",
            )

        latent_expression_ax.scatter(
            aligned[rows, 0],
            Y[rows, 0],
            c="blue",
            label=label,
            marker=marker,
            s=SCATTER_POINT_SIZE,
        )
        if has_second_output:
            latent_expression_ax.scatter(
                aligned[rows, 0],
                Y[rows, 1],
                c="orange",
                label=label,
                marker=marker,
                s=SCATTER_POINT_SIZE,
            )

        if samples is not None:
            latent_expression_ax.scatter(
                aligned[rows, 0],
                samples[rows, 0],
                c="red",
                marker=marker,
                s=SCATTER_POINT_SIZE,
            )
            if has_second_output:
                latent_expression_ax.scatter(
                    aligned[rows, 0],
                    samples[rows, 1],
                    c="green",
                    marker=marker,
                    s=SCATTER_POINT_SIZE,
                )

    if prediction_ax is not None:
        prediction_ax.cla()
        prediction_ax.set_title("Predictions")
        prediction_ax.set_xlabel("True outcome")
        prediction_ax.set_ylabel("Predicted outcome")

        test_aligned = X_test_aligned["expression"].detach().numpy()
        preds = Y_pred.detach().numpy()

        latent_expression_ax.scatter(
            test_aligned[:, 0],
            preds[:, 0],
            c="blue",
            label="Prediction",
            marker="^",
            s=SCATTER_POINT_SIZE,
        )
        prediction_ax.scatter(
            Y_test_true[:, 0],
            preds[:, 0],
            c="black",
            s=SCATTER_POINT_SIZE,
        )
        # Output column 1 is optional, matching the Y.shape[1] guards above;
        # indexing it unconditionally fails for single-output predictions.
        if preds.shape[1] > 1:
            latent_expression_ax.scatter(
                test_aligned[:, 0],
                preds[:, 1],
                c="orange",
                label="Prediction",
                marker="^",
                s=SCATTER_POINT_SIZE,
            )
            prediction_ax.scatter(
                Y_test_true[:, 1],
                preds[:, 1],
                c="black",
                s=SCATTER_POINT_SIZE,
                marker="^",
            )

    data_expression_ax.legend()
    # Redraw, then briefly run the GUI event loop so an interactive figure
    # refreshes between calls.
    plt.draw()
    plt.pause(1 / 60.0)


def callback_twod(
    model,
    X,
    Y,
    X_aligned,
    data_expression_ax,
    latent_expression_ax,
    is_mle=False,
    gene_idx=0,
    s=200,
    include_legend=False,
):
    """Redraw observed and aligned 2-D coordinates, coloured by one output column.

    Clears both axes, then draws one seaborn scatter of the observed ``X`` and
    one of ``X_aligned["expression"]``. Hue is ``Y[:, gene_idx]`` and marker
    style distinguishes views ("Observation 1", "Observation 2", ...).

    Args:
        model: Object providing ``eval()``, ``fixed_view_idx``, ``n_views`` and
            ``view_idx["expression"]`` (an int-indexable collection of per-view
            row indices). ``model.eval()`` is called and not undone here.
        X: Observed coordinates; columns 0 and 1 are plotted. Must support
            ``astype`` (presumably NumPy).
        Y: 2-D outputs; column ``gene_idx`` sets the colour.
        X_aligned: Dict whose ``"expression"`` entry is a torch tensor of
            aligned coordinates. Unless ``is_mle`` is true, the fixed view's rows
            are overwritten IN PLACE with the corresponding rows of ``X`` when
            ``model.fixed_view_idx`` is not ``None``.
        data_expression_ax: Matplotlib axes for the observed coordinates.
        latent_expression_ax: Matplotlib axes for the aligned coordinates.
        is_mle: When true, skip resetting the fixed view's aligned rows.
        gene_idx: Column of ``Y`` used for the colour scale.
        s: Marker size forwarded to ``seaborn.scatterplot``.
        include_legend: Keep seaborn's legends when true; remove them otherwise.

    Side effects:
        Leaves ``latent_expression_ax`` as pyplot's current axes (``plt.sca``).
    """
    view_idx = model.view_idx["expression"]

    if model.fixed_view_idx is not None and not is_mle:
        # Reset (in place) the fixed view's aligned rows to its observed ones.
        curr_idx = view_idx[model.fixed_view_idx]
        X_aligned["expression"][curr_idx] = torch.tensor(
            X[curr_idx].astype(np.float32)
        )

    model.eval()

    data_expression_ax.cla()
    latent_expression_ax.cla()
    data_expression_ax.set_title("Observed data")
    latent_expression_ax.set_title("Aligned data")

    aligned = X_aligned["expression"].detach().numpy()

    observed_blocks, aligned_blocks, output_blocks, name_blocks = [], [], [], []
    for vv in range(model.n_views):
        rows = view_idx[vv]
        view_aligned = aligned[rows]
        observed_blocks.append(X[rows])
        aligned_blocks.append(view_aligned)
        output_blocks.append(Y[rows, gene_idx])
        name_blocks.append(["Observation {}".format(vv + 1)] * view_aligned.shape[0])

    outputs = np.concatenate(output_blocks)
    view_names = np.concatenate(name_blocks)

    for ax, coords in (
        (data_expression_ax, np.concatenate(observed_blocks, axis=0)),
        (latent_expression_ax, np.concatenate(aligned_blocks, axis=0)),
    ):
        frame = pd.DataFrame(
            {
                "X1": coords[:, 0],
                "X2": coords[:, 1],
                "Y": outputs,
                "view": view_names,
            }
        )
        # Keep pyplot's "current axes" in step with the axes being drawn.
        plt.sca(ax)
        sns.scatterplot(
            data=frame,
            x="X1",
            y="X2",
            hue="Y",
            style="view",
            ax=ax,
            s=s,
            linewidth=1.8,
            edgecolor="black",
            palette="viridis",
        )
        if not include_legend:
            # Guard against seaborn not having created a legend at all.
            legend = ax.get_legend()
            if legend is not None:
                legend.remove()


def callback_twod_aligned_only(
    model,
    X,
    Y,
    X_aligned,
    latent_expression_ax1,
    latent_expression_ax2,
    is_mle=False,
    gene_idx=0,
):
    """Scatter the aligned 2-D coordinates of views 1 and 2 on separate axes.

    View 1 (index 0) is drawn on ``latent_expression_ax1`` and view 2 (index 1)
    on ``latent_expression_ax2``. Points are coloured by ``Y[:, gene_idx]`` and
    drawn with hexagon markers of size 24.

    Args:
        model: Object providing ``eval()``, ``fixed_view_idx`` and
            ``view_idx["expression"]``, which must hold at least two views.
            ``model.eval()`` is called and not undone here.
        X: Observed coordinates; only used to reset the fixed view. Must support
            ``astype`` (presumably NumPy).
        Y: 2-D outputs; column ``gene_idx`` sets the colour.
        X_aligned: Dict whose ``"expression"`` entry is a torch tensor of
            aligned coordinates. Unless ``is_mle`` is true, the fixed view's rows
            are overwritten IN PLACE with the corresponding rows of ``X`` when
            ``model.fixed_view_idx`` is not ``None``.
        latent_expression_ax1: Matplotlib axes for view 1.
        latent_expression_ax2: Matplotlib axes for view 2.
        is_mle: When true, skip resetting the fixed view's aligned rows.
        gene_idx: Column of ``Y`` used for the colour scale.
    """
    view_idx = model.view_idx["expression"]

    if model.fixed_view_idx is not None and not is_mle:
        # Reset (in place) the fixed view's aligned rows to its observed ones.
        curr_idx = view_idx[model.fixed_view_idx]
        X_aligned["expression"][curr_idx] = torch.tensor(
            X[curr_idx].astype(np.float32)
        )

    model.eval()

    latent_expression_ax1.cla()
    latent_expression_ax2.cla()
    # NOTE(review): both panels plot *aligned* coordinates (view 1 and view 2).
    # "Observed data" is only accurate when view 1 is the fixed view reset
    # above; titles kept unchanged for compatibility — confirm intent.
    latent_expression_ax1.set_title("Observed data")
    latent_expression_ax2.set_title("Aligned data")

    aligned = X_aligned["expression"].detach().numpy()
    for ax, vv in ((latent_expression_ax1, 0), (latent_expression_ax2, 1)):
        rows = view_idx[vv]
        view_coords = aligned[rows]
        ax.scatter(
            view_coords[:, 0],
            view_coords[:, 1],
            c=Y[rows][:, gene_idx].squeeze(),
            s=24,
            marker="h",
        )

    # NOTE(review): this acts on pyplot's *current* axes, which this function
    # never sets, so it may be neither of the two axes above — confirm target.
    plt.axis("off")


def callback_twod_multimodal(
    model, data_dict, X_aligned, axes, rgb=False, scatterpoint_size=100
):
    """Scatter observed and aligned 2-D coordinates for expression and histology.

    Every axes in ``axes`` is cleared. The first four are then used as
    ``[observed expression, aligned expression, observed histology,
    aligned histology]``. Each view gets its own marker; markers cycle through
    ".", "+", "^" when there are more than three views.

    Unlike the other callbacks in this module, the fixed view's aligned
    coordinates are not reset here.

    Args:
        model: Object providing ``eval()``, ``n_views`` and ``view_idx[mod]``
            (an int-indexable collection of per-view row indices) for
            ``mod`` in ``("expression", "histology")``. ``model.eval()`` is
            called and not undone here.
        data_dict: ``data_dict[mod]`` holds 2-D ``"spatial_coords"`` and 2-D
            ``"outputs"`` arrays for each modality.
        X_aligned: ``X_aligned[mod]`` is a torch tensor of aligned coordinates.
        axes: Sequence of at least four matplotlib axes.
        rgb: When true, histology points are coloured by their full output row
            (presumably RGB triples); otherwise every modality is coloured by
            output column 0.
        scatterpoint_size: Marker size passed as ``s=`` to every scatter call.
    """
    model.eval()
    markers = [".", "+", "^"]

    for ax in axes:
        ax.cla()

    axes[0].set_title("Observed expression")
    axes[1].set_title("Aligned expression")
    axes[2].set_title("Observed histology")
    axes[3].set_title("Aligned histology")

    # Each modality owns two consecutive axes: observed, then aligned.
    for mod_idx, mod in enumerate(["expression", "histology"]):
        observed_ax = axes[2 * mod_idx]
        aligned_ax = axes[2 * mod_idx + 1]
        view_idx = model.view_idx[mod]
        coords = data_dict[mod]["spatial_coords"]
        outputs = data_dict[mod]["outputs"]
        aligned = X_aligned[mod].detach().numpy()
        colour_by_row = mod == "histology" and rgb

        for vv in range(model.n_views):
            rows = view_idx[vv]
            colour = outputs[rows, :] if colour_by_row else outputs[rows, 0]
            label = "View {}".format(vv + 1)
            marker = markers[vv % len(markers)]
            observed_ax.scatter(
                coords[rows, 0],
                coords[rows, 1],
                c=colour,
                label=label,
                marker=marker,
                s=scatterpoint_size,
            )
            aligned_ax.scatter(
                aligned[rows, 0],
                aligned[rows, 1],
                c=colour,
                label=label,
                marker=marker,
                s=scatterpoint_size,
            )