From 596d1f1ab4d0ca0bd3837d39450bda9df8d152e6 Mon Sep 17 00:00:00 2001
From: Zhixian Li
Date: Sun, 26 Oct 2025 01:01:23 -0400
Subject: [PATCH] Add support for dataframe and table in add_entities

---
Review notes (this section is ignored by git am):

- Line structure restored; the patch had been collapsed onto two lines.
  Context-line whitespace is reconstructed from the hunk line counts and
  should be checked against the target tree before applying.
- Only the "<cls>_id" column is now treated as the entity ID column.  The
  previous endswith("_id") test silently dropped every other attribute
  column whose name ends in "_id".
- The recursive add_entities() call now forwards `duplicates`.  The keyword
  name is taken from the docstring tail visible in the hunk context;
  confirm it against the method signature.
- The `index` blob hashes below are carried over from the original patch
  and no longer describe the post-image, so do not rely on them for a
  3-way merge.

 src/lenskit/data/builder.py         | 30 ++++++++++++++++++++----
 tests/data/test_builder_entities.py | 49 +++++++++++++++++++++++++++++
 2 files changed, 75 insertions(+), 4 deletions(-)

diff --git a/src/lenskit/data/builder.py b/src/lenskit/data/builder.py
index d891b3fce..6146755d4 100644
--- a/src/lenskit/data/builder.py
+++ b/src/lenskit/data/builder.py
@@ -265,10 +265,32 @@ def add_entities(
             duplicates:
                 How to handle duplicate entity IDs.
         """
-        if isinstance(source, pd.DataFrame):  # pragma: nocover
-            raise NotImplementedError()
-        if isinstance(source, pa.Table):  # pragma: nocover
-            raise NotImplementedError()
+        if isinstance(source, pd.DataFrame):
+            source = pa.Table.from_pandas(source)
+        if isinstance(source, pa.Table):
+            id_name = cls + "_id"
+            entity_col = source.column(id_name)
+            # Register the IDs through the plain-array path, forwarding the
+            # caller's duplicate-handling policy instead of the default.
+            self.add_entities(cls, entity_col, duplicates=duplicates)
+
+            for col_name in source.column_names:
+                # Skip only the entity ID column itself; other columns whose
+                # names merely end in "_id" are ordinary attributes.
+                if col_name == id_name:
+                    continue
+
+                column = source.column(col_name)
+                col_type = column.type
+                if (
+                    pa.types.is_list(col_type)
+                    or pa.types.is_large_list(col_type)
+                    or pa.types.is_fixed_size_list(col_type)
+                ):
+                    self.add_list_attribute(cls, col_name, entity_col, column)
+                else:
+                    self.add_scalar_attribute(cls, col_name, entity_col, column)
+            return
 
         self._validate_entity_name(cls)
 
diff --git a/tests/data/test_builder_entities.py b/tests/data/test_builder_entities.py
index efbc039fa..93964b715 100644
--- a/tests/data/test_builder_entities.py
+++ b/tests/data/test_builder_entities.py
@@ -7,11 +7,13 @@
 # pyright: strict
 import numpy as np
+import pandas as pd
 import pyarrow as pa
 
 from pytest import raises
 
 from lenskit.data import DatasetBuilder
 from lenskit.diagnostics import DataError
+from lenskit.testing import ml_test_dir
 
 
 def test_empty_builder():
@@ -143,3 +145,50 @@ def test_add_entities_twice():
     assert ds.item_count == 0
     assert ds.user_count == 8
     assert np.all(ds.users.ids() == ["a", "b", "x", "y", "z", "q", "r", "s"])
+
+
+def test_add_entities_with_dataframe():
+    dsb = DatasetBuilder()
+
+    items = pd.read_csv(ml_test_dir / "movies.csv")
+    items = items.rename(columns={"movieId": "item_id"}).set_index("item_id")
+
+    genres = items["genres"].str.split("|")
+    items["genres"] = genres
+
+    dsb.add_entities("item", items)
+
+    ds = dsb.build()
+
+    assert ds.entities("item").attribute("title").is_scalar
+    assert ds.entities("item").attribute("genres").is_list
+
+
+def test_add_entities_with_arrow_table():
+    dsb = DatasetBuilder()
+
+    items = pd.read_csv(ml_test_dir / "movies.csv")
+    items = items.rename(columns={"movieId": "item_id"}).set_index("item_id")
+
+    genres = items["genres"].str.split("|")
+    items["genres"] = genres
+    table = pa.Table.from_pandas(items)
+
+    dsb.add_entities("item", table)
+
+    ds = dsb.build()
+
+    assert ds.entities("item").attribute("title").is_scalar
+    assert ds.entities("item").attribute("genres").is_list
+
+
+def test_add_entities_table_keeps_other_id_columns():
+    dsb = DatasetBuilder()
+
+    table = pa.table({"item_id": [1, 2, 3], "imdb_id": ["tt1", "tt2", "tt3"]})
+    dsb.add_entities("item", table)
+
+    ds = dsb.build()
+
+    assert ds.item_count == 3
+    assert ds.entities("item").attribute("imdb_id").is_scalar