Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion avae/decoders/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def __init__(

def forward(self, x, x_pose):
if self.pose:
return self.decoder(torch.cat([x_pose, x], dim=-1))
return self.decoder(torch.cat([x_pose, x], dim=-1)), self.decoder(torch.cat([x_pose, x], dim=-1))
else:
return self.decoder(x)

Expand Down
48 changes: 39 additions & 9 deletions avae/decoders/differentiable.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from typing import Optional, Tuple
import typing
import logging
import torchvision
import numpy as np
from scipy import stats

import torch
from avae.utils import save_imshow_png
from avae import settings, vis

from avae.decoders.base import AbstractDecoder
from avae.decoders.spatial import (
Expand All @@ -10,6 +17,24 @@
quaternion_to_rotation_matrix,
)

class STEFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
return (input > 0).float()

@staticmethod
def backward(ctx, grad_output):
return torch.nn.functional.hardtanh(grad_output)


class StraightThroughEstimator(torch.nn.Module):
def __init__(self):
super(StraightThroughEstimator, self).__init__()

def forward(self, x):
x = STEFunction.apply(x)
return x


class GaussianSplatRenderer(torch.nn.Module):
"""Perform gaussian splatting."""
Expand Down Expand Up @@ -185,13 +210,15 @@ def __init__(
torch.nn.Linear(latent_dims, n_splats * 3),
torch.nn.Tanh(),
)


# weights are effectively whether a splat is used or not
# use a soft step function to make this `binary` (but differentiable)
# NOTE(arl): not sure if this really makes any difference
self.weights = torch.nn.Sequential(
torch.nn.Linear(latent_dims, n_splats),
torch.nn.Tanh(),
SoftStep(k=10.0),
StraightThroughEstimator(),
#SoftStep(k=10.0),
)
# sigma ends up being scaled by `splat_sigma_range`
self.sigmas = torch.nn.Sequential(
Expand Down Expand Up @@ -226,11 +253,11 @@ def __init__(
else torch.nn.Conv3d
)
self._decoder = torch.nn.Sequential(
conv(1, 32, 3, padding="same"),
torch.nn.ReLU(),
conv(32, 32, 3, padding="same"),
torch.nn.ReLU(),
conv(32, output_channels, 3, padding="same"),
conv(1, 1, 9, padding="same"),
#torch.nn.ReLU(),
#conv(32, 32, 3, padding="same"),
#torch.nn.ReLU(),
#conv(32, output_channels, 3, padding="same"),
)

def configure_renderer(
Expand Down Expand Up @@ -332,11 +359,14 @@ def forward(
x = self._splatter(
splats, weights, sigmas, splat_sigma_range=self._splat_sigma_range
)
# if we're doing a final convolution, do it here

x_before_conv = x

if (
self._output_channels is not None
and self._output_channels != 0
and use_final_convolution
):
x = self._decoder(x)
return x

return x, x_before_conv
2 changes: 1 addition & 1 deletion avae/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def evaluate(

vae.eval()
for b, batch in enumerate(tests):
x, x_hat, lat_mu, lat_logvar, lat, lat_pose, _ = pass_batch(
x, x_hat, x_before_conv, lat_mu, lat_logvar, lat, lat_pose, _ = pass_batch(
device, vae, batch, b, len(tests)
)
x_test.extend(lat_mu.cpu().detach().numpy())
Expand Down
4 changes: 2 additions & 2 deletions avae/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ def forward(self, x):
# reparametrise
latent = self.reparametrise(latent_mu, latent_logvar)
# decode
x_recon = self.decoder(latent, latent_pose) # pose set to None if pd=0
x_recon, x_before_conv = self.decoder(latent, latent_pose) # pose set to None if pd=0

return x_recon, latent_mu, latent_logvar, latent, latent_pose
return x_recon, x_before_conv, latent_mu, latent_logvar, latent, latent_pose

def reparametrise(self, mu, log_var):
if self.training:
Expand Down
28 changes: 28 additions & 0 deletions avae/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def train(
(
x,
x_hat,
x_before_conv,
lat_mu,
lat_logvar,
lat,
Expand Down Expand Up @@ -435,6 +436,7 @@ def train(
(
v,
v_hat,
v_before_conv,
v_mu,
v_logvar,
vlat,
Expand Down Expand Up @@ -582,6 +584,32 @@ def train(
epoch=epoch,
writer=writer,
)

xx = x_before_conv.detach().cpu().numpy()
vis.plot_array_distribution_tool((xx - np.min(xx)) / (np.max(xx) - np.min(xx)), "xx_normalised")
vis.plot_array_distribution_tool(x_hat.detach().cpu().numpy(), "x_hat")
vis.plot_array_distribution_tool(xx, "xx")

vis.recon_plot(
x,
x_before_conv,
y_train,
data_dim,
mode="trn_before_conv",
epoch=epoch,
writer=writer,
)

vis.recon_plot(
x,
(x_before_conv - torch.min(x_before_conv)) / (torch.max(x_before_conv) - torch.min(x_before_conv)),
y_train,
data_dim,
mode="trn_before_conv_normalised",
epoch=epoch,
writer=writer,
)

vis.recon_plot(
v,
v_hat,
Expand Down
2 changes: 1 addition & 1 deletion avae/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def pose_interpolation(

# Decode interpolated vectors
with torch.no_grad():
decoded_img = vae.decoder(lat, pos)
decoded_img, x_before_conv = vae.decoder(lat, pos)

decoded_grid.append(decoded_img.cpu().squeeze().numpy())

Expand Down
4 changes: 2 additions & 2 deletions avae/utils_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def pass_batch(

# forward
x = x.to(torch.float32)
x_hat, lat_mu, lat_logvar, lat, lat_pose = vae(x)
x_hat, x_before_conv,lat_mu, lat_logvar, lat, lat_pose = vae(x)
if loss is not None:
history_loss = loss(x, x_hat, lat_mu, lat_logvar, e, batch_aff=aff)

Expand All @@ -164,7 +164,7 @@ def pass_batch(
optimizer.step()
optimizer.zero_grad()

return x, x_hat, lat_mu, lat_logvar, lat, lat_pose, history
return x, x_hat, x_before_conv, lat_mu, lat_logvar, lat, lat_pose, history


def add_meta(
Expand Down
63 changes: 58 additions & 5 deletions avae/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1370,7 +1370,7 @@ def latent_4enc_interpolate_plot(

# Decode the interpolated encoding to generate an image
with torch.no_grad():
decoded_images = vae.decoder(
decoded_images, x_before_conv = vae.decoder(
interpolated_z.view(-1, latent_dim).to(device=device),
(torch.zeros(1, poses[0].shape[0]) + pose_mean).to(
device=device
Expand Down Expand Up @@ -1474,11 +1474,11 @@ def latent_disentamglement_plot(
current_pos_grid = torch.from_numpy(
np.array([pos_means])
).to(device)
current_recon = vae.decoder(
current_recon, x_before_conv = vae.decoder(
current_lat_grid, current_pos_grid
)
else:
current_recon = vae.decoder(current_lat_grid, None)
current_recon, x_before_conv = vae.decoder(current_lat_grid, None)

recon_images.append(current_recon.cpu().squeeze().numpy())

Expand Down Expand Up @@ -1808,15 +1808,15 @@ def interpolations_plot(
)
with torch.no_grad():
if poses is not None:
decoded_images = vae.decoder(
decoded_images, x_before_conv = vae.decoder(
interpolated_z.view(-1, latent_dim).to(device=device),
(
torch.zeros(1, poses[0].shape[0])
+ interpolated_pose
).to(device=device),
)
else:
decoded_images = vae.decoder(
decoded_images, x_before_conv = vae.decoder(
interpolated_z.view(-1, latent_dim).to(device=device),
None,
)
Expand Down Expand Up @@ -2073,3 +2073,56 @@ def latent_space_similarity_plot(
plt.close()
else:
plt.show()



def plot_array_distribution_tool(data, array_name, display: bool = False,
):
"""
This is a tool for developers
Plot histogram, boxplot, and violin plot of the data on a single figure.

Parameters:
data (array-like): The input data.

Returns:
None
"""
if isinstance(data, torch.Tensor):
# If data is a torch tensor, detach and move it to CPU
data = data.detach().cpu().numpy()

elif not isinstance(data, np.ndarray):
# If data is not a numpy array or a torch tensor, convert it to numpy array
data = np.array(data)

# Flatten the array if it has more than two dimensions
if data.ndim > 2:
data = data.flatten()

fig, axes = plt.subplots(3, 1, figsize=(8, 18))

axes[0].hist(data, bins=10, density=True, alpha=0.6, color='b')
axes[0].set_title('Histogram of Data')
axes[0].set_xlabel('Value')
axes[0].set_ylabel('Frequency')

axes[1].boxplot(data)
axes[1].set_title('Boxplot of Data')
axes[1].set_ylabel('Value')

axes[2].violinplot(data)
axes[2].set_title('Violin Plot of Data')
axes[2].set_ylabel('Value')

plt.tight_layout()

if not display:
if not os.path.exists("plots"):
os.mkdir("plots")
plt.savefig(
f"plots/array_{array_name}_stats.{settings.VIS_FORMAT}"
)
plt.close()
else:
plt.show()