@@ -10,7 +10,7 @@ def ssd( # noqa: C901
1010 k : int ,
1111 diversity : float = 0.5 ,
1212 recent_embeddings : np .ndarray | None = None ,
13- window : int = 10 ,
13+ window : int | None = None ,
1414 gamma : float = 1.0 ,
1515 normalize : bool = True ,
1616 append_bias : bool = True ,
@@ -33,7 +33,7 @@ def ssd( # noqa: C901
3333 1.0 = pure diversity, 0.0 = pure relevance.
3434 :param recent_embeddings: Optional 2D array (m, n_dims), oldest → newest; seeds the sliding window so
3535 selection is aware of what was recently shown.
36- :param window: Sliding window size (≥ 1) for Gram-Schmidt bases.
36+ :param window: Window size (≥ 1) for Gram-Schmidt bases. If None, defaults to len(recent_embeddings) + k .
3737 :param gamma: Diversity scale (> 0).
3838 :param normalize: Whether to normalize embeddings before computing similarity.
3939 :param append_bias: Append a constant-one bias dimension after normalization.
@@ -44,7 +44,7 @@ def ssd( # noqa: C901
4444 # Validate parameters
4545 if not (0.0 <= float (diversity ) <= 1.0 ):
4646 raise ValueError ("diversity must be in [0, 1]" )
47- if window < 1 :
47+ if window is not None and window < 1 :
4848 raise ValueError ("window must be >= 1" )
4949 if gamma <= 0.0 :
5050 raise ValueError ("gamma must be > 0" )
@@ -65,6 +65,7 @@ def ssd( # noqa: C901
6565 )
6666
6767 # Validate recent_embeddings
68+ n_recent = 0
6869 if recent_embeddings is not None and np .size (recent_embeddings ) > 0 :
6970 if recent_embeddings .ndim != 2 :
7071 raise ValueError ("recent_embeddings must be a 2D array of shape (n_items, n_dims)." )
@@ -73,6 +74,10 @@ def ssd( # noqa: C901
7374 f"recent_embeddings has { recent_embeddings .shape [1 ]} dims; "
7475 f"expected { feature_matrix .shape [1 ]} to match `embeddings` columns."
7576 )
77+ n_recent = int (recent_embeddings .shape [0 ])
78+
79+ # Determine effective window size
80+ window_size = (n_recent + top_k ) if window is None else int (window )
7681
7782 # Pure relevance: select top-k by raw scores
7883 if float (theta ) == 1.0 :
@@ -83,7 +88,7 @@ def ssd( # noqa: C901
8388 selection_scores = selection_scores ,
8489 strategy = Strategy .SSD ,
8590 diversity = diversity ,
86- parameters = {"gamma" : gamma , "window" : window },
91+ parameters = {"gamma" : gamma , "window" : window_size },
8792 )
8893
8994 def _prepare_vectors (matrix : np .ndarray ) -> np .ndarray :
@@ -124,7 +129,7 @@ def _prepare_vectors(matrix: np.ndarray) -> np.ndarray:
124129
125130 def _push_basis_vector (basis_vector : np .ndarray ) -> None :
126131 """Add a new basis vector to the sliding window and update residuals/projections."""
127- if len (basis_vectors ) == window :
132+ if len (basis_vectors ) == window_size :
128133 # Remove oldest basis and restore its contribution to residuals
129134 oldest_basis = basis_vectors .pop (0 )
130135 oldest_coefficients = projection_coefficients_per_basis .pop (0 )
@@ -148,7 +153,7 @@ def _push_basis_vector(basis_vector: np.ndarray) -> None:
148153 seeded_bases = 0
149154 if recent_embeddings is not None and np .size (recent_embeddings ) > 0 :
150155 context = _prepare_vectors (recent_embeddings .astype (feature_matrix .dtype , copy = False ))
151- context = context [- window :] # keep only the latest `window ` items
156+ context = context [- window_size :] # keep only the latest `window_size ` items
152157 for context_vector in context :
153158 residual_context = context_vector .copy ()
154159 for basis in basis_vectors :
@@ -201,5 +206,5 @@ def _push_basis_vector(basis_vector: np.ndarray) -> None:
201206 selection_scores = selection_scores .astype (np .float32 , copy = False ),
202207 strategy = Strategy .SSD ,
203208 diversity = diversity ,
204- parameters = {"gamma" : gamma , "window" : window },
209+ parameters = {"gamma" : gamma , "window" : window_size },
205210 )
0 commit comments