Skip to content

Commit 415bc24

Browse files
authored
feat: Changed default behavior for SSD window (#21)
1 parent 4a2562e commit 415bc24

File tree

2 files changed: +49 -7 lines changed

src/pyversity/strategies/ssd.py

Lines changed: 12 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@ def ssd( # noqa: C901
1010
k: int,
1111
diversity: float = 0.5,
1212
recent_embeddings: np.ndarray | None = None,
13-
window: int = 10,
13+
window: int | None = None,
1414
gamma: float = 1.0,
1515
normalize: bool = True,
1616
append_bias: bool = True,
@@ -33,7 +33,7 @@ def ssd( # noqa: C901
3333
1.0 = pure diversity, 0.0 = pure relevance.
3434
:param recent_embeddings: Optional 2D array (m, n_dims), oldest → newest; seeds the sliding window so
3535
selection is aware of what was recently shown.
36-
:param window: Sliding window size (≥ 1) for Gram-Schmidt bases.
36+
:param window: Window size (≥ 1) for Gram-Schmidt bases. If None, defaults to len(recent_embeddings) + k.
3737
:param gamma: Diversity scale (> 0).
3838
:param normalize: Whether to normalize embeddings before computing similarity.
3939
:param append_bias: Append a constant-one bias dimension after normalization.
@@ -44,7 +44,7 @@ def ssd( # noqa: C901
4444
# Validate parameters
4545
if not (0.0 <= float(diversity) <= 1.0):
4646
raise ValueError("diversity must be in [0, 1]")
47-
if window < 1:
47+
if window is not None and window < 1:
4848
raise ValueError("window must be >= 1")
4949
if gamma <= 0.0:
5050
raise ValueError("gamma must be > 0")
@@ -65,6 +65,7 @@ def ssd( # noqa: C901
6565
)
6666

6767
# Validate recent_embeddings
68+
n_recent = 0
6869
if recent_embeddings is not None and np.size(recent_embeddings) > 0:
6970
if recent_embeddings.ndim != 2:
7071
raise ValueError("recent_embeddings must be a 2D array of shape (n_items, n_dims).")
@@ -73,6 +74,10 @@ def ssd( # noqa: C901
7374
f"recent_embeddings has {recent_embeddings.shape[1]} dims; "
7475
f"expected {feature_matrix.shape[1]} to match `embeddings` columns."
7576
)
77+
n_recent = int(recent_embeddings.shape[0])
78+
79+
# Determine effective window size
80+
window_size = (n_recent + top_k) if window is None else int(window)
7681

7782
# Pure relevance: select top-k by raw scores
7883
if float(theta) == 1.0:
@@ -83,7 +88,7 @@ def ssd( # noqa: C901
8388
selection_scores=selection_scores,
8489
strategy=Strategy.SSD,
8590
diversity=diversity,
86-
parameters={"gamma": gamma, "window": window},
91+
parameters={"gamma": gamma, "window": window_size},
8792
)
8893

8994
def _prepare_vectors(matrix: np.ndarray) -> np.ndarray:
@@ -124,7 +129,7 @@ def _prepare_vectors(matrix: np.ndarray) -> np.ndarray:
124129

125130
def _push_basis_vector(basis_vector: np.ndarray) -> None:
126131
"""Add a new basis vector to the sliding window and update residuals/projections."""
127-
if len(basis_vectors) == window:
132+
if len(basis_vectors) == window_size:
128133
# Remove oldest basis and restore its contribution to residuals
129134
oldest_basis = basis_vectors.pop(0)
130135
oldest_coefficients = projection_coefficients_per_basis.pop(0)
@@ -148,7 +153,7 @@ def _push_basis_vector(basis_vector: np.ndarray) -> None:
148153
seeded_bases = 0
149154
if recent_embeddings is not None and np.size(recent_embeddings) > 0:
150155
context = _prepare_vectors(recent_embeddings.astype(feature_matrix.dtype, copy=False))
151-
context = context[-window:] # keep only the latest `window` items
156+
context = context[-window_size:] # keep only the latest `window_size` items
152157
for context_vector in context:
153158
residual_context = context_vector.copy()
154159
for basis in basis_vectors:
@@ -201,5 +206,5 @@ def _push_basis_vector(basis_vector: np.ndarray) -> None:
201206
selection_scores=selection_scores.astype(np.float32, copy=False),
202207
strategy=Strategy.SSD,
203208
diversity=diversity,
204-
parameters={"gamma": gamma, "window": window},
209+
parameters={"gamma": gamma, "window": window_size},
205210
)

tests/test_strategies.py

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -247,6 +247,43 @@ def test_ssd_recent_embeddings_window_blocks_multiple_recent() -> None:
247247
assert res.indices[0] in (2, 3)
248248

249249

250+
def test_ssd_window_none_matches_large_window_when_recent_smaller() -> None:
251+
"""Test that window=None behaves like a large window when recent_embeddings is smaller than k."""
252+
emb = np.eye(4, dtype=np.float32)
253+
scores = np.ones(4, dtype=np.float32)
254+
recent = emb[[0, 1]]
255+
k = 3
256+
257+
res_none = ssd(
258+
emb,
259+
scores,
260+
k=k,
261+
window=None,
262+
recent_embeddings=recent,
263+
)
264+
res_big = ssd(
265+
emb,
266+
scores,
267+
k=k,
268+
window=10,
269+
recent_embeddings=recent,
270+
)
271+
272+
assert np.array_equal(res_none.indices, res_big.indices)
273+
274+
275+
def test_ssd_window_none_equals_k_when_no_recent() -> None:
276+
"""Test that window=None behaves like window=k when recent_embeddings is not provided."""
277+
emb = np.eye(5, dtype=np.float32)
278+
scores = np.array([0.4, 0.9, 0.1, 0.7, 0.2], dtype=np.float32)
279+
k = 3
280+
281+
res_none = ssd(emb, scores, k=k, window=None, recent_embeddings=None)
282+
res_k = ssd(emb, scores, k=k, window=k, recent_embeddings=None)
283+
284+
assert np.array_equal(res_none.indices, res_k.indices)
285+
286+
250287
@pytest.mark.parametrize(
251288
"strategy, fn, kwargs",
252289
[

0 commit comments

Comments
 (0)