diff --git a/bemb/model/__init__.py b/bemb/model/__init__.py index 60c8b2e..d5f6d8a 100644 --- a/bemb/model/__init__.py +++ b/bemb/model/__init__.py @@ -1,2 +1,3 @@ from .bemb import * +from .bemb_chunked import * from .bemb_flex_lightning import * diff --git a/bemb/model/bayesian_coefficient.py b/bemb/model/bayesian_coefficient.py index 790c4fe..0a6460f 100644 --- a/bemb/model/bayesian_coefficient.py +++ b/bemb/model/bayesian_coefficient.py @@ -6,9 +6,12 @@ """ from typing import Optional, Tuple, Union +import numpy as np import torch import torch.nn as nn from torch.distributions.lowrank_multivariate_normal import LowRankMultivariateNormal +from torch.distributions.gamma import Gamma +from torch.distributions.log_normal import LogNormal class BayesianCoefficient(nn.Module): @@ -21,7 +24,8 @@ def __init__(self, num_obs: Optional[int] = None, dim: int = 1, prior_mean: float = 0.0, - prior_variance: Union[float, torch.Tensor] = 1.0 + prior_variance: Union[float, torch.Tensor] = 1.0, + distribution: str = 'gaussian' ) -> None: """The Bayesian coefficient object represents a learnable tensor mu_i in R^k, where i is from a family (e.g., user, item) so there are num_classes * num_obs learnable weights in total. @@ -63,12 +67,56 @@ def __init__(self, If a tensor with shape (num_classes, dim) is supplied, supplying a (num_classes, dim) tensor is amount to specifying a different prior variance for each entry in the coefficient. Defaults to 1.0. + distribution (str, optional): the distribution of the coefficient. Currently we support 'gaussian', 'gamma' and 'lognormal'. + Defaults to 'gaussian'. """ super(BayesianCoefficient, self).__init__() # do we use this at all? TODO: drop self.variation. 
assert variation in ['item', 'user', 'constant', 'category'] self.variation = variation + + assert distribution in ['gaussian', 'gamma', 'lognormal'], f'Unsupported distribution {distribution}' + if distribution == 'gamma': + ''' + assert not obs2prior, 'Gamma distribution is not supported for obs2prior at present.' + mean = 1.0 + variance = 10.0 + assert mean > 0, 'Gamma distribution requires mean > 0' + assert variance > 0, 'Gamma distribution requires variance > 0' + # shape (concentration) is mean^2/variance, rate is variance/mean for Gamma distribution. + ''' + self.mean_clamp = (np.log(0.1), np.log(100.0)) + self.logstd_clamp = (np.log(0.1), np.log(10000.0)) + shape = prior_mean ** 2 / prior_variance + rate = prior_mean / prior_variance + if not obs2prior: + assert shape > np.exp(self.mean_clamp[0])**2 and shape < np.exp(self.mean_clamp[1])**2, f'Gamma shape {shape} is out of range, should be in ({np.exp(self.mean_clamp[0])**2}, {np.exp(self.mean_clamp[1])**2})' + assert rate > np.exp(self.logstd_clamp[0]) and rate < np.exp(self.logstd_clamp[1]), f'Gamma rate {rate} is out of range, should be in ({np.exp(self.logstd_clamp[0])}, {np.exp(self.logstd_clamp[1])})' + # prior_mean stores ln(shape) for gamma + prior_mean = np.log(shape) + # prior_variance stores rate for gamma + prior_variance = rate + # prior_mean = np.log(prior_mean) + # prior_variance = prior_variance + + elif distribution == 'lognormal': + # mean is exp(mu + sigma^2/2), variance is (exp(sigma^2) - 1) * exp(2*mu + sigma^2) + # prior_mean in -2, exp(3) + self.mean_clamp = (-10.0, np.exp(1.5)) + # sigma sq clamp exp(-20), exp(1.5) + # therefore sigma in (exp(-10), exp(0.75)) + # therefore log sigma in (-10, 0.75) + self.logstd_clamp = (-10.0, 0.75) + if not obs2prior: + assert prior_mean > self.mean_clamp[0] and prior_mean < self.mean_clamp[1], f'Lognormal distribution requires prior_mean in {self.mean_clamp}, given {prior_mean}' + assert np.sqrt(prior_variance) > np.exp(self.logstd_clamp[0]) and 
np.sqrt(prior_variance) < np.exp(self.logstd_clamp[1]), f'Lognormal distribution requires prior_variance in {self.logstd_clamp}, given {prior_variance}' + # assert prior_mean > np.exp(-100.0) and prior_mean < np.exp(10.0), f'Lognormal distribution requires shape in (exp(-100), exp(10)), given {prior_mean}' + # assert prior_variance > np.exp(-100.0) and prior_variance < np.exp(2.0), f'Lognormal distribution requires rate in (exp(-100), exp(2)), given {prior_variance}' + + + self.distribution = distribution + self.obs2prior = obs2prior if variation == 'constant' or variation == 'category': if obs2prior: @@ -89,13 +137,15 @@ def __init__(self, if self.obs2prior: # the mean of prior distribution depends on observables. # initiate a Bayesian Coefficient with shape (dim, num_obs) standard Gaussian. + prior_H_dist = 'gaussian' self.prior_H = BayesianCoefficient(variation='constant', num_classes=dim, obs2prior=False, dim=num_obs, prior_variance=1.0, H_zero_mask=self.H_zero_mask, - is_H=True) # this is a distribution responsible for the obs2prior H term. + is_H=True, + distribution=prior_H_dist) # this is a distribution responsible for the obs2prior H term. else: self.register_buffer( @@ -114,15 +164,49 @@ def __init__(self, num_classes, dim) * self.prior_variance) # create variational distribution. - self.variational_mean_flexible = nn.Parameter( - torch.randn(num_classes, dim), requires_grad=True) + if self.distribution == 'gaussian': + if self.is_H: + self.variational_mean_flexible = nn.Parameter( + torch.randn(num_classes, dim), requires_grad=True) + # multiply by 0.0001 to avoid numerical issues. 
+ self.variational_mean_flexible.data *= 0.0001 + + else: + self.variational_mean_flexible = nn.Parameter( + torch.randn(num_classes, dim), requires_grad=True) + elif self.distribution == 'lognormal': + self.variational_mean_flexible = nn.Parameter( + torch.randn(num_classes, dim), requires_grad=True) + self.variational_mean_flexible.data = torch.clamp( + self.variational_mean_flexible.data, min=self.mean_clamp[0], max=self.mean_clamp[1]) + # TOOD(kanodiaayush): initialize the gamma distribution variational mean in a more principled way. + elif self.distribution == 'gamma': + # initialize using uniform distribution between 0.5 and 1.5 + # for a gamma distribution, we store the concentration (shape) as log(concentration) = variational_mean_flexible + self.variational_mean_flexible = nn.Parameter( + torch.rand(num_classes, dim) + 0.5, requires_grad=True) + self.variational_mean_flexible.data = torch.clamp( + self.variational_mean_flexible.data, min=self.mean_clamp[0], max=self.mean_clamp[1]) if self.is_H and self.H_zero_mask is not None: assert self.H_zero_mask.shape == self.variational_mean_flexible.shape, \ f"The H_zero_mask should have exactly the shape as the H variable, `H_zero_mask`.shape is {self.H_zero_mask.shape}, `H`.shape is {self.variational_mean_flexible.shape} " - self.variational_logstd = nn.Parameter( - torch.randn(num_classes, dim), requires_grad=True) + # for gamma distribution, we store the rate as log(rate) = variational_logstd + if self.distribution == 'gaussian': + self.variational_logstd = nn.Parameter( + torch.randn(num_classes, dim), requires_grad=True) + elif self.distribution == 'lognormal': + # uniform -1 to 1 + self.variational_logstd = nn.Parameter( + torch.rand(num_classes, dim) * 2 - 1, requires_grad=True) + self.variational_logstd.data = torch.clamp( + self.variational_logstd.data, min=self.logstd_clamp[0], max=self.logstd_clamp[1]) + elif self.distribution == 'gamma': + self.variational_logstd = nn.Parameter( + 
torch.randn(num_classes, dim), requires_grad=True) + self.variational_logstd.data = torch.clamp( + self.variational_logstd.data, min=self.logstd_clamp[0], max=self.logstd_clamp[1]) self.register_buffer('variational_cov_factor', torch.zeros(num_classes, dim, 1)) @@ -158,10 +242,21 @@ def variational_mean(self) -> torch.Tensor: Returns: torch.Tensor: the current mean of the variational distribution with shape (num_classes, dim). """ - if self.variational_mean_fixed is None: - M = self.variational_mean_flexible - else: - M = self.variational_mean_fixed + self.variational_mean_flexible + assert self.variational_mean_fixed is None, "not supported" + + # if self.variational_mean_fixed is None: + # else: + # M = self.variational_mean_fixed + self.variational_mean_flexible + + M = self.variational_mean_flexible + + if self.distribution == 'gamma': + # M = torch.pow(M, 2) + 0.000001 + M = (torch.minimum((M.exp() + 0.000001), torch.tensor(1e3))) / self.variational_logstd.exp() + + elif self.distribution == 'lognormal': + M = (torch.minimum((M.exp() + 0.000001), torch.tensor(1e3))) + # M = torch.minimum(M + 0.000001, torch.tensor(1e3)) if self.is_H and (self.H_zero_mask is not None): # a H-variable with zero-entry restriction. @@ -196,7 +291,11 @@ def log_prior(self, Returns: torch.Tensor: the log prior of the variable with shape (num_seeds, num_classes). 
""" - # p(sample) + # DEBUG_MARKER + ''' + print(sample) + print('log_prior') + ''' num_seeds, num_classes, dim = sample.shape # shape (num_seeds, num_classes) if self.obs2prior: @@ -211,9 +310,32 @@ def log_prior(self, else: mu = self.prior_zero_mean - out = LowRankMultivariateNormal(loc=mu, - cov_factor=self.prior_cov_factor, - cov_diag=self.prior_cov_diag).log_prob(sample) + + if self.distribution == 'gaussian':# or self.distribution == 'lognormal': + out = LowRankMultivariateNormal(loc=mu, + cov_factor=self.prior_cov_factor, + cov_diag=self.prior_cov_diag).log_prob(sample) + elif self.distribution == 'lognormal': + mu = torch.clamp(mu, min=-100.0, max=10.0) + out = LowRankMultivariateNormal(loc=mu, + cov_factor=self.prior_cov_factor, + cov_diag=self.prior_cov_diag).log_prob(sample) + # out = LogNormal(loc=mu, scale=np.sqrt(self.prior_variance)).log_prob(sample) + # out = torch.sum(out, dim=-1) + # out = torch.zeros((num_seeds, num_classes), device=sample.device) + + elif self.distribution == 'gamma': + mu = torch.clamp(mu, min=-100.0, max=4.0) + out = LowRankMultivariateNormal(loc=mu, + cov_factor=self.prior_cov_factor, + cov_diag=self.prior_cov_diag).log_prob(sample) + # concentration = torch.exp(mu) + # rate = self.prior_variance + # out = Gamma(concentration=concentration, + # rate=rate).log_prob(sample) + # out = torch.sum(out, dim=-1) + + # out = torch.zeros((num_seeds, num_classes), device=sample.device) assert out.shape == (num_seeds, num_classes) return out @@ -250,6 +372,11 @@ def rsample(self, num_seeds: int = 1) -> Union[torch.Tensor, Tuple[torch.Tensor] """ value_sample = self.variational_distribution.rsample( torch.Size([num_seeds])) + # if self.distribution == 'lognormal': + # print(torch.min(value_sample)) + # print(torch.max(value_sample)) + # breakpoint() + # DEBUG_MARKER if self.obs2prior: # sample obs2prior H as well. 
H_sample = self.prior_H.rsample(num_seeds=num_seeds) @@ -258,14 +385,71 @@ def rsample(self, num_seeds: int = 1) -> Union[torch.Tensor, Tuple[torch.Tensor] return value_sample @property - def variational_distribution(self) -> LowRankMultivariateNormal: + def variational_distribution(self) -> Union[LowRankMultivariateNormal, Gamma]: """Constructs the current variational distribution of the coefficient from current variational mean and covariance. """ - return LowRankMultivariateNormal(loc=self.variational_mean, - cov_factor=self.variational_cov_factor, - cov_diag=torch.exp(self.variational_logstd)) + if self.distribution == 'gaussian':# or self.distribution == 'lognormal': + # print(torch.max(self.variational_mean), torch.min(self.variational_mean)) + # print(torch.max(self.variational_logstd), torch.min(self.variational_logstd)) + return LowRankMultivariateNormal(loc=self.variational_mean_flexible, + cov_factor=self.variational_cov_factor, + cov_diag=torch.exp(self.variational_logstd)) + elif self.distribution == 'lognormal': + # print(self.variational_mean_flexible) + # print(self.variational_logstd) + # print(torch.max(self.variational_logstd), torch.min(self.variational_logstd)) + # print(torch.max(self.variational_mean_flexible), torch.min(self.variational_mean_flexible)) + # print(self.variational_mean_flexible.shape, self.variational_logstd.shape) + # return LowRankMultivariateNormal(loc=self.variational_mean_flexible, + # cov_factor=self.variational_cov_factor, + # cov_diag=torch.exp(self.variational_logstd)) + # variational_mean_flexible = torch.clamp(self.variational_mean_flexible, min=-10, max=10) + # variational_logstd = torch.clamp(self.variational_logstd, min=-4, max=3) + loc = self.variational_mean_flexible + scale = torch.exp(self.variational_logstd) + return LogNormal(loc=loc, scale=scale) + elif self.distribution == 'gamma': + # for a gamma distribution, we store the concentration as log(concentration) = variational_mean_flexible + # assert 
self.variational_mean_fixed == None, 'Gamma distribution does not support fixed mean' + concentration = torch.exp(self.variational_mean_flexible) + # assert that all concentration should be between exp -4 and exp 4 + # assert torch.all(concentration > 0.1353 - 0.0001), 'concentration should be greater than exp -2' + # assert torch.all(concentration < 54.5981 + 0.0001), 'concentration should be less than exp 4' + # concentration = self.variational_mean_flexible.exp() + 0.000001 + # concentration = torch.clamp(concentration, min=1e-2, max=1e2) + # concentration = torch.minimum(concentration, torch.tensor(1e3)) + # for gamma distribution, we store the rate as log(rate) = variational_logstd + rate = torch.exp(self.variational_logstd) + # print(concentration, rate) + # rate = torch.clamp(rate, min=1e-2, max=1e2) + return Gamma(concentration=concentration, rate=rate) + else: + raise NotImplementedError("Unknown variational distribution type.") @property def device(self) -> torch.device: """Returns the device of tensors contained in this module.""" return self.variational_mean.device + + def clamp_params(self) -> None: + """Clamps the parameters of the variational distribution to be within a reasonable range. 
+ """ + if self.distribution == 'gaussian': + # do nothing + pass + # self.variational_mean_flexible.data = torch.clamp( + # self.variational_mean_flexible.data, min=-10, max=10) + # self.variational_logstd.data = torch.clamp( + # self.variational_logstd.data, min=-4, max=3) + elif self.distribution in ['lognormal', 'gamma']: + self.variational_mean_flexible.data = torch.clamp( + self.variational_mean_flexible.data, min=self.mean_clamp[0], max=self.mean_clamp[1]) + self.variational_logstd.data = torch.clamp( + self.variational_logstd.data, min=self.logstd_clamp[0], max=self.logstd_clamp[1]) + # elif self.distribution == 'gamma': + # self.variational_mean_flexible.data = torch.clamp( + # self.variational_mean_flexible.data, min=-2.0, max=4) + # self.variational_logstd.data = torch.clamp( + # self.variational_logstd.data, min=-100.0, max=2.0) + else: + raise NotImplementedError("Unknown variational distribution type.") diff --git a/bemb/model/bemb.py b/bemb/model/bemb.py index be704db..f6df9d9 100644 --- a/bemb/model/bemb.py +++ b/bemb/model/bemb.py @@ -39,21 +39,26 @@ def parse_utility(utility_string: str) -> List[Dict[str, Union[List[str], None]] A helper function parse utility string into a list of additive terms. Example: - utility_string = 'lambda_item + theta_user * alpha_item + gamma_user * beta_item * price_obs' + utility_string = 'lambda_item + theta_user * alpha_item - gamma_user * beta_item * price_obs' output = [ { 'coefficient': ['lambda_item'], - 'observable': None + 'observable': None, + 'sign': 1.0, + }, { 'coefficient': ['theta_user', 'alpha_item'], 'observable': None + 'sign': 1.0, }, { 'coefficient': ['gamma_user', 'beta_item'], 'observable': 'price_obs' + 'sign': -1.0, } ] + Note that 'minus' is allowed in the utility string. If the first term is negative, the minus should be without a space. 
""" # split additive terms coefficient_suffix = ('_item', '_user', '_constant', '_category') @@ -76,10 +81,16 @@ def is_coefficient(name: str) -> bool: def is_observable(name: str) -> bool: return any(name.startswith(prefix) for prefix in observable_prefix) + utility_string = utility_string.replace(' - ', ' + -') additive_terms = utility_string.split(' + ') additive_decomposition = list() for term in additive_terms: - atom = {'coefficient': [], 'observable': None} + if term.startswith('-'): + sign = -1.0 + term = term[1:] + else: + sign = 1.0 + atom = {'coefficient': [], 'observable': None, 'sign': sign} # split multiplicative terms. for x in term.split(' * '): assert not (is_observable(x) and is_coefficient(x)), f"The element {x} is ambiguous, it follows naming convention of both an observable and a coefficient." @@ -113,6 +124,7 @@ def __init__(self, num_items: int, pred_item: bool, num_classes: int = 2, + coef_dist_dict: Dict[str, str] = {'default' : 'gaussian'}, H_zero_mask_dict: Optional[Dict[str, torch.BoolTensor]] = None, prior_mean: Union[float, Dict[str, float]] = 0.0, prior_variance: Union[float, Dict[str, float]] = 1.0, @@ -140,6 +152,14 @@ def __init__(self, lambda_item + theta_user * alpha_item + gamma_user * beta_item * price_obs See the doc-string of parse_utility for an example. + coef_dist_dict (Dict[str, str]): a dictionary mapping coefficient name to coefficient distribution name. + The coefficient distribution name can be one of the following: + 1. 'gaussian' + 2. 'gamma' - obs2prior is not supported for gamma coefficients + If a coefficient does not appear in the dictionary, it will be assigned the distribution specified + by the 'default' key. By default, the default distribution is 'gaussian'. + For coefficients which have gamma distributions, prior mean and variance MUST be specified in the prior_mean and prior_variance arguments if obs2prior is False for this coefficient. 
If obs2prior is True, prior_variance is still required + obs2prior_dict (Dict[str, bool]): a dictionary maps coefficient name (e.g., 'lambda_item') to a boolean indicating if observable (e.g., item_obs) enters the prior of the coefficient. @@ -184,6 +204,8 @@ def __init__(self, If no `prior_mean['default']` is provided, the default prior mean will be 0.0 for those coefficients not in the prior_mean.keys(). + For coefficients with gamma distributions, prior_mean specifies the shape parameter of the gamma prior. + Defaults to 0.0. prior_variance (Union[float, Dict[str, float]], Dict[str, torch. Tensor]): the variance of prior distribution @@ -203,6 +225,8 @@ def __init__(self, If no `prior_variance['default']` is provided, the default prior variance will be 1.0 for those coefficients not in the prior_variance.keys(). + For coefficients with gamma distributions, prior_variance specifies the concentration parameter of the gamma prior. + Defaults to 1.0, which means all priors have identity matrix as the covariance matrix. num_users (int, optional): number of users, required only if coefficient or observable @@ -233,6 +257,7 @@ def __init__(self, self.utility_formula = utility_formula self.obs2prior_dict = obs2prior_dict self.coef_dim_dict = coef_dim_dict + self.coef_dist_dict = coef_dist_dict if H_zero_mask_dict is not None: self.H_zero_mask_dict = H_zero_mask_dict else: @@ -325,6 +350,21 @@ def __init__(self, for additive_term in self.formula: for coef_name in additive_term['coefficient']: variation = coef_name.split('_')[-1] + + if coef_name not in self.coef_dist_dict.keys(): + if 'default' in self.coef_dist_dict.keys(): + self.coef_dist_dict[coef_name] = self.coef_dist_dict['default'] + else: + warnings.warn(f"You provided a dictionary of coef_dist_dict, but coefficient {coef_name} is not a key in it. 
Supply a value for 'default' in the coef_dist_dict dictionary to use that as default value (e.g., coef_dist_dict['default'] = 'gaussian'); now using distribution='gaussian' since this is not supplied.") + self.coef_dist_dict[coef_name] = 'gaussian' + + elif self.coef_dist_dict[coef_name] == 'gamma': + if not self.obs2prior_dict[coef_name]: + assert isinstance(self.prior_mean, dict) and coef_name in self.prior_mean.keys(), \ + f"Prior mean for {coef_name} needs to be provided because it's posterior is estimated as a gamma distribution." + assert isinstance(self.prior_variance, dict) and coef_name in self.prior_variance.keys(), \ + f"Prior variance for {coef_name} needs to be provided because it's posterior is estimated as a gamma distribution." + if isinstance(self.prior_mean, dict): # the user didn't specify prior mean for this coefficient. if coef_name not in self.prior_mean.keys(): @@ -345,7 +385,7 @@ def __init__(self, if 'default' in self.prior_variance.keys(): self.prior_variance[coef_name] = self.prior_variance['default'] else: - warnings.warn(f"You provided a dictionary of prior variance, but coefficient {coef_name} is not a key in it. Supply a value for 'default' in the prior_variance dictionary to use that as default value (e.g., prior_variance['default'] = 0.3); now using variance=1.0 since this is not supplied.") + # warnings.warn(f"You provided a dictionary of prior variance, but coefficient {coef_name} is not a key in it. 
Supply a value for 'default' in the prior_variance dictionary to use that as default value (e.g., prior_variance['default'] = 0.3); now using variance=1.0 since this is not supplied.") self.prior_variance[coef_name] = 1.0 s2 = self.prior_variance[coef_name] if isinstance( @@ -359,6 +399,7 @@ def __init__(self, if (not self.obs2prior_dict[coef_name]) and (H_zero_mask is not None): raise ValueError(f'You specified H_zero_mask for {coef_name}, but obs2prior is False for this coefficient.') + print(coef_name) coef_dict[coef_name] = BayesianCoefficient(variation=variation, num_classes=variation_to_num_classes[variation], obs2prior=self.obs2prior_dict[coef_name], @@ -367,7 +408,8 @@ def __init__(self, prior_mean=mean, prior_variance=s2, H_zero_mask=H_zero_mask, - is_H=False) + is_H=False, + distribution=self.coef_dist_dict[coef_name]) self.coef_dict = nn.ModuleDict(coef_dict) # ============================================================================================================== @@ -380,6 +422,10 @@ def __init__(self, 'Additional modules are temporarily disabled for further development.') self.additional_modules = nn.ModuleList(additional_modules) + def clamp_coefs(self): + for coef_name in self.coef_dict.keys(): + self.coef_dict[coef_name].clamp_params() + def __str__(self): return f'Bayesian EMBedding Model with U[user, item, session] = {self.raw_formula}\n' \ + f'Total number of parameters: {self.num_params}.\n' \ @@ -654,7 +700,12 @@ def sample_coefficient_dictionary(self, num_seeds: int, deterministic: bool = Fa sample_dict = dict() for coef_name, coef in self.coef_dict.items(): if deterministic: - sample_dict[coef_name] = coef.variational_distribution.mean.unsqueeze(dim=0) # (1, num_*, dim) + s = coef.variational_distribution.mean.unsqueeze(dim=0) # (1, num_*, dim) + # print(torch.min(s), torch.max(s)) + # breakpoint() + # if coef.distribution == 'lognormal': + # s = torch.exp(s) + sample_dict[coef_name] = s if coef.obs2prior: sample_dict[coef_name + '.H'] 
= coef.prior_H.variational_distribution.mean.unsqueeze(dim=0) # (1, num_*, dim) else: @@ -662,11 +713,18 @@ def sample_coefficient_dictionary(self, num_seeds: int, deterministic: bool = Fa if coef.obs2prior: # sample both obs2prior weight and realization of variable. assert isinstance(s, tuple) and len(s) == 2 - sample_dict[coef_name] = s[0] + # if coef.distribution == 'lognormal': + if False: + ss = torch.exp(s[0]) + else: + ss = s[0] + sample_dict[coef_name] = ss sample_dict[coef_name + '.H'] = s[1] else: # only sample the realization of variable. assert torch.is_tensor(s) + # if coef.distribution == 'lognormal': + # s = torch.exp(s) sample_dict[coef_name] = s return sample_dict @@ -907,6 +965,7 @@ def reshape_observable(obs, name): sample_dict[coef_name], coef_name) assert coef_sample.shape == (R, P, I, 1) additive_term = coef_sample.view(R, P, I) + additive_term *= term['sign'] # Type II: factorized coefficient, e.g., . elif len(term['coefficient']) == 2 and term['observable'] is None: @@ -922,6 +981,7 @@ def reshape_observable(obs, name): R, P, I, positive_integer) additive_term = (coef_sample_0 * coef_sample_1).sum(dim=-1) + additive_term *= term['sign'] # Type III: single coefficient multiplied by observable, e.g., theta_user * x_obs_item. elif len(term['coefficient']) == 1 and term['observable'] is not None: @@ -935,6 +995,7 @@ def reshape_observable(obs, name): assert obs.shape == (R, P, I, positive_integer) additive_term = (coef_sample * obs).sum(dim=-1) + additive_term *= term['sign'] # Type IV: factorized coefficient multiplied by observable. # e.g., gamma_user * beta_item * price_obs. @@ -961,10 +1022,13 @@ def reshape_observable(obs, name): R, P, I, num_obs, latent_dim) coef_sample_1 = coef_sample_1.view( R, P, I, num_obs, latent_dim) + # coef_sample_0 = torch.exp(coef_sample_0) + # coef_sample_1 = torch.exp(coef_sample_1) # compute the factorized coefficient with shape (R, P, I, O). 
coef = (coef_sample_0 * coef_sample_1).sum(dim=-1) additive_term = (coef * obs).sum(dim=-1) + additive_term *= term['sign'] else: raise ValueError(f'Undefined term type: {term}') @@ -1138,6 +1202,7 @@ def reshape_observable(obs, name): sample_dict[coef_name], coef_name) assert coef_sample.shape == (R, total_computation, 1) additive_term = coef_sample.view(R, total_computation) + additive_term *= term['sign'] # Type II: factorized coefficient, e.g., . elif len(term['coefficient']) == 2 and term['observable'] is None: @@ -1153,6 +1218,7 @@ def reshape_observable(obs, name): R, total_computation, positive_integer) additive_term = (coef_sample_0 * coef_sample_1).sum(dim=-1) + additive_term *= term['sign'] # Type III: single coefficient multiplied by observable, e.g., theta_user * x_obs_item. elif len(term['coefficient']) == 1 and term['observable'] is not None: @@ -1167,6 +1233,7 @@ def reshape_observable(obs, name): assert obs.shape == (R, total_computation, positive_integer) additive_term = (coef_sample * obs).sum(dim=-1) + additive_term *= term['sign'] # Type IV: factorized coefficient multiplied by observable. # e.g., gamma_user * beta_item * price_obs. @@ -1196,6 +1263,7 @@ def reshape_observable(obs, name): coef = (coef_sample_0 * coef_sample_1).sum(dim=-1) additive_term = (coef * obs).sum(dim=-1) + additive_term *= term['sign'] else: raise ValueError(f'Undefined term type: {term}') diff --git a/bemb/model/bemb_chunked.py b/bemb/model/bemb_chunked.py new file mode 100644 index 0000000..b1572c8 --- /dev/null +++ b/bemb/model/bemb_chunked.py @@ -0,0 +1,1404 @@ +""" +A chunked version of BEMB. + +We divide users, items(categories) and sessions into u, i and s chunks. +Then for each user, there are i*s parameters, for each item there are u*s parameters and for each session there are u*i parameters. 
+ +Author: Ayush Kanodia +Update: Dec 04, 2022 +""" +import warnings +from pprint import pprint +import warnings +from typing import Dict, List, Optional, Tuple, Union +from pprint import pprint +from typing import Dict, List, Optional, Union, Tuple + +import numpy as np +import torch +import torch.nn as nn +from torch_choice.data import ChoiceDataset +from torch_scatter import scatter_logsumexp, scatter_max +from torch_scatter.composite import scatter_log_softmax + +from bemb.model.bayesian_coefficient import BayesianCoefficient + +# ====================================================================================================================== +# helper functions. +# ====================================================================================================================== + +from bemb.model.bemb import PositiveInteger, parse_utility + +positive_integer = PositiveInteger() + +# ====================================================================================================================== +# core class of the BEMB model. +# ====================================================================================================================== + + +class BEMBFlexChunked(nn.Module): + # ================================================================================================================== + # core function as a PyTorch module. 
+ # ================================================================================================================== + def __init__(self, + utility_formula: str, + obs2prior_dict: Dict[str, bool], + coef_dim_dict: Dict[str, int], + num_items: int, + pred_item: bool, + num_classes: int = 2, + coef_dist_dict: Dict[str, str] = {'default' : 'gaussian'}, + H_zero_mask_dict: Optional[Dict[str, torch.BoolTensor]] = None, + prior_mean: Union[float, Dict[str, float]] = 0.0, + prior_variance: Union[float, Dict[str, float]] = 1.0, + num_users: Optional[int] = None, + num_sessions: Optional[int] = None, + trace_log_q: bool = False, + category_to_item: Dict[int, List[int]] = None, + # number of observables. + num_user_obs: Optional[int] = None, + num_item_obs: Optional[int] = None, + num_session_obs: Optional[int] = None, + num_price_obs: Optional[int] = None, + num_taste_obs: Optional[int] = None, + # additional modules. + additional_modules: Optional[List[nn.Module]] = None, + deterministic_variational: bool = False, + chunk_info: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] = None, + ) -> None: + """ + Args: + utility_formula (str): a string representing the utility function U[user, item, session]. + See documentation for more details in the documentation for the format of formula. + Examples: + lambda_item + lambda_item + theta_user * alpha_item + zeta_user * item_obs + lambda_item + theta_user * alpha_item + gamma_user * beta_item * price_obs + See the doc-string of parse_utility for an example. + + coef_dist_dict (Dict[str, str]): a dictionary mapping coefficient name to coefficient distribution name. + The coefficient distribution name can be one of the following: + 1. 'gaussian' + 2. 'gamma' - obs2prior is not supported for gamma coefficients + If a coefficient does not appear in the dictionary, it will be assigned the distribution specified + by the 'default' key. By default, the default distribution is 'gaussian'. 
+ For coefficients which have gamma distributions, prior mean and variance MUST be specified in the prior_mean and prior_variance arguments if obs2prior is False for this coefficient. If obs2prior is True, prior_variance is still required + + obs2prior_dict (Dict[str, bool]): a dictionary maps coefficient name (e.g., 'lambda_item') + to a boolean indicating if observable (e.g., item_obs) enters the prior of the coefficient. + + coef_dim_dict (Dict[str, int]): a dictionary maps coefficient name (e.g., 'lambda_item') + to an integer indicating the dimension of coefficient. + For standalone coefficients like U = lambda_item, the dim should be 1. + For factorized coefficients like U = theta_user * alpha_item, the dim should be the + latent dimension of theta and alpha. + For coefficients multiplied with observables like U = zeta_user * item_obs, the dim + should be the number of observables in item_obs. + For factorized coefficient multiplied with observables like U = gamma_user * beta_item * price_obs, + the dim should be the latent dim multiplied by number of observables in price_obs. + + H_zero_mask_dict (Dict[str, torch.BoolTensor]): A dictionary maps coefficient names to a boolean tensor, + you should only specify the H_zero_mask for coefficients with obs2prior turned on. + Recall that with obs2prior on, the prior of coefficient looks like N(H*X_obs, sigma * I), the H_zero_mask + the mask for this coefficient should have the same shape as H, and H[H_zero_mask] will be set to zeros + and non-learnable during the training. + Defaults to None. + + num_items (int): number of items. + + pred_item (bool): there are two use cases of this model, suppose we have `user_index[i]` and `item_index[i]` + for the i-th observation in the dataset. + Case 1: which item among all items user `user_index[i]` is going to purchase, the prediction label + is therefore `item_index[i]`. Equivalently, we can ask what's the likelihood for user `user_index[i]` + to purchase `item_index[i]`. 
+ Case 2: what rating would user `user_index[i]` assign to item `item_index[i]`? In this case, the dataset + object needs to contain a separate label. + NOTE: for now, we only support binary labels. + + prior_mean (Union[float, Dict[str, float]]): the mean of prior + distribution for coefficients. If a float is provided, all prior + mean will be diagonal matrix with the provided value. If a + dictionary is provided, keys of prior_mean should be coefficient + names, and the mean of prior of coef_name would the provided + value Defaults to 0.0, which means all prior means are + initialized to 0.0 + + If a dictionary prior_mean is supplied, for coefficient names not in the prior_mean.keys(), the + user can add a `prior_mean['default']` value to specify the mean for those coefficients. + If no `prior_mean['default']` is provided, the default prior mean will be 0.0 for those coefficients + not in the prior_mean.keys(). + + For coefficients with gamma distributions, prior_mean specifies the shape parameter of the gamma prior. + + Defaults to 0.0. + + prior_variance (Union[float, Dict[str, float]], Dict[str, torch. Tensor]): the variance of prior distribution + for coefficients. + If a float is provided, all priors will be diagonal matrix with prior_variance along the diagonal. + If a float-valued dictionary is provided, keys of prior_variance should be coefficient names, and the + variance of prior of coef_name would be a diagonal matrix with prior_variance[coef_name] along the diagonal. + If a tensor-valued dictionary is provided, keys of prior_variance should be coefficient names, and the + values need to be tensor with shape (num_classes, coef_dim_dict[coef_name]). For example, for `beta_user` in + `U = beta_user * item_obs`, the prior_variance should be a tensor with shape (num_classes, dimension_of_item_obs). + In this case, every single entry in the coefficient has its own prior variance. 
+                Following the `beta_user` example, for every `i` and `j`, `beta_user[i, j]` is a scalar with prior variance
+                `prior_variance['beta_user'][i, j]`. Moreover, `beta_user[i, j]`'s are independent for different `i, j`.
+
+                If a dictionary prior_variance is supplied, for coefficient names not in the prior_variance.keys(), the
+                user can add a `prior_variance['default']` value to specify the variance for those coefficients.
+                If no `prior_variance['default']` is provided, the default prior variance will be 1.0 for those coefficients
+                not in the prior_variance.keys().
+
+                For coefficients with gamma distributions, prior_variance specifies the rate parameter of the gamma prior
+                (prior_mean above specifies the shape/concentration parameter).
+
+                Defaults to 1.0, which means all priors have identity matrix as the covariance matrix.
+
+            num_users (int, optional): number of users, required only if coefficient or observable
+                depending on user is in utility. Defaults to None.
+            num_sessions (int, optional): number of sessions, required only if coefficient or
+                observable depending on session is in utility. Defaults to None.
+
+            trace_log_q (bool, optional): whether to trace the derivative of variational likelihood logQ
+                with respect to variational parameters in the ELBO while conducting gradient update.
+                Defaults to False.
+
+            category_to_item (Dict[str, List[int]], optional): a dictionary with category id or name
+                as keys, and category_to_item[C] contains the list of item ids belonging to category C.
+                If None is provided, all items are assumed to be in the same category.
+                Defaults to None.
+
+            num_{user, item, session, price, taste}_obs (int, optional): number of observables of
+                each type of features, only required if observable enters prior.
+                NOTE: currently we only allow coefficient to depend on either user or item, thus only
+                user and item observables can enter the prior of coefficient. Hence session, price,
+                and taste observables are never required, we include it here for completeness. 
+ + deterministic_variational (bool, optional): if True, the variational posterior is equivalent to frequentist MLE estimates of parameters + + chunk_info (Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor], optional): a tuple of four tensors + The first tensor specifies a chunk id for each user + The second tensor specifies a chunk id for each item + The third tensor specifies a chunk id for each category + The fourth tensor specifies a chunk id for each session + """ + super(BEMBFlexChunked, self).__init__() + self.utility_formula = utility_formula + self.obs2prior_dict = obs2prior_dict + self.coef_dim_dict = coef_dim_dict + self.coef_dist_dict = coef_dist_dict + if H_zero_mask_dict is not None: + self.H_zero_mask_dict = H_zero_mask_dict + else: + self.H_zero_mask_dict = dict() + self.prior_variance = prior_variance + self.prior_mean = prior_mean + self.pred_item = pred_item + if not self.pred_item: + assert isinstance(num_classes, int) and num_classes > 0, \ + f"With pred_item being False, the num_classes should be a positive integer, received {num_classes} instead." + self.num_classes = num_classes + if self.num_classes != 2: + raise NotImplementedError('Multi-class classification is not supported yet.') + # we don't set the num_classes attribute when pred_item == False to avoid calling it accidentally. + + self.num_items = num_items + self.num_users = num_users + self.num_sessions = num_sessions + self.deterministic_variational = deterministic_variational + + self.trace_log_q = trace_log_q + self.category_to_item = category_to_item + + # ============================================================================================================== + # Category ID to Item ID mapping. + # Category ID to Category Size mapping. + # Item ID to Category ID mapping. 
+ # ============================================================================================================== + if self.category_to_item is None: + if self.pred_item: + # assign all items to the same category if predicting items. + self.category_to_item = {0: list(np.arange(self.num_items))} + else: + # otherwise, for the j-th observation in the dataset, the label[j] + # only depends on user_index[j] and item_index[j], so we put each + # item to its own category. + self.category_to_item = {i: [i] for i in range(self.num_items)} + + self.num_categories = len(self.category_to_item) + + max_category_size = max(len(x) for x in self.category_to_item.values()) + category_to_item_tensor = torch.full( + (self.num_categories, max_category_size), -1) + category_to_size_tensor = torch.empty(self.num_categories) + + for c, item_in_c in self.category_to_item.items(): + category_to_item_tensor[c, :len( + item_in_c)] = torch.LongTensor(item_in_c) + category_to_size_tensor[c] = torch.scalar_tensor(len(item_in_c)) + + self.register_buffer('category_to_item_tensor', + category_to_item_tensor.long()) + self.register_buffer('category_to_size_tensor', + category_to_size_tensor.long()) + + item_to_category_tensor = torch.zeros(self.num_items) + for c, items_in_c in self.category_to_item.items(): + item_to_category_tensor[items_in_c] = c + self.register_buffer('item_to_category_tensor', + item_to_category_tensor.long()) + + # ============================================================================================================== + # Chunk Information + self.num_user_chunks = chunk_info[0].max().item() + 1 + self.num_item_chunks = chunk_info[1].max().item() + 1 + self.num_category_chunks = chunk_info[2].max().item() + 1 + self.num_session_chunks = chunk_info[3].max().item() + 1 + self.register_buffer('user_chunk_ids', chunk_info[0]) + self.register_buffer('item_chunk_ids', chunk_info[1]) + self.register_buffer('category_chunk_ids', chunk_info[2]) + 
self.register_buffer('session_chunk_ids', chunk_info[3]) + # ============================================================================================================== + # Create Bayesian Coefficient Objects + # ============================================================================================================== + # model configuration. + self.formula = parse_utility(utility_formula) + print('BEMB: utility formula parsed:') + pprint(self.formula) + self.raw_formula = utility_formula + self.obs2prior_dict = obs2prior_dict + + # dimension of each observable, this one is used only for obs2prior. + self.num_obs_dict = { + 'user': num_user_obs, + 'item': num_item_obs, + 'category' : 0, + 'session': num_session_obs, + 'price': num_price_obs, + 'taste': num_taste_obs, + 'constant': 1 # not really used, for dummy variables. + } + + # how many classes for the variational distribution. + # for example, beta_item would be `num_items` 10-dimensional gaussian if latent dim = 10. + variation_to_num_classes = { + 'user': self.num_users, + 'item': self.num_items, + 'constant': 1, + 'category' : self.num_categories, + } + + variation_to_num_chunks = { + 'user':(self.num_category_chunks, self.num_session_chunks), + 'item':(self.num_session_chunks, self.num_user_chunks), + 'category':(self.num_session_chunks, self.num_user_chunks), + 'session':(self.num_user_chunks, self.num_category_chunks), + 'constant': (1, 1), + } + + coef_dict = dict() + for additive_term in self.formula: + for coef_name in additive_term['coefficient']: + variation = coef_name.split('_')[-1] + + if coef_name not in self.coef_dist_dict.keys(): + if 'default' in self.coef_dist_dict.keys(): + self.coef_dist_dict[coef_name] = self.coef_dist_dict['default'] + else: + warnings.warn(f"You provided a dictionary of coef_dist_dict, but coefficient {coef_name} is not a key in it. 
Supply a value for 'default' in the coef_dist_dict dictionary to use that as default value (e.g., coef_dist_dict['default'] = 'gaussian'); now using distribution='gaussian' since this is not supplied.") + self.coef_dist_dict[coef_name] = 'gaussian' + + ''' + elif self.coef_dist_dict[coef_name] == 'gamma': + if not self.obs2prior_dict[coef_name]: + assert isinstance(self.prior_mean, dict) and coef_name in self.prior_mean.keys(), \ + f"Prior mean for {coef_name} needs to be provided because it's posterior is estimated as a gamma distribution." + assert isinstance(self.prior_variance, dict) and coef_name in self.prior_variance.keys(), \ + f"Prior variance for {coef_name} needs to be provided because it's posterior is estimated as a gamma distribution." + ''' + + if isinstance(self.prior_mean, dict): + # the user didn't specify prior mean for this coefficient. + if coef_name not in self.prior_mean.keys(): + # the user may specify 'default' prior variance through the prior_variance dictionary. + if 'default' in self.prior_mean.keys(): + # warnings.warn(f"You provided a dictionary of prior mean, but coefficient {coef_name} is not a key in it. We found a key 'default' in the dictionary, so we use the value of 'default' as the prior mean for coefficient {coef_name}.") + self.prior_mean[coef_name] = self.prior_mean['default'] + else: + # warnings.warn(f"You provided a dictionary of prior mean, but coefficient {coef_name} is not a key in it. Supply a value for 'default' in the prior_mean dictionary to use that as default value (e.g., prior_mean['default'] = 0.1); now using mean=0.0 since this is not supplied.") + self.prior_mean[coef_name] = 0.0 + + mean = self.prior_mean[coef_name] if isinstance( + self.prior_mean, dict) else self.prior_mean + + if isinstance(self.prior_variance, dict): + # the user didn't specify prior variance for this coefficient. 
+ if coef_name not in self.prior_variance.keys(): + # the user may specify 'default' prior variance through the prior_variance dictionary. + if 'default' in self.prior_variance.keys(): + # warnings.warn(f"You provided a dictionary of prior variance, but coefficient {coef_name} is not a key in it. We found a key 'default' in the dictionary, so we use the value of 'default' as the prior variance for coefficient {coef_name}.") + self.prior_variance[coef_name] = self.prior_variance['default'] + else: + # warnings.warn(f"You provided a dictionary of prior variance, but coefficient {coef_name} is not a key in it. Supply a value for 'default' in the prior_variance dictionary to use that as default value (e.g., prior_variance['default'] = 0.3); now using variance=1.0 since this is not supplied.") + self.prior_variance[coef_name] = 1.0 + + s2 = self.prior_variance[coef_name] if isinstance( + self.prior_variance, dict) else self.prior_variance + + if coef_name in self.H_zero_mask_dict.keys(): + H_zero_mask = self.H_zero_mask_dict[coef_name] + else: + H_zero_mask = None + + if (not self.obs2prior_dict[coef_name]) and (H_zero_mask is not None): + raise ValueError(f'You specified H_zero_mask for {coef_name}, but obs2prior is False for this coefficient.') + + chunk_sizes = variation_to_num_chunks[variation] + bayesian_coefs = [] * chunk_sizes[0] + for ii in range(chunk_sizes[0]): + bayesian_coefs_inner = [] + for jj in range(chunk_sizes[1]): + if self.coef_dist_dict[coef_name] == 'gamma' and not self.obs2prior_dict[coef_name]: + assert mean > 0, 'shape of gamma distribution specified as prior_mean needs to be > 0' + bayesian_coefs_inner.append(BayesianCoefficient(variation=variation, + num_classes=variation_to_num_classes[variation], + obs2prior=self.obs2prior_dict[coef_name], + num_obs=self.num_obs_dict[variation], + dim=self.coef_dim_dict[coef_name], + prior_mean=mean, + prior_variance=s2, + H_zero_mask=H_zero_mask, + is_H=False, + distribution=self.coef_dist_dict[coef_name]), 
+ ) + bayesian_coefs_inner = nn.ModuleList(bayesian_coefs_inner) + bayesian_coefs.append(bayesian_coefs_inner) + coef_dict[coef_name] = nn.ModuleList(bayesian_coefs) + + self.coef_dict = nn.ModuleDict(coef_dict) + + # ============================================================================================================== + # Optional: register additional modules. + # ============================================================================================================== + if additional_modules is None: + self.additional_modules = [] + else: + raise NotImplementedError( + 'Additional modules are temporarily disabled for further development.') + self.additional_modules = nn.ModuleList(additional_modules) + + + def clamp_coefs(self): + for coef_name in self.coef_dict.keys(): + for ii in range(len(self.coef_dict[coef_name])): + for jj in range(len(self.coef_dict[coef_name][ii])): + self.coef_dict[coef_name][ii][jj].clamp_params() + + def __str__(self): + return f'Bayesian EMBedding Model with U[user, item, session] = {self.raw_formula}\n' \ + + f'Total number of parameters: {self.num_params}.\n' \ + + 'With the following coefficients:\n' \ + + str(self.coef_dict) + '\n' \ + + str(self.additional_modules) + + def posterior_mean(self, coef_name: str) -> torch.Tensor: + """Returns the mean of estimated posterior distribution of coefficient `coef_name`. + + Args: + coef_name (str): name of the coefficient to query. + + Returns: + torch.Tensor: mean of the estimated posterior distribution of `coef_name`. + """ + if coef_name in self.coef_dict.keys(): + return self.coef_dict[coef_name].variational_mean + else: + raise KeyError(f'{coef_name} is not a valid coefficient name in {self.utility_formula}.') + + def posterior_distribution(self, coef_name: str) -> torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal: + """Returns the posterior distribution of coefficient `coef_name`. + + Args: + coef_name (str): name of the coefficient to query. 
+
+        Returns:
+            torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal: the estimated variational posterior distribution of `coef_name`.
+        """
+        if coef_name in self.coef_dict.keys():
+            return self.coef_dict[coef_name].variational_distribution
+        else:
+            raise KeyError(f'{coef_name} is not a valid coefficient name in {self.utility_formula}.')
+
+    def ivs(self, batch) -> torch.Tensor:
+        """The combined method of computing utilities and log probability.
+
+        Args:
+            batch (dict): a batch of data.
+
+        Returns:
+            torch.Tensor: the combined utility and log probability.
+        """
+        # Use the means of variational distributions as the sole MC sample.
+        sample_dict = self.sample_coefficient_dictionary(1, deterministic=True)
+        # there is 1 random seed in this case.
+        # (num_seeds=1, len(batch), num_items)
+        out = self.log_likelihood_all_items(batch, return_logit=True, sample_dict=sample_dict)
+        out = out.squeeze(0)
+        # import pdb; pdb.set_trace()
+        out = out.view(-1, self.num_items)
+        ivs = scatter_logsumexp(out, self.item_to_category_tensor, dim=-1)
+        return ivs # (len(batch), num_categories)
+
+    def sample_choices(self, batch:ChoiceDataset, debug: bool = False, num_seeds: int = 1, **kwargs) -> Tuple[torch.Tensor]:
+        """Samples choices given model parameters and trips
+
+        Args:
+            batch(ChoiceDataset): batch data containing trip information; item choice information is discarded
+            debug(bool): whether to print debug information
+
+        Returns:
+            Tuple[torch.Tensor]: sampled choices; shape: (batch_size, num_categories)
+        """
+        # Use the means of variational distributions as the sole MC sample. 
+ sample_dict = dict() + for coef_name, coef in self.coef_dict.items(): + sample_dict[coef_name] = coef.variational_distribution.mean.unsqueeze(dim=0) # (1, num_*, dim) + # sample_dict = self.sample_coefficient_dictionary(num_seeds) + maxes, out = self.sample_log_likelihoods(batch, sample_dict) + return maxes.squeeze(), out.squeeze() + + def sample_log_likelihoods(self, batch:ChoiceDataset, sample_dict: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + """Samples log likelihoods given model parameters and trips + + Args: + batch(ChoiceDataset): batch data containing trip information; item choice information is discarded + sample_dict(Dict[str, torch.Tensor]): sampled coefficient values + + Returns: + Tuple[torch.Tensor]: sampled log likelihoods; shape: (batch_size, num_categories) + """ + # TODO(akanodia): disallow this for now + raise NotImplementedError() + # get the log likelihoods for all items for all categories + utility = self.log_likelihood_all_items(batch, return_logit=True, sample_dict=sample_dict) + mu_gumbel = 0.0 + beta_gumbel = 1.0 + EUL_MAS_CONST = 0.5772156649 + mean_gumbel = torch.tensor([mu_gumbel + beta_gumbel * EUL_MAS_CONST], device=self.device) + m = torch.distributions.gumbel.Gumbel(torch.tensor([0.0], device=self.device), torch.tensor([1.0], device=self.device)) + # m = torch.distributions.gumbel.Gumbel(0.0, 1.0) + gumbel_samples = m.sample(utility.shape).squeeze(-1) + gumbel_samples -= mean_gumbel + utility += gumbel_samples + max_by_category, argmax_by_category = scatter_max(utility, self.item_to_category_tensor, dim=-1) + return max_by_category, argmax_by_category + log_likelihoods = self.sample_log_likelihoods_per_category(batch, sample_dict) + + # sum over all categories. + log_likelihoods = log_likelihoods.sum(dim=1) + + return log_likelihoods, log_likelihoods + + @torch.no_grad() + def predict_proba(self, batch: ChoiceDataset) -> torch.Tensor: + """ + Draw prediction on a given batch of dataset. 
+ + Args: + batch (ChoiceDataset): the dataset to draw inference on. + + Returns: + torch.Tensor: the predicted probabilities for each class, the behavior varies by self.pred_item. + (1: pred_item == True) While predicting items, the return tensor has shape (len(batch), num_items), out[i, j] is the predicted probability for choosing item j AMONG ALL ITEMS IN ITS CATEGORY in observation i. Please note that since probabilities are computed from within-category normalization, hence out.sum(dim=0) can be greater than 1 if there are multiple categories. + (2: pred_item == False) While predicting external labels for each observations, out[i, 0] is the predicted probability for label == 0 on the i-th observation, out[i, 1] is the predicted probability for label == 1 on the i-th observation. Generally, out[i, 0] + out[i, 1] = 1.0. However, this could be false if under-flowing/over-flowing issue is encountered. + + We highly recommend users to get log-probs as those are less prone to overflow/underflow; those can be accessed using the forward() function. + """ + if self.pred_item: + # (len(batch), num_items) + log_p = self.forward(batch, return_type='log_prob', return_scope='all_items', deterministic=True) + p = log_p.exp() + else: + # (len(batch), num_items) + # probability of getting label = 1. 
+ p_1 = torch.nn.functional.sigmoid(self.forward(batch, return_type='utility', return_scope='all_items', deterministic=True)) + # (len(batch), 1) + p_1 = p_1[torch.arange(len(batch)), batch.item_index].view(len(batch), 1) + p_0 = 1 - p_1 + # (len(batch), 2) + p = torch.cat([p_0, p_1], dim=1) + + if self.pred_item: + assert p.shape == (len(batch), self.num_items) + else: + assert p.shape == (len(batch), self.num_classes) + + return p + + def forward(self, batch: ChoiceDataset, + return_type: str, + return_scope: str, + deterministic: bool = True, + sample_dict: Optional[Dict[str, torch.Tensor]] = None, + num_seeds: Optional[int] = None, + debug=False, + ) -> torch.Tensor: + """A combined method for inference with the model. + + Args: + batch (ChoiceDataset): batch data containing choice information. + return_type (str): either 'log_prob' or 'utility'. + 'log_prob': return the log-probability (by within-category log-softmax) for items + 'utility': return the utility value of items. + return_scope (str): either 'item_index' or 'all_items'. + 'item_index': for each observation i, return log-prob/utility for the chosen item batch.item_index[i] only. + 'all_items': for each observation i, return log-prob/utility for all items. + deterministic (bool, optional): + True: expectations of parameter variational distributions are used for inference. + False: the user needs to supply a dictionary of sampled parameters for inference. + Defaults to True. + sample_dict (Optional[Dict[str, torch.Tensor]], optional): sampled parameters for inference task. + This is not needed when `deterministic` is True. + When `deterministic` is False, the user can supply a `sample_dict`. If `sample_dict` is not provided, + this method will create `num_seeds` samples. + Defaults to None. + num_seeds (Optional[int]): the number of random samples of parameters to construct. This is only required + if `deterministic` is False (i.e., stochastic mode) and `sample_dict` is not provided. 
+ Defaults to None. + Returns: + torch.Tensor: a tensor of log-probabilities or utilities, depending on `return_type`. + The shape of the returned tensor depends on `return_scope` and `deterministic`. + ------------------------------------------------------------------------- + | `return_scope` | `deterministic` | Output shape | + ------------------------------------------------------------------------- + | 'item_index` | True | (len(batch),) | + ------------------------------------------------------------------------- + | 'all_items' | True | (len(batch), num_items) | + ------------------------------------------------------------------------- + | 'item_index' | False | (num_seeds, len(batch)) | + ------------------------------------------------------------------------- + | 'all_items' | False | (num_seeds, len(batch), num_items) | + ------------------------------------------------------------------------- + """ + # ============================================================================================================== + # check arguments. + # ============================================================================================================== + assert return_type in [ + 'log_prob', 'utility'], "return_type must be either 'log_prob' or 'utility'." + assert return_scope in [ + 'item_index', 'all_items'], "return_scope must be either 'item_index' or 'all_items'." + assert deterministic in [True, False] + if (not deterministic) and (sample_dict is None): + assert num_seeds >= 1, "A positive interger `num_seeds` is required if `deterministic` is False and no `sample_dict` is provided." + + # when pred_item is true, the model is predicting which item is bought (specified by item_index). + if self.pred_item: + batch.label = batch.item_index + + # ============================================================================================================== + # get sample_dict ready. 
+ # ============================================================================================================== + if deterministic: + sample_dict = self.sample_coefficient_dictionary(num_seeds, deterministic=True) + ''' + num_seeds = 1 + # Use the means of variational distributions as the sole deterministic MC sample. + # NOTE: here we don't need to sample the obs2prior weight H since we only compute the log-likelihood. + # TODO: is this correct? + sample_dict = dict() + for coef_name, bayesian_coeffs in self.coef_dict.items(): + num_classes = bayesian_coeffs[0][0].num_classes + dim = bayesian_coeffs[0][0].dim + this_sample = torch.FloatTensor(num_seeds, num_classes, dim, len(bayesian_coeffs), len(bayesian_coeffs[0])).to(self.device) + # outer_list = [] + for ii, bayesian_coeffs_inner in enumerate(bayesian_coeffs): + # inner_list = [] + for jj, coef in enumerate(bayesian_coeffs_inner): + this_sample[:, :, :, ii, jj] = coef.variational_distribution.mean.unsqueeze(dim=0) # (1, num_*, dim) + # inner_list.append(coef.variational_distribution.mean.unsqueeze(dim=0)) # (1, num_*, dim) + # inner_list.append(coef.variational_distribution.mean.unsqueeze(dim=0)) # (1, num_*, dim) + # outer_list.append(inner_list) + sample_dict[coef_name] = this_sample + ''' + else: + if sample_dict is None: + # sample stochastic parameters. + sample_dict = self.sample_coefficient_dictionary(num_seeds) + else: + # use the provided sample_dict. + num_seeds = list(sample_dict.values())[0].shape[0] + + # ============================================================================================================== + # call the sampling method of additional modules. + # ============================================================================================================== + for module in self.additional_modules: + # deterministic sample. 
+ if deterministic: + module.dsample() + else: + module.rsample(num_seeds=num_seeds) + + # if utility is requested, don't run log-softmax, simply return logit. + return_logit = (return_type == 'utility') + if return_scope == 'all_items': + # (num_seeds, len(batch), num_items) + # TODO: (akanodia) disallow this for now. + raise NotImplementedError() + out = self.log_likelihood_all_items( + batch=batch, sample_dict=sample_dict, return_logit=return_logit) + elif return_scope == 'item_index': + # (num_seeds, len(batch)) + out = self.log_likelihood_item_index( + batch=batch, sample_dict=sample_dict, return_logit=return_logit, debug=debug) + + if deterministic: + # drop the first dimension, which has size of `num_seeds` (equals 1 in the deterministic case). + # (len(batch), num_items) or (len(batch),) + return out.squeeze(dim=0) + + return out + + @property + def num_params(self) -> int: + return sum([p.numel() for p in self.parameters()]) + + @property + def device(self) -> torch.device: + for coef in self.coef_dict.values(): + return coef[0][0].device + + # ================================================================================================================== + # helper functions. + # ================================================================================================================== + def sample_coefficient_dictionary(self, num_seeds: int, deterministic: bool=False) -> Dict[str, torch.Tensor]: + """A helper function to sample parameters from coefficients. + + Args: + num_seeds (int): number of random samples. + + Returns: + Dict[str, torch.Tensor]: a dictionary maps coefficient names to tensor of sampled coefficient parameters, + where the first dimension of the sampled tensor has size `num_seeds`. + Each sample tensor has shape (num_seeds, num_classes, dim). + """ + sample_dict = dict() + if deterministic: + num_seeds = 1 + # Use the means of variational distributions as the sole deterministic MC sample. 
+ # NOTE: here we don't need to sample the obs2prior weight H since we only compute the log-likelihood. + # TODO: is this correct? + sample_dict = dict() + for coef_name, bayesian_coeffs in self.coef_dict.items(): + num_classes = bayesian_coeffs[0][0].num_classes + dim = bayesian_coeffs[0][0].dim + this_sample = torch.FloatTensor(num_seeds, num_classes, dim, len(bayesian_coeffs), len(bayesian_coeffs[0])).to(self.device) + # outer_list = [] + for ii, bayesian_coeffs_inner in enumerate(bayesian_coeffs): + # inner_list = [] + for jj, coef in enumerate(bayesian_coeffs_inner): + s = coef.variational_distribution.mean.unsqueeze(dim=0) # (1, num_*, dim) + # if coef.distribution == 'lognormal': + # s = s.exp() + this_sample[:, :, :, ii, jj] = s + # inner_list.append(coef.variational_distribution.mean.unsqueeze(dim=0)) # (1, num_*, dim) + # inner_list.append(coef.variational_distribution.mean.unsqueeze(dim=0)) # (1, num_*, dim) + # outer_list.append(inner_list) + sample_dict[coef_name] = this_sample + else: + for coef_name, bayesian_coeffs in self.coef_dict.items(): + # outer_list = [] + num_classes = bayesian_coeffs[0][0].num_classes + dim = bayesian_coeffs[0][0].dim + this_sample = torch.FloatTensor(num_seeds, num_classes, dim, len(bayesian_coeffs), len(bayesian_coeffs[0])).to(self.device) + obs2prior = self.obs2prior_dict[coef_name] + if obs2prior: + num_obs = bayesian_coeffs[0][0].num_obs + this_sample_H = torch.FloatTensor(num_seeds, dim, num_obs, len(bayesian_coeffs), len(bayesian_coeffs[0])).to(self.device) + # outer_list_H = [] + for ii, bayesian_coeffs_inner in enumerate(bayesian_coeffs): + # inner_list = [] + # if obs2prior: + # inner_list_H = [] + for jj, coef in enumerate(bayesian_coeffs_inner): + s = coef.rsample(num_seeds) + if coef.obs2prior: + # sample both obs2prior weight and realization of variable. 
+ assert isinstance(s, tuple) and len(s) == 2 + # if coef.distribution == 'lognormal': + if False: + ss = torch.exp(s[0]) + else: + ss = s[0] + this_sample[:, :, :, ii, jj] = s[0] + this_sample_H[:, :, :, ii, jj] = s[1] + # inner_list.append(s[0]) + # inner_list_H.append(s[1]) + else: + # only sample the realization of variable. + assert torch.is_tensor(s) + # if coef.distribution == 'lognormal': + # s = torch.exp(s) + this_sample[:, :, :, ii, jj] = s + # inner_list.append(s) + # outer_list.append(inner_list) + # if obs2prior: + # outer_list_H.append(inner_list_H) + sample_dict[coef_name] = this_sample + # sample_dict[coef_name] = outer_list + if obs2prior: + sample_dict[coef_name + '.H'] = this_sample_H + # sample_dict[coef_name + '.H'] = outer_list_H + + return sample_dict + + @torch.no_grad() + def get_within_category_accuracy(self, log_p_all_items: torch.Tensor, label: torch.LongTensor) -> Dict[str, float]: + """A helper function for computing prediction accuracy (i.e., all non-differential metrics) + within category. + In particular, this method calculates the accuracy, precision, recall and F1 score. + + + This method has the same functionality as the following peusodcode: + for C in categories: + # get sessions in which item in category C was purchased. + T <- (t for t in {0,1,..., len(label)-1} if label[t] is in C) + Y <- label[T] + + predictions = list() + for t in T: + # get the prediction within category for this session. + y_pred = argmax_{items in C} log prob computed before. + predictions.append(y_pred) + + accuracy = mean(Y == predictions) + + Similarly, this function computes precision, recall and f1score as well. + + Args: + log_p_all_items (torch.Tensor): shape (num_sessions, num_items) the log probability of + choosing each item in each session. + label (torch.LongTensor): shape (num_sessions,), the IDs of items purchased in each session. + + Returns: + [Dict[str, float]]: A dictionary containing performance metrics. 
+ """ + # argmax: (num_sessions, num_categories), within category argmax. + # item IDs are consecutive, thus argmax is the same as IDs of the item with highest P. + _, argmax_by_category = scatter_max( + log_p_all_items, self.item_to_category_tensor, dim=-1) + + # category_purchased[t] = the category of item label[t]. + # (num_sessions,) + category_purchased = self.item_to_category_tensor[label] + + # pred[t] = the item with highest utility from the category item label[t] belongs to. + # (num_sessions,) + pred_from_category = argmax_by_category[torch.arange( + len(label)), category_purchased] + + within_category_accuracy = ( + pred_from_category == label).float().mean().item() + + # precision + precision = list() + + recall = list() + for i in range(self.num_items): + correct_i = torch.sum( + (torch.logical_and(pred_from_category == i, label == i)).float()) + precision_i = correct_i / \ + torch.sum((pred_from_category == i).float()) + recall_i = correct_i / torch.sum((label == i).float()) + + # do not add if divided by zero. + if torch.any(pred_from_category == i): + precision.append(precision_i.cpu().item()) + if torch.any(label == i): + recall.append(recall_i.cpu().item()) + + precision = float(np.mean(precision)) + recall = float(np.mean(recall)) + + if precision == recall == 0: + f1 = 0 + else: + f1 = 2 * precision * recall / (precision + recall) + + return {'accuracy': within_category_accuracy, + 'precision': precision, + 'recall': recall, + 'f1score': f1} + + # ================================================================================================================== + # Methods for terms in the ELBO: prior, likelihood, and variational. 
+ # ================================================================================================================== + def log_likelihood_all_items(self, batch: ChoiceDataset, return_logit: bool, sample_dict: Dict[str, torch.Tensor]) -> torch.Tensor: + """ + NOTE to developers: + NOTE (akanodia to tianyudu): Is this really slow; even with log_likelihood you need log_prob which depends on logits of all items? + This method computes utilities for all items available, which is a relatively slow operation. For + training the model, you only need the utility/log-prob for the chosen/relevant item (i.e., item_index[i] for each i-th observation). + Use this method for inference only. + Use self.log_likelihood_item_index() for training instead. + + Computes the log probability of choosing `each` item in each session based on current model parameters. + NOTE (akanodiadu to tianyudu): What does the next line mean? I think it just says its allowing for samples instead of posterior mean. + This method allows for specifying {user, item}_latent_value for Monte Carlo estimation in ELBO. + For actual prediction tasks, use the forward() function, which will use means of variational + distributions for user and item latents. + + Args: + batch (ChoiceDataset): a ChoiceDataset object containing relevant information. + return_logit(bool): if set to True, return the log-probability, otherwise return the logit/utility. + sample_dict(Dict[str, torch.Tensor]): Monte Carlo samples for model coefficients + (i.e., those Greek letters). + sample_dict.keys() should be the same as keys of self.obs2prior_dict, i.e., those + greek letters actually enter the functional form of utility. + The value of sample_dict should be tensors of shape (num_seeds, num_classes, dim) + where num_classes in {num_users, num_items, 1} + and dim in {latent_dim(K), num_item_obs, num_user_obs, 1}. 
+
+        Returns:
+            torch.Tensor: a tensor of shape (num_seeds, len(batch), self.num_items), where
+                out[x, y, z] is the probability of choosing item z in session y conditioned on
+                latents to be the x-th Monte Carlo sample.
+        """
+        batch.item_index = torch.arange(self.num_items, device=batch.device)
+        batch.item_index = batch.item_index.repeat(batch.user_index.shape[0])
+        batch.user_index = batch.user_index.repeat_interleave(self.num_items)
+        batch.session_index = batch.session_index.repeat_interleave(self.num_items)
+        return self.log_likelihood_item_index(batch, return_logit, sample_dict, all_items=True)
+
+    def log_likelihood_item_index(self, batch: ChoiceDataset, return_logit: bool, sample_dict: Dict[str, torch.Tensor], all_items: bool=False, debug=False, return_logit_components=False, return_price_coeff=False) -> torch.Tensor:
+        """
+        NOTE for developers:
+        This method is more efficient and only computes log-likelihood/logit(utility) for item in item_index[i] for each
+        i-th observation.
+        Developers should use `log_likelihood_all_items` for inference purposes and to compute log-likelihoods/utilities
+        for ALL items for the i-th observation.
+
+        Computes the log probability of choosing item_index[i] in each session based on current model parameters.
+        This method allows for specifying {user, item}_latent_value for Monte Carlo estimation in ELBO.
+        For actual prediction tasks, use the forward() function, which will use means of variational
+        distributions for user and item latents.
+
+        Args:
+            batch (ChoiceDataset): a ChoiceDataset object containing relevant information.
+            return_logit(bool): if set to True, return the logit/utility, otherwise return the log-probability.
+            sample_dict(Dict[str, torch.Tensor]): Monte Carlo samples for model coefficients
+                (i.e., those Greek letters).
+                sample_dict.keys() should be the same as keys of self.obs2prior_dict, i.e., those
+                greek letters actually enter the functional form of utility.
+ The value of sample_dict should be tensors of shape (num_seeds, num_classes, dim) + where num_classes in {num_users, num_items, 1} + and dim in {latent_dim(K), num_item_obs, num_user_obs, 1}. + all_items: return for all items + debug: debug param, keeps evolving, see code + return_logit_components: return a tensor of size (len(batch), num_terms) where num_terms is the number of additive terms in self.formula. If this is set to True, return_logit must be set to True too + + + Returns: + torch.Tensor: a tensor of shape (num_seeds, len(batch)), where + out[x, y] is the probabilities of choosing item batch.item[y] in session y + conditioned on latents to be the x-th Monte Carlo sample. + """ + num_seeds = list(sample_dict.values())[0].shape[0] + + # get category id of the item bought in each row of batch. + cate_index = self.item_to_category_tensor[batch.item_index] + + # get item ids of all items from the same category of each item bought. + relevant_item_index = self.category_to_item_tensor[cate_index, :] + relevant_item_index = relevant_item_index.view(-1,) + # index were padded with -1's, drop those dummy entries. + relevant_item_index = relevant_item_index[relevant_item_index != -1] + + # the first repeats[0] entries in relevant_item_index are for the category of item_index[0] + repeats = self.category_to_size_tensor[cate_index] + # argwhere(reverse_indices == k) are positions in relevant_item_index for the category of item_index[k]. + reverse_indices = torch.repeat_interleave( + torch.arange(len(batch), device=self.device), repeats) + # expand the user_index and session_index. + # if all_items: + # breakpoint() + user_index = torch.repeat_interleave(batch.user_index, repeats) + repeat_category_index = torch.repeat_interleave(cate_index, repeats) + session_index = torch.repeat_interleave(batch.session_index, repeats) + # duplicate the item focused to match. 
+ item_index_expanded = torch.repeat_interleave( + batch.item_index, repeats) + + # short-hands for easier shape check. + R = num_seeds + # total number of relevant items. + total_computation = len(session_index) + S = self.num_sessions + U = self.num_users + I = self.num_items + NC = self.num_categories + + user_chunk_ids = torch.repeat_interleave(self.user_chunk_ids[batch.user_index], repeats) + item_chunk_ids = torch.repeat_interleave(self.item_chunk_ids[batch.item_index], repeats) + session_chunk_ids = torch.repeat_interleave(self.session_chunk_ids[batch.session_index], repeats) + category_chunk_ids = torch.repeat_interleave(self.category_chunk_ids[cate_index], repeats) + + # ========================================================================================== + # Helper Functions for Reshaping. + # ========================================================================================== + + def reshape_coef_sample(sample, name): + # reshape the monte carlo sample of coefficients to (R, P, I, *). + if name.endswith('_user'): + # (R, total_computation, dim, chunk_size_1, chunk_size_2) + all_chunks_sample = sample[:, user_index, :, :, :] + # (total_computation) --> (1, total_computation, 1, 1, 1) + second_chunk_index = session_chunk_ids.reshape(1, -1, 1, 1, 1) + # (1, total_computation, 1, 1, 1) --> (R, total_computation, dim, chunk_size_1, 1) + second_chunk_index = second_chunk_index.repeat(R, 1, all_chunks_sample.shape[2], all_chunks_sample.shape[3], 1) + # (total_computation) --> (1, total_computation, 1, 1) + first_chunk_index = category_chunk_ids.reshape(1, -1, 1, 1) + # (1, total_computation, 1, 1) --> (R, total_computation, dim, 1) + first_chunk_index = first_chunk_index.repeat(R, 1, all_chunks_sample.shape[2], 1) + # select the first chunk. + second_chunk_selected = torch.gather(all_chunks_sample, -1, second_chunk_index).squeeze(-1) + # select the second chunk. 
+ first_chunk_selected = torch.gather(second_chunk_selected, -1, first_chunk_index).squeeze(-1) + return first_chunk_selected + elif name.endswith('_item'): + # (R, total_computation, dim, chunk_size_1, chunk_size_2) + all_chunks_sample = sample[:, relevant_item_index, :, :, :] + # (total_computation) --> (1, total_computation, 1, 1, 1) + second_chunk_index = user_chunk_ids.reshape(1, -1, 1, 1, 1) + # (1, total_computation, 1, 1, 1) --> (R, total_computation, dim, chunk_size_1, 1) + second_chunk_index = second_chunk_index.repeat(R, 1, all_chunks_sample.shape[2], all_chunks_sample.shape[3], 1) + # (total_computation) --> (1, total_computation, 1, 1) + first_chunk_index = session_chunk_ids.reshape(1, -1, 1, 1) + # (1, total_computation, 1, 1) --> (R, total_computation, dim, 1) + first_chunk_index = first_chunk_index.repeat(R, 1, all_chunks_sample.shape[2], 1) + # select the first chunk. + second_chunk_selected = torch.gather(all_chunks_sample, -1, second_chunk_index).squeeze(-1) + # select the second chunk. + first_chunk_selected = torch.gather(second_chunk_selected, -1, first_chunk_index).squeeze(-1) + return first_chunk_selected + elif name.endswith('_category'): + # (R, total_computation, dim, chunk_size_1, chunk_size_2) + all_chunks_sample = sample[:, repeat_category_index, :, :, :] + # (total_computation) --> (1, total_computation, 1, 1, 1) + second_chunk_index = user_chunk_ids.reshape(1, -1, 1, 1, 1) + # (1, total_computation, 1, 1, 1) --> (R, total_computation, dim, chunk_size_1, 1) + second_chunk_index = second_chunk_index.repeat(R, 1, all_chunks_sample.shape[2], all_chunks_sample.shape[3], 1) + # (total_computation) --> (1, total_computation, 1, 1) + first_chunk_index = session_chunk_ids.reshape(1, -1, 1, 1) + # (1, total_computation, 1, 1) --> (R, total_computation, dim, 1) + first_chunk_index = first_chunk_index.repeat(R, 1, all_chunks_sample.shape[2], 1) + # select the first chunk. 
+ second_chunk_selected = torch.gather(all_chunks_sample, -1, second_chunk_index).squeeze(-1) + # select the second chunk. + first_chunk_selected = torch.gather(second_chunk_selected, -1, first_chunk_index).squeeze(-1) + return first_chunk_selected + elif name.endswith('_constant'): + # (R, *) --> (R, total_computation, *) + return sample[:, 0, 0].view(R, 1, -1).expand(-1, total_computation, -1) + else: + raise ValueError + + def reshape_observable(obs, name): + # reshape observable to (R, P, I, *) so that it can be multiplied with monte carlo + # samples of coefficients. + O = obs.shape[-1] # number of observables. + assert O == positive_integer + if name.startswith('item_'): + assert obs.shape == (I, O) + obs = obs[relevant_item_index, :] + elif name.startswith('user_'): + assert obs.shape == (U, O) + obs = obs[user_index, :] + elif name.startswith('session_'): + assert obs.shape == (S, O) + obs = obs[session_index, :] + elif name.startswith('price_'): + assert obs.shape == (S, I, O) + obs = obs[session_index, relevant_item_index, :] + elif name.startswith('taste_'): + assert obs.shape == (U, I, O) + obs = obs[user_index, relevant_item_index, :] + else: + raise ValueError + assert obs.shape == (total_computation, O) + return obs.unsqueeze(dim=0).expand(R, -1, -1) + + # ========================================================================================== + # Compute Components related to users and items only. 
+ # ========================================================================================== + utility = torch.zeros(R, total_computation, device=self.device) + if return_logit_components: + assert R == 1, 'return_logit_components is not supported for R > 1' + assert return_logit, "return_logit_components requires return_logit" + utility_components = torch.zeros(len(self.formula), R, total_computation, device=self.device) + + if return_price_coeff: + # 'price_obs' needs to be seen in self.formula exactly once in self.utility_formula + assert R == 1, 'return_price_coeff is not supported for R > 1' + assert self.utility_formula.count('price_obs') == 1, "price_obs needs to be seen in self.formula exactly once for return_price_coeff" + price_coeffs = torch.zeros(R, total_computation, device=self.device) + + # loop over additive term to utility + for ii, term in enumerate(self.formula): + obs_coeff = None + if debug: + breakpoint() + # Type I: single coefficient, e.g., lambda_item or lambda_user. + if len(term['coefficient']) == 1 and term['observable'] is None: + # E.g., lambda_item or lambda_user + coef_name = term['coefficient'][0] + coef_sample = reshape_coef_sample( + sample_dict[coef_name], coef_name) + assert coef_sample.shape == (R, total_computation, 1) + additive_term = coef_sample.view(R, total_computation) + additive_term *= term['sign'] + + # Type II: factorized coefficient, e.g., . 
+ elif len(term['coefficient']) == 2 and term['observable'] is None: + coef_name_0 = term['coefficient'][0] + coef_name_1 = term['coefficient'][1] + + coef_sample_0 = reshape_coef_sample( + sample_dict[coef_name_0], coef_name_0) + coef_sample_1 = reshape_coef_sample( + sample_dict[coef_name_1], coef_name_1) + + assert coef_sample_0.shape == coef_sample_1.shape == ( + R, total_computation, positive_integer) + + additive_term = (coef_sample_0 * coef_sample_1).sum(dim=-1) + additive_term *= term['sign'] + + # Type III: single coefficient multiplied by observable, e.g., theta_user * x_obs_item. + elif len(term['coefficient']) == 1 and term['observable'] is not None: + coef_name = term['coefficient'][0] + coef_sample = reshape_coef_sample( + sample_dict[coef_name], coef_name) + # breakpoint() + assert coef_sample.shape == ( + R, total_computation, positive_integer) + + obs_name = term['observable'] + obs = reshape_observable(getattr(batch, obs_name), obs_name) + assert obs.shape == (R, total_computation, positive_integer) + + if return_price_coeff and term['observable'] is not None and term['observable'].startswith('price_'): + obs_coeff = coef_sample.sum(dim=-1) + price_coeffs = obs_coeff + price_coeffs *= term['sign'] + + additive_term = (coef_sample * obs).sum(dim=-1) + additive_term *= term['sign'] + + # Type IV: factorized coefficient multiplied by observable. + # e.g., gamma_user * beta_item * price_obs. 
+ elif len(term['coefficient']) == 2 and term['observable'] is not None: + coef_name_0, coef_name_1 = term['coefficient'][0], term['coefficient'][1] + coef_sample_0 = reshape_coef_sample( + sample_dict[coef_name_0], coef_name_0) + coef_sample_1 = reshape_coef_sample( + sample_dict[coef_name_1], coef_name_1) + assert coef_sample_0.shape == coef_sample_1.shape == ( + R, total_computation, positive_integer) + num_obs_times_latent_dim = coef_sample_0.shape[-1] + + obs_name = term['observable'] + obs = reshape_observable(getattr(batch, obs_name), obs_name) + assert obs.shape == (R, total_computation, positive_integer) + num_obs = obs.shape[-1] # number of observables. + + assert (num_obs_times_latent_dim % num_obs) == 0 + latent_dim = num_obs_times_latent_dim // num_obs + + coef_sample_0 = coef_sample_0.view( + R, total_computation, num_obs, latent_dim) + coef_sample_1 = coef_sample_1.view( + R, total_computation, num_obs, latent_dim) + # compute the factorized coefficient with shape (R, P, I, O). + coef = (coef_sample_0 * coef_sample_1).sum(dim=-1) + + if return_price_coeff and term['observable'] is not None and term['observable'].startswith('price_'): + obs_coeff = coef.sum(dim=-1) + price_coeffs = obs_coeff + price_coeffs *= term['sign'] + + additive_term = (coef * obs).sum(dim=-1) + additive_term *= term['sign'] + + else: + raise ValueError(f'Undefined term type: {term}') + + assert additive_term.shape == (R, total_computation) + utility += additive_term + if return_logit_components: + utility_components[ii] = additive_term + + # ========================================================================================== + # Mask Out Unavailable Items in Each Session. + # ========================================================================================== + + if batch.item_availability is not None: + # expand to the Monte Carlo sample dimension. 
+ A = batch.item_availability[session_index, relevant_item_index].unsqueeze( + dim=0).expand(R, -1) + utility[~A] = - (torch.finfo(utility.dtype).max / 2) + if return_logit_components: + utility_components[:, ~A] = - (torch.finfo(utility.dtype).max / 2) + if return_price_coeff: + price_coeffs[~A] = 0 + + for module in self.additional_modules: + # current utility shape: (R, total_computation) + assert False, "additional modules not supported for bemb_chunked" + additive_term = module(batch) + assert additive_term.shape == ( + R, len(batch)) or additive_term.shape == (R, len(batch), 1) + if additive_term.shape == (R, len(batch), 1): + # TODO: need to make this consistent with log_likelihood_all. + # be tolerant for some customized module with BayesianLinear that returns (R, len(batch), 1). + additive_term = additive_term.view(R, len(batch)) + # expand to total number of computation, query by reverse_indices. + # reverse_indices has length total_computation, and reverse_indices[i] correspond to the row-id that this + # computation is responsible for. + additive_term = additive_term[:, reverse_indices] + assert additive_term.shape == (R, total_computation) + + if return_logit: + # (num_seeds, len(batch)) + if return_logit_components: + u = utility_components[:, :, item_index_expanded == relevant_item_index] + assert u.shape == (len(self.formula), R, len(batch)) + return u + else: + u = utility[:, item_index_expanded == relevant_item_index] + assert u.shape == (R, len(batch)) + if return_price_coeff: + price_coeffs = price_coeffs[ :, item_index_expanded == relevant_item_index].squeeze(dim=0) + assert price_coeffs.shape[0] == len(batch) + return u, price_coeffs + else: + return u + + if self.pred_item: + # compute log likelihood log p(choosing item i | user, item latents) + # compute the log probability from logits/utilities. 
+            # output shape: (num_seeds, len(batch), num_items)
+            # breakpoint()
+            log_p = scatter_log_softmax(utility, reverse_indices, dim=-1)
+            # select the log-P of the item actually bought.
+            log_p = log_p[:, item_index_expanded == relevant_item_index]
+            assert log_p.shape == (R, len(batch))
+            return log_p
+        else:
+            # This is the binomial choice situation in which case we just report sigmoid log likelihood
+            utility = utility[:, item_index_expanded == relevant_item_index]
+            assert utility.shape == (R, len(batch))
+            bce = nn.BCELoss(reduction='none')
+            # make num_seeds copies of the label, expand to (R, len(batch))
+            label_expanded = batch.label.to(torch.float32).view(1, len(batch)).expand(R, -1)
+            assert label_expanded.shape == (R, len(batch))
+            log_p = - bce(torch.sigmoid(utility), label_expanded)
+            assert log_p.shape == (R, len(batch))
+            return log_p
+
+    def log_prior(self, batch: ChoiceDataset, sample_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
+        """Calculates the log-likelihood of Monte Carlo samples of Bayesian coefficients under their
+        prior distribution. This method assumes coefficients are statistically independent.
+
+        Args:
+            batch (ChoiceDataset): a dataset object containing observables for computing the prior distribution
+                if obs2prior is True.
+            sample_dict (Dict[str, torch.Tensor]): a dictionary mapping coefficient names to Monte Carlo samples.
+
+        Raises:
+            ValueError: if obs2prior is enabled for a coefficient that has no matching observable to condition on.
+
+        Returns:
+            torch.scalar_tensor: a tensor with shape (num_seeds,) of [ log P_{prior_distribution}(param[i]) ],
+            where param[i] is the i-th Monte Carlo sample.
+ """ + # assert sample_dict.keys() == self.coef_dict.keys() + num_seeds = list(sample_dict.values())[0].shape[0] + cate_index = self.item_to_category_tensor[batch.item_index] + user_chunk_ids = self.user_chunk_ids[batch.user_index] + item_chunk_ids = self.item_chunk_ids[batch.item_index] + session_chunk_ids = self.session_chunk_ids[batch.session_index] + category_chunk_ids = self.category_chunk_ids[cate_index] + + total = torch.zeros(num_seeds, device=self.device) + + def reshape_coef_sample(sample, name): + # reshape the monte carlo sample of coefficients to (R, P, I, *). + if name.endswith('_user'): + # (R, U, *) --> (R, total_computation, *) + temp = sample[:, :, :, :, :] + stemp = session_chunk_ids.reshape(1, -1, 1, 1, 1) + stemp = stemp.repeat(1, 1, temp.shape[2], temp.shape[3], 1) + ctemp = category_chunk_ids.reshape(1, -1, 1, 1) + ctemp = ctemp.repeat(1, 1, temp.shape[2], 1) + gathered1 = torch.gather(temp, 4, stemp).squeeze(4) + gathered2 = torch.gather(gathered1, 3, ctemp).squeeze(3) + return gathered2 + # return sample[:, user_index, :, category_chunk_ids, session_chunk_ids] + elif name.endswith('_item'): + # (R, I, *) --> (R, total_computation, *) + temp = sample[:, :, :, :, :] + utemp = user_chunk_ids.reshape(1, -1, 1, 1, 1) + utemp = utemp.repeat(1, 1, temp.shape[2], temp.shape[3], 1) + stemp = session_chunk_ids.reshape(1, -1, 1, 1) + stemp = stemp.repeat(1, 1, temp.shape[2], 1) + gathered1 = torch.gather(temp, 4, utemp).squeeze(4) + gathered2 = torch.gather(gathered1, 3, stemp).squeeze(3) + return gathered2 + # return sample[:, relevant_item_index, :, session_chunk_ids, user_chunk_ids] + elif name.endswith('_category'): + # (R, NC, *) --> (R, total_computation, *) + return sample[:, repeat_category_index, :, session_chunk_ids, user_chunk_ids] + elif name.endswith('_constant'): + # (R, *) --> (R, total_computation, *) + return sample[:, 0, 0].view(R, 1, -1).expand(-1, total_computation, -1) + else: + raise ValueError + + # for coef_name, coef in 
self.coef_dict.items(): + for coef_name, bayesian_coeffs in self.coef_dict.items(): + for ii, bayesian_coeffs_inner in enumerate(bayesian_coeffs): + for jj, coef in enumerate(bayesian_coeffs_inner): + if self.obs2prior_dict[coef_name]: + if coef_name.endswith('_item'): + x_obs = batch.item_obs + elif coef_name.endswith('_user'): + x_obs = batch.user_obs + else: + raise ValueError( + f'No observable found to support obs2prior for {coef_name}.') + + total += coef.log_prior(sample=sample_dict[coef_name][:, :, :, ii, jj], + H_sample=sample_dict[coef_name + '.H'][:, :, :, ii, jj], + x_obs=x_obs).sum(dim=-1) + else: + # log_prob outputs (num_seeds, num_{items, users}), sum to (num_seeds). + total += coef.log_prior( + sample=sample_dict[coef_name][:, :, :, ii, jj], H_sample=None, x_obs=None).sum(dim=-1) + # break + # break + + for module in self.additional_modules: + raise NotImplementedError() + total += module.log_prior() + + return total + + def log_variational(self, sample_dict: Dict[str, torch.Tensor]) -> torch.Tensor: + """Calculate the log-likelihood of samples in sample_dict under the current variational + distribution. + + Args: + sample_dict (Dict[str, torch.Tensor]): a dictionary coefficient names to Monte Carlo + samples. + + Returns: + torch.Tensor: a tensor of shape (num_seeds) of [ log P_{variational_distribution}(param[i]) ], + where param[i] is the i-th Monte Carlo sample. + """ + num_seeds = list(sample_dict.values())[0].shape[0] + total = torch.zeros(num_seeds, device=self.device) + + for coef_name, coef in self.coef_dict.items(): + # log_prob outputs (num_seeds, num_{items, users}), sum to (num_seeds). 
+ total += coef.log_variational(sample_dict[coef_name]).sum(dim=-1) + + for module in self.additional_modules: + raise NotImplementedError() + # with shape (num_seeds,) + total += module.log_variational().sum() + + return total + + def elbo(self, batch: ChoiceDataset, num_seeds: int = 1) -> torch.Tensor: + """A combined method to computes the current ELBO given a batch, this method is used for training the model. + + Args: + batch (ChoiceDataset): a ChoiceDataset containing necessary information. + num_seeds (int, optional): the number of Monte Carlo samples from variational distributions + to evaluate the expectation in ELBO. + Defaults to 1. + + Returns: + torch.Tensor: a scalar tensor of the ELBO estimated from num_seeds Monte Carlo samples. + """ + # ============================================================================================================== + # 1. sample latent variables from their variational distributions. + # ============================================================================================================== + if self.deterministic_variational: + sample_dict = self.sample_coefficient_dictionary(num_seeds) + ''' + num_seeds = 1 + # Use the means of variational distributions as the sole deterministic MC sample. + # NOTE: here we don't need to sample the obs2prior weight H since we only compute the log-likelihood. + # TODO: is this correct? + sample_dict = dict() + for coef_name, coef in self.coef_dict.items(): + sample_dict[coef_name] = coef.variational_distribution.mean.unsqueeze( + dim=0) # (1, num_*, dim) + ''' + else: + sample_dict = self.sample_coefficient_dictionary(num_seeds) + + # ============================================================================================================== + # 2. compute log p(latent) prior. + # (num_seeds,) --mean--> scalar. 
+ # with torch.no_grad(): + # while True: + # elbo = self.log_prior(batch, sample_dict).mean(dim=0) + # elbo = torch.tensor(0.0, device=self.device) + elbo = self.log_prior(batch, sample_dict).mean(dim=0) + # ============================================================================================================== + + # ============================================================================================================== + # 3. compute the log likelihood log p(obs|latent). + # sum over independent purchase decision for individual observations, mean over MC seeds. + # the forward() function calls module.rsample(num_seeds) for module in self.additional_modules. + # ============================================================================================================== + if self.pred_item: + # the prediction target is item_index. + elbo_expanded = self.forward(batch, + return_type='log_prob', + return_scope='item_index', + deterministic=self.deterministic_variational, + sample_dict=sample_dict) + if self.deterministic_variational: + elbo_expanded = elbo_expanded.unsqueeze(dim=0) + elbo += elbo_expanded.sum(dim=1).mean(dim=0) # (num_seeds, len(batch)) --> scalar. + else: + # the prediction target is binary. + # TODO: update the prediction function. + utility = self.forward(batch, + return_type='utility', + return_scope='item_index', + deterministic=self.deterministic_variational, + sample_dict=sample_dict) # (num_seeds, len(batch)) + + # compute the log-likelihood for binary label. + # (num_seeds, len(batch)) + y_stacked = torch.stack([batch.label] * num_seeds).float() + assert y_stacked.shape == utility.shape + bce = nn.BCELoss(reduction='none') + # scalar. + ll = - bce(torch.sigmoid(utility), + y_stacked).sum(dim=1).mean(dim=0) + elbo += ll + + # ============================================================================================================== + # 4. optionally add log likelihood under variational distributions q(latent). 
+ # ============================================================================================================== + if self.trace_log_q: + #TODO(akanodia): do not allow at this time + raise NotImplementedError() + assert not self.deterministic_variational, "deterministic_variational is not compatible with trace_log_q." + elbo -= self.log_variational(sample_dict).mean(dim=0) + + return elbo diff --git a/bemb/model/bemb_flex_lightning.py b/bemb/model/bemb_flex_lightning.py index 7873aa5..0c7dbee 100644 --- a/bemb/model/bemb_flex_lightning.py +++ b/bemb/model/bemb_flex_lightning.py @@ -79,6 +79,14 @@ def training_step(self, batch, batch_idx): loss = - elbo return loss + # DEBUG_MARKER + ''' + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + print(f"Epoch {self.current_epoch} has ended") + breakpoint() + ''' + # DEBUG_MARKER + def _get_performance_dict(self, batch): if self.model.pred_item: log_p = self.model(batch, return_type='log_prob', diff --git a/bemb/model/bemb_supermarket_lightning.py b/bemb/model/bemb_supermarket_lightning.py new file mode 100644 index 0000000..a26dc4e --- /dev/null +++ b/bemb/model/bemb_supermarket_lightning.py @@ -0,0 +1,201 @@ +""" +BEMB Flex model adopted to PyTorch Lightning. 
+"""
+import os
+
+import pandas as pd
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, random_split
+from torchvision.datasets import MNIST
+from torchvision import transforms
+import pytorch_lightning as pl
+# from pytorch_lightning.metrics.functional import accuracy
+from torchmetrics.functional import accuracy
+from bemb.model import BEMBFlex
+from bemb.model.bemb import parse_utility
+from torch_choice.data.utils import create_data_loader
+
+
+class WeekTrendPreprocessor(nn.Module):
+    def __init__(self, num_weeks: int, latent_dim: int):
+        super().__init__()
+        self.emb = nn.Embedding(num_weeks, latent_dim)
+
+    def forward(self, batch):
+        # batch.session_week_id expected, one week id per session.
+        # convert to batch.session_week, (num_session, latent_dim) embedding.
+        # session_week will be considered as a session-specific observable by BEMB.
+        batch.session_week = self.emb(batch.session_week_id.long())
+        return batch
+
+class LitBEMBFlex(pl.LightningModule):
+
+    def __init__(self, configs=None, user_encoder=None, item_encoder=None, session_encoder=None, category_encoder=None, obs_dict=None, print_info=None, train_data=None, validation_data=None, batch_size: int = -1, num_workers: int = 8, learning_rate: float = 0.3, num_seeds: int=1, num_weeks=0, num_week_trend_latents=0, test_data=None, preprocess=True, lr_decay_type='multi_step_lr', lr_milestones=[], lr_decay=1.0, check_val_every_n_epoch=5, **kwargs):
+        # use kwargs to pass parameter to BEMB Torch.
+ super().__init__() + # import pdb; pdb.pdb.set_trace() + self.save_hyperparameters(ignore=['train_data', 'validation_data', 'test_data']) + self.train_data = train_data + self.validation_data = validation_data + self.test_data = test_data + bemb_params = {key : self.hparams[key] for key in kwargs.keys()} + self.model = BEMBFlex(**bemb_params) + self.configs = self.hparams.configs + # self.configs = configs + self.num_seeds = self.hparams.num_seeds + self.batch_size = self.hparams.batch_size + self.num_workers = self.hparams.num_workers + self.learning_rate = self.hparams.learning_rate + self.preprocess = self.hparams.preprocess + self.lr_decay_type = self.hparams.lr_decay_type + self.lr_milestones = self.hparams.lr_milestones + self.lr_decay = self.hparams.lr_decay + self.check_val_every_n_epoch = self.hparams.check_val_every_n_epoch + if self.preprocess: + self.batch_preprocess = WeekTrendPreprocessor(num_weeks=self.hparams.num_weeks, latent_dim=self.hparams.num_week_trend_latents) + + def __str__(self) -> str: + return str(self.model) + + def training_step(self, batch, batch_idx): + if self.preprocess: + batch = self.batch_preprocess(batch) + elbo = self.model.elbo(batch, num_seeds=self.num_seeds) + self.log('train_elbo', elbo) + loss = - elbo + return loss + + def validation_step(self, batch, batch_idx): + if self.preprocess: + batch = self.batch_preprocess(batch) + LL = self.model.forward(batch, return_type='log_prob', return_scope='item_index').mean() + # self.log('val_log_likelihood', LL, prog_bar=True) + return {'LL':LL} + + # def training_epoch_end(self, outputs): + # sch = self.lr_schedulers() + # print("Learning Rate %f" % sch.get_lr) + + def validation_epoch_end(self, outputs): + sch = self.lr_schedulers() + # print("Learning Rate %r" % sch.get_lr()) + avg_LL = torch.stack([x["LL"] for x in outputs]).mean() + # self.log("lr", sch.get_last_lr()[0], prog_bar=True) + self.log("lr", self.hparams.learning_rate, prog_bar=True) + 
self.log("val_log_likelihood", avg_LL, prog_bar=True) + + # If the selected scheduler is a ReduceLROnPlateau scheduler. + if isinstance(sch, torch.optim.lr_scheduler.ReduceLROnPlateau): + sch.step(self.trainer.callback_metrics["val_log_likelihood"]) + # avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean() + # self.log("ptl/val_loss", avg_loss) + # self.log("ptl/val_accuracy", avg_acc) + + # def test_step(self, batch, batch_idx): + # LL = self.model.forward(batch, return_logit=False, all_items=False).mean() + # self.log('test_log_likelihood', LL) + + # pred = self.model(batch) + # performance = self.model.get_within_category_accuracy(pred, batch.label) + # for key, val in performance.items(): + # self.log('test_' + key, val, prog_bar=True) + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) + # return optimizer + # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[400], gamma=0.99) + if self.lr_decay_type == 'reduce_on_plateau': + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2, factor=self.lr_decay) + lr_scheduler = { + 'scheduler': scheduler, + 'name': 'lr_log', + 'monitor' : 'val_log_likelihood', + 'frequency' : self.check_val_every_n_epoch, + } + elif self.lr_decay_type == 'multi_step_lr': + scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.lr_milestones, gamma=self.lr_decay) + lr_scheduler = scheduler + else: + raise ValueError + + # scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1.00, total_steps=200)# steps_per_epoch=len(data_loader), epochs=10) + # lr_scheduler = scheduler + return [optimizer], [lr_scheduler] + # return optimizer + + def train_dataloader(self): + train_dataloader = create_data_loader(self.train_data, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers) + return train_dataloader + + def val_dataloader(self): + validation_dataloader = 
create_data_loader(self.validation_data, + # batch_size=-1, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers) + return validation_dataloader + + def test_dataloader(self): + if self.test_data is None: + return None + test_dataloader = create_data_loader(self.test_data, + # batch_size=-1, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers) + return test_dataloader + + # def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure, ): + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs): + optimizer.step(closure=optimizer_closure) + with torch.no_grad(): + self.model.clamp_coefs() + + + def write_bemb_cpp_format(self): + model = self.model + configs = self.hparams.configs + print_info = self.hparams.print_info + users = self.hparams.user_encoder.classes_ + items = self.hparams.item_encoder.classes_ + categories = self.hparams.category_encoder.classes_ + sessions = self.hparams.session_encoder.classes_ + user_df = pd.DataFrame(users, columns=['user_id']) + item_df = pd.DataFrame(items, columns=['item_id']) + category_df = pd.DataFrame(categories, columns=['category_id']) + formula = parse_utility(model.utility_formula) + variations = {'user', 'item', 'category', 'constant', 'session'} + variations_dfs = { + 'user' : user_df, + 'item' : item_df, + 'category' : category_df, + } + variations_ids = { + 'user' : 'user_id', + 'item' : 'item_id', + 'category' : 'category_id', + } + for additive_term in formula: + for coef_name in additive_term['coefficient']: + variation = coef_name.split('_')[-1] + params_mean = model.coef_dict[coef_name].variational_mean + params_std = torch.exp(model.coef_dict[coef_name].variational_logstd) + for params, moment in zip((params_mean, params_std), ('mean', 'std')): + params_df = pd.DataFrame(params.detach().cpu().numpy()) + if variation != 'constant': + params_df = pd.concat((variations_dfs[variation], 
params_df), axis=1) + params_df.to_csv('%s/param_%s_%s.tsv' % (configs.out_dir, coef_name, moment), sep='\t') + + if 'lambda_item' in model.coef_dict: + print_info['store_dummies'].to_csv('%s/store_dummies.tsv' % (configs.out_dir), sep='\t') + if 'delta_item' in model.coef_dict: + print_info['store_weekday_dummies'].to_csv('%s/store_weekday_dummies.tsv' % (configs.out_dir), sep='\t') + if 'mu_item' in model.coef_dict: + week_trend_latents = self.batch_preprocess.emb.weight + week_trend_latents_df = pd.DataFrame(week_trend_latents.detach().cpu().numpy()) + week_trend_latents_df.to_csv('%s/param_weektrends_mean.tsv' % (configs.out_dir), sep='\t') diff --git a/bemb/utils/run_helper.py b/bemb/utils/run_helper.py index 410745d..33fd544 100644 --- a/bemb/utils/run_helper.py +++ b/bemb/utils/run_helper.py @@ -6,6 +6,7 @@ import time import pytorch_lightning as pl +from pytorch_lightning.callbacks.early_stopping import EarlyStopping from torch_choice.data.utils import create_data_loader from typing import List from torch_choice.data import ChoiceDataset @@ -17,7 +18,7 @@ def section_print(input_text): print('=' * 20, input_text, '=' * 20) -def run(model: LitBEMBFlex, dataset_list: List[ChoiceDataset], batch_size: int=-1, num_epochs: int=10, num_workers: int=8, **kwargs) -> LitBEMBFlex: +def run(model: LitBEMBFlex, dataset_list: List[ChoiceDataset], batch_size: int=-1, num_epochs: int=10, num_workers: int=8, run_test=True, patience=100, **kwargs) -> LitBEMBFlex: """A standard pipeline of model training and evaluation. Args: @@ -25,6 +26,7 @@ def run(model: LitBEMBFlex, dataset_list: List[ChoiceDataset], batch_size: int=- dataset_list (List[ChoiceDataset]): train_dataset, validation_test, and test_dataset in a list of length 3. batch_size (int, optional): batch_size for training and evaluation. Defaults to -1, which indicates full-batch training. num_epochs (int, optional): number of epochs for training. Defaults to 10. 
+ run_test (bool, optional): whether to run evaluation on test set. Defaults to True. **kwargs: additional keyword argument for the pytorch-lightning Trainer. Returns: @@ -52,11 +54,13 @@ def run(model: LitBEMBFlex, dataset_list: List[ChoiceDataset], batch_size: int=- max_epochs=num_epochs, check_val_every_n_epoch=1, log_every_n_steps=1, + callbacks=[EarlyStopping(monitor='val_log_likelihood', patience=patience, mode='max')], **kwargs) start_time = time.time() trainer.fit(model, train_dataloaders=train, val_dataloaders=validation) print(f'time taken: {time.time() - start_time}') - section_print('test performance') - trainer.test(model, dataloaders=test) + if run_test: + section_print('test performance') + trainer.test(model, dataloaders=test) return model diff --git a/tutorials/supermarket/configs.yaml b/tutorials/supermarket/configs.yaml index ce6b672..7ecc4f7 100644 --- a/tutorials/supermarket/configs.yaml +++ b/tutorials/supermarket/configs.yaml @@ -1,30 +1,36 @@ device: cuda # data_dir: /home/tianyudu/Data/MoreSupermarket/tsv/ -data_dir: /home/tianyudu/Data/MoreSupermarket/20180101-20191231_13/tsv/ +# data_dir: /home/tianyudu/Data/MoreSupermarket/20180101-20191231_13/tsv/ +data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/1904/20180101-20191231_44/tsv # utility: lambda_item # utility: lambda_item + theta_user * alpha_item # utility: lambda_item + theta_user * alpha_item + zeta_user * item_obs -utility: lambda_item + theta_user * alpha_item + gamma_user * beta_item * price_obs +# utility: lambda_item + theta_user * alpha_item + gamma_user * beta_item * price_obs +utility: lambda_item + theta_user * alpha_item + gamma_user * price_obs +# utility: alpha_item * gamma_user * price_obs out_dir: ./output/ # model configuration. 
+coef_dist_dict: + default: 'gaussian' + gamma_user: 'gamma' obs2prior_dict: - lambda_item: True - theta_user: True - alpha_item: True + lambda_item: False + theta_user: False + alpha_item: False zeta_user: True lota_item: True - gamma_user: True + gamma_user: False beta_item: True coef_dim_dict: lambda_item: 1 theta_user: 10 alpha_item: 10 - gamma_user: 10 - beta_item: 10 + gamma_user: 1 + beta_item: 1 #### optimization. trace_log_q: False shuffle: False batch_size: 100000 -num_epochs: 3 +num_epochs: 100 learning_rate: 0.03 -num_mc_seeds: 1 +num_mc_seeds: 2 diff --git a/tutorials/supermarket/configs3_1.yaml b/tutorials/supermarket/configs3_1.yaml new file mode 100644 index 0000000..c2bbb7e --- /dev/null +++ b/tutorials/supermarket/configs3_1.yaml @@ -0,0 +1,70 @@ +device: cuda +# data_dir: /home/tianyudu/Data/MoreSupermarket/tsv/ +# data_dir: /home/tianyudu/Data/MoreSupermarket/20180101-20191231_13/tsv/ +# data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/m_9999/20180101-20191231_42/tsv +# data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/m_9999/2 +# data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/m_1904-1887-1974-2012-1992/20180101-20191231_44/tsv/ +# data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/m_1887/chunked_1 +data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/sims34/1 +# utility: lambda_item +# utility: lambda_item + theta_user * alpha_item +# utility: lambda_item + theta_user * alpha_item + zeta_user * item_obs +# utility: lambda_item + theta_user * alpha_item + gamma_user * beta_item * price_obs +# utility: lambda_item + theta_user * alpha_item + gamma_user * price_obs +# utility: lambda_item + theta_user * alpha_item - nfact_category * gamma_user * price_obs +# utility: theta_user * alpha_item + nfact_category * gamma_user * price_obs +# utility: -nfact_category * gamma_user * price_obs +# utility: -nfact_category * gamma_user 
* price_obs +utility: -gamma_user * price_obs +# utility: lambda_item +out_dir: ./output/ +# model configuration. +obs2prior_dict: + lambda_item: False + theta_user: False + alpha_item: False + zeta_user: True + lota_item: True + gamma_user: False + nfact_category: False + beta_item: True +coef_dim_dict: + lambda_item: 1 + theta_user: 50 + alpha_item: 50 + gamma_user: 1 + nfact_category: 1 + beta_item: 10 +coef_dist_dict: + default: 'gaussian' + # gamma_user: 'gamma' + # nfact_category: 'gamma' + # gamma_user: 'gamma' + gamma_user: 'lognormal' + # nfact_category: 'gamma' +prior_mean: + default: 0.0 + # mean is shape for gamma variable + # shape is mean^sq / var + # we want mean 1, var 100 for these two vars + # gamma_user: 0.01 + gamma_user: 1.0 + nfact_category: 1.0 +prior_variance: + default: 100000.0 + # variance is rate for gamma variable + # rate is mean / var + gamma_user: 2.0 + nfact_category: 2.0 +#### optimization. +trace_log_q: False +shuffle: False +batch_size: 50000 +num_epochs: 200 +learning_rate: 0.03 +num_mc_seeds: 1 +num_price_obs: 1 +obs_user: False +obs_item: False +patience: 10 +complete_availability: True diff --git a/tutorials/supermarket/main.py b/tutorials/supermarket/main.py index 7bbb1e4..1fe1f85 100644 --- a/tutorials/supermarket/main.py +++ b/tutorials/supermarket/main.py @@ -10,7 +10,8 @@ from termcolor import cprint from example_customized_module import ExampleCustomizedModule from torch_choice.data import ChoiceDataset -from bemb.model import LitBEMBFlex +# from bemb.model import LitBEMBFlex +from bemb.model.bemb_supermarket_lightning import LitBEMBFlex from bemb.utils.run_helper import run @@ -70,54 +71,36 @@ def load_tsv(file_name, data_dir): # ============================================================================================== # user observables # ============================================================================================== - user_obs = pd.read_csv(os.path.join(configs.data_dir, 'obsUser.tsv'), - 
sep='\t', - index_col=0, - header=None) - # TODO(Tianyu): there could be duplicate information for each user. - # do we need to catch it in some check process? - user_obs = user_obs.groupby(user_obs.index).first().sort_index() - user_obs = torch.Tensor(user_obs.values) - configs.num_user_obs = user_obs.shape[1] - configs.coef_dim_dict['obsuser_item'] = configs.num_user_obs + if configs.obs_user: + user_obs = pd.read_csv(os.path.join(configs.data_dir, 'obsUser.tsv'), + sep='\t', + index_col=0, + header=None) + # TODO(Tianyu): there could be duplicate information for each user. + # do we need to catch it in some check process? + user_obs = user_obs.groupby(user_obs.index).first().sort_index() + user_obs = torch.Tensor(user_obs.values) + configs.num_user_obs = user_obs.shape[1] + configs.coef_dim_dict['obsuser_item'] = configs.num_user_obs # ============================================================================================== # item observables # ============================================================================================== - item_obs = pd.read_csv(os.path.join(configs.data_dir, 'obsItem.tsv'), - sep='\t', - index_col=0, - header=None) - item_obs = item_obs.groupby(item_obs.index).first().sort_index() - item_obs = torch.Tensor(item_obs.values) - configs.num_item_obs = item_obs.shape[1] - configs.coef_dim_dict['obsitem_user'] = configs.num_item_obs + if configs.obs_item: + item_obs = pd.read_csv(os.path.join(configs.data_dir, 'obsItem.tsv'), + sep='\t', + index_col=0, + header=None) + item_obs = item_obs.groupby(item_obs.index).first().sort_index() + item_obs = torch.Tensor(item_obs.values) + configs.num_item_obs = item_obs.shape[1] + configs.coef_dim_dict['obsitem_user'] = configs.num_item_obs # ============================================================================================== # item availability # ============================================================================================== # parse item availability. 
# Try and catch? Optionally specify full availability? - a_tsv = pd.read_csv(os.path.join(configs.data_dir, 'availabilityList.tsv'), - sep='\t', - index_col=None, - header=None, - names=['session_id', 'item_id']) - - # availability ties session as well. - session_encoder = LabelEncoder().fit(a_tsv['session_id'].values) - configs.num_sessions = len(session_encoder.classes_) - assert is_sorted(session_encoder.classes_) - # this loop could be slow, depends on # sessions. - item_availability = torch.zeros(configs.num_sessions, configs.num_items).bool() - - a_tsv['item_id'] = item_encoder.transform(a_tsv['item_id'].values) - a_tsv['session_id'] = session_encoder.transform(a_tsv['session_id'].values) - - for session_id, df_group in a_tsv.groupby('session_id'): - # get IDs of items available at this date. - a_item_ids = df_group['item_id'].unique() # this unique is not necessary if the dataset is well-prepared. - item_availability[session_id, a_item_ids] = True # ============================================================================================== # price observables @@ -125,6 +108,10 @@ def load_tsv(file_name, data_dir): df_price = pd.read_csv(os.path.join(configs.data_dir, 'item_sess_price.tsv'), sep='\t', names=['item_id', 'session_id', 'price']) + # availability ties session as well. + session_encoder = LabelEncoder().fit(df_price['session_id'].values) + configs.num_sessions = len(session_encoder.classes_) + assert is_sorted(session_encoder.classes_) # only keep prices of relevant items. 
mask = df_price['item_id'].isin(item_encoder.classes_) @@ -138,6 +125,26 @@ def load_tsv(file_name, data_dir): price_obs = torch.Tensor(df_price.values).view(configs.num_sessions, configs.num_items, 1) configs.num_price_obs = 1 + if not(hasattr(configs, 'complete_availability') and configs.complete_availability): + a_tsv = pd.read_csv(os.path.join(configs.data_dir, 'availabilityList.tsv'), + sep='\t', + index_col=None, + header=None, + names=['session_id', 'item_id']) + + item_availability = torch.zeros(configs.num_sessions, configs.num_items).bool() + + a_tsv['item_id'] = item_encoder.transform(a_tsv['item_id'].values) + a_tsv['session_id'] = session_encoder.transform(a_tsv['session_id'].values) + + # this loop could be slow, depends on # sessions. + for session_id, df_group in a_tsv.groupby('session_id'): + # get IDs of items available at this date. + a_item_ids = df_group['item_id'].unique() # this unique is not necessary if the dataset is well-prepared. + item_availability[session_id, a_item_ids] = True + else: + item_availability = torch.ones(configs.num_sessions, configs.num_items).bool() + # ============================================================================================== # create datasets # ============================================================================================== @@ -152,14 +159,22 @@ def load_tsv(file_name, data_dir): # example day of week, random example. 
session_day_of_week = torch.LongTensor(np.random.randint(0, 7, configs.num_sessions)) - choice_dataset = ChoiceDataset(item_index=label, - user_index=user_index, - session_index=session_index, - item_availability=item_availability, - user_obs=user_obs, - item_obs=item_obs, - price_obs=price_obs, - session_day_of_week=session_day_of_week) + choice_dataset_args = { + "item_index": label, + "user_index": user_index, + "session_index": session_index, + "item_availability": item_availability, + "price_obs": price_obs, + "session_day_of_week": session_day_of_week + } + + if configs.obs_user: + choice_dataset_args["user_obs"] = user_obs + + if configs.obs_item: + choice_dataset_args["item_obs"] = item_obs + + choice_dataset = ChoiceDataset(**choice_dataset_args) dataset_list.append(choice_dataset) @@ -187,16 +202,26 @@ def load_tsv(file_name, data_dir): item_groups['category_id'] = category_encoder.transform( item_groups['category_id'].values) - print('Category sizes:') - print(item_groups.groupby('category_id').size().describe()) + # print('Category sizes:') + # print(item_groups.groupby('category_id').size().describe()) item_groups = item_groups.groupby('category_id')['item_id'].apply(list) category_to_item = dict(zip(item_groups.index, item_groups.values)) # ============================================================================================== # pytorch-lightning training # ============================================================================================== + # if prior_mean is in the configs namespace set it to the prior mean + if hasattr(configs, 'prior_mean'): + prior_mean = configs.prior_mean + else: + prior_mean = 0.0 + if hasattr(configs, 'prior_variance'): + prior_variance = configs.prior_variance + else: + prior_variance = 1.0 bemb = LitBEMBFlex( # trainings args. 
pred_item = configs.pred_item, + coef_dist_dict=configs.coef_dist_dict, learning_rate=configs.learning_rate, num_seeds=configs.num_mc_seeds, # model args, will be passed to BEMB constructor. @@ -208,14 +233,47 @@ def load_tsv(file_name, data_dir): coef_dim_dict=configs.coef_dim_dict, trace_log_q=configs.trace_log_q, category_to_item=category_to_item, - num_user_obs=configs.num_user_obs, - num_item_obs=configs.num_item_obs, - # num_price_obs=configs.num_price_obs, + num_user_obs=configs.num_user_obs if configs.obs_user else None, + num_item_obs=configs.num_item_obs if configs.obs_item else None, + prior_mean = prior_mean, + prior_variance= prior_variance, + num_price_obs=configs.num_price_obs, + preprocess=False, # additional_modules=[ExampleCustomizedModule()] ) + if hasattr(configs, 'patience'): + patience = configs.patience + else: + patience = 100 bemb = bemb.to(configs.device) - bemb = run(bemb, dataset_list, batch_size=configs.batch_size, num_epochs=configs.num_epochs) + bemb = run(bemb, dataset_list, batch_size=configs.batch_size, num_epochs=configs.num_epochs, run_test=False, patience=patience) + + # ''' + # coeffs = coeffs**2 + # give distribution statistics + out_dir = './output' + if not os.path.exists(out_dir): + os.makedirs(out_dir) + if 'gamma_user' in configs.utility: + coeffs_gamma = bemb.model.coef_dict['gamma_user'].variational_mean.detach().cpu().numpy() + # if configs.coef_dist_dict['gamma_user'] == 'lognormal': + # coeffs_gamma = np.exp(coeffs_gamma) + print('Coefficients statistics Gamma:') + print(pd.DataFrame(coeffs_gamma).describe()) + # write to file + print("writing") + np.savetxt(f'{out_dir}/gammas_{sys.argv[1]}.txt', coeffs_gamma, delimiter=',') + if 'nfact_category' in configs.utility: + coeffs_nfact = bemb.model.coef_dict['nfact_category'].variational_mean.detach().cpu().numpy() + # if configs.coef_dist_dict['nfact_category'] == 'lognormal': + # coeffs_nfact = np.exp(coeffs_nfact) + print('Coefficients statistics nfact_category:') + 
print(pd.DataFrame(coeffs_nfact).describe()) + # write to file + print("writing") + np.savetxt(f'{out_dir}/nfacts_{sys.argv[1]}.txt', coeffs_nfact, delimiter=',') + # ''' # ============================================================================================== # inference example