%load_ext autoreload

Swiss Roll Experiment

import os
import glob
import pickle
import imageio
import pandas as pd
import pyvista as pv

from itertools import product
from matplotlib import pyplot as plt
# NOTE(review): store_results is not referenced in the visible part of this
# notebook; presumably it toggles saving of outputs -- TODO confirm.
store_results = True
# Display order of the result directories: used as the sort key for the panels
# and as the row order when the metric tables are reindexed.
plot_order = ['dataset', 'ppca_2_components', 'lddmm', 'umap', 'ebm', 'ae', 'vae', 'naf', 'snf', 'fvae', 'nmf']
def read_plot_kwargs(file_name, file_extension='obj'):
    """Load one result file per method directory of the swiss-roll experiment.

    Searches ``../results/swiss_roll/<method>/<file_name>.<file_extension>``
    (relative to the current working directory) and returns the loaded content
    keyed by the ``<method>`` directory name.

    Args:
        file_name: Base name (without extension) of the file to load.
        file_extension: ``'obj'`` (unpickled object), ``'png'`` (image read
            with ``imageio.imread``) or ``'txt'`` (one-element list holding
            the file's text).

    Returns:
        dict mapping directory name to the loaded content. An entry for
        ``'ppca_3_components'`` is always left out.

    Raises:
        ValueError: If ``file_extension`` is not one of the supported values.
    """
    def _read_pickle(path):
        # NOTE: unpickling can execute arbitrary code; only load result files
        # that come from a trusted source.
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    def _read_image(path):
        return imageio.imread(path)

    def _read_text(path):
        with open(path, 'r') as handle:
            return [handle.read()]

    readers = {'obj': _read_pickle, 'png': _read_image, 'txt': _read_text}
    try:
        f_read = readers[file_extension]
    except KeyError:
        raise ValueError(
            'unsupported file_extension {!r}; expected one of {}'.format(
                file_extension, sorted(readers))) from None

    pattern = '../results/swiss_roll/*/{}.{}'.format(file_name, file_extension)
    # The name of the directory containing the file identifies the method;
    # os.path handles the platform's path separator.
    objs = {os.path.basename(os.path.dirname(f)): f_read(f)
            for f in glob.glob(pattern)}
    # Not every file exists for this variant, so tolerate its absence.
    objs.pop('ppca_3_components', None)

    return objs
# Select PyVista's "document" plot theme for all plotters created below.
pv.set_plot_theme("document")
# NOTE(review): backend 'None' presumably turns off interactive notebook
# rendering so that static images are produced -- confirm against PyVista docs.
pv.set_jupyter_backend('None')
# (rows, columns) of the panel grid shared by the figures in this notebook.
plot_shape = (6, 2)
def plotter_for_each_panel(axis, f_plot, plot_kwargs, plot_shape):
    """Draw one entry of ``plot_kwargs`` per subplot of a multi-panel plotter.

    Entries are visited in ``plot_order`` order and laid out row by row.
    Entries beyond ``plot_shape[0] * plot_shape[1]`` are not drawn.

    Args:
        axis: Plotter providing ``subplot(row, col)`` and ``add_text``; it is
            also forwarded to ``f_plot`` as the ``axis`` keyword.
        f_plot: Callable invoked as ``f_plot(**plot_kwargs[key], axis=axis)``.
        plot_kwargs: dict mapping a name listed in ``plot_order`` to the
            keyword arguments for ``f_plot``.
        plot_shape: ``(rows, columns)`` of the panel grid.

    Raises:
        ValueError: If a key of ``plot_kwargs`` is not in ``plot_order``.
    """
    keys = sorted(plot_kwargs, key=plot_order.index)
    panels = product(range(plot_shape[0]), range(plot_shape[1]))
    # zip stops at the shorter input: surplus panels stay empty and surplus
    # keys are skipped.
    for (i, j), key in zip(panels, keys):
        # Address the plotter that was passed in instead of relying on a
        # module-level variable that happens to be called `plotter`.
        axis.subplot(i, j)
        axis.add_text(key, font_size=24)
        f_plot(**plot_kwargs[key], axis=axis)

