diff --git a/GCEm/__init__.py b/GCEm/__init__.py index be5b61e..dd4c152 100644 --- a/GCEm/__init__.py +++ b/GCEm/__init__.py @@ -89,7 +89,7 @@ def gp_model(training_params, training_data, data_processors=None, return Emulator(model, training_params, data, name=name, gpu=gpu) -def _get_gpflow_kernel(names, n_params, active_dims=None, operator='add'): +def _get_gpflow_kernel(names, n_params, active_dims=None, operator='add', return_individual=False): """ Helper function for creating a single GPFlow kernel from a combination of kernel names. @@ -103,6 +103,8 @@ def _get_gpflow_kernel(names, n_params, active_dims=None, operator='add'): The active dimensions to allow the kernel to fit operator: {'add', 'mul'} The operator to use to combine the kernels + return_individual: bool + Whether to return a list of the individually initialized kernels Returns ------- @@ -141,7 +143,7 @@ def init_kernel(k): try: K_Class = kernel_dict[k] except KeyError: - raise ValueError("Invalid Kernel: {}. Please choose from one of: {}".format(k, kernel_dict)) + raise ValueError("Invalid Kernel: {}. Please choose from one of: {}".format(k, kernel_dict.keys())) if issubclass(K_Class, gpflow.kernels.Static): # This covers e.g. White return K_Class(active_dims=active_dims) @@ -154,7 +156,10 @@ def init_kernel(k): else: raise ValueError("Unexpected Kernel type: {}".format(K_Class)) # This shouldn't happen... - return reduce(operator_dict[operator], (init_kernel(k) for k in names)) + if return_individual: + return reduce(operator_dict[operator], (init_kernel(k) for k in names)), [init_kernel(k) for k in names] + else: + return reduce(operator_dict[operator], (init_kernel(k) for k in names)) def cnn_model(training_params, training_data, data_processors=None, diff --git a/GCEm/utils.py b/GCEm/utils.py index 0cd628c..211e8f1 100644 --- a/GCEm/utils.py +++ b/GCEm/utils.py @@ -2,6 +2,94 @@ import tensorflow as tf from tqdm import tqdm +def kernel_plot(kernels, kernel_op=None): + """ Function for plotting kernel decomposition """ + import gpflow + from operator import add, mul + from functools import reduce + import matplotlib.pyplot as plt + + from GCEm.__init__ import _get_gpflow_kernel + + assert isinstance(kernels, list), "Input argument `kernels` must be a list of strings." + assert np.all([type(_)==str for _ in kernels]), "Input argument `kernels` must be a list of strings." + + # Initialize kernels for plotting + kernel_list = _get_gpflow_kernel(names=kernels, n_params=1, + active_dims=None, operator='add', + return_individual=True)[1] + + if kernel_op is not None: + kernel_list.append(_get_gpflow_kernel(names=kernels, n_params=1, + active_dims=None, operator=kernel_op, + return_individual=True)[0]) + + # Plotting function + def plotkernelsample(k, ax, xmin=-3, xmax=3): + xx = np.linspace(xmin, xmax, 100)[:, None] + K = k(xx) + ax.plot(xx, np.random.multivariate_normal(np.zeros(100), K, 5).T) + ax.set_title(k.__class__.__name__) + + # Set up figure + if kernel_op is not None: + if len(kernels)>=2: + ncols = 3 + else: + ncols = int(len(kernels)+1) + else: + if len(kernels)>=3: + ncols = 3 + else: + ncols = int(len(kernels)) + + if kernel_op is not None: + nrows = int(np.ceil((len(kernels)+1)/3)) + else: + nrows = int(np.ceil(len(kernels)/3)) + + # Pad end of kernel list with 0 and 1 + # So it matches axes shape (nrows*ncols) + kernel_save = kernels*1 + if kernel_op is not None: + kernels.extend([1] * (1)) + kernels.extend([0] * (nrows*ncols - len(kernels))) + + # Plot + fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(6*nrows, 3*nrows), dpi=100, sharex=True, sharey=True) + + for k_idx, k in enumerate(kernels): + if k==0: + # Redundant axes + axes.flatten()[k_idx].axis('off') + continue + + elif k==1: + # Sum/Product axis + plotkernelsample(kernel_list[-1], axes.flatten()[k_idx]) + else: + K_class = kernel_list[k_idx] + plotkernelsample(K_class, axes.flatten()[k_idx]) + + if kernel_op is not None: + extra = 1 + else: + extra = 0 + + for ax in axes.flatten()[0:len(kernel_save)+extra]: + xmin, xmax = ax.get_xlim() + ymin, ymax = ax.get_ylim() + ax.hlines(0, xmin-0.1, xmax+0.1, 'k', '--', zorder=-10, lw=0.5) + ax.vlines(0, ymin-0.1, ymax+0.1, 'k', '--', zorder=-10, lw=0.5) + ax.set_xlim(xmin, xmax) + ax.set_ylim(ymin, ymax) + + _ = axes.flatten()[0].set_ylim(-3, 3) + fig.tight_layout() + + return fig, axes + + def add_121_line(ax): import numpy as np