Skip to content

Commit 9bf346a

Browse files
author
Paul Miles
authored
Merge pull request #3 from prmiles/update_coverage
Update coverage
2 parents 9dd7b58 + 2eb5f96 commit 9bf346a

6 files changed

Lines changed: 189 additions & 224 deletions

File tree

mcmcplot/mcmatplot.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,11 @@ def plot_pairwise_correlation_panel(chains, names = None, settings = None):
196196
'skip': 1,
197197
'fig': dict(figsize = (7,5), dpi = 100),
198198
'plot': dict(color = 'b', marker = '.', linestyle = 'none'),
199-
'xlabel': {'s': 'Iteration'},
199+
'xlabel': {},
200200
'ylabel': {},
201201
'title': {},
202202
'add_5095_contours': False,
203-
'plot_50': dict(color = 'r', marker = None, linewidth = 2, linestyle = '--', label = '50%'),
203+
'plot_50': dict(color = 'g', marker = None, linewidth = 2, linestyle = '--', label = '50%'),
204204
'plot_95': dict(color = 'r', marker = None, linewidth = 2, linestyle = '--', label = '95%'),
205205
'add_legend': False,
206206
'legend': dict(loc = 1),
@@ -300,7 +300,11 @@ def plot_chain_metrics(chain, name = None, settings = None):
300300

301301
class Plot:
302302
'''
303-
Plotting routines for analyzing sampling chains from MCMC process.
303+
Wrapper routines for analyzing/plotting sampling chains from MCMC process.
304+
305+
Uses methods from the `matplotlib` package:
306+
307+
https://matplotlib.org/
304308
305309
Attributes:
306310
- :meth:`~plot_density_panel`

mcmcplot/mcseaborn.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def plot_joint_distributions(chains, names = None, sns_style = 'white', settings
2929
'skip': 1,
3030
'sns_style': sns_style,
3131
'sns': sns.axes_style(style = sns_style),
32-
'jointplot': dict(kind='kde', height=6, space=0)
32+
'jointplot': dict(kind='kde', data = None, height = 6.0, space=0)
3333
}
3434

3535
settings = check_settings(default_settings = default_settings, user_settings = settings)
@@ -46,7 +46,8 @@ def plot_joint_distributions(chains, names = None, sns_style = 'white', settings
4646
chain2 = pd.Series(chains[inds,jj-1], name=names[jj-1])
4747

4848
# Show the joint distribution using kernel density estimation
49-
g.append(sns.jointplot(chain1, chain2, **settings['jointplot']))
49+
a = sns.jointplot(x = chain1, y = chain2, **settings['jointplot'])
50+
g.append(a)
5051

5152
return g, settings
5253

@@ -91,4 +92,20 @@ def plot_paired_density_matrix(chains, names = None, sns_style = 'white', index
9192
g.map_lower(settings['ld_type'], **settings['ld'])
9293
g.map_upper(settings['ud_type'], **settings['ud'])
9394
g.map_diag(settings['md_type'], **settings['md'])
94-
return g, settings
95+
return g, settings
96+
97+
class Plot:
98+
'''
99+
Wrapper routines for analyzing/plotting sampling chains from MCMC process.
100+
101+
Uses methods from the `seaborn` package:
102+
103+
https://seaborn.pydata.org/
104+
105+
Attributes:
106+
- :meth:`~plot_joint_distributions`
107+
- :meth:`~plot_paired_density_matrix`
108+
'''
109+
def __init__(self):
110+
self.plot_joint_distributions = plot_joint_distributions
111+
self.plot_paired_density_matrix = plot_paired_density_matrix

mcmcplot/utilities.py

Lines changed: 33 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -6,43 +6,44 @@
66
@author: prmiles
77
"""
88
import numpy as np
9-
from scipy.interpolate import interp1d
109
from scipy import pi,sin,cos
1110
import sys
1211
import math
1312

1413
def check_settings(default_settings, user_settings = None):
15-
16-
settings = default_settings.copy()
17-
18-
options = list(default_settings.keys())
19-
if user_settings is None:
20-
user_settings = {}
21-
user_options = list(user_settings.keys())
22-
for ii in range(len(user_options)):
23-
if user_options[ii] in options:
24-
# check if checking a dictionary
25-
if isinstance(settings[user_options[ii]], dict):
26-
settings[user_options[ii]] = check_settings(settings[user_options[ii]], user_settings[user_options[ii]])
27-
else:
28-
settings[user_options[ii]] = user_settings[user_options[ii]]
29-
if user_options[ii] not in options:
30-
settings[user_options[ii]] = user_settings[user_options[ii]]
31-
32-
return settings
14+
'''
15+
Check user settings with default.
16+
17+
Recursively checks elements of user settings against the defaults and updates settings
18+
as it goes. If a user setting does not exist in the default, then the user setting
19+
is added to the settings. If the setting is defined in both the user and default
20+
settings, then the user setting overrides the default. Otherwise, the default
21+
settings persist.
3322
34-
def generate_plotly_subplot_coords(nparam, ns1, ns2):
35-
sprow = []
36-
spcol = []
37-
counter = 0
38-
for ii in range(ns1):
39-
for jj in range(ns2):
40-
sprow.append(ii + 1)
41-
spcol.append(jj + 1)
42-
counter += 1
43-
if counter is nparam:
44-
break
45-
return sprow, spcol
23+
Args:
24+
* **default_settings** (:py:class:`dict`): Default settings for particular method.
25+
* **user_settings** (:py:class:`dict`): User defined settings.
26+
27+
Returns:
28+
* (:py:class:`dict`): Updated settings.
29+
'''
30+
settings = default_settings.copy() # initially define settings as default
31+
32+
options = list(default_settings.keys()) # get default settings
33+
if user_settings is None: # convert to empty dict
34+
user_settings = {}
35+
user_options = list(user_settings.keys()) # get user settings
36+
for uo in user_options: # iterate through settings
37+
if uo in options:
38+
# check if checking a dictionary
39+
if isinstance(settings[uo], dict):
40+
settings[uo] = check_settings(settings[uo], user_settings[uo])
41+
else:
42+
settings[uo] = user_settings[uo]
43+
if uo not in options:
44+
settings[uo] = user_settings[uo]
45+
46+
return settings
4647

4748
def generate_subplot_grid(nparam = 2):
4849
'''
@@ -85,30 +86,6 @@ def generate_names(nparam, names):
8586
names = extend_names_to_match_nparam(names, nparam)
8687
return names
8788

