Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions src/pyversity/strategies/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def ssd( # noqa: C901
k: int,
diversity: float = 0.5,
recent_embeddings: np.ndarray | None = None,
window: int = 10,
window: int | None = None,
gamma: float = 1.0,
normalize: bool = True,
append_bias: bool = True,
Expand All @@ -33,7 +33,7 @@ def ssd( # noqa: C901
1.0 = pure diversity, 0.0 = pure relevance.
:param recent_embeddings: Optional 2D array (m, n_dims), oldest → newest; seeds the sliding window so
selection is aware of what was recently shown.
:param window: Sliding window size (≥ 1) for Gram-Schmidt bases.
:param window: Window size (≥ 1) for Gram-Schmidt bases. If None, defaults to len(recent_embeddings) + k.
:param gamma: Diversity scale (> 0).
:param normalize: Whether to normalize embeddings before computing similarity.
:param append_bias: Append a constant-one bias dimension after normalization.
Expand All @@ -44,7 +44,7 @@ def ssd( # noqa: C901
# Validate parameters
if not (0.0 <= float(diversity) <= 1.0):
raise ValueError("diversity must be in [0, 1]")
if window < 1:
if window is not None and window < 1:
raise ValueError("window must be >= 1")
if gamma <= 0.0:
raise ValueError("gamma must be > 0")
Expand All @@ -65,6 +65,7 @@ def ssd( # noqa: C901
)

# Validate recent_embeddings
n_recent = 0
if recent_embeddings is not None and np.size(recent_embeddings) > 0:
if recent_embeddings.ndim != 2:
raise ValueError("recent_embeddings must be a 2D array of shape (n_items, n_dims).")
Expand All @@ -73,6 +74,10 @@ def ssd( # noqa: C901
f"recent_embeddings has {recent_embeddings.shape[1]} dims; "
f"expected {feature_matrix.shape[1]} to match `embeddings` columns."
)
n_recent = int(recent_embeddings.shape[0])

# Determine effective window size
window_size = (n_recent + top_k) if window is None else int(window)

# Pure relevance: select top-k by raw scores
if float(theta) == 1.0:
Expand All @@ -83,7 +88,7 @@ def ssd( # noqa: C901
selection_scores=selection_scores,
strategy=Strategy.SSD,
diversity=diversity,
parameters={"gamma": gamma, "window": window},
parameters={"gamma": gamma, "window": window_size},
)

def _prepare_vectors(matrix: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -124,7 +129,7 @@ def _prepare_vectors(matrix: np.ndarray) -> np.ndarray:

def _push_basis_vector(basis_vector: np.ndarray) -> None:
"""Add a new basis vector to the sliding window and update residuals/projections."""
if len(basis_vectors) == window:
if len(basis_vectors) == window_size:
# Remove oldest basis and restore its contribution to residuals
oldest_basis = basis_vectors.pop(0)
oldest_coefficients = projection_coefficients_per_basis.pop(0)
Expand All @@ -148,7 +153,7 @@ def _push_basis_vector(basis_vector: np.ndarray) -> None:
seeded_bases = 0
if recent_embeddings is not None and np.size(recent_embeddings) > 0:
context = _prepare_vectors(recent_embeddings.astype(feature_matrix.dtype, copy=False))
context = context[-window:] # keep only the latest `window` items
context = context[-window_size:] # keep only the latest `window_size` items
for context_vector in context:
residual_context = context_vector.copy()
for basis in basis_vectors:
Expand Down Expand Up @@ -201,5 +206,5 @@ def _push_basis_vector(basis_vector: np.ndarray) -> None:
selection_scores=selection_scores.astype(np.float32, copy=False),
strategy=Strategy.SSD,
diversity=diversity,
parameters={"gamma": gamma, "window": window},
parameters={"gamma": gamma, "window": window_size},
)
37 changes: 37 additions & 0 deletions tests/test_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,43 @@ def test_ssd_recent_embeddings_window_blocks_multiple_recent() -> None:
assert res.indices[0] in (2, 3)


def test_ssd_window_none_matches_large_window_when_recent_smaller() -> None:
    """Check that window=None selects the same items as an oversized explicit window.

    Two recent items are supplied and three are requested, so both the default
    window and window=10 are large enough that nothing is evicted during selection.
    """
    n_items, top_k = 4, 3
    embeddings = np.eye(n_items, dtype=np.float32)
    relevance = np.ones(n_items, dtype=np.float32)
    # The first two items act as the "recently shown" history.
    history = embeddings[[0, 1]]

    # Run the default window first, then the explicit large one.
    picks_by_window = {
        size: ssd(
            embeddings,
            relevance,
            k=top_k,
            window=size,
            recent_embeddings=history,
        ).indices
        for size in (None, 10)
    }

    assert np.array_equal(picks_by_window[None], picks_by_window[10])


def test_ssd_window_none_equals_k_when_no_recent() -> None:
    """Check that, without recent_embeddings, window=None picks the same items as window=k."""
    top_k = 3
    relevance = np.array([0.4, 0.9, 0.1, 0.7, 0.2], dtype=np.float32)
    # One orthogonal unit vector per candidate item.
    embeddings = np.eye(relevance.size, dtype=np.float32)

    default_window = ssd(embeddings, relevance, k=top_k, window=None, recent_embeddings=None)
    explicit_window = ssd(embeddings, relevance, k=top_k, window=top_k, recent_embeddings=None)

    picked_default, picked_explicit = default_window.indices, explicit_window.indices
    assert np.array_equal(picked_default, picked_explicit)


@pytest.mark.parametrize(
"strategy, fn, kwargs",
[
Expand Down