def image_plotter_for_each_panel(imgs, plot_shape):
    """Show the images of ``imgs`` as a grid of titled matplotlib subplots.

    Images are arranged in ``plot_order`` order on the module-level figure
    ``fig``; at most ``plot_shape[0] * plot_shape[1]`` of them are drawn.

    Args:
        imgs: dict mapping a name listed in ``plot_order`` to image data
            accepted by ``plt.imshow``.
        plot_shape: ``(rows, columns)`` of the subplot grid.
    """
    ordered = sorted(imgs, key=plot_order.index)
    n_panels = min(plot_shape[0] * plot_shape[1], len(ordered))
    # Subplot positions are 1-based.
    for slot, name in zip(range(1, n_panels + 1), ordered):
        panel = fig.add_subplot(*plot_shape, slot)
        panel.title.set_text(name)
        # The pyplot calls below act on the current axes.
        plt.axis('off')
        plt.imshow(imgs[name])

Input Representation

# Load every method's saved 'base_representation.png' and show them in one grid.
imgs = read_plot_kwargs('base_representation', 'png')
# The figure grows with the grid: 10 inches per column (width) and per row (height).
# image_plotter_for_each_panel draws into this module-level `fig`.
fig = plt.figure(figsize=(plot_shape[1]*10,plot_shape[0]*10))
image_plotter_for_each_panel(imgs, plot_shape)
plt.show()
../../_images/1_swiss_roll_experiment_11_0.png

Input Reconstruction

# Read each method's 'reconstruction.txt' (kept as raw text) into a one-column table.
mse_errors = read_plot_kwargs('reconstruction', 'txt')
df_reconstruction = pd.DataFrame.from_dict(mse_errors, orient='index', columns=['MSE'])
# reindex returns a new frame ordered by plot_order (df_reconstruction itself is
# not modified); names in plot_order without a result file appear as NaN.
df_reconstruction.reindex(plot_order)
MSE
dataset NaN
ppca_2_components 10.422249043973492
lddmm 5.1534423828125
umap 1.7059484720230103
ebm 135.48646545410156
ae 0.19316305220127106
vae 0.8526504635810852
naf 8.366670272152987e-07
snf 2.8050467903995013e-07
fvae 0.9597674012184143
nmf 0.0002189374208683148
from pdmtut.vis import plot_reconstruction

# Unpickle each method's 'reconstruction.obj'; its content is passed to
# plot_reconstruction as keyword arguments.
plot_kwargs = read_plot_kwargs('reconstruction')
# One 500x500 pixel panel per grid cell.
plotter = pv.Plotter(shape=plot_shape, window_size=(500*plot_shape[1], 500*plot_shape[0]))

plotter_for_each_panel(plotter, plot_reconstruction, plot_kwargs, plot_shape)

# Share one camera across all panels, then set it once.
plotter.link_views()
# PyVista camera triple: (position, focal point, view-up direction).
plotter.camera_position = [(-65, 0, 65), (0, 0, 0), (0, 1, 0)]
# Assign the return value so the notebook does not echo it.
_ = plotter.show()
../../_images/1_swiss_roll_experiment_15_0.png

Density Estimation

# Read each method's 'density.txt' (kept as raw text) into a one-column table.
data_log_likelihood = read_plot_kwargs('density', 'txt')
df_density = pd.DataFrame.from_dict(data_log_likelihood, orient='index', columns=['Log Likelihood'])
# reindex returns a new frame ordered by plot_order (df_density itself is not
# modified); names in plot_order without a result file appear as NaN.
df_density.reindex(plot_order)
Log Likelihood
dataset NaN
ppca_2_components -8.538554969111125
lddmm 1.0
umap 1.0
ebm -7.6159820556640625
ae -7.60769510269165
vae -7.7125563621521
naf 0.27156203985214233
snf -4.133373737335205
fvae -8.800947189331055
nmf -8.880420684814453
from pdmtut.vis import plot_density

