Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,36 @@ hlda = HierarchicalLDA(corpus, vocab, alpha=1.0, gamma=1.0, eta=0.1,
hlda.estimate(iterations=50, display_topics=10)
```

### Integration with scikit-learn

The package provides a `HierarchicalLDAEstimator` that follows the scikit-learn estimator API (`fit`/`transform`, `get_params`/`set_params`), so the sampler can be used as the final step of a standard `Pipeline`.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.preprocessing import FunctionTransformer
from sklearn.pipeline import Pipeline
from hlda.sklearn_wrapper import HierarchicalLDAEstimator

vectorizer = CountVectorizer()
prep = FunctionTransformer(
lambda X: (
[[i for i, c in enumerate(row) for _ in range(int(c))] for row in X.toarray()],
list(vectorizer.get_feature_names_out()),
),
validate=False,
)

pipeline = Pipeline([
("vect", vectorizer),
("prep", prep),
("hlda", HierarchicalLDAEstimator(num_levels=3, iterations=10, seed=0)),
])

pipeline.fit(documents)
assignments = pipeline.transform(documents)
```


## Running the tests

The repository includes a small test suite that checks the sampler on both the BBC
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pandas = "^2.2.3"
matplotlib = "^3.10.3"
seaborn = "^0.13.2"
tqdm = "^4.67.1"
scikit-learn = "^1.5.0"

[tool.poetry.group.dev.dependencies]
jupyterlab = "^4.4.3"
Expand Down
3 changes: 2 additions & 1 deletion src/hlda/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__version__ = "0.4"

from .sampler import HierarchicalLDA
from .sklearn_wrapper import HierarchicalLDAEstimator

__all__ = ["HierarchicalLDA"]
__all__ = ["HierarchicalLDA", "HierarchicalLDAEstimator"]
104 changes: 104 additions & 0 deletions src/hlda/sklearn_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Sklearn wrapper for HierarchicalLDA

from __future__ import annotations

from typing import Any, List, Sequence, Tuple

import numpy as np
from scipy import sparse
from sklearn.base import BaseEstimator, TransformerMixin

from .sampler import HierarchicalLDA


def _dtm_to_corpus(dtm: Any) -> List[List[int]]:
"""Convert a document-term matrix into an integer corpus."""
if sparse.issparse(dtm):
dtm = dtm.toarray()
else:
dtm = np.asarray(dtm)
corpus: List[List[int]] = []
for row in dtm:
doc: List[int] = []
for idx, count in enumerate(row):
if count:
doc.extend([idx] * int(count))
corpus.append(doc)
return corpus


class HierarchicalLDAEstimator(BaseEstimator, TransformerMixin):
    """Scikit-learn compatible estimator for :class:`HierarchicalLDA`.

    Accepts either a ``(corpus, vocab)`` tuple (as produced by an upstream
    preprocessing step), a 2-D document-term matrix (dense or sparse — in
    which case ``vocab`` must be supplied at construction time), or a raw
    integer corpus (again requiring ``vocab``).

    Parameters
    ----------
    alpha : float
        Document smoothing parameter forwarded to the sampler.
    gamma : float
        Concentration parameter forwarded to the sampler.
    eta : float
        Topic/word smoothing parameter forwarded to the sampler.
    num_levels : int
        Depth of the topic hierarchy.
    iterations : int
        Number of sampling iterations run during :meth:`fit`.
    seed : int
        Random seed forwarded to the sampler.
    verbose : bool
        Verbosity flag forwarded to the sampler.
    vocab : sequence of str, optional
        Vocabulary to use when the fit input does not carry its own.

    Attributes
    ----------
    vocab_ : list of str
        Vocabulary actually used at fit time.
    model_ : HierarchicalLDA
        The fitted sampler instance.
    """

    def __init__(
        self,
        *,
        alpha: float = 10.0,
        gamma: float = 1.0,
        eta: float = 0.1,
        num_levels: int = 3,
        iterations: int = 100,
        seed: int = 0,
        verbose: bool = False,
        vocab: Sequence[str] | None = None,
    ) -> None:
        self.alpha = alpha
        self.gamma = gamma
        self.eta = eta
        self.num_levels = num_levels
        self.iterations = iterations
        self.seed = seed
        self.verbose = verbose
        # Store the parameter unmodified: the sklearn estimator contract
        # requires __init__ to do no conversion or validation, so that
        # get_params()/set_params()/clone() round-trip exactly.  (The copy
        # happens in fit(), which sets ``vocab_``.)
        self.vocab = vocab

    # ------------------------------------------------------------------
    def _prepare_input(self, X: Any) -> Tuple[List[List[int]], Sequence[str]]:
        """Normalize *X* into ``(corpus, vocab)``.

        Raises
        ------
        ValueError
            If no vocabulary is available (neither carried by *X* nor set
            via the ``vocab`` parameter).
        """
        corpus: List[List[int]]
        vocab: Sequence[str] | None = None

        if isinstance(X, tuple) and len(X) == 2:
            # Upstream step already produced (corpus, vocab).
            corpus, vocab = X
        elif sparse.issparse(X) or (isinstance(X, np.ndarray) and X.ndim == 2):
            corpus = _dtm_to_corpus(X)
            vocab = self.vocab
        else:
            corpus = X  # assume already integer corpus
            vocab = self.vocab

        if vocab is None:
            raise ValueError("Vocabulary is required to fit the model")
        return corpus, vocab

    # ------------------------------------------------------------------
    def fit(self, X: Any, y: Any | None = None):
        """Build the sampler and run ``iterations`` sampling passes.

        Returns ``self`` so the estimator composes with ``Pipeline``.
        """
        corpus, vocab = self._prepare_input(X)
        self.vocab_ = list(vocab)
        self.model_ = HierarchicalLDA(
            corpus,
            self.vocab_,
            alpha=self.alpha,
            gamma=self.gamma,
            eta=self.eta,
            num_levels=self.num_levels,
            seed=self.seed,
            verbose=self.verbose,
        )
        if self.iterations > 0:
            # display_topics > iterations suppresses intermediate topic
            # printouts during estimation.
            self.model_.estimate(
                self.iterations,
                display_topics=self.iterations + 1,
                n_words=0,
                with_weights=False,
            )
        return self

    # ------------------------------------------------------------------
    def transform(self, X: Any) -> np.ndarray:
        """Return the leaf-node id assigned to each *training* document.

        NOTE: *X* is ignored — the sampler does not support folding in new
        documents, so this reports the leaf assignments recorded at fit
        time (one integer node id per fitted document).

        Raises
        ------
        RuntimeError
            If called before :meth:`fit`.
        """
        if not hasattr(self, "model_"):
            raise RuntimeError("Estimator has not been fitted")
        leaves = self.model_.document_leaves
        assignments = np.zeros(len(leaves), dtype=int)
        for d in range(len(leaves)):
            assignments[d] = leaves[d].node_id
        return assignments
64 changes: 64 additions & 0 deletions tests/test_sklearn_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import numpy as np
from importlib import import_module
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.preprocessing import FunctionTransformer
from sklearn.pipeline import Pipeline


HierarchicalLDAEstimator = import_module(
"hlda.sklearn_wrapper"
).HierarchicalLDAEstimator # noqa: E501


def _prepare_input(vectorizer):
def _transform(X):
if hasattr(X, "toarray"):
arr = X.toarray()
else:
arr = np.asarray(X)
corpus = []
for row in arr:
doc = []
for idx, count in enumerate(row):
doc.extend([idx] * int(count))
corpus.append(doc)
vocab = list(vectorizer.get_feature_names_out())
return corpus, vocab

return _transform


def test_pipeline_fit_transform():
    """End-to-end: vectorize -> corpus prep -> hLDA inside a Pipeline."""
    docs = [
        "apple orange banana",
        "apple orange",
        "banana banana orange",
    ]

    vectorizer = CountVectorizer()
    estimator = HierarchicalLDAEstimator(
        num_levels=2,
        iterations=1,
        seed=0,
        verbose=False,
    )
    prep_step = FunctionTransformer(
        _prepare_input(vectorizer),
        validate=False,
    )
    pipeline = Pipeline(
        [
            ("vect", vectorizer),
            ("prep", prep_step),
            ("hlda", estimator),
        ]
    )

    pipeline.fit(docs)
    result = pipeline.transform(docs)

    # One integer leaf-node id per document.
    assert result.shape[0] == len(docs)
    assert isinstance(result[0], (int, np.integer))
Loading