%reload_ext tensorboard
%load_ext autoreload
Large Deformation Diffeomorphic Metric Mapping and US¶
import os
import time
import torch
import numpy as np
import pyvista as pv
from tqdm import tqdm
from geomloss import SamplesLoss
from joblib import dump, load
import pytorch_lightning as pl
from sklearn.decomposition import PCA
from pdmtut.datasets import SwissRoll
from pdmtut.core import GenerativeModel
from pytorch_lightning import loggers as pl_loggers
from regilib.core.invertible_modules.charts import PadProj
from regilib.core.dynamics.dynamical_state import DynamicalState
from torchcubicspline import natural_cubic_spline_coeffs, NaturalCubicSpline
from regilib.core.invertible_modules.bijective import ShootingLayer, PointCloudDeformationLayer
store_results = True
load_models = True
Introduction¶
Implementation¶
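Before reading the class below, it helps to recall the control-point form of LDDMM that `ShootingLayer` and `PointCloudDeformationLayer` implement. As a brief, hedged sketch (standard LDDMM notation, not taken from the regilib documentation; the kernel is assumed Gaussian here): the velocity field at time $t$ is carried by control points $c_k(t)$ with attached momenta $m_k(t)$,

$$ v_t(x) = \sum_{k=1}^{K} K_\sigma\big(x, c_k(t)\big)\, m_k(t), \qquad K_\sigma(x, y) = \exp\!\big(-\lVert x - y\rVert^2 / \sigma^2\big), $$

and every point is transported along the flow ODE $\dot{\phi}_t(x) = v_t(\phi_t(x))$ with $\phi_0 = \mathrm{id}$. Only the initial momenta are free parameters: `ShootingLayer` integrates the control points and momenta forward in time (geodesic shooting), and `PointCloudDeformationLayer` advects a point cloud along the resulting velocity field.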
class LDDMMPCA(pl.LightningModule, GenerativeModel):
    def __init__(self, n_components):
        super().__init__()
        self.n_components = n_components

        # base distribution: a flat 100x100 grid in the y=0 plane that the
        # diffeomorphism warps onto the data manifold
        self.base_distribution = torch.nn.parameter.Parameter(
            torch.stack(torch.meshgrid(
                torch.linspace(0, 1, steps=100),
                torch.zeros(1),
                torch.linspace(-1, +1, steps=100),
                indexing='xy'), -1).view(100*100, 3),
            requires_grad=False)

        # control-point grid and Gaussian kernel width for the shooting layer
        n_cps_x, n_cps_z = 20, 20
        self.shooting_layer = ShootingLayer(
            control_points=torch.stack(
                torch.meshgrid(
                    torch.linspace(0, 1, steps=n_cps_x),
                    torch.zeros(1),
                    torch.linspace(-1, +1, steps=n_cps_z),
                    indexing='xy'),
                -1).view(n_cps_x*n_cps_z, 3),
            sigma=4/20, default_n_steps=20, solver='rk4')
        self.warp = PointCloudDeformationLayer(default_n_steps=20, solver='rk4')
        self.loss = SamplesLoss(loss="sinkhorn", p=2, blur=1e-12, reach=0.1)

        # initial momenta at the control points, optimised during training
        self.momenta = torch.nn.parameter.Parameter(
            torch.empty(self.shooting_layer.control_points.shape)
        )
        #torch.nn.init.normal_(self.momenta)
        torch.nn.init.zeros_(self.momenta)

        self.chart = PadProj()
    def fit_model(self, X, X_val=None, path=None):
        start_time = time.time()

        if path is None:
            tb_logger = False
            checkpoint_callback = False
        else:
            tb_logger = pl_loggers.TensorBoardLogger(path, version=0)
            checkpoint_callback = True

        trainer = pl.Trainer(
            max_epochs=1000, gpus=1, logger=tb_logger,
            log_every_n_steps=1, checkpoint_callback=checkpoint_callback
        )
        trainer.fit(self, train_dataloaders=X, val_dataloaders=X_val)

        elapsed_time = time.time() - start_time
        if path is not None:
            with open(os.path.join(path, 'training_time.txt'), 'w') as f:
                f.write(str(elapsed_time))
    def v_t(self, t, x, interpolatable_state):
        # interpolate state of control_points at time t
        cs, ms = interpolatable_state(t).view(2, -1, 3)
        # compute change in velocity for input x
        return self.shooting_layer.velocity(x, cs, ms)[0]

    def forward(self, ds_z, **kwargs):
        s_time_steps, st = self.shooting_layer.forward(self.momenta)
        interp = NaturalCubicSpline(natural_cubic_spline_coeffs(
            s_time_steps, st.view(s_time_steps.shape[0], -1)))

        ds_u = self.chart.forward(ds_z)
        ds_x = self.warp.forward(
            ds_u, lambda t, x: self.v_t(t, x, interp.evaluate), **kwargs)
        return ds_x

    def inverse(self, ds_x, **kwargs):
        s_time_steps, st = self.shooting_layer.forward(self.momenta)
        interp = NaturalCubicSpline(natural_cubic_spline_coeffs(
            s_time_steps, st.view(s_time_steps.shape[0], -1)))

        ds_u = self.warp.inverse(
            ds_x, lambda t, x: self.v_t(t, x, interp.evaluate), **kwargs)
        ds_z = self.chart.inverse(ds_u)
        return ds_z
    def training_step(self, batch, batch_idx):
        x = batch[0]

        s_time_steps, st = self.shooting_layer.forward(self.momenta)
        interp = NaturalCubicSpline(natural_cubic_spline_coeffs(
            s_time_steps, st.view(s_time_steps.shape[0], -1)))
        ds_x = self.warp.forward(
            self.base_distribution,
            lambda t, x: self.v_t(t, x, interp.evaluate))

        # Sinkhorn divergence between the deformed base grid and the data,
        # regularised by the kernel norm of the initial momenta
        dist = self.loss(x, ds_x.state)
        reg = self.shooting_layer.inner_product(self.momenta, self.momenta)
        loss = dist + 0.0001*reg

        self.log('train_loss', loss)
        return {'loss': loss}
    def encode(self, X, **kwargs):
        ds_x = X.clone() if isinstance(X, DynamicalState) else DynamicalState(state=X)
        ds_z = self.inverse(ds_x, **kwargs)
        return ds_z.state.detach()

    def decode(self, z, **kwargs):
        ds_z = z.clone() if isinstance(z, DynamicalState) else DynamicalState(state=z)
        ds_x = self.forward(ds_z, **kwargs)
        return ds_x.state.detach()

    def save(self, path):
        torch.save(self, os.path.join(path, 'model.pt'))

    @staticmethod
    def load(path):
        return torch.load(os.path.join(path, 'model.pt'))

    @staticmethod
    def save_exists(path):
        return os.path.isfile(os.path.join(path, 'model.pt'))

    def log_likelihood(self, X):
        # the deformation model does not define a tractable density;
        # return a constant placeholder
        return torch.ones(X.shape[0])
    def sample_posterior(self, n_samples):
        # sample uniformly in the 2-D chart and push through the learned deformation
        z_encodings = torch.stack([
            torch.FloatTensor(n_samples).uniform_(0, 1),
            torch.FloatTensor(n_samples).uniform_(-1, 1)
        ], -1)
        return self.decode(z_encodings)

    def configure_optimizers(self):
        #optimizer = torch.optim.LBFGS([self.momenta, self.shooting_layer.control_points], lr=0.01)
        optimizer = torch.optim.Adam(
            [self.momenta, self.shooting_layer.control_points], lr=0.01)

        return {
            'optimizer': optimizer,
            'lr_scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, min_lr=1e-8, factor=0.5, verbose=True,
                patience=100
            ),
            'monitor': 'train_loss'
        }

    def __str__(self):
        return 'lddmm'
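Putting `training_step` and `configure_optimizers` together, the quantity minimised over the initial momenta $m$ can be summarised as follows (a summary in the notation of the sketch above, assuming `inner_product` computes the usual kernel metric):

$$ \mathcal{L}(m) = S_\varepsilon\big(\phi_1^{m}(U),\, X\big) + \lambda \sum_{j,k} m_j^\top K_\sigma(c_j, c_k)\, m_k, \qquad \lambda = 10^{-4}, $$

where $U$ is the flat base grid, $\phi_1^{m}$ the diffeomorphism obtained by shooting the momenta, $S_\varepsilon$ the Sinkhorn divergence from geomloss, and the second term the smoothness regulariser returned by `shooting_layer.inner_product`.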
Experiment 1: swiss roll¶
import pyvista as pv
from pdmtut.datasets import SwissRoll
pv.set_plot_theme("document")
model_save_path = '../results/swiss_roll/lddmm'
if store_results:
    result_save_path = '../results/swiss_roll/lddmm'
    pv.set_jupyter_backend('None')
else:
    pv.set_jupyter_backend('ipygany')
    result_save_path = None
dataset = SwissRoll(n_samples=100**2, seed=11)
if load_models and LDDMMPCA.save_exists(model_save_path):
    model = LDDMMPCA.load(model_save_path)
else:
    model = LDDMMPCA(n_components=2)
    model.fit_model(
        X=dataset.X_loader(batch_size=100**2, shuffle=False),
        path=result_save_path)

    if store_results:
        model.save(model_save_path)
model = model.eval()
%tensorboard --logdir ../results/swiss_roll/lddmm
Input Representation¶
from pdmtut.vis import plot_representation
z = model.encode(dataset.X)
z_extremes = model.encode(dataset.y_extremes)
z_extremes = torch.cat([z_extremes, z_extremes[[1,2]]])
plot_representation(z.numpy(), index_colors=dataset.index_colors, z_extremes=z_extremes, interpolate_background=True, root=result_save_path)
Input Reconstruction¶
from pdmtut.vis import plot_reconstruction
z = model.encode(dataset.X)
x = model.decode(z)
mse = (dataset.unnormalise_scale(dataset.X) - dataset.unnormalise_scale(x)).pow(2).sum(-1).mean()
if result_save_path is not None:
    with open(os.path.join(result_save_path, 'reconstruction.txt'), 'w') as f:
        f.write(str(mse.item()))
mse
tensor(5.1534)
plot_reconstruction(dataset.unnormalise_scale(x).numpy(), dataset.index_colors, root=result_save_path)
Density Estimation¶
from pdmtut.vis import plot_density
log_likelihood = model.log_likelihood(dataset.X)
data_log_likelihood = log_likelihood.mean()
if result_save_path is not None:
    with open(os.path.join(result_save_path, 'density.txt'), 'w') as f:
        f.write(str(data_log_likelihood.item()))
data_log_likelihood
tensor(1.)
plot_density(dataset.unnormalise_scale(dataset.X).numpy(), log_likelihood.numpy(), root=result_save_path)
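`log_likelihood` is a constant placeholder, because the LDDMM deformation does not come with a tractable change-of-variables density in this setup. The reported value therefore follows directly from the definition:

$$ \log \hat{p}(x_i) = 1 \ \ \text{for every sample}, \qquad \frac{1}{N} \sum_{i=1}^{N} \log \hat{p}(x_i) = 1, $$

which is why the cell above prints `tensor(1.)` and the density colouring carries no information for this model.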
Generate Samples¶
from pdmtut.vis import plot_generated_samples
generated_samples = model.sample_posterior(100**2)
generated_samples_log_likelihood = model.log_likelihood(generated_samples)
plot_generated_samples(dataset.unnormalise_scale(generated_samples).numpy(), generated_samples_log_likelihood.numpy(), root=result_save_path)
Interpolation¶
from pdmtut.vis import plot_interpolation
from scipy.interpolate import interp1d
z_extremes = model.encode(dataset.y_extremes)
uniform_state, uniform_log_prob, _ = dataset.sample_points_uniformly(n_samples=100**2, seed=11)
linfit1 = interp1d([1,20], z_extremes[:2].numpy(), axis=0)
linfit2 = interp1d([1,20], z_extremes[2:].numpy(), axis=0)
linfit3 = interp1d([1,20], z_extremes[[1,2]].numpy(), axis=0)
interpolated_points_1 = model.decode(torch.Tensor(linfit1(np.arange(1,21))))
interpolated_points_2 = model.decode(torch.Tensor(linfit2(np.arange(1,21))))
interpolated_points_3 = model.decode(torch.Tensor(linfit3(np.arange(1,21))))
plot_interpolation(
dataset.unnormalise_scale(interpolated_points_1).numpy(),
dataset.unnormalise_scale(interpolated_points_2).numpy(),
dataset.unnormalise_scale(interpolated_points_3).numpy(),
uniform_state.detach().view(100, 100, 3).permute(2, 0, 1).numpy(),
uniform_log_prob.numpy(), root=result_save_path
)
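The three paths above are linear only in the 2-D chart coordinates; they become curved in 3-D after decoding through the learned deformation. Schematically, with $z_a$ and $z_b$ two encoded extreme points:

$$ z(s) = (1 - s)\, z_a + s\, z_b, \qquad x(s) = \mathrm{decode}\big(z(s)\big), \quad s \in [0, 1], $$

and `interp1d` evaluated at the 20 equally spaced positions in `np.arange(1, 21)` yields the 20 decoded points plotted for each path.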
Extra¶
Reconstruction trajectory¶
n_steps = 100
z_traj = model.inverse(DynamicalState(state=dataset.X), steps=n_steps, include_trajectory=True)
x_traj = model.forward(DynamicalState(state=z_traj.state), steps=n_steps, include_trajectory=True)
def plot_traj_state(traj, t, file_name):
    traj_state = traj[t]

    pv.set_plot_theme("document")
    pv.set_jupyter_backend('None')
    plotter = pv.Plotter()
    plotter.add_mesh(
        pv.PolyData(dataset.unnormalise_scale(traj_state.detach()).numpy()),
        render_points_as_spheres=True,
        scalars=dataset.index_colors, rgb=True, point_size=5)
    plotter.camera_position = [(-65, 0, 65), (0, 0, 0), (0, 1, 0)]
    _ = plotter.show(window_size=[800, 800])

    if result_save_path is not None:
        plotter.screenshot(os.path.join(
            result_save_path, file_name + '{}.png'.format(str(t).replace('.', '_'))))
for t in torch.linspace(0, n_steps-1, 5):
    plot_traj_state(z_traj.trajectory, int(t), file_name='base_representation_t_')

for t in torch.linspace(0, n_steps-1, 5):
    plot_traj_state(x_traj.trajectory, int(t), file_name='reconstruction_t_')
pv.set_jupyter_backend('None')
from regilib.vis.video_plotter import plot_video
if result_save_path is not None:
    _ = plot_video(
        z_traj.trajectory.detach(),
        os.path.join(result_save_path, 'lddmm_inverse.gif'),
        render_points_as_spheres=True, reverse=False,
        camera_pos=[(0, 0, 5), (0, 0, 0), (0, 1, 0)]
    )
    _ = plot_video(
        x_traj.trajectory.detach(),
        os.path.join(result_save_path, 'lddmm_forward.gif'),
        render_points_as_spheres=True, reverse=False,
        camera_pos=[(0, 0, 5), (0, 0, 0), (0, 1, 0)]
    )

