-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathGaussianKDE.py
More file actions
78 lines (61 loc) · 2.31 KB
/
GaussianKDE.py
File metadata and controls
78 lines (61 loc) · 2.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from numpy.lib.shape_base import apply_along_axis
import torch
import numpy as np
from torch.distributions import MultivariateNormal, Normal
from torch.distributions.distribution import Distribution
# Seed PyTorch's global RNG as an import-time side effect so that torch-based
# random draws are reproducible from run to run.
torch.manual_seed(0)
class GaussianKDE(Distribution):
    """Kernel density estimate using axis-aligned Gaussian kernels.

    One Gaussian kernel is centred on every row of `X`; the estimated density
    is the equally weighted mixture of those kernels.  Each kernel factorises
    over the `d` coordinates, so a kernel's log-density is the sum of the
    per-coordinate `Normal` log-densities.
    """

    # No constructor arguments are registered for validation; defining this
    # keeps `Distribution.__repr__` from raising NotImplementedError.
    arg_constraints = {}

    # Maximum number of rows of `X` / `Y` evaluated against each other at once,
    # which bounds the size of the intermediate (rows_Y, rows_X, d) tensor.
    _CHUNK_SIZE = 1000

    def __init__(self, X, bw):
        """
        X : tensor (n, d)
            `n` points with `d` dimensions to which KDE will be fit
        bw : numeric or tensor
            bandwidth (standard deviation) of the Gaussian kernel.  Either a
            single value shared by all kernels, or a tensor with one row per
            point of `X` (shape (n, 1) or (n, d)).
        """
        # Initialise the Distribution base class so batch_shape / event_shape
        # exist.  Validation is disabled because no constraints are declared.
        super().__init__(event_shape=X.shape[-1:], validate_args=False)
        self.X = X
        self.bw = bw
        self.dims = X.shape[-1]
        self.n = X.shape[0]
        # Kept as a public attribute for callers that access it directly.
        self.normal = Normal(loc=self.X, scale=self.bw)

    def _bw_is_per_point(self):
        """Return True when `bw` holds a separate bandwidth row for each point."""
        bw = self.bw
        return (
            torch.is_tensor(bw)
            and bw.dim() == self.X.dim()
            and bw.shape[0] == self.n
        )

    def _log_density(self, Y, X=None):
        """Return log((1 / n) * sum_j K(y_i, x_j)) for every row y_i of `Y`.

        Kernels are evaluated chunk by chunk and combined with a log-sum-exp,
        so tiny kernel values do not underflow to zero before being summed.
        """
        if X is None:
            X = self.X
        x_chunks = X.split(self._CHUNK_SIZE)
        if self._bw_is_per_point():
            if X.shape[0] != self.n:
                raise ValueError(
                    "a per-point bandwidth requires X to have one row per "
                    "fitted point"
                )
            # Split identically to X so each chunk keeps its own bandwidths.
            bw_chunks = self.bw.split(self._CHUNK_SIZE)
        else:
            bw_chunks = (self.bw,) * len(x_chunks)
        kernels = [Normal(loc=x, scale=b) for x, b in zip(x_chunks, bw_chunks)]

        # Always normalise by the number of fitted points (see score_samples).
        log_n = float(np.log(self.n))
        pieces = []
        for y in Y.split(self._CHUNK_SIZE):
            y = y.unsqueeze(-2)  # (m, 1, d): broadcasts against (k, d) kernels
            log_kernels = torch.cat(
                [kernel.log_prob(y).sum(-1) for kernel in kernels], dim=-1
            )
            pieces.append(torch.logsumexp(log_kernels, dim=-1) - log_n)
        return torch.cat(pieces)

    def sample(self, num_samples):
        """Draw `num_samples` points from the estimated density.

        A kernel centre is picked uniformly at random (NumPy RNG) and a point
        is then drawn from the Gaussian centred there (PyTorch RNG).

        Returns
        -------
        samples : tensor (num_samples, d)
        """
        idxs = torch.as_tensor(
            np.random.randint(0, self.n, size=num_samples), dtype=torch.long
        )
        # A scalar bandwidth cannot be indexed; only per-point tensors are.
        scale = self.bw[idxs] if self._bw_is_per_point() else self.bw
        return Normal(loc=self.X[idxs], scale=scale).sample()

    def score_samples(self, Y, X=None):
        """Returns the kernel density estimates of each point in `Y`.

        Parameters
        ----------
        Y : tensor (m, d)
            `m` points with `d` dimensions for which the probability density
            will be calculated
        X : tensor (k, d), optional
            kernel centres to evaluate against.  By default, `X` is None and
            all points used to initialize the estimator are included.  The
            result is always divided by the number of fitted points `n`, so
            the values obtained for disjoint subsets of the fitted points add
            up to the full estimate.

        Returns
        -------
        probs : tensor (m)
            probability densities (not log-densities) for each of the queried
            points in `Y`
        """
        return torch.exp(self._log_density(Y, X))

    def log_prob(self, Y):
        """Returns the total log probability of one or more points, `Y`, under
        the kernel density estimate fit to `X` with bandwidth `bw`.

        Parameters
        ----------
        Y : tensor (m, d)
            `m` points with `d` dimensions for which the probability density
            will be calculated

        Returns
        -------
        log_prob : scalar tensor
            sum over the rows of `Y` of the log probability density
        """
        return self._log_density(Y).sum(dim=0)