diff --git a/README.md b/README.md
index 9f8a067..ae9e214 100644
--- a/README.md
+++ b/README.md
@@ -83,6 +83,36 @@ hlda = HierarchicalLDA(corpus, vocab, alpha=1.0, gamma=1.0, eta=0.1,
 hlda.estimate(iterations=50, display_topics=10)
 ```
 
+### Integration with scikit-learn
+
+The package provides a `HierarchicalLDAEstimator` that follows the scikit-learn estimator API, so the sampler can be used inside a standard `Pipeline`. In the example below, `documents` is a list of raw text strings, and the final `transform` step returns the identifier of the leaf topic node assigned to each document.
+
+```python
+from sklearn.feature_extraction.text import CountVectorizer
+from sklearn.preprocessing import FunctionTransformer
+from sklearn.pipeline import Pipeline
+from hlda.sklearn_wrapper import HierarchicalLDAEstimator
+
+vectorizer = CountVectorizer()
+prep = FunctionTransformer(
+    lambda X: (
+        [[i for i, c in enumerate(row) for _ in range(int(c))] for row in X.toarray()],
+        list(vectorizer.get_feature_names_out()),
+    ),
+    validate=False,
+)
+
+pipeline = Pipeline([
+    ("vect", vectorizer),
+    ("prep", prep),
+    ("hlda", HierarchicalLDAEstimator(num_levels=3, iterations=10, seed=0)),
+])
+
+pipeline.fit(documents)
+assignments = pipeline.transform(documents)
+```
+
+
 ## Running the tests
 
 The repository includes a small test suite that checks the sampler on both the BBC
diff --git a/pyproject.toml b/pyproject.toml
index aa4fa1b..d4b4159 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,7 @@ pandas = "^2.2.3"
 matplotlib = "^3.10.3"
 seaborn = "^0.13.2"
 tqdm = "^4.67.1"
+scikit-learn = "^1.5.0"
 
 [tool.poetry.group.dev.dependencies]
 jupyterlab = "^4.4.3"
diff --git a/src/hlda/__init__.py b/src/hlda/__init__.py
index fae5176..d00cb31 100644
--- a/src/hlda/__init__.py
+++ b/src/hlda/__init__.py
@@ -1,5 +1,6 @@
 __version__ = "0.4"
 
 from .sampler import HierarchicalLDA
+from .sklearn_wrapper import HierarchicalLDAEstimator
 
-__all__ = ["HierarchicalLDA"]
+__all__ = ["HierarchicalLDA", "HierarchicalLDAEstimator"]
diff --git a/src/hlda/sklearn_wrapper.py b/src/hlda/sklearn_wrapper.py
new file mode 100644
index 0000000..c22e4cc
--- /dev/null
+++ b/src/hlda/sklearn_wrapper.py
@@ -0,0 +1,106 @@
+"""Scikit-learn wrapper for :class:`HierarchicalLDA`."""
+
+from __future__ import annotations
+
+from typing import Any, List, Sequence, Tuple
+
+import numpy as np
+from scipy import sparse
+from sklearn.base import BaseEstimator, TransformerMixin
+
+from .sampler import HierarchicalLDA
+
+
+def _dtm_to_corpus(dtm: Any) -> List[List[int]]:
+    """Convert a document-term matrix into an integer corpus."""
+    if sparse.issparse(dtm):
+        dtm = dtm.toarray()
+    else:
+        dtm = np.asarray(dtm)
+    corpus: List[List[int]] = []
+    for row in dtm:
+        doc: List[int] = []
+        for idx, count in enumerate(row):
+            if count:
+                doc.extend([idx] * int(count))
+        corpus.append(doc)
+    return corpus
+
+
+class HierarchicalLDAEstimator(BaseEstimator, TransformerMixin):
+    """Scikit-learn compatible estimator for :class:`HierarchicalLDA`."""
+
+    def __init__(
+        self,
+        *,
+        alpha: float = 10.0,
+        gamma: float = 1.0,
+        eta: float = 0.1,
+        num_levels: int = 3,
+        iterations: int = 100,
+        seed: int = 0,
+        verbose: bool = False,
+        vocab: Sequence[str] | None = None,
+    ) -> None:
+        self.alpha = alpha
+        self.gamma = gamma
+        self.eta = eta
+        self.num_levels = num_levels
+        self.iterations = iterations
+        self.seed = seed
+        self.verbose = verbose
+        self.vocab = list(vocab) if vocab is not None else None
+
+    # ------------------------------------------------------------------
+    def _prepare_input(self, X: Any) -> Tuple[List[List[int]], Sequence[str]]:
+        corpus: List[List[int]]
+        vocab: Sequence[str] | None = None
+
+        if isinstance(X, tuple) and len(X) == 2:
+            corpus, vocab = X
+        elif sparse.issparse(X) or (isinstance(X, np.ndarray) and X.ndim == 2):
+            corpus = _dtm_to_corpus(X)
+            vocab = self.vocab
+        else:
+            corpus = X  # assume already an integer corpus
+            vocab = self.vocab
+
+        if vocab is None:
+            raise ValueError("Vocabulary is required to fit the model")
+        return corpus, vocab
+
+    # ------------------------------------------------------------------
+    def fit(self, X: Any, y: Any | None = None):
+        corpus, vocab = self._prepare_input(X)
+        self.vocab_ = list(vocab)
+        self.model_ = HierarchicalLDA(
+            corpus,
+            self.vocab_,
+            alpha=self.alpha,
+            gamma=self.gamma,
+            eta=self.eta,
+            num_levels=self.num_levels,
+            seed=self.seed,
+            verbose=self.verbose,
+        )
+        if self.iterations > 0:
+            self.model_.estimate(
+                self.iterations,
+                display_topics=self.iterations + 1,  # interval past the run, so nothing is printed
+                n_words=0,
+                with_weights=False,
+            )
+        return self
+
+    # ------------------------------------------------------------------
+    def transform(self, X: Any) -> np.ndarray:
+        if not hasattr(self, "model_"):
+            raise RuntimeError("Estimator has not been fitted")
+        # X is accepted only for pipeline compatibility: the assignments are
+        # read from the leaves recorded for the training documents.
+        n_docs = len(self.model_.document_leaves)
+        assignments = np.zeros(n_docs, dtype=int)
+        for d in range(n_docs):
+            leaf = self.model_.document_leaves[d]
+            assignments[d] = leaf.node_id
+        return assignments
diff --git a/tests/test_sklearn_wrapper.py b/tests/test_sklearn_wrapper.py
new file mode 100644
index 0000000..b05c60e
--- /dev/null
+++ b/tests/test_sklearn_wrapper.py
@@ -0,0 +1,60 @@
+import numpy as np
+from sklearn.feature_extraction.text import CountVectorizer
+from sklearn.preprocessing import FunctionTransformer
+from sklearn.pipeline import Pipeline
+
+from hlda.sklearn_wrapper import HierarchicalLDAEstimator
+
+
+def _prepare_input(vectorizer):
+    def _transform(X):
+        if hasattr(X, "toarray"):
+            arr = X.toarray()
+        else:
+            arr = np.asarray(X)
+        corpus = []
+        for row in arr:
+            doc = []
+            for idx, count in enumerate(row):
+                doc.extend([idx] * int(count))
+            corpus.append(doc)
+        vocab = list(vectorizer.get_feature_names_out())
+        return corpus, vocab
+
+    return _transform
+
+
+def test_pipeline_fit_transform():
+    docs = [
+        "apple orange banana",
+        "apple orange",
+        "banana banana orange",
+    ]
+
+    vectorizer = CountVectorizer()
+    hlda = HierarchicalLDAEstimator(
+        num_levels=2,
+        iterations=1,
+        seed=0,
+        verbose=False,
+    )
+
+    pipeline = Pipeline(
+        [
+            ("vect", vectorizer),
+            (
+                "prep",
+                FunctionTransformer(
+                    _prepare_input(vectorizer),
+                    validate=False,
+                ),
+            ),
+            ("hlda", hlda),
+        ]
+    )
+
+    pipeline.fit(docs)
+    result = pipeline.transform(docs)  # one leaf node id per document
+
+    assert result.shape[0] == len(docs)
+    assert isinstance(result[0], (int, np.integer))
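
For quick experimentation outside a `Pipeline`, the estimator can also be fitted directly on a `(corpus, vocab)` tuple, the first input form `_prepare_input` accepts. A minimal sketch, assuming the package is installed; the three-document toy corpus is invented for illustration:

```python
import numpy as np

from hlda import HierarchicalLDAEstimator

# Toy input (illustrative only): each document is a list of vocabulary indices.
vocab = ["apple", "banana", "orange"]
corpus = [
    [0, 2, 1],  # "apple orange banana"
    [0, 2],     # "apple orange"
    [1, 1, 2],  # "banana banana orange"
]

est = HierarchicalLDAEstimator(num_levels=2, iterations=1, seed=0)
est.fit((corpus, vocab))        # the tuple form carries the vocabulary alongside the corpus
leaves = est.transform(corpus)  # leaf node id recorded for each training document

assert isinstance(leaves, np.ndarray) and leaves.shape == (3,)
```

This is the same `(corpus, vocab)` tuple that the `FunctionTransformer` step produces inside the pipeline, so both entry points exercise the same code path through `_prepare_input`.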