Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,34 @@
# Neural Importance Sampling

## Configuration example
config/config.conf:
```
train
{
function = Gaussian
epochs = 200
batch_size = 100
learning_rate = 0.0001
num_hidden_dims = 128
num_coupling_layers = 4
num_hidden_layers = 5
num_blob_bins = 0
num_piecewise_bins = 5
loss = MSE
coupling_name = "piecewiseLinear"
}
logging
{
plot_dir_name = "./Plots/SaveOutput"
save_plots = True
save_plt_interval = 10
plot_dimension = 2
tensorboard
{
use_tensorboard = False
# wandb_project = "NIS"
}
}
```
## How to run code

Try :
Expand Down
38 changes: 16 additions & 22 deletions integration_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from network import MLP
from transform import CompositeTransform
from utils import pyhocon_wrapper
from visualize import visualize
from visualize import visualize, FunctionVisualizer


@dataclass
Expand All @@ -38,6 +38,8 @@ class ExperimentConfig:
save_plt_interval: Frequency for plot saving (default : 10)
wandb_project: Name of wandb project in neural_importance_sampling team
use_tensorboard: Use tensorboard logging
save_plots: save plots if ndims >= 2
plot_dimension: add 2d or 3d plot
"""

experiment_dir_name: str = f"test_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
Expand All @@ -56,6 +58,8 @@ class ExperimentConfig:
save_plt_interval: int = 10
wandb_project: Union[str, None] = None
use_tensorboard: bool = False
save_plots: bool = True
plot_dimension: int = 2

@classmethod
def init_from_pyhocon(cls, pyhocon_config: pyhocon_wrapper.ConfigTree):
Expand All @@ -73,7 +77,9 @@ def init_from_pyhocon(cls, pyhocon_config: pyhocon_wrapper.ConfigTree):
funcname=pyhocon_config.get_string('train.function'),
coupling_name=pyhocon_config.get_string('train.coupling_name'),
wandb_project=pyhocon_config.get_string('logging.tensorboard.wandb_project', None),
use_tensorboard=pyhocon_config.get_bool('logging.tensorboard.use_tensorboard', False)
use_tensorboard=pyhocon_config.get_bool('logging.tensorboard.use_tensorboard', False),
save_plots=pyhocon_config.get_bool('logging.save_plots', False),
plot_dimension=pyhocon_config.get_int('logging.plot_dimension', 2),
)


Expand Down Expand Up @@ -154,12 +160,10 @@ def run_experiment(config: ExperimentConfig):
scheduler=scheduler,
loss_func=config.loss_func)

if config.ndims == 2: # if 2D -> prepare x1,x2 gird for visualize
grid_x1, grid_x2 = torch.meshgrid(torch.linspace(0, 1, 100), torch.linspace(0, 1, 100))
grid = torch.cat([grid_x1.reshape(-1, 1), grid_x2.reshape(-1, 1)], axis=1)
func_out = function(grid).reshape(100, 100)
if config.save_plots and config.ndims >= 2:
function_visualizer = FunctionVisualizer(vis_object=visObject, function=function, input_dimension=config.ndims,
max_plot_dimension=config.plot_dimension)

bins = None
means = []
errors = []
for epoch in range(1, config.epochs + 1):
Expand Down Expand Up @@ -195,31 +199,21 @@ def run_experiment(config: ExperimentConfig):
visObject.AddCurves(x=epoch, x_err=0, title="Integral value", dict_val=dict_val)
visObject.AddCurves(x=epoch, x_err=0, title="Integral uncertainty", dict_val=dict_error)

if config.ndims == 2: # if 2D -> visualize distribution
if bins is None:
bins, x_edges, y_edges = np.histogram2d(x[:, 0], x[:, 1], bins=20, range=[[0, 1], [0, 1]])
else:
newbins, x_edges, y_edges = np.histogram2d(x[:, 0], x[:, 1], bins=20, range=[[0, 1], [0, 1]])
bins += newbins.T
x_centers = (x_edges[:-1] + x_edges[1:]) / 2
y_centers = (y_edges[:-1] + y_edges[1:]) / 2
x_centers, y_centers = np.meshgrid(x_centers, y_centers)
visObject.AddPointSet(x, title="Observed $x$ %s" % config.coupling_name, color='b')
visObject.AddContour(x_centers, y_centers, bins, "Cumulative %s" % config.coupling_name)
if config.save_plots and config.ndims >= 2: # if 2D -> visualize distribution
visualize_x = function_visualizer.add_trained_function_plot(x=x, plot_name="Cumulative %s" % config.coupling_name)
visObject.AddPointSet(visualize_x, title="Observed $x$ %s" % config.coupling_name, color='b')
visObject.AddPointSet(z, title="Latent space $z$", color='b')


if config.use_tensorboard:
tb_writer.add_scalar('Train/Loss', loss, epoch)
tb_writer.add_scalar('Train/Integral', mean_wgt, epoch)
tb_writer.add_scalar('Train/LR', lr, epoch)

# Plot function output #
if epoch % config.save_plt_interval == 0:
if config.ndims == 2:
if config.save_plots and config.ndims >= 2:
visObject.AddPointSet(z, title="Latent space $z$", color='b')
visObject.AddContour(grid_x1, grid_x2, func_out,
"Target function : " + function.name)
function_visualizer.add_target_function_plot()
visObject.MakePlot(epoch)


Expand Down
3 changes: 2 additions & 1 deletion integrator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import sys
import numpy as np
import torch

import warnings
warnings.filterwarnings("ignore", category=UserWarning)
from divergences import Divergence

class Integrator():
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ torch==1.12.1
torchvision==0.13.1
tensorboard==2.10.1
numpy==1.23.3
scikit-learn==1.1.3
matplotlib==3.6.1
wandb==0.13.4
pyhocon==0.3.59
Expand Down
125 changes: 122 additions & 3 deletions visualize.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import os
import sys
import numpy as np
Expand All @@ -7,19 +8,24 @@
import torch
import shutil

from sklearn.manifold import LocallyLinearEmbedding
from sklearn.preprocessing import MinMaxScaler


class visualize:
def __init__(self,path):
self.idx = 0
self.data_dict = {}
self.cont_dict = {}
self.plot_3d_dict = {}
self.hist_dict = {}
self.curv_dict = {}
self.path = os.path.abspath(path)
if os.path.exists(self.path):
print ("Already exists directory %s, will recreate it"%self.path)
shutil.rmtree(self.path)
os.makedirs(self.path)
print ("Created directory %s"%self.path)
print("Created directory %s"%self.path)


def AddPointSet(self,data,title,color):
Expand All @@ -31,6 +37,13 @@ def AddContour(self,X,Y,Z,title):
assert Z.shape[0] == Z.shape[1]
self.cont_dict[title] = [X,Y,Z]

    def Add3dPlot(self, X, Y, Z, func, title):
        """Register a 3d scatter plot under `title`, drawn later by MakePlot.

        X, Y, Z are coordinate grids and func the per-point values used for
        coloring; all four are expected to be square (same resolution per axis).
        """
        assert X.shape[0] == X.shape[1]
        assert Y.shape[0] == Y.shape[1]
        assert Z.shape[0] == Z.shape[1]
        assert func.shape[0] == func.shape[1]
        self.plot_3d_dict[title] = [X,Y,Z,func]

def AddCurves(self,x,x_err,dict_val,title):
# x : float
# x_err : float or tuple (down, up)
Expand Down Expand Up @@ -65,10 +78,11 @@ def AddHistogram(self,vec,bins,title):
def MakePlot(self,epoch):
N_data = len(self.data_dict.keys())
N_cont = len(self.cont_dict.keys())
N_plot_3d = len(self.plot_3d_dict.keys())
N_hist = len(self.hist_dict.keys())
N_curv = len(self.curv_dict.keys())
Nh = max(N_data,N_cont,N_curv,N_hist)
Nv = int(N_data!=0)+int(N_cont!=0)+int(N_hist!=0)+int(N_curv!=0)
Nh = max(N_data,N_cont,N_plot_3d,N_curv,N_hist)
Nv = int(N_data!=0)+int(N_cont!=0)+int(N_plot_3d!=0)+int(N_hist!=0)+int(N_curv!=0)
fig, axs = plt.subplots(Nv,Nh,figsize=(Nh*6,Nv*6))
plt.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.9, wspace=0.2, hspace=0.2)
fig.suptitle("Epoch %d"%epoch,fontsize=22)
Expand All @@ -77,6 +91,7 @@ def MakePlot(self,epoch):
axs = axs.reshape(1,-1)
idx_data = 0
idx_cont = 0
idx_3d_plot = 0
idx_hist = 0
idx_curv = 0
idx_vert = 0
Expand Down Expand Up @@ -110,6 +125,16 @@ def MakePlot(self,epoch):
idx_cont += 1
idx_vert += 1

##### 3d plots #####
if len(self.plot_3d_dict.keys()) != 0:
for title,data in self.plot_3d_dict.items():
axs[idx_vert, idx_3d_plot].remove()
ax = fig.add_subplot(Nv, Nh, idx_vert * Nv + idx_3d_plot, projection='3d')
ax.set_title(title,fontsize=20)
cs = ax.scatter(data[0],data[1],data[2], c=data[3], s=20)
idx_3d_plot += 1
idx_vert += 1

##### Hist plots ####
if len(self.hist_dict.keys()) != 0:
for title,(centers,vals,widths) in self.hist_dict.items():
Expand Down Expand Up @@ -150,6 +175,100 @@ def MakePlot(self,epoch):
fig.savefig(path_fig)
plt.close(fig)
self.idx += 1


class FunctionVisualizer:
    """Build 2d/3d plots of a target function and of sampled points.

    When the function's input dimension exceeds the requested plot dimension,
    samples are projected down with LocallyLinearEmbedding and rescaled into
    the unit hyper-cube with MinMaxScaler before histogramming.
    """

    def __init__(self, vis_object: "visualize", function, input_dimension, max_plot_dimension):
        """
        Args:
            vis_object: visualize instance that collects the plots
            function: callable mapping a (n_samples, input_dimension) tensor
                      to one value per sample
            input_dimension: dimensionality of the function input
            max_plot_dimension: requested plot dimensionality (2 or 3)
        """
        self.vis_object = vis_object
        self.function = function
        self.input_dimension = input_dimension
        # Never plot in more dimensions than the function actually has.
        self.n_components = min(max_plot_dimension, input_dimension)

        assert self.n_components in [2, 3], "plot_dimension can be 2 or 3"
        if self.n_components < self.input_dimension:
            self.use_dimension_reduction = True
            self._init_dimension_reduction()
        else:
            self.use_dimension_reduction = False

        self.grids, self.func_out = self.compute_target_function_grid()
        # Cumulative histogram of observed points, accumulated across calls.
        self.bins = None

    def _init_dimension_reduction(self):
        # LLE projects samples down to n_components; MinMaxScaler maps the
        # embedding into the [0, 1] range expected by the histograms below.
        self.dimension_transform = LocallyLinearEmbedding(n_components=self.n_components)
        self.scaler = MinMaxScaler()

    def compute_target_function_grid(self):
        """
        Generate a grid over the unit hyper-cube, evaluate the target function
        on it and (optionally) project the grid down to the plot dimension.

        Returns:
            (grids, func_out): grids is a list of n_components arrays shaped
            like func_out; func_out holds the function values on the grid.
        """
        if self.n_components == 2:
            num_grid_samples = 100 ** self.n_components
            target_shape = [100] * self.n_components
        else:  # n_components == 3, guaranteed by the assert in __init__
            num_grid_samples = 10 ** self.n_components
            target_shape = [10] * self.n_components
        # Sample each input dimension finely enough that the full grid has at
        # least num_grid_samples points; the surplus is truncated below.
        num_samples_per_dimension = math.ceil(num_grid_samples ** (1 / self.input_dimension))
        grid = torch.meshgrid(*[torch.linspace(0, 1, num_samples_per_dimension)
                                for _ in range(self.input_dimension)])
        grid = torch.cat([dim_grid.reshape(-1, 1) for dim_grid in grid], axis=1)[:num_grid_samples]
        func_out = self.function(grid).reshape(target_shape)

        if self.use_dimension_reduction:
            grid = self.dimension_transform.fit_transform(grid)
            grid = self.scaler.fit_transform(grid)
        grids = [grid[:num_grid_samples, dim].reshape(target_shape) for dim in range(self.n_components)]
        return grids, func_out

    def add_target_function_plot(self):
        """
        Add the target-function plot (contour in 2d, scatter in 3d) to the
        visualize object.
        """
        title = "Target function : " + self.function.name
        if self.n_components == 2:
            self.vis_object.AddContour(*self.grids, self.func_out, title)
        elif self.n_components == 3:
            self.vis_object.Add3dPlot(*self.grids, self.func_out, title)

    def add_trained_function_plot(self, x, plot_name) -> np.ndarray:
        """
        Accumulate the histogram of sampled points x into self.bins and add the
        corresponding plot to the visualize object.

        Returns:
            x unchanged, or its dimension-reduced projection when reduction
            is active.
        """
        if self.use_dimension_reduction:
            visualize_x = self.dimension_transform.transform(x)
            visualize_x = self.scaler.transform(visualize_x)
        else:
            visualize_x = x

        if self.n_components == 2:
            # BUGFIX: the first histogram was stored in a *local* `bins`, so
            # self.bins stayed None forever and every later call hit the else
            # branch, where the plotted `bins` name was unbound (NameError).
            # Store/accumulate in self.bins and plot self.bins consistently.
            hist, x_edges, y_edges = np.histogram2d(visualize_x[:, 0], visualize_x[:, 1], bins=20,
                                                    range=[[0, 1], [0, 1]])
            # Transpose so the histogram axes match the meshgrid orientation
            # (same convention the accumulation already used).
            if self.bins is None:
                self.bins = hist.T
            else:
                self.bins += hist.T
            x_centers = (x_edges[:-1] + x_edges[1:]) / 2
            y_centers = (y_edges[:-1] + y_edges[1:]) / 2
            x_centers, y_centers = np.meshgrid(x_centers, y_centers)
            self.vis_object.AddContour(x_centers, y_centers, self.bins, plot_name)
            return visualize_x
        elif self.n_components == 3:
            # Same accumulation fix for the 3d case.
            hist, (x_edges, y_edges, z_edges) = np.histogramdd(visualize_x, bins=10,
                                                               range=[[0, 1], [0, 1], [0, 1]])
            if self.bins is None:
                self.bins = hist.T
            else:
                self.bins += hist.T
            x_centers = (x_edges[:-1] + x_edges[1:]) / 2
            y_centers = (y_edges[:-1] + y_edges[1:]) / 2
            z_centers = (z_edges[:-1] + z_edges[1:]) / 2
            x_centers, y_centers, z_centers = np.meshgrid(x_centers, y_centers, z_centers)
            self.vis_object.Add3dPlot(x_centers, y_centers, z_centers, self.bins, plot_name)
            return visualize_x


#import numpy as np
Expand Down