-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
67 lines (55 loc) · 2.09 KB
/
utils.py
File metadata and controls
67 lines (55 loc) · 2.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from pathlib import Path
from typing import List
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from torch import Tensor
def generate_random_corr_matrix(d: int) -> Tensor:
    """Draw a random ``d x d`` correlation matrix.

    A matrix ``A`` with i.i.d. ``U[0, 1)`` entries is drawn from torch's
    global RNG (seed with ``torch.manual_seed`` for reproducibility). The
    positive semi-definite product ``A @ A.T`` is then rescaled to unit
    diagonal. Because every entry of ``A`` is non-negative, all off-diagonal
    correlations are non-negative too.

    The result is guaranteed to be exactly symmetric, to have an exactly unit
    diagonal, and to have every entry in ``[-1, 1]``. Floating-point rounding
    in the matrix products would otherwise break these properties slightly.

    Args:
        d: Dimension of the matrix.

    Returns:
        A ``(d, d)`` tensor holding the correlation matrix.
    """
    A = torch.rand(d, d)
    cov = A @ A.t()
    # Matmul rounding can leave cov[i, j] != cov[j, i]; averaging with the
    # transpose makes it exactly symmetric (float + and * are commutative).
    cov = 0.5 * (cov + cov.t())
    std = torch.sqrt(torch.diagonal(cov))
    # Equivalent to D @ cov @ D with D = diag(1 / std), without two O(d^3) matmuls.
    corr = cov / torch.outer(std, std)
    # Cauchy-Schwarz bounds |corr| <= 1 mathematically; clamp away rounding overshoot.
    corr = corr.clamp(-1.0, 1.0)
    corr.fill_diagonal_(1.0)
    return corr
def _plot_discrete_marginal(ax, dist, marginal_sample) -> None:
    # Draw empirical probabilities as bars and overlay dist.pmf as red crosses.
    values, counts = np.unique(marginal_sample, return_counts=True)
    probs = counts / counts.sum()
    ax.bar(values, probs, alpha=0.6, color='b', label='Empirical')
    ax.plot(values, dist.pmf(values), 'rx', label='True PMF')
    ax.set_ylabel('Probability')


def _plot_continuous_marginal(ax, dist, marginal_sample) -> None:
    # Draw a 100-bin density histogram and overlay dist.pdf as a red line.
    ax.hist(marginal_sample, bins=100, density=True, edgecolor='black',
            alpha=0.6, label='Empirical')
    x = np.linspace(marginal_sample.min(), marginal_sample.max(), 1000)
    ax.plot(x, dist.pdf(x), 'r', lw=2, label='True PDF')
    ax.set_ylabel('Density')


def plot_marginal_distributions(
        distributions: List,
        sample: Tensor,
        filepath: Path,
        ncols: int = 3) -> None:
    """Plot each marginal of ``sample`` against its reference distribution.

    One subplot is drawn per entry of ``distributions``, laid out on a grid
    with ``ncols`` columns; unused grid cells are removed.

    - Distributions for which ``is_discrete`` is true are drawn as a bar chart
      of empirical probabilities, with ``pmf`` overlaid as red crosses.
    - All other distributions are drawn as a 100-bin density histogram, with
      ``pdf`` overlaid as a red line.

    The figure is saved to ``filepath`` and always closed, even if plotting
    fails.

    Args:
        distributions: One reference distribution per column of ``sample``.
            Each must provide ``pmf`` (discrete) or ``pdf`` (continuous), and
            these are called with a NumPy array.
        sample: 2-D tensor whose column ``i`` holds draws for
            ``distributions[i]``. It may require grad or live on any device;
            it is detached and copied to CPU before plotting.
        filepath: Destination passed to ``Figure.savefig``.
        ncols: Number of subplot columns (default 3).

    Raises:
        ValueError: If ``distributions`` is empty or ``ncols < 1``.
    """
    d = len(distributions)
    if d == 0:
        raise ValueError("distributions must contain at least one distribution")
    if ncols < 1:
        raise ValueError(f"ncols must be a positive integer, got {ncols}")
    rows = -(-d // ncols)  # ceiling division
    # squeeze=False always yields a 2-D array, so ravel() works even for a 1x1 grid.
    fig, axs = plt.subplots(rows, ncols, figsize=(5 * ncols, 5 * rows), squeeze=False)
    axs = axs.ravel()
    try:
        for i, (dist, ax) in enumerate(zip(distributions, axs)):
            # .numpy() rejects tensors that require grad or are not on the CPU.
            marginal_sample = sample[:, i].detach().cpu().numpy()
            if is_discrete(dist):
                _plot_discrete_marginal(ax, dist, marginal_sample)
            else:
                _plot_continuous_marginal(ax, dist, marginal_sample)
            ax.set_xlabel('x')
            # Without this call, the labels attached to the plots are never shown.
            ax.legend()
        # Remove the empty cells in the last row of the grid.
        for ax in axs[d:]:
            fig.delaxes(ax)
        fig.tight_layout()
        fig.savefig(filepath)
    finally:
        plt.close(fig)
def plot_corr_matrices(
        corr_matrix_target: Tensor,
        corr_matrix_sample: Tensor,
        filepath: Path) -> None:
    """Save side-by-side annotated heatmaps of two correlation matrices.

    The target matrix is drawn on the left and the sample matrix on the
    right. Each cell is annotated with two decimals, using the ``coolwarm``
    colormap. Both panels share the fixed colour range ``[-1, 1]`` so that
    equal colours mean equal correlations across the two panels. The figure
    is saved to ``filepath`` and always closed, even if plotting fails.

    Args:
        corr_matrix_target: Reference correlation matrix (2-D tensor).
        corr_matrix_sample: Empirical correlation matrix (2-D tensor).
        filepath: Destination passed to ``Figure.savefig``.

    Both tensors may require grad or live on any device; they are detached
    and copied to CPU before plotting.
    """
    fig, axs = plt.subplots(1, 2, figsize=(20, 10))
    try:
        panels = (
            (corr_matrix_target, 'Target Correlation Matrix'),
            (corr_matrix_sample, 'Sample Correlation Matrix'),
        )
        for ax, (matrix, title) in zip(axs, panels):
            # Fixed vmin/vmax: independent auto-scaling would make the two
            # panels' colours incomparable.
            sns.heatmap(matrix.detach().cpu().numpy(), annot=True, fmt=".2f",
                        cmap='coolwarm', vmin=-1.0, vmax=1.0, ax=ax)
            ax.set_title(title)
        fig.tight_layout()
        fig.savefig(filepath)
    finally:
        plt.close(fig)
def is_discrete(dist):
    """Return True if ``dist`` has a ``pmf`` attribute, otherwise False.

    This is a pure duck-typing check: any object on which looking up ``pmf``
    succeeds counts as discrete. Only ``AttributeError`` is treated as
    "missing"; any other exception raised by the lookup propagates.
    """
    try:
        _ = dist.pmf
    except AttributeError:
        return False
    return True