%reload_ext tensorboard
%reload_ext autoreload

Normalising Soft Ambient Flows

import os
import time
import math
import torch
import numpy as np
import torch.nn as nn
import pytorch_lightning as pl
import torch.nn.functional as F
import torch.distributions as tdist
import torchdyn.nn.node_layers as tdnl

from joblib import dump, load
from sklearn.decomposition import PCA
from pdmtut.core import GenerativeModel
from pytorch_lightning import loggers as pl_loggers
from regilib.core.distributions import MultivariateNormal
from regilib.core.dynamics.dynamics import RegularisedDynamics
from regilib.core.dynamics.dynamical_state import DynamicalState
from regilib.core.invertible_modules import NormalisingFlow
from regilib.core.invertible_modules.bijective import ContinuousAmbientFlow
store_results = True
load_models = True

Introduction
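
In short (summarising the implementation below): the model is a continuous normalising flow in the ambient space, trained on noise-perturbed copies of the data and conditioned on the noise level, so that samples and likelihoods can later be evaluated at c = 0. Using the notation from the code comments,

ν ~ 𝓝(0, c²·I) with c ~ U[a, b],   x′ = x + ν
d log p(z(t)) / dt = −Tr ∂f/∂z(t)   (instantaneous change of variables)
loss = −log p(x′ | c) + λ_e·e + λ_n·n

where e and n are the regularisation terms produced by RegularisedDynamics, and c is rescaled to the interval [−1, +1] before it is passed to the dynamics network as a condition.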

Implementation

class NormalisingSoftAmbientFlows(NormalisingFlow, pl.LightningModule, GenerativeModel):
    class FunctionDynamics(nn.Module):
        def __init__(self):
            super().__init__()
            
            self._in_channels = 3
            self._out_channels = 3
            
            # expected format: N x (C * L)
            # +1 for time, +1 for noise condition
            self.fc1 = nn.Linear(self._in_channels + 2, 64)
            self.fc2 = nn.Linear(64, 128)
            self.fc3 = nn.Linear(128, 128)
            self.fc4 = nn.Linear(128, 64)
            self.fc5 = nn.Linear(64, self._out_channels)  
            
        @property
        def in_channels(self):
            return self._in_channels

        @property
        def out_channels(self):
            return self._out_channels

        def forward(self, ds):
            x = torch.cat([ds.state, ds.condition, ds.t], -1)
            x = torch.tanh(self.fc1(x))
            x = torch.tanh(self.fc2(x))
            x = torch.tanh(self.fc3(x))
            x = torch.tanh(self.fc4(x))
            x = self.fc5(x)
            return x
    
    def __init__(self, input_dimensions=3):
        super().__init__(
            base_distribution=MultivariateNormal(torch.zeros(3), torch.eye(3)))
        
        # [a,b] interval of standard deviation of noise distribution
        self.a, self.b = 0, 0.1
        
        self.input_dimensions = input_dimensions
        
        # state=[l, e, n | state]
        self.aug1 = tdnl.Augmenter(augment_dims=3)
        self.af1 = ContinuousAmbientFlow(
            dynamics=RegularisedDynamics(fdyn=NormalisingSoftAmbientFlows.FunctionDynamics()),
            sensitivity='autograd', default_n_steps=5
        )
        
    def noise_enhance_data(self, x, c=None, seed=None, perturb_state=True):
        ds = x.clone() if isinstance(x, DynamicalState) else DynamicalState(state=x)

        n_samples, device = ds.state.shape[0], ds.state.device
        
        # sample uniform distribution c ∈ [a, b]
        if c is None:
            if seed is not None: torch.manual_seed(seed)
            c = torch.FloatTensor(n_samples).uniform_(self.a, self.b)[:,None].to(device)
        else:
            c = (c * torch.ones(n_samples, 1)).to(device)

        if perturb_state: 
            # sample gaussian noise ν ~ 𝓝(0, c²·I)
            if seed is not None: torch.manual_seed(seed)
            nu = torch.randn(n_samples, 3, device=device) * c
            
            ds.state = ds.state + nu # perturb datapoint

        # scale c so that it matches [-1,+1] interval of points
        scale_c = 2*((c - self.a) / (1.e-10 + (self.b - self.a))) - 1
        ds.condition = scale_c # store std as condition
        
        return ds
 
    # Region NormalisingFlow
    def forward(self, ds, af_estimate=True):
        assert hasattr(ds, 'condition')
        
        ds = super().forward(ds)
        ds = self.af1.dynamics.update_ds(ds, self.aug1(ds['state']))
        ds = self.af1.forward(ds, estimate_trace=af_estimate)
        return ds

    def inverse(self, ds, af_estimate=True):
        assert hasattr(ds, 'condition')
        
        ds = self.af1.dynamics.update_ds(ds, self.aug1(ds['state']))
        ds = self.af1.inverse(ds, estimate_trace=af_estimate)
        ds = super().inverse(ds)
        return ds
    
    # Region GenerativeModel
    def encode(self, X, c, seed=None):
        if not isinstance(c, torch.Tensor):
            c = torch.tensor([c], dtype=torch.float)
        
        ds = DynamicalState(state=X)
        ds = self.noise_enhance_data(ds, c=c, seed=seed)
        ds = self.inverse(ds)
        return ds['state'].cpu().detach()

    def decode(self, z, c, seed=None):
        if not isinstance(c, torch.Tensor):
            c = torch.tensor([c], dtype=torch.float)
            
        ds = DynamicalState(state=z)
        
        # do not add noise to the latent state, this can lead to problems during reconstruction
        ds = self.noise_enhance_data(ds, c=c, seed=seed, perturb_state=False)
        ds = self.forward(ds)
        return ds['state'].cpu().detach()
    
    def save(self, path):
        torch.save(self, os.path.join(path, 'model.pt'))
    
    @staticmethod
    def load(path):
        return torch.load(os.path.join(path, 'model.pt'))

    @staticmethod
    def save_exists(path):
        return os.path.isfile(os.path.join(path, 'model.pt'))

    def log_likelihood(self, x, c, seed=None):
        if not isinstance(c, torch.Tensor):
            c = torch.tensor([c], dtype=torch.float)
            
        ds = DynamicalState(state=x)
        ds = self.noise_enhance_data(ds, c=c, seed=seed)
        ds = self.inverse(ds, af_estimate=False)
        return ds.log_prob.cpu().detach()

    def sample_posterior(self, n_samples, c, seed=None):        
        z = self.sample_prior(n_samples)
        return self.decode(z, c, seed)
    
    def fit_model(self, X, X_val=None, path=None):
        start_time = time.time()
        
        if path is None:
            tb_logger = False
            checkpoint_callback=False
        else:
            tb_logger = pl_loggers.TensorBoardLogger(path, version=0)
            checkpoint_callback=True
        
        trainer = pl.Trainer(
            max_epochs=5000, gpus=1, logger=tb_logger,
            checkpoint_callback=checkpoint_callback
        )
        trainer.fit(
            self, train_dataloaders=X, val_dataloaders=X_val)
        elapsed_time = time.time() - start_time
        
        if path is not None: 
            with open(os.path.join(path, 'training_time.txt'), 'w') as f:
                f.write(str(elapsed_time))

    def training_step(self, batch, batch_idx):
        x = batch[0]
        lambda_e, lambda_n = 0.01, 0.01
        
        # state=[x+nu|c]
        ds_x_prime = self.noise_enhance_data(x)
        
        # instantaneous change of variables: d logp(z(t))/dt = -Tr ∂f/∂z(t)
        ds_z = self.inverse(ds_x_prime, af_estimate=True)

        # minimise the negative log-likelihood plus the regularisation terms
        loss = (
            -ds_z.log_prob + lambda_e * ds_z.e[:, 0] + lambda_n * ds_z.n[:, 0]
        ).sum() / (x.shape[0] * x.shape[1])

        self.log('train_loss', loss)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        x = batch[0]
        lambda_e, lambda_n = 0.01, 0.01
        
        # state=[x+nu|c]
        ds_x_prime = self.noise_enhance_data(x)
        
        # instantaneous change of variables: d logp(z(t))/dt = -Tr ∂f/∂z(t)
        ds_z = self.inverse(ds_x_prime, af_estimate=True)

        loss = (
            -ds_z.log_prob + lambda_e * ds_z.e[:, 0] + lambda_n * ds_z.n[:, 0]
        ).sum() / (x.shape[0] * x.shape[1])

        self.log('validation_loss', loss)
        return {'val_loss': loss}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return {
            'optimizer': optimizer,
            'lr_scheduler':
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, min_lr=1e-8, factor=0.5, verbose=True,
                patience=100
            ), 'monitor': 'train_loss'
        }

    def __str__(self):
        return 'snf'
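
A minimal sketch (the variable names are illustrative and not part of the pipeline below) of how the noise conditioning behaves, assuming the regilib classes imported above work as they are used here:

sketch_model = NormalisingSoftAmbientFlows()
sketch_x = torch.randn(4, 3)  # hypothetical batch of 3-D points

# perturb the points with Gaussian noise of std c=0.05 and attach the scaled condition
sketch_ds = sketch_model.noise_enhance_data(sketch_x, c=0.05, seed=0)
print(sketch_ds.state.shape)          # torch.Size([4, 3]) -- the perturbed points
print(sketch_ds.condition.flatten())  # ≈ 0, since 2*((0.05 - 0)/(0.1 - 0)) - 1 = 0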

Experiment 1: Swiss Roll

import pyvista as pv
from pdmtut.datasets import SwissRoll
pv.set_plot_theme("document")

model_save_path = '../results/swiss_roll/snf'

if store_results:
    result_save_path = '../results/swiss_roll/snf'
    pv.set_jupyter_backend('None')
else:
    pv.set_jupyter_backend('ipygany')
    result_save_path = None
dataset = SwissRoll(n_samples=100**2, seed=11)

if load_models and NormalisingSoftAmbientFlows.save_exists(model_save_path):
    model = NormalisingSoftAmbientFlows.load(model_save_path)
else:
    model = NormalisingSoftAmbientFlows()
    model.fit_model(
        X=dataset.train_loader(batch_size=512),
        X_val=dataset.validation_loader(batch_size=512),
        path=result_save_path)
    
    if store_results:
        model.save(model_save_path)
        
model = model.eval()
%tensorboard --logdir ../results/swiss_roll/snf

Input Representation

from pdmtut.vis import plot_representation
z = model.encode(dataset.X, c=0., seed=3)
z_extremes = model.encode(dataset.y_extremes, c=0., seed=3)
z_extremes = torch.cat([z_extremes, z_extremes[[1,2]]])
plot_representation(z.numpy(), index_colors=dataset.index_colors, z_extremes=z_extremes, interpolate_background=True, root=result_save_path)
../../_images/8_SNF_15_0.png

Input Reconstruction

from pdmtut.vis import plot_reconstruction
z = model.encode(dataset.X, c=0., seed=3)
x = model.decode(z, c=0., seed=3)
mse = (dataset.unnormalise_scale(dataset.X) - dataset.unnormalise_scale(x)).pow(2).sum(-1).mean()

if result_save_path is not None: 
    with open(os.path.join(result_save_path, 'reconstruction.txt'), 'w') as f:
        f.write(str(mse.item()))
        
mse
tensor(2.8050e-07)
plot_reconstruction(dataset.unnormalise_scale(x).numpy(), dataset.index_colors, root=result_save_path)
../../_images/8_SNF_20_0.png

Density Estimation

from pdmtut.vis import plot_density
from regilib.core.invertible_modules.bijective import AffineTransform
log_likelihood = model.log_likelihood(dataset.X, c=0., seed=3)

# unnormalise the data and compute the change in density
un_normalise = AffineTransform(dataset._mean, 1/dataset._std)
data = un_normalise.forward(DynamicalState(state=dataset.X.clone().requires_grad_(True), log_prob=log_likelihood.clone()))
data_log_likelihood = data.log_prob.mean()

if result_save_path is not None: 
    with open(os.path.join(result_save_path, 'density.txt'), 'w') as f:
        f.write(str(data_log_likelihood.item()))
        
data_log_likelihood
tensor(-4.1334)
plot_density(data.state.detach().numpy(), data.log_prob.detach().numpy(), root=result_save_path)
../../_images/8_SNF_25_0.png
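
The unnormalisation step above relies on the change-of-variables rule for an affine map: if y = σ ⊙ x + μ, then log p_Y(y) = log p_X(x) − Σ_i log σ_i, so the reported log-likelihood is expressed in the units of the original, unnormalised data. The AffineTransform is used to apply this correction (its exact parameterisation, taking the mean and the reciprocal standard deviation, is regilib-specific). A small self-contained check of the rule itself, using torch.distributions rather than regilib:

import torch
import torch.distributions as tdist

std, mean = torch.tensor([2.0, 3.0, 0.5]), torch.tensor([1.0, -1.0, 0.0])
x = torch.randn(3)
y = std * x + mean  # affine "unnormalisation" of a standard-normal sample

lp_x = tdist.Normal(0., 1.).log_prob(x).sum()
lp_y = tdist.Normal(mean, std).log_prob(y).sum()
assert torch.allclose(lp_y, lp_x - std.log().sum())  # log p_Y(y) = log p_X(x) - Σ log σ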

Generate Samples

from pdmtut.vis import plot_generated_samples
from regilib.core.invertible_modules.bijective import AffineTransform
generated_samples = model.sample_posterior(100**2, c=0., seed=3)
generated_samples_log_likelihood = model.log_likelihood(generated_samples, c=0., seed=3)

# unnormalise the data and compute the change in density
un_normalise = AffineTransform(dataset._mean, 1/dataset._std)
data = un_normalise.forward(DynamicalState(state=generated_samples.clone().requires_grad_(True), log_prob=generated_samples_log_likelihood.clone()))
plot_generated_samples(data.state.detach().numpy(), data.log_prob.detach().numpy(), root=result_save_path)
../../_images/8_SNF_29_0.png

Interpolation

from pdmtut.vis import plot_interpolation
from scipy.interpolate import interp1d
z_extremes = model.encode(dataset.y_extremes, c=0., seed=3)
uniform_state, uniform_log_prob, _ = dataset.sample_points_uniformly(n_samples=100**2, seed=11)

linfit1 = interp1d([1,20], z_extremes[:2].numpy(), axis=0)
linfit2 = interp1d([1,20], z_extremes[2:].numpy(), axis=0)
linfit3 = interp1d([1,20], z_extremes[[1,2]].numpy(), axis=0)

interpolated_points_1 = model.decode(torch.Tensor(linfit1(np.arange(1,21))), c=0., seed=3)
interpolated_points_2 = model.decode(torch.Tensor(linfit2(np.arange(1,21))), c=0., seed=3)
interpolated_points_3 = model.decode(torch.Tensor(linfit3(np.arange(1,21))), c=0., seed=3)
plot_interpolation(
    dataset.unnormalise_scale(interpolated_points_1).numpy(), 
    dataset.unnormalise_scale(interpolated_points_2).numpy(), 
    dataset.unnormalise_scale(interpolated_points_3).numpy(), 
    uniform_state.detach().view(100, 100, 3).permute(2, 0, 1).numpy(),
    uniform_log_prob.numpy(), root=result_save_path
)
../../_images/8_SNF_33_0.png

Extra

Noisy Generative Samples

import pickle
import pyvista as pv
def plot_noisy_samples(c):
    generated_samples = dataset.unnormalise_scale(
        model.sample_posterior(100**2, c=c, seed=3)).numpy()
    
    
    plotter = pv.Plotter() 
    plotter.add_mesh(
        pv.PolyData(generated_samples),
        render_points_as_spheres=True, point_size=10,
        diffuse=0.99, specular=0.8, ambient=0.3, smooth_shading=True,
        style='points'
    )

    plotter.camera_position = [(-65, 0, 65), (0, 0, 0), (0, 1, 0)]

    _ = plotter.show(window_size=[800, 800])

    if result_save_path is not None:
        plotter.screenshot(os.path.join(
            result_save_path, 'generated_samples_c_{}.png'.format(str(c).replace('.', '_'))))


for c in [0, 0.01, 0.05, 0.1, 0.2, 0.5]:
    plot_noisy_samples(c)
../../_images/8_SNF_36_0.png ../../_images/8_SNF_36_1.png ../../_images/8_SNF_36_2.png ../../_images/8_SNF_36_3.png ../../_images/8_SNF_36_4.png ../../_images/8_SNF_36_5.png

Input Representation (3D)

import pickle
import pyvista as pv
z = model.encode(dataset.X, c=0., seed=3)
plotter = pv.Plotter() 

plotter.add_mesh(
    pv.PolyData(z.detach().numpy()),
    render_points_as_spheres=True, point_size=10,
    diffuse=0.99, specular=0.8, ambient=0.3, smooth_shading=True,
    scalars=dataset.index_colors,
    style='points', rgb=True
)

plotter.camera_position = [(-10, 0, 10), (0, 0, 0), (0, 1, 0)]
_ = plotter.show(window_size=[800, 800])

if result_save_path is not None:
    plotter.screenshot(os.path.join(result_save_path, '3d_base_representation.png'))
    pickle.dump({
        'reconstructed_state': z,
        'index_colors': dataset.index_colors
    }, open(os.path.join(result_save_path, '3d_base_representation.obj'), 'wb'))
../../_images/8_SNF_40_0.png

Input Representation (3D - c=0.1)

import pickle
import pyvista as pv
z = model.encode(dataset.X, c=0.1, seed=3)
plotter = pv.Plotter() 

plotter.add_mesh(
    pv.PolyData(z.detach().numpy()),
    render_points_as_spheres=True, point_size=10,
    diffuse=0.99, specular=0.8, ambient=0.3, smooth_shading=True,
    scalars=dataset.index_colors,
    style='points', rgb=True
)

plotter.camera_position = [(-10, 0, 10), (0, 0, 0), (0, 1, 0)]
_ = plotter.show(window_size=[800, 800])

if result_save_path is not None:
    plotter.screenshot(os.path.join(result_save_path, '3d_base_representation_c_01.png'))
    pickle.dump({
        'reconstructed_state': z,
        'index_colors': dataset.index_colors
    }, open(os.path.join(result_save_path, '3d_base_representation_c_01.obj'), 'wb'))
../../_images/8_SNF_44_0.png