Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions GCEm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def gp_model(training_params, training_data, data_processors=None,
return Emulator(model, training_params, data, name=name, gpu=gpu)


def _get_gpflow_kernel(names, n_params, active_dims=None, operator='add'):
def _get_gpflow_kernel(names, n_params, active_dims=None, operator='add', return_individual=False):
"""
Helper function for creating a single GPFlow kernel from a combination of kernel names.

Expand All @@ -103,6 +103,8 @@ def _get_gpflow_kernel(names, n_params, active_dims=None, operator='add'):
The active dimensions to allow the kernel to fit
operator: {'add', 'mul'}
The operator to use to combine the kernels
return_individual: bool
Whether to return a list of the individually initialized kernels

Returns
-------
Expand Down Expand Up @@ -141,7 +143,7 @@ def init_kernel(k):
try:
K_Class = kernel_dict[k]
except KeyError:
raise ValueError("Invalid Kernel: {}. Please choose from one of: {}".format(k, kernel_dict))
raise ValueError("Invalid Kernel: {}. Please choose from one of: {}".format(k, kernel_dict.keys()))

if issubclass(K_Class, gpflow.kernels.Static): # This covers e.g. White
return K_Class(active_dims=active_dims)
Expand All @@ -154,7 +156,10 @@ def init_kernel(k):
else:
raise ValueError("Unexpected Kernel type: {}".format(K_Class)) # This shouldn't happen...

return reduce(operator_dict[operator], (init_kernel(k) for k in names))
if return_individual:
return reduce(operator_dict[operator], (init_kernel(k) for k in names)), [init_kernel(k) for k in names]
else:
return reduce(operator_dict[operator], (init_kernel(k) for k in names))


def cnn_model(training_params, training_data, data_processors=None,
Expand Down
88 changes: 88 additions & 0 deletions GCEm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,94 @@
import tensorflow as tf
from tqdm import tqdm

def kernel_plot(kernels, kernel_op=None):
""" Function for plotting kernel decomposition """
import gpflow
from operator import add, mul
from functools import reduce
import matplotlib.pyplot as plt

from GCEm.__init__ import _get_gpflow_kernel

assert isinstance(kernels, list), "Input argument `kernels` must be a list of strings."
assert np.all([type(_)==str for _ in kernels]), "Input argument `kernels` must be a list of strings."

# Initialize kernels for plotting
kernel_list = _get_gpflow_kernel(names=kernels, n_params=1,
active_dims=None, operator='add',
return_individual=True)[1]

if kernel_op is not None:
kernel_list.append(_get_gpflow_kernel(names=kernels, n_params=1,
active_dims=None, operator=kernel_op,
return_individual=True)[0])

# Plotting function
def plotkernelsample(k, ax, xmin=-3, xmax=3):
xx = np.linspace(xmin, xmax, 100)[:, None]
K = k(xx)
ax.plot(xx, np.random.multivariate_normal(np.zeros(100), K, 5).T)
ax.set_title(k.__class__.__name__)

# Set up figure
if kernel_op is not None:
if len(kernels)>=2:
ncols = 3
else:
ncols = int(len(kernels)+1)
else:
if len(kernels)>=3:
ncols = 3
else:
ncols = int(len(kernels))

if kernel_op is not None:
nrows = int(np.ceil((len(kernels)+1)/3))
else:
nrows = int(np.ceil(len(kernels)/3))

# Pad end of kernel list with 0 and 1
# So it matches axes shape (nrows*ncols)
kernel_save = kernels*1
if kernel_op is not None:
kernels.extend([1] * (1))
kernels.extend([0] * (nrows*ncols - len(kernels)))

# Plot
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(6*nrows, 3*nrows), dpi=100, sharex=True, sharey=True)

for k_idx, k in enumerate(kernels):
if k==0:
# Redundant axes
axes.flatten()[k_idx].axis('off')
continue

elif k==1:
# Sum/Product axis
plotkernelsample(kernel_list[-1], axes.flatten()[k_idx])
else:
K_class = kernel_list[k_idx]
plotkernelsample(K_class, axes.flatten()[k_idx])

if kernel_op is not None:
extra = 1
else:
extra = 0

for ax in axes.flatten()[0:len(kernel_save)+extra]:
xmin, xmax = ax.get_xlim()
ymin, ymax = ax.get_ylim()
ax.hlines(0, xmin-0.1, xmax+0.1, 'k', '--', zorder=-10, lw=0.5)
ax.vlines(0, ymin-0.1, ymax+0.1, 'k', '--', zorder=-10, lw=0.5)
ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax)

_ = axes.flatten()[0].set_ylim(-3, 3)
fig.tight_layout()

return fig, axes



def add_121_line(ax):
import numpy as np
Expand Down