From f8caa056b06acd531fb36338b382422f5160399c Mon Sep 17 00:00:00 2001
From: marjanfamili
Date: Fri, 12 Apr 2024 12:18:44 +0100
Subject: [PATCH] added saving before conv layers and straight through
 estimator

---
 avae/decoders/decoders.py       |  2 +-
 avae/decoders/differentiable.py | 48 ++++++++++++++++++++-----
 avae/evaluate.py                |  2 +-
 avae/models.py                  |  4 +--
 avae/train.py                   | 28 +++++++++++++++
 avae/utils.py                   |  2 +-
 avae/utils_learning.py          |  4 +--
 avae/vis.py                     | 63 ++++++++++++++++++++++++++++++---
 8 files changed, 132 insertions(+), 21 deletions(-)

diff --git a/avae/decoders/decoders.py b/avae/decoders/decoders.py
index d87a2e9f..7b3f3bbc 100644
--- a/avae/decoders/decoders.py
+++ b/avae/decoders/decoders.py
@@ -253,7 +253,7 @@ def __init__(
 
     def forward(self, x, x_pose):
         if self.pose:
-            return self.decoder(torch.cat([x_pose, x], dim=-1))
+            return self.decoder(torch.cat([x_pose, x], dim=-1)), self.decoder(torch.cat([x_pose, x], dim=-1))
         else:
             return self.decoder(x)
 
diff --git a/avae/decoders/differentiable.py b/avae/decoders/differentiable.py
index 4c52f032..29e3b059 100644
--- a/avae/decoders/differentiable.py
+++ b/avae/decoders/differentiable.py
@@ -1,6 +1,13 @@
 from typing import Optional, Tuple
+import typing
+import logging
+import torchvision
+import numpy as np
+from scipy import stats
 
 import torch
+from avae.utils import save_imshow_png
+from avae import settings, vis
 
 from avae.decoders.base import AbstractDecoder
 from avae.decoders.spatial import (
@@ -10,6 +17,24 @@
     quaternion_to_rotation_matrix,
 )
 
+
+class STEFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input):
+        return (input > 0).float()
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return torch.nn.functional.hardtanh(grad_output)
+
+
+class StraightThroughEstimator(torch.nn.Module):
+    def __init__(self):
+        super(StraightThroughEstimator, self).__init__()
+
+    def forward(self, x):
+        x = STEFunction.apply(x)
+        return x
+
 
 class GaussianSplatRenderer(torch.nn.Module):
     """Perform gaussian splatting."""
@@ -185,13 +210,15 @@ def __init__(
             torch.nn.Linear(latent_dims, n_splats * 3),
             torch.nn.Tanh(),
         )
+
+
         # weights are effectively whether a splat is used or not
         # use a soft step function to make this `binary` (but differentiable)
         # NOTE(arl): not sure if this really makes any difference
         self.weights = torch.nn.Sequential(
             torch.nn.Linear(latent_dims, n_splats),
-            torch.nn.Tanh(),
-            SoftStep(k=10.0),
+            StraightThroughEstimator(),
+            #SoftStep(k=10.0),
         )
         # sigma ends up being scaled by `splat_sigma_range`
         self.sigmas = torch.nn.Sequential(
@@ -226,11 +253,11 @@ def __init__(
             else torch.nn.Conv3d
         )
         self._decoder = torch.nn.Sequential(
-            conv(1, 32, 3, padding="same"),
-            torch.nn.ReLU(),
-            conv(32, 32, 3, padding="same"),
-            torch.nn.ReLU(),
-            conv(32, output_channels, 3, padding="same"),
+            conv(1, 1, 9, padding="same"),
+            #torch.nn.ReLU(),
+            #conv(32, 32, 3, padding="same"),
+            #torch.nn.ReLU(),
+            #conv(32, output_channels, 3, padding="same"),
         )
 
     def configure_renderer(
@@ -332,11 +359,14 @@ def forward(
         x = self._splatter(
             splats, weights, sigmas, splat_sigma_range=self._splat_sigma_range
        )
-        # if we're doing a final convolution, do it here
+
+        x_before_conv = x
+
         if (
             self._output_channels is not None
             and self._output_channels != 0
             and use_final_convolution
         ):
             x = self._decoder(x)
-        return x
+
+        return x, x_before_conv
diff --git a/avae/evaluate.py b/avae/evaluate.py
index e9991a34..218cd7e6 100644
--- a/avae/evaluate.py
+++ b/avae/evaluate.py
@@ -121,7 +121,7 @@ def evaluate(
 
     vae.eval()
     for b, batch in enumerate(tests):
-        x, x_hat, lat_mu, lat_logvar, lat, lat_pose, _ = pass_batch(
+        x, x_hat, x_before_conv, lat_mu, lat_logvar, lat, lat_pose, _ = pass_batch(
             device, vae, batch, b, len(tests)
         )
         x_test.extend(lat_mu.cpu().detach().numpy())
diff --git a/avae/models.py b/avae/models.py
index c0571353..9b2711b7 100644
--- a/avae/models.py
+++ b/avae/models.py
@@ -88,9 +88,9 @@ def forward(self, x):
         # reparametrise
         latent = self.reparametrise(latent_mu, latent_logvar)
         # decode
-        x_recon = self.decoder(latent, latent_pose)  # pose set to None if pd=0
+        x_recon, x_before_conv = self.decoder(latent, latent_pose)  # pose set to None if pd=0
 
-        return x_recon, latent_mu, latent_logvar, latent, latent_pose
+        return x_recon, x_before_conv, latent_mu, latent_logvar, latent, latent_pose
 
     def reparametrise(self, mu, log_var):
         if self.training:
diff --git a/avae/train.py b/avae/train.py
index 7f412d00..fb721046 100644
--- a/avae/train.py
+++ b/avae/train.py
@@ -380,6 +380,7 @@ def train(
             (
                 x,
                 x_hat,
+                x_before_conv,
                 lat_mu,
                 lat_logvar,
                 lat,
@@ -435,6 +436,7 @@ def train(
                 (
                     v,
                     v_hat,
+                    v_before_conv,
                     v_mu,
                     v_logvar,
                     vlat,
@@ -582,6 +584,32 @@ def train(
                 epoch=epoch,
                 writer=writer,
             )
+
+            xx = x_before_conv.detach().cpu().numpy()
+            vis.plot_array_distribution_tool((xx - np.min(xx)) / (np.max(xx) - np.min(xx)), "xx_normalised")
+            vis.plot_array_distribution_tool(x_hat.detach().cpu().numpy(), "x_hat")
+            vis.plot_array_distribution_tool(xx, "xx")
+
+            vis.recon_plot(
+                x,
+                x_before_conv,
+                y_train,
+                data_dim,
+                mode="trn_before_conv",
+                epoch=epoch,
+                writer=writer,
+            )
+
+            vis.recon_plot(
+                x,
+                (x_before_conv - torch.min(x_before_conv)) / (torch.max(x_before_conv) - torch.min(x_before_conv)),
+                y_train,
+                data_dim,
+                mode="trn_before_conv_normalised",
+                epoch=epoch,
+                writer=writer,
+            )
+
             vis.recon_plot(
                 v,
                 v_hat,
diff --git a/avae/utils.py b/avae/utils.py
index 85b9e25d..f67b881f 100644
--- a/avae/utils.py
+++ b/avae/utils.py
@@ -302,7 +302,7 @@ def pose_interpolation(
 
         # Decode interpolated vectors
         with torch.no_grad():
-            decoded_img = vae.decoder(lat, pos)
+            decoded_img, x_before_conv = vae.decoder(lat, pos)
 
         decoded_grid.append(decoded_img.cpu().squeeze().numpy())
 
diff --git a/avae/utils_learning.py b/avae/utils_learning.py
index 767c3c8f..d7d4cbe7 100644
--- a/avae/utils_learning.py
+++ b/avae/utils_learning.py
@@ -140,7 +140,7 @@ def pass_batch(
 
     # forward
     x = x.to(torch.float32)
-    x_hat, lat_mu, lat_logvar, lat, lat_pose = vae(x)
+    x_hat, x_before_conv, lat_mu, lat_logvar, lat, lat_pose = vae(x)
 
     if loss is not None:
         history_loss = loss(x, x_hat, lat_mu, lat_logvar, e, batch_aff=aff)
@@ -164,7 +164,7 @@ def pass_batch(
             optimizer.step()
             optimizer.zero_grad()
 
-    return x, x_hat, lat_mu, lat_logvar, lat, lat_pose, history
+    return x, x_hat, x_before_conv, lat_mu, lat_logvar, lat, lat_pose, history
 
 
 def add_meta(
diff --git a/avae/vis.py b/avae/vis.py
index 4380268c..bca6353a 100644
--- a/avae/vis.py
+++ b/avae/vis.py
@@ -1370,7 +1370,7 @@ def latent_4enc_interpolate_plot(
 
         # Decode the interpolated encoding to generate an image
         with torch.no_grad():
-            decoded_images = vae.decoder(
+            decoded_images, x_before_conv = vae.decoder(
                 interpolated_z.view(-1, latent_dim).to(device=device),
                 (torch.zeros(1, poses[0].shape[0]) + pose_mean).to(
                     device=device
@@ -1474,11 +1474,11 @@ def latent_disentamglement_plot(
                 current_pos_grid = torch.from_numpy(
                     np.array([pos_means])
                 ).to(device)
-                current_recon = vae.decoder(
+                current_recon, x_before_conv = vae.decoder(
                     current_lat_grid, current_pos_grid
                 )
             else:
-                current_recon = vae.decoder(current_lat_grid, None)
+                current_recon, x_before_conv = vae.decoder(current_lat_grid, None)
 
             recon_images.append(current_recon.cpu().squeeze().numpy())
 
@@ -1808,7 +1808,7 @@ def interpolations_plot(
             )
             with torch.no_grad():
                 if poses is not None:
-                    decoded_images = vae.decoder(
+                    decoded_images, x_before_conv = vae.decoder(
                         interpolated_z.view(-1, latent_dim).to(device=device),
                         (
                             torch.zeros(1, poses[0].shape[0])
@@ -1816,7 +1816,7 @@ def interpolations_plot(
                         ).to(device=device),
                     )
                 else:
-                    decoded_images = vae.decoder(
+                    decoded_images, x_before_conv = vae.decoder(
                         interpolated_z.view(-1, latent_dim).to(device=device),
                         None,
                     )
@@ -2073,3 +2073,56 @@ def latent_space_similarity_plot(
         plt.close()
     else:
         plt.show()
+
+
+
+def plot_array_distribution_tool(data, array_name, display: bool = False,
+):
+    """
+    Developer tool: plot a histogram, boxplot and violin plot of the data.
+
+    Parameters:
+        data (array-like): The input data.
+        array_name (str): Label used in the output filename.
+        display (bool): If True, show the figure instead of saving it.
+
+    Returns: None
+    """
+    if isinstance(data, torch.Tensor):
+        # If data is a torch tensor, detach and move it to CPU
+        data = data.detach().cpu().numpy()
+
+    elif not isinstance(data, np.ndarray):
+        # If data is not a numpy array or a torch tensor, convert it to numpy array
+        data = np.array(data)
+
+    # Flatten the array if it has more than two dimensions
+    if data.ndim > 2:
+        data = data.flatten()
+
+    fig, axes = plt.subplots(3, 1, figsize=(8, 18))
+
+    axes[0].hist(data, bins=10, density=True, alpha=0.6, color='b')
+    axes[0].set_title('Histogram of Data')
+    axes[0].set_xlabel('Value')
+    axes[0].set_ylabel('Frequency')
+
+    axes[1].boxplot(data)
+    axes[1].set_title('Boxplot of Data')
+    axes[1].set_ylabel('Value')
+
+    axes[2].violinplot(data)
+    axes[2].set_title('Violin Plot of Data')
+    axes[2].set_ylabel('Value')
+
+    plt.tight_layout()
+
+    if not display:
+        if not os.path.exists("plots"):
+            os.mkdir("plots")
+        plt.savefig(
+            f"plots/array_{array_name}_stats.{settings.VIS_FORMAT}"
+        )
+        plt.close()
+    else:
+        plt.show()
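
The weights head in avae/decoders/differentiable.py now swaps the SoftStep activation for a straight-through estimator: the forward pass thresholds each splat weight to a hard 0/1 value, while the backward pass passes the incoming gradient through a hardtanh clip so the weights stay trainable. A minimal standalone sketch of that behaviour, assuming the patch is applied so StraightThroughEstimator is importable from avae.decoders.differentiable (the tensor values below are illustrative only):

    import torch

    from avae.decoders.differentiable import StraightThroughEstimator

    ste = StraightThroughEstimator()
    w = torch.tensor([-0.7, 0.2, 1.5], requires_grad=True)

    hard = ste(w)          # forward pass: (w > 0).float() -> tensor([0., 1., 1.])
    hard.sum().backward()  # backward pass: hardtanh of the incoming gradient
    print(w.grad)          # tensor([1., 1., 1.]) -- the gradient flows through the hard step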
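
In the new train.py diagnostics, the pre-convolution volume is min-max rescaled to [0, 1] before being plotted as "trn_before_conv_normalised", since the raw splatter output has no fixed range. The same rescaling in isolation, using a random stand-in tensor (illustrative only; note the expression divides by zero if the tensor is constant):

    import torch

    x_before_conv = torch.randn(2, 1, 16, 16)  # stand-in for the pre-conv decoder output
    x_norm = (x_before_conv - torch.min(x_before_conv)) / (
        torch.max(x_before_conv) - torch.min(x_before_conv)
    )
    print(x_norm.min().item(), x_norm.max().item())  # 0.0 1.0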