88-
def setup_plot_features(nparam, names, figsizeinches):
89-
'''
90-
Setup plot features.
91-
92-
Args:
93-
* **nparam** (:py:class:`int`): Number of parameters
94-
* **names** (:py:class:`list`): Names of parameters provided by user
95-
* **figsizeinches** (:py:class:`list`): [Width, Height]
96-
97-
Returns:
98-
* **ns1** (:py:class:`int`): Number of rows in subplot
99-
* **ns2** (:py:class:`int`): Number of columns in subplot
100-
* **names** (:py:class:`list`): List of strings - parameter names
101-
* **figsizeiches** (:py:class:`list`): [Width, Height]
102-
'''
103-
ns1, ns2 = generate_subplot_grid(nparam = nparam)
104-
105-
names = generate_names(nparam = nparam, names = names)
106-
107-
if figsizeinches is None:
108-
figsizeinches = [5,4]
109-
110-
return ns1, ns2, names, figsizeinches
111-
11289
def generate_default_names(nparam):
11390
'''
11491
Generate generic parameter name set.
@@ -348,86 +325,4 @@ def append_to_nrow_ncol_based_on_shape(sh, nrow, ncol):
348325
else:
349326
nrow.append(sh[0])
350327
ncol.append(sh[1])
351-
return nrow, ncol
352-
353-
# --------------------------------------------
354-
def convert_flag_to_boolean(flag):
355-
'''
356-
Convert flag to boolean for backwards compatibility.
357-
358-
Args:
359-
* **flag** (:py:class:`bool` or :py:class:`int`): Flag to specify something.
360-
361-
Returns:
362-
* **flag** (:py:class:`bool`): Flag to converted to boolean.
363-
'''
364-
if flag is 'on':
365-
flag = True
366-
elif flag is 'off':
367-
flag = False
368-
369-
return flag
370-
371-
# --------------------------------------------
372-
def set_local_parameters(ii, local):
373-
'''
374-
Set local parameters based on tests.
375-
376-
:Test 1:
377-
* `local == 0`
378-
:Test 2:
379-
* `local == ii`
380-
381-
Args:
382-
* **ii** (:py:class:`int`): Index.
383-
* **local** (:class:`~numpy.ndarray`): Local flags.
384-
385-
Returns:
386-
* **test** (:class:`~numpy.ndarray`): Array of Booleans indicated test results.
387-
'''
388-
# some parameters may only apply to certain batch sets
389-
test1 = local == 0
390-
test2 = local == ii
391-
test = test1 + test2
392-
return test.reshape(test.size,)
393-
394-
# --------------------------------------------
395-
def empirical_quantiles(x, p = np.array([0.25, 0.5, 0.75])):
396-
'''
397-
Calculate empirical quantiles.
398-
399-
Args:
400-
* **x** (:class:`~numpy.ndarray`): Observations from which to generate quantile.
401-
* **p** (:class:`~numpy.ndarray`): Quantile limits.
402-
403-
Returns:
404-
* (:class:`~numpy.ndarray`): Interpolated quantiles.
405-
'''
406-
407-
# extract number of rows/cols from np.array
408-
n = x.shape[0]
409-
# define vector valued interpolation function
410-
xpoints = range(n)
411-
interpfun = interp1d(xpoints, np.sort(x, 0), axis = 0)
412-
413-
# evaluation points
414-
itpoints = (n-1)*p
415-
416-
return interpfun(itpoints)
417-
418-
# --------------------------------------------
419-
def check_defaults(kwargs, defaults):
420-
'''
421-
Check if defaults are defined in kwargs
422-
423-
Args:
424-
* **kwargs** (:py:class:`dict`): Keyword arguments.
425-
* **defaults** (:py:class:`dict`): Default settings.
426-
427-
Returns:
428-
* **kwargs** (:py:class:`dict`): Updated keyword arguments with at least defaults set.
429-
'''
430-
for ii in defaults:
431-
if ii not in kwargs:
432-
kwargs[ii] = defaults[ii]
433-
return kwargs
328+
return nrow, ncol

test/test_mcmatplot.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import math
1313

1414
import unittest
15-
15+
1616
# --------------------------
1717
class PlotDensityPanel(unittest.TestCase):
1818
def test_basic_plot_features(self):
@@ -66,6 +66,21 @@ def test_basic_plot_features_nsimu_gt_maxpoints(self):
6666
self.assertEqual(f.axes[1].get_xlabel(), 'Iteration', msg = 'Should be Iteration')
6767
plt.close()
6868

69+
def test_basic_plot_features_nsimu_gt_maxpoints_with_pm2std(self):
70+
nsimu = 1000
71+
chains = np.random.random_sample(size = (nsimu,2))
72+
f, _ = MP.plot_chain_panel(chains = chains, settings = dict(add_pm2std = True))
73+
x1, y1 = f.axes[0].lines[0].get_xydata().T
74+
x2, y2 = f.axes[1].lines[0].get_xydata().T
75+
skip = int(math.floor(nsimu/500))
76+
self.assertTrue(np.array_equal(y1, chains[range(0,nsimu,skip),0]), msg = 'Expect y1 to match column 1')
77+
self.assertTrue(np.array_equal(y2, chains[range(0,nsimu,skip),1]), msg = 'Expect y2 to match column 2')
78+
self.assertEqual(f.axes[0].get_xlabel(), '', msg = 'Should be blank')
79+
self.assertEqual(f.axes[1].get_xlabel(), 'Iteration', msg = 'Should be Iteration')
80+
self.assertEqual(len(f.axes[0].lines), 4, msg = 'Expect 4 lines')
81+
self.assertEqual(len(f.axes[1].lines), 4, msg = 'Expect 4 lines')
82+
plt.close()
83+
6984
# --------------------------
7085
class PlotHistogramPanel(unittest.TestCase):
7186
def test_basic_plot_features_nsimu_lt_maxpoints(self):
@@ -106,6 +121,43 @@ def test_basic_plot_features_nsimu_lt_maxpoints(self):
106121

107122
plt.close()
108123

124+
def test_basic_plot_features_nsimu_lt_maxpoints_and_2_chains(self):
125+
chains = np.random.random_sample(size = (100,2))
126+
f, _ = MP.plot_pairwise_correlation_panel(chains = chains)
127+
x1, y1 = f.axes[0].lines[0].get_xydata().T
128+
self.assertTrue(np.array_equal(x1, chains[:,0]), msg = 'Expect x1 to match column 0')
129+
self.assertTrue(np.array_equal(y1, chains[:,1]), msg = 'Expect y1 to match column 1')
130+
for ai in f.axes:
131+
self.assertEqual(ai.get_title(), '', msg = 'Should be blank')
132+
self.assertEqual(f.axes[0].get_xlabel(),'$p_{0}$', msg = 'Expect $p_{0}$')
133+
self.assertEqual(f.axes[0].get_ylabel(),'$p_{1}$', msg = 'Expect $p_{1}$')
134+
plt.close()
135+
136+
def test_basic_plot_features_2c_w_contours(self):
137+
chains = np.random.random_sample(size = (100,2))
138+
f, _ = MP.plot_pairwise_correlation_panel(chains = chains, settings = dict(add_5095_contours = True))
139+
x1, y1 = f.axes[0].lines[0].get_xydata().T
140+
141+
self.assertTrue(np.array_equal(x1, chains[:,0]), msg = 'Expect x1 to match column 0')
142+
self.assertTrue(np.array_equal(y1, chains[:,1]), msg = 'Expect y1 to match column 1')
143+
self.assertEqual(len(f.axes[0].lines), 3, msg = 'Expect 3 lines')
144+
for ai in f.axes:
145+
self.assertEqual(ai.get_title(), '', msg = 'Should be blank')
146+
self.assertEqual(f.axes[0].get_xlabel(),'$p_{0}$', msg = 'Expect $p_{0}$')
147+
self.assertEqual(f.axes[0].get_ylabel(),'$p_{1}$', msg = 'Expect $p_{1}$')
148+
plt.close()
149+
150+
def test_basic_plot_features_2c_w_contours_and_legend(self):
151+
chains = np.random.random_sample(size = (100,2))
152+
f, _ = MP.plot_pairwise_correlation_panel(chains = chains, settings = dict(add_5095_contours = True, add_legend = True))
153+
x1, y1 = f.axes[0].lines[0].get_xydata().T
154+
155+
self.assertTrue(np.array_equal(x1, chains[:,0]), msg = 'Expect x1 to match column 0')
156+
self.assertTrue(np.array_equal(y1, chains[:,1]), msg = 'Expect y1 to match column 1')
157+
self.assertEqual(len(f.axes[0].lines), 3, msg = 'Expect 3 lines')
158+
self.assertEqual(len(f.legends), 1, msg = 'Expect legend')
159+
plt.close()
160+
109161
# --------------------------
110162
class PlotChainMetrics(unittest.TestCase):
111163
def test_basic_plot_features(self):

test/test_mcseaborn.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
"""
4+
Created on Thu Jun 21 12:21:24 2018
5+
6+
@author: prmiles
7+
"""
8+
9+
from mcmcplot import mcseaborn as MP
10+
from mcmcplot import utilities
11+
import matplotlib.pyplot as plt
12+
import seaborn as sns
13+
import numpy as np
14+
15+
import unittest
16+
17+
# --------------------------
18+
class Plot(unittest.TestCase):
19+
def test_plot_attributes(self):
20+
P = MP.Plot()
21+
check_these = ['plot_joint_distributions', 'plot_paired_density_matrix']
22+
for ct in check_these:
23+
self.assertTrue(hasattr(P, ct), msg = str('P contains method {}'.format(ct)))
24+
25+
# --------------------------
26+
class PlotJointDistributions(unittest.TestCase):
27+
def test_basic_joint_distributions(self):
28+
npar = 3
29+
chains = np.random.random_sample(size = (100,npar))
30+
f, _ = MP.plot_joint_distributions(chains = chains)
31+
names = utilities.generate_names(nparam = npar, names = None)
32+
count = 0
33+
for jj in range(2,npar+1):
34+
for ii in range(1,jj):
35+
name1 = names[ii-1]
36+
name2 = names[jj-1]
37+
self.assertEqual(f[count].ax_joint.get_xlabel(), name1, msg = str('Should be {}'.format(name1)))
38+
self.assertEqual(f[count].ax_joint.get_ylabel(), name2, msg = str('Should be {}'.format(name2)))
39+
count += 1
40+
plt.close()
41+
42+
# --------------------------
43+
class PlotPairedGrid(unittest.TestCase):
44+
def test_basic_paired_grid(self):
45+
npar = 3
46+
chains = np.random.random_sample(size = (100,npar))
47+
f, _ = MP.plot_paired_density_matrix(chains = chains)
48+
self.assertTrue(isinstance(f, sns.axisgrid.PairGrid), msg = 'Expect seaborn axisgrid PairGrid object')
49+
plt.close()

0 commit comments

Comments
 (0)