forked from probml/pyprobml
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdirichlet_3d_spiky_plot.py
More file actions
57 lines (47 loc) · 1.89 KB
/
dirichlet_3d_spiky_plot.py
File metadata and controls
57 lines (47 loc) · 1.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import numpy as np
import matplotlib.pyplot as plt
from pyprobml_utils import save_fig
from mpl_toolkits.mplot3d import proj3d
from scipy.stats import dirichlet
# Plot resolution, and how far the plotted region is pulled in from the simplex faces.
grain = 20  # number of grid points along each of the two plotting axes
edgedist = 0.008  # smallest value any component may take, so points never reach e.g. [1, 0, 0]
weight = np.linspace(0, 1, grain)  # 1-D grid reused for both barycentric weights v1 and v2

# Most extreme points of the sample space that are actually plotted: the simplex
# vertices pulled inward. Each row of the 3x3 matrix below is one corner, with
# 1 - 2*edgedist on the diagonal and edgedist everywhere else.
Corner1, Corner2, Corner3 = np.where(np.eye(3, dtype=bool), 1.0 - 2 * edgedist, edgedist)
# Dirichlet density evaluated from 2-D plotting coordinates.
def dpdf(v1, v2, alphavec):
    """Return the Dirichlet density at the simplex point addressed by (v1, v2).

    The point is the convex combination
        v1 * Corner1 + v2 * Corner2 + (1 - v1 - v2) * Corner3
    of the three (slightly shrunken) simplex corners. Grid points with
    v1 + v2 > 1 lie outside that triangle and yield NaN instead of a density.

    Args:
        v1: weight on Corner1.
        v2: weight on Corner2.
        alphavec: concentration parameters passed to scipy.stats.dirichlet.pdf.

    Returns:
        The density value, or np.nan when v1 + v2 > 1.
    """
    if not v1 + v2 > 1:
        point = v1 * Corner1 + v2 * Corner2 + (1.0 - v1 - v2) * Corner3
        return dirichlet.pdf(point, alphavec)
    return np.nan
def _plot_density_surface(probs, alphavec, view_azim):
    """Draw a density grid as a 3-D surface titled with its alpha vector.

    Args:
        probs: array with probs[i, j] = density at v1 = weight[i], v2 = weight[j]
            (the layout produced by the loop below).
        alphavec: Dirichlet parameters; formatted to 2 decimals for the title.
        view_azim: azimuth passed to Axes3D.view_init.

    Returns:
        The newly created figure.
    """
    X, Y = np.meshgrid(weight, weight)
    fig = plt.figure(figsize=(20, 15))
    ax = fig.add_subplot(111, projection='3d')
    # meshgrid's Y varies along rows (v1) and X along columns (v2), so passing
    # (Y, X) puts v1 on the x axis and v2 on the y axis, matching probs[i, j].
    # The colour scale is fixed to [0, 3] for every alpha vector.
    ax.plot_surface(Y, X, probs, cmap='jet', vmin=0, vmax=3, rstride=1, cstride=1, linewidth=0)
    ax.view_init(elev=25, azim=view_azim)
    ax.set_zlabel('p')
    ax.set_title(','.join('{:0.2f}'.format(d) for d in alphavec))
    return fig


# Dirichlet concentration parameters to plot; one figure per vector.
alphas = [[20, 20, 20], [3, 3, 20], [0.1, 0.1, 0.1]]
azim = 20  # azimuth of the view that is saved to disk
# Set True to also display each surface from the opposite side (azim=200).
# That extra view is only shown, never saved (replaces a former `if 0:` block).
show_back_view = False

for alphavec in map(np.array, alphas):
    # Outer loop over v1, inner over v2, so probs[i, j] = dpdf(weight[i], weight[j]).
    probs = np.array([dpdf(v1, v2, alphavec) for v1 in weight for v2 in weight]).reshape(-1, grain)

    _plot_density_surface(probs, alphavec, azim)
    # The file name encodes round(10 * alpha_1): 20 -> 200, 3 -> 30, 0.1 -> 1.
    alpha = int(np.round(alphavec[0] * 10))
    save_fig('dirSimplexAlpha{}.pdf'.format(alpha))
    plt.show()

    if show_back_view:
        _plot_density_surface(probs, alphavec, 200)
        plt.show()