diff --git a/src/lenskit/data/builder.py b/src/lenskit/data/builder.py
index d891b3fce..6146755d4 100644
--- a/src/lenskit/data/builder.py
+++ b/src/lenskit/data/builder.py
@@ -265,10 +265,24 @@
             duplicates:
                 How to handle duplicate entity IDs.
         """
-        if isinstance(source, pd.DataFrame):  # pragma: nocover
-            raise NotImplementedError()
-        if isinstance(source, pa.Table):  # pragma: nocover
-            raise NotImplementedError()
+        if isinstance(source, pd.DataFrame):
+            # Normalize to Arrow and fall through to the Table branch below.
+            source = pa.Table.from_pandas(source)
+        if isinstance(source, pa.Table):
+            # Register the entity IDs first, then attach every non-ID column
+            # as an attribute of the new entities.
+            entity_col = source.column(cls + "_id")
+            self.add_entities(cls, entity_col)
+
+            for col_name in source.column_names:
+                if not col_name.endswith("_id"):
+                    column = source.column(col_name)
+                    col_type = column.type
+
+                    # List-typed columns become list attributes; everything
+                    # else is treated as a scalar attribute.
+                    if (pa.types.is_list(col_type)
+                            or pa.types.is_large_list(col_type)
+                            or pa.types.is_fixed_size_list(col_type)):
+                        self.add_list_attribute(cls, col_name, entity_col, column)
+                    else:
+                        self.add_scalar_attribute(cls, col_name, entity_col, column)
+            return
 
         self._validate_entity_name(cls)
diff --git a/tests/data/test_builder_entities.py b/tests/data/test_builder_entities.py
index efbc039fa..93964b715 100644
--- a/tests/data/test_builder_entities.py
+++ b/tests/data/test_builder_entities.py
@@ -7,11 +7,13 @@
 # pyright: strict
 
 import numpy as np
+import pandas as pd
 import pyarrow as pa
 from pytest import raises
 
 from lenskit.data import DatasetBuilder
 from lenskit.diagnostics import DataError
+from lenskit.testing import ml_test_dir
 
 
 def test_empty_builder():
@@ -143,3 +145,38 @@
     assert ds.item_count == 0
     assert ds.user_count == 8
     assert np.all(ds.users.ids() == ["a", "b", "x", "y", "z", "q", "r", "s"])
+
+
+def test_add_entities_with_dataframe():
+    dsb = DatasetBuilder()
+
+    items = pd.read_csv(ml_test_dir / "movies.csv")
+    items = items.rename(columns={"movieId": "item_id"}).set_index("item_id")
+
+    genres = items["genres"].str.split("|")
+    items["genres"] = genres
+
+    dsb.add_entities("item", items)
+
+    ds = dsb.build()
+
+    assert ds.entities("item").attribute("title").is_scalar
+    assert ds.entities("item").attribute("genres").is_list
+
+
+def test_add_entities_with_arrow_table():
+    dsb = DatasetBuilder()
+
+    items = pd.read_csv(ml_test_dir / "movies.csv")
+    items = items.rename(columns={"movieId": "item_id"}).set_index("item_id")
+
+    genres = items["genres"].str.split("|")
+    items["genres"] = genres
+    table = pa.Table.from_pandas(items)
+
+    dsb.add_entities("item", table)
+
+    ds = dsb.build()
+
+    assert ds.entities("item").attribute("title").is_scalar
+    assert ds.entities("item").attribute("genres").is_list