# Note: the plot uses 'discrete_density.obj', whereas the table above is built
# from 'density.txt'.
plot_kwargs = read_plot_kwargs('discrete_density')
# One 500x500 pixel panel per grid cell.
plotter = pv.Plotter(shape=plot_shape, window_size=(500*plot_shape[1], 500*plot_shape[0]))

plotter_for_each_panel(plotter, plot_density, plot_kwargs, plot_shape)

# Share one camera across all panels, then set it once.
plotter.link_views()
# PyVista camera triple: (position, focal point, view-up direction).
plotter.camera_position = [(-65, 0, 65), (0, 0, 0), (0, 1, 0)]
# Assign the return value so the notebook does not echo it.
_ = plotter.show()
../../_images/1_swiss_roll_experiment_19_0.png

Generate Samples

from pdmtut.vis import plot_generated_samples

# Unpickle each method's 'generated_samples.obj'; its content is passed to
# plot_generated_samples as keyword arguments.
plot_kwargs = read_plot_kwargs('generated_samples')
# NOTE(review): the grid has one row fewer than plot_shape, but the window
# height and the shape handed to plotter_for_each_panel below still use the
# full plot_shape. With more entries than (plot_shape[0]-1)*plot_shape[1] the
# helper would address a subplot row that does not exist -- confirm intent.
plotter = pv.Plotter(shape=(plot_shape[0]-1, plot_shape[1]), window_size=(500*plot_shape[1], 500*plot_shape[0]))

plotter_for_each_panel(plotter, plot_generated_samples, plot_kwargs, (plot_shape[0], plot_shape[1]))

# Share one camera across all panels, then set it once.
plotter.link_views()
# PyVista camera triple: (position, focal point, view-up direction).
plotter.camera_position = [(-65, 0, 65), (0, 0, 0), (0, 1, 0)]
# Assign the return value so the notebook does not echo it.
_ = plotter.show()
../../_images/1_swiss_roll_experiment_22_0.png

Interpolation

from pdmtut.vis import plot_interpolation

# Unpickle each method's 'interpolation.obj'; its content is passed to
# plot_interpolation as keyword arguments.
plot_kwargs = read_plot_kwargs('interpolation')
# One 500x500 pixel panel per grid cell.
plotter = pv.Plotter(shape=plot_shape, window_size=(500*plot_shape[1], 500*plot_shape[0]))

plotter_for_each_panel(plotter, plot_interpolation, plot_kwargs, plot_shape)

# Share one camera across all panels, then set it once.
plotter.link_views()
# PyVista camera triple: (position, focal point, view-up direction).
plotter.camera_position = [(-65, 0, 65), (0, 0, 0), (0, 1, 0)]
# Assign the return value so the notebook does not echo it.
_ = plotter.show()
../../_images/1_swiss_roll_experiment_25_0.png

Extra

# Put the MSE and log-likelihood columns side by side (aligned on the method name).
table = pd.concat([df_reconstruction, df_density], axis=1)
# The values were read from text files, hence the cast to float64 before
# rounding to two decimals and printing the table as LaTeX.
print(table.reindex(plot_order).astype('float64').round(2).to_latex()) 
\begin{tabular}{lrr}
\toprule
{} &     MSE &  Log Likelihood \\
\midrule
dataset           &     NaN &             NaN \\
ppca\_2\_components &   10.42 &           -8.54 \\
lddmm             &    5.15 &            1.00 \\
umap              &    1.71 &            1.00 \\
ebm               &  135.49 &           -7.62 \\
ae                &    0.19 &           -7.61 \\
vae               &    0.85 &           -7.71 \\
naf               &    0.00 &            0.27 \\
snf               &    0.00 &           -4.13 \\
fvae              &    0.96 &           -8.80 \\
nmf               &    0.00 &           -8.88 \\
\bottomrule
\end{tabular}