Towards BEMB v1.0.
Corresponding branch: api-update
We are planning to refine and expand the current API of BEMBFlex.
The `pred_item` option and multi-class prediction.
- Note: this is a non-trivial extension. For each observation (a consumer picking an item), we compute a scalar utility for each available item when predicting items. When predicting labels, however, there is only a single scalar utility, computed from the chosen item (`item_index[i]`), which only allows binary classification. We might have to drop this feature.
- Currently the model supports predicting the binary `batch.label` or the multi-class `batch.item_index`. We plan to support arbitrary multi-class classification.
  - In particular, you don't need to change anything if `pred_item=True`: the model knows the number of classes is exactly the `num_items` parameter. In this case, your `ChoiceDataset` object does not need a `label` attribute either, since the model uses `item_index` as the ground truth for training.
  - In contrast, if `pred_item=False`, you now need to supply `num_classes` to the `BEMBFlex.__init__()` method. You also need a `label` attribute in the `ChoiceDataset` object. The `label` attribute should be a `LongTensor` with values from `{0, 1, ..., num_classes - 1}`.
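To make the distinction above concrete, here is a minimal pure-Python sketch; it is not BEMB code. It assumes that item prediction turns one observation's per-item utilities into probabilities with a softmax over the available items, and that the current label prediction passes the single utility of the chosen item through a sigmoid (logistic link), which can only describe two classes. The helpers `softmax`, `sigmoid`, and `check_labels` are hypothetical; `check_labels` just encodes the requirement that labels lie in `{0, 1, ..., num_classes - 1}`.

```python
import math
from typing import List, Sequence


def softmax(utilities: Sequence[float]) -> List[float]:
    """Turn one observation's per-item utilities into a distribution over items."""
    m = max(utilities)  # subtract the max for numerical stability
    exps = [math.exp(u - m) for u in utilities]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(utility: float) -> float:
    """Probability of label 1 from a single scalar utility (logistic link)."""
    return 1.0 / (1.0 + math.exp(-utility))


def check_labels(labels: Sequence[int], num_classes: int) -> None:
    """Hypothetical check: every label must lie in {0, 1, ..., num_classes - 1}."""
    for y in labels:
        if not (isinstance(y, int) and 0 <= y < num_classes):
            raise ValueError(f"label {y!r} is outside range(0, {num_classes})")


# Item prediction: one utility per available item -> num_items probabilities.
item_utilities = [1.2, -0.3, 0.5]  # toy values, num_items = 3
item_proba = softmax(item_utilities)

# Current label prediction: one utility (of the chosen item) -> only two classes.
p_one = sigmoid(0.8)
label_proba = [1.0 - p_one, p_one]

# With num_classes = 3 the valid labels are 0, 1, 2; a label of 3 is rejected.
check_labels([0, 2, 1], num_classes=3)
```

The off-by-one matters in practice: with `num_classes` classes and zero-based indexing, the largest valid label is `num_classes - 1`.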
Post-Estimation
- Thanks to feedback from our valued users, we are planning to reorganize our post-estimation prediction methods for a better user experience.
- We will implement a method called `predict_proba()`, the same name as the inference method of scikit-learn classifiers.
  - This method will have `@torch.no_grad()` as a decorator, so you can call it freely without worrying about gradient tracking.
  - With `pred_item=True`, the `batch` needs an `item_index` attribute only if it is involved in the utility computation (e.g., within-category computation).
  - With `pred_item=False`, the `batch` does not need to have a `label` attribute.
  - The preliminary API of `predict_proba()` is used as follows:
```python
# Case 1: predicting items.
batch = ChoiceDataset(...)
bemb = BEMBFlex(..., pred_item=True, ...)
proba = bemb.predict_proba(batch)  # shape = (len(batch), num_items)
```

```python
# Case 2: predicting labels.
batch = ChoiceDataset(...)
# note that batch doesn't need to have a label attribute.
bemb = BEMBFlex(..., pred_item=False, num_classes=..., ...)
proba = bemb.predict_proba(batch)  # shape = (len(batch), num_classes)
```
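The effect of decorating `predict_proba()` with `@torch.no_grad()` can be seen on a toy model. The sketch below is not `BEMBFlex`: `ToyChoiceModel`, its coefficient names, and passing `user_index` directly instead of a `ChoiceDataset` are all simplifying assumptions. It only illustrates the intended contract for the `pred_item=True` case: the returned tensor has shape `(len(batch), num_items)`, each row sums to one, and no autograd graph is attached to the result.

```python
import torch


class ToyChoiceModel(torch.nn.Module):
    """Hypothetical stand-in for a choice model; not the BEMBFlex implementation."""

    def __init__(self, num_users: int, num_items: int, dim: int) -> None:
        super().__init__()
        # Latent user and item coefficients; utility is their inner product.
        self.user_coef = torch.nn.Parameter(torch.randn(num_users, dim))
        self.item_coef = torch.nn.Parameter(torch.randn(num_items, dim))

    def utility(self, user_index: torch.Tensor) -> torch.Tensor:
        # (batch_size, dim) @ (dim, num_items) -> (batch_size, num_items)
        return self.user_coef[user_index] @ self.item_coef.T

    @torch.no_grad()
    def predict_proba(self, user_index: torch.Tensor) -> torch.Tensor:
        # Softmax over items; gradient tracking is disabled inside this method.
        return torch.softmax(self.utility(user_index), dim=-1)


torch.manual_seed(0)
model = ToyChoiceModel(num_users=4, num_items=5, dim=3)
user_index = torch.tensor([0, 2, 3])  # a "batch" of 3 observations
proba = model.predict_proba(user_index)  # shape = (3, 5), no grad history
```

Because the decorator only affects `predict_proba()`, calling `utility()` directly during training still builds the graph needed for backpropagation.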
Renaming Variables.
- We received feedback that the name price-variation is ambiguous; we propose to change it to session-item-variation instead (this is precisely the definition of such variables: they vary across both sessions and items).
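As a small illustration of the proposed name, the sketch below stores a toy observable that differs across both sessions and items, so looking it up requires a session index and an item index. The name `sessionitem_obs`, the nested-list layout `obs[session][item]`, and the `lookup` helper are hypothetical and only mirror the naming discussion; they are not the actual BEMB or torch-choice data format.

```python
from typing import List

# Hypothetical session-item observable, shape (num_sessions, num_items).
# The toy numbers could be prices, but any feature indexed by
# (session, item) has exactly the same structure.
sessionitem_obs: List[List[float]] = [
    [1.99, 3.49, 0.99],  # session 0
    [2.19, 3.49, 1.09],  # session 1
]


def lookup(session_index: int, item_index: int) -> float:
    """Value of the observable for `item_index` in session `session_index`."""
    return sessionitem_obs[session_index][item_index]


# Item 0 takes different values in sessions 0 and 1; item 1 happens to match.
price_item0_s0 = lookup(0, 0)
price_item0_s1 = lookup(1, 0)
```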