-
Notifications
You must be signed in to change notification settings - Fork 72
Add support for dataframe and arrow table in add_entities #907
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,3 @@ | ||
| # This file is part of LensKit. | ||
| # Copyright (C) 2018-2023 Boise State University. | ||
| # Copyright (C) 2023-2025 Drexel University. | ||
|
|
@@ -265,10 +265,23 @@ | |
| duplicates: | ||
| How to handle duplicate entity IDs. | ||
| """ | ||
| if isinstance(source, pd.DataFrame): # pragma: nocover | ||
| raise NotImplementedError() | ||
| if isinstance(source, pa.Table): # pragma: nocover | ||
| raise NotImplementedError() | ||
| if isinstance(source, pd.DataFrame): | ||
| source = pa.Table.from_pandas(source) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is an interesting and challenging edge case here, that we need to clearly document and/or design for. Right now, this works because your test case names the index However, if the client provides code that has no I think we probably want to use the Pandas index, with the following logic:
Implementing this logic will require this line to be a little more aware of the Pandas data frames, and also require tests for each of the different conditions. Importantly, for case (1), this line here will create a new attribute called The input cases we will need to test for correct behavior with:
This isn't a problem for PyArrow input, because Arrow tables do not have indices. |
||
| if isinstance(source, pa.Table): | ||
| entity_col = source.column(cls + "_id") | ||
| self.add_entities(cls, entity_col) | ||
|
|
||
| for col_name in source.column_names: | ||
| if not col_name.endswith("_id"): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should only exclude the |
||
| col_type = source.column(col_name).type | ||
|
|
||
| if any([pa.types.is_list(col_type), | ||
| pa.types.is_large_list(col_type), | ||
| pa.types.is_fixed_size_list(col_type)]): | ||
| self.add_list_attribute(cls, col_name, entity_col, source.column(col_name)) | ||
| else: | ||
| self.add_scalar_attribute(cls, col_name, entity_col, source.column(col_name)) | ||
| return | ||
|
|
||
| self._validate_entity_name(cls) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,3 @@ | ||
| # This file is part of LensKit. | ||
| # Copyright (C) 2018-2023 Boise State University. | ||
| # Copyright (C) 2023-2025 Drexel University. | ||
|
|
@@ -5,13 +5,15 @@ | |
| # SPDX-License-Identifier: MIT | ||
|
|
||
| # pyright: strict | ||
| import numpy as np | ||
| import pyarrow as pa | ||
| import pandas as pd | ||
|
|
||
| from pytest import raises | ||
|
|
||
| from lenskit.data import DatasetBuilder | ||
| from lenskit.diagnostics import DataError | ||
| from lenskit.testing import ml_test_dir | ||
|
|
||
|
|
||
| def test_empty_builder(): | ||
|
|
@@ -143,3 +145,38 @@ | |
| assert ds.item_count == 0 | ||
| assert ds.user_count == 8 | ||
| assert np.all(ds.users.ids() == ["a", "b", "x", "y", "z", "q", "r", "s"]) | ||
|
|
||
|
|
||
| def test_add_entities_with_dataframe(): | ||
| dsb = DatasetBuilder() | ||
|
|
||
| items = pd.read_csv(ml_test_dir / "movies.csv") | ||
| items = items.rename(columns={"movieId": "item_id"}).set_index("item_id") | ||
|
|
||
| genres = items["genres"].str.split("|") | ||
| items["genres"] = genres | ||
|
|
||
| dsb.add_entities("item", items) | ||
|
|
||
| ds = dsb.build() | ||
|
|
||
| assert ds.entities("item").attribute("title").is_scalar | ||
| assert ds.entities("item").attribute("genres").is_list | ||
|
Comment on lines
+163
to
+164
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should test that a few item IDs have the correct titles, too. It's possible for the code to set up the structures in the right format, but not align them correctly, and the tests should check for that. |
||
|
|
||
|
|
||
| def test_add_entities_with_arrow_table(): | ||
| dsb = DatasetBuilder() | ||
|
|
||
| items = pd.read_csv(ml_test_dir / "movies.csv") | ||
| items = items.rename(columns={"movieId": "item_id"}).set_index("item_id") | ||
|
|
||
| genres = items["genres"].str.split("|") | ||
| items["genres"] = genres | ||
| table = pa.Table.from_pandas(items) | ||
|
|
||
| dsb.add_entities("item", table) | ||
|
|
||
| ds = dsb.build() | ||
|
|
||
| assert ds.entities("item").attribute("title").is_scalar | ||
| assert ds.entities("item").attribute("genres").is_list | ||
|
Comment on lines
+181
to
+182
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above. |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should update the docstring to document the kinds of attributes supported, limitations, etc., along with the index logic.
This should be in the main body of the docstring (before
Args:), not in the argument documentation, for readability.