Coverage for src/lilbee/retrieval/clustering_embedding/helpers.py: 100% (170 statements)

1"""Numeric and tokenization helpers for the embedding clusterer.
3Mutual-kNN is hub-robust: a pathological hub ends up in many one-way
4neighborhoods but can reciprocate at most ``k`` of them, so hub-driven
5bridging across topics is broken at the graph-construction step without
6any post-hoc similarity rescaling.
8The similarity kernel is blocked in row chunks of ``_BLOCK_SIZE`` to keep
9peak memory bounded regardless of corpus size, so it scales comfortably
10to tens of thousands of chunks on a laptop.
11"""

from __future__ import annotations

import math
from collections import Counter

import numpy as np

from lilbee.core.config import CHUNKS_TABLE
from lilbee.data.store import Store
from lilbee.retrieval.clustering import SourceCluster
from lilbee.retrieval.clustering_embedding.types import ChunkRecord

# Block size for the similarity kernel. With N=10000 this caps the peak
# float32 similarity block at _BLOCK_SIZE * N * 4 bytes ~= 40 MB,
# independent of the embedding dimension.
_BLOCK_SIZE = 1024

# Label Propagation hard iteration cap. Convergence is typically reached
# in well under 10 passes on real corpora.
_MAX_LPA_ITERATIONS = 30

# Minimum non-zero L2 norm for a row vector to be kept.
_MIN_VECTOR_NORM = 1e-12

# Source-membership thresholds. A source joins a chunk community when it
# contributes at least `min(_MIN_SOURCE_CHUNKS, ceil(total * _MIN_SOURCE_FRACTION))`
# of its chunks. The more lenient (smaller) side wins, so short sources are
# not locked out, while a single stray chunk from a long document never
# pulls the whole source into an unrelated cluster.
_MIN_SOURCE_CHUNKS = 3
_MIN_SOURCE_FRACTION = 0.2
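
# Worked cutoffs (illustrative): total=5 -> min(3, ceil(1.0)) = 1 chunk;
# total=10 -> min(3, ceil(2.0)) = 2; total=50 -> min(3, ceil(10.0)) = 3.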

# TF-IDF labeling knobs.
_LABEL_TOP_TERMS = 3
# Chunks with fewer tokens than this are down-weighted when accumulating
# term frequency so short boilerplate (headings, captions) cannot dominate
# a cluster label. 20 tokens roughly matches the token count of a section
# heading or a two-sentence summary.
_SHORT_CHUNK_TOKEN_CAP = 20

# kNN auto-scaling bounds. Formula: clamp(round(log2(N) + 2), _MIN_K, _MAX_K).
_MIN_K = 5
_MAX_K = 20

# Minimum token length for TF-IDF labeling. Shorter tokens are mostly
# articles, prepositions, and single letters: noise that inflates term
# counts without adding topic signal. Three characters keeps useful
# acronyms (api, xml, sql).
_MIN_TF_TOKEN_LEN = 3


def _tokenize_for_tf(text: str) -> list[str]:
    """Lowercase alphanumeric tokens for TF-IDF scoring.

    Deliberately has NO stopword list: common words like "the" or "and"
    get an IDF near zero (they appear in almost every chunk) so TF-IDF
    filters them automatically. A hand-curated English stoplist would
    add maintenance burden and break on non-English corpora for no
    additional quality.
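
    Example (illustrative; "the" survives tokenization and is left to IDF,
    while "a" and "DB" fall below the length floor)::

        >>> _tokenize_for_tf("The API, the XML -- a sql DB!")
        ['the', 'api', 'the', 'xml', 'sql']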
70 """
71 result: list[str] = []
72 for raw in text.lower().split():
73 word = "".join(ch for ch in raw if ch.isalnum())
74 if len(word) >= _MIN_TF_TOKEN_LEN:
75 result.append(word)
76 return result


def auto_k(n: int) -> int:
    """Pick a neighborhood size from corpus size via ``clamp(log2(N)+2)``."""
    if n <= 1:
        return _MIN_K
    raw = round(math.log2(max(n, 4)) + 2)
    return max(_MIN_K, min(_MAX_K, raw))
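

# Worked values for ``auto_k`` (illustrative; they follow directly from the
# clamp formula): auto_k(1) -> 5 (floor), auto_k(100) -> 9 (round(8.64)),
# auto_k(10**6) -> 20 (capped at _MAX_K).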


def _parse_chunk_row(
    row: dict[str, object],
) -> tuple[ChunkRecord, list[float] | tuple[float, ...]] | None:
    """Extract a chunk record + vector from a raw Arrow row, or None on invalid."""
    vector = row.get("vector")
    if not isinstance(vector, (list, tuple)):
        return None
    source = row.get("source")
    if not isinstance(source, str):
        return None
    raw_text = row.get("chunk")
    chunk_text = raw_text if isinstance(raw_text, str) else ""
    raw_index = row.get("chunk_index")
    chunk_index = raw_index if isinstance(raw_index, int) else 0
    record = ChunkRecord(
        source=source,
        chunk_index=chunk_index,
        text=chunk_text,
        tokens=_tokenize_for_tf(chunk_text),
    )
    return record, vector
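

# Parsing sketch (illustrative row values): a well-formed row yields a record
# plus its raw vector; a row with a missing or non-sequence vector, or a
# non-string source, is rejected.
#
#     _parse_chunk_row({"vector": [0.1, 0.2], "source": "a.md",
#                       "chunk": "Hello world", "chunk_index": 0})
#     # -> (ChunkRecord(source="a.md", chunk_index=0, text="Hello world",
#     #                 tokens=["hello", "world"]), [0.1, 0.2])
#     _parse_chunk_row({"source": "a.md"})  # -> None (no vector)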


def _load_chunk_records(
    store: Store,
) -> tuple[list[ChunkRecord], np.ndarray]:
    """Scan the chunks table once and return records plus a float32 matrix.

    Rows with an unparseable vector are skipped. Records are sorted by
    ``(source, chunk_index)`` so downstream cluster IDs are stable
    regardless of LanceDB's row return order. Records are tokenized once
    here so TF-IDF labeling does not re-tokenize. The vector matrix is
    preallocated and populated via numpy row-assignment, which pushes
    the Python-level float cast into numpy's C loop and avoids building
    a transient ``list[list[float]]``.
    """
    table = store.open_table(CHUNKS_TABLE)
    if table is None:
        return [], np.zeros((0, 0), dtype=np.float32)

    parsed = [pair for pair in map(_parse_chunk_row, table.to_arrow().to_pylist()) if pair]
    if not parsed:
        return [], np.zeros((0, 0), dtype=np.float32)

    parsed.sort(key=lambda pair: (pair[0].source, pair[0].chunk_index))
    dim = len(parsed[0][1])
    matrix = np.empty((len(parsed), dim), dtype=np.float32)
    records: list[ChunkRecord] = []
    for row_idx, (record, vector) in enumerate(parsed):
        records.append(record)
        matrix[row_idx] = vector
    return records, matrix


def normalize_rows(matrix: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Return (normalized_matrix, keep_mask). Zero-norm rows are dropped."""
    if matrix.size == 0:
        return matrix, np.zeros(0, dtype=bool)
    norms = np.linalg.norm(matrix, axis=1)
    keep = norms > _MIN_VECTOR_NORM
    if not keep.all():
        matrix = matrix[keep]
        norms = norms[keep]
    return matrix / norms[:, None], keep
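

# A zero row is dropped and reported in the mask (illustrative):
#
#     >>> m, keep = normalize_rows(np.array([[3.0, 4.0], [0.0, 0.0]]))
#     >>> m.tolist(), keep.tolist()
#     ([[0.6, 0.8]], [True, False])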


def mutual_knn(matrix: np.ndarray, k: int) -> dict[int, set[int]]:
    """Build a mutual k-nearest-neighbors graph over L2-normalized rows.

    Computes similarity in row blocks so peak memory stays bounded.
    Self-similarity is masked so each row's neighbors exclude itself.
    Mutuality is enforced by keeping only edges ``(i, j)`` where
    ``j`` is in row ``i``'s top-k AND ``i`` is in row ``j``'s top-k;
    this single rule breaks hub-driven bridging without any extra
    similarity rescaling.
    """
    n = matrix.shape[0]
    if n == 0 or k <= 0:
        return {}

    effective_k = min(k, n - 1)
    if effective_k <= 0:
        return {i: set() for i in range(n)}

    top_neighbors: list[set[int]] = [set() for _ in range(n)]

    for start in range(0, n, _BLOCK_SIZE):
        stop = min(start + _BLOCK_SIZE, n)
        sim_block = matrix[start:stop] @ matrix.T  # (block, n)
        # Mask self-similarity so each row's own index is never returned.
        # For the tail block where stop-start < _BLOCK_SIZE the fancy-index
        # pairing still lines up because both sides are length stop-start.
        block_rows = np.arange(stop - start)
        sim_block[block_rows, np.arange(start, stop)] = -math.inf
        # Partition by largest similarities without allocating a negated
        # copy of the block: pass a negative kth to select the tail.
        neighbor_idx = np.argpartition(sim_block, -effective_k, axis=1)[:, -effective_k:]
        for local_row, global_row in enumerate(range(start, stop)):
            top_neighbors[global_row] = set(neighbor_idx[local_row].tolist())

    mutual: dict[int, set[int]] = {i: set() for i in range(n)}
    for i in range(n):
        for j in top_neighbors[i]:
            if i in top_neighbors[j]:
                mutual[i].add(j)
    return mutual
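

# Mutuality in action (illustrative, k=1): rows 0 and 1 are near-duplicates;
# row 2's nearest neighbor is 1, but 1 does not reciprocate, so 2 stays
# isolated instead of bridging into the pair.
#
#     >>> m, _ = normalize_rows(np.array([[1.0, 0.0], [0.99, 0.14], [0.5, 0.87]]))
#     >>> mutual_knn(m, 1)
#     {0: {1}, 1: {0}, 2: set()}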


def label_propagation(
    adjacency: dict[int, set[int]],
    order: list[int],
) -> list[int]:
    """Async Label Propagation with deterministic min-label tie-breaking.

    Each node adopts the most common label among its neighbors. Ties are
    broken by smallest label id so the outcome is reproducible across
    runs on the same corpus. The caller is responsible for passing a
    complete ``adjacency`` (one entry per node 0..n-1) and an ``order``
    that covers every node: ``get_clusters`` guarantees both.
    """
    n = len(adjacency)
    labels = list(range(n))
    for _ in range(_MAX_LPA_ITERATIONS):
        changed = False
        for node in order:
            neighbors = adjacency.get(node)
            if not neighbors:
                continue
            counts: Counter[int] = Counter(labels[j] for j in neighbors)
            top_count = max(counts.values())
            best = min(label for label, count in counts.items() if count == top_count)
            if labels[node] != best:
                labels[node] = best
                changed = True
        if not changed:
            break
    return labels
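

# Two disjoint triangles each converge to one shared label (illustrative; the
# surviving label ids depend on the update order):
#
#     >>> adj = {0: {1, 2}, 1: {0, 2}, 2: {0, 1}, 3: {4, 5}, 4: {3, 5}, 5: {3, 4}}
#     >>> label_propagation(adj, list(range(6)))
#     [1, 1, 1, 4, 4, 4]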


def communities_by_label(labels: list[int]) -> dict[int, list[int]]:
    """Group node indices by their final community label."""
    communities: dict[int, list[int]] = {}
    for node, label in enumerate(labels):
        communities.setdefault(label, []).append(node)
    return communities
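

# Continuing the example above (illustrative):
# communities_by_label([1, 1, 1, 4, 4, 4]) -> {1: [0, 1, 2], 4: [3, 4, 5]}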


def _source_totals(records: list[ChunkRecord]) -> dict[str, int]:
    """Return the total chunk count per source across the whole corpus."""
    totals: dict[str, int] = {}
    for record in records:
        totals[record.source] = totals.get(record.source, 0) + 1
    return totals


def _filter_sources(
    member_indices: list[int],
    records: list[ChunkRecord],
    source_totals: dict[str, int],
) -> frozenset[str]:
    """Apply the source-membership threshold to a community's members."""
    per_source: dict[str, int] = {}
    for idx in member_indices:
        source = records[idx].source
        per_source[source] = per_source.get(source, 0) + 1
    kept: set[str] = set()
    for source, count in per_source.items():
        total = source_totals.get(source, count)
        fractional_cutoff = math.ceil(total * _MIN_SOURCE_FRACTION)
        cutoff = min(_MIN_SOURCE_CHUNKS, fractional_cutoff)
        if count >= cutoff:
            kept.add(source)
    return frozenset(kept)


def _corpus_document_frequency(records: list[ChunkRecord]) -> dict[str, int]:
    """Compute document frequency (chunk count containing term) for every term."""
    df: dict[str, int] = {}
    for record in records:
        for term in set(record.tokens):
            df[term] = df.get(term, 0) + 1
    return df


def _label_community(
    member_indices: list[int],
    records: list[ChunkRecord],
    df: dict[str, int],
    total_chunks: int,
    fallback: str,
) -> str:
    """Pick a topic label for a community using sublinear TF-IDF scoring."""
    tf: dict[str, float] = {}
    for idx in member_indices:
        tokens = records[idx].tokens
        if not tokens:
            continue
        weight = min(1.0, len(tokens) / _SHORT_CHUNK_TOKEN_CAP)
        counts: Counter[str] = Counter(tokens)
        for term, count in counts.items():
            tf[term] = tf.get(term, 0.0) + weight * (1.0 + math.log(count))

    scored: list[tuple[float, str]] = []
    for term, term_tf in tf.items():
        # Standard ``log(N / (1 + df))`` smoothing: the +1 keeps the
        # denominator non-zero for new terms and damps the score of
        # terms that appear in every chunk (where idf goes negative
        # and the term is filtered out entirely).
        idf = math.log(total_chunks / (1 + df.get(term, 0)))
        if idf <= 0:
            continue
        scored.append((term_tf * idf, term))
    if not scored:
        return fallback

    scored.sort(key=lambda pair: (-pair[0], pair[1]))
    return " ".join(term for _, term in scored[:_LABEL_TOP_TERMS])
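

# Worked score (illustrative): a term appearing 3x in one full-length chunk,
# with df=10 in a 100-chunk corpus, scores 1.0 * (1 + ln 3) * ln(100 / 11)
# ~= 2.10 * 2.21 ~= 4.63.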


def _build_clusters(
    communities: dict[int, list[int]],
    records: list[ChunkRecord],
    source_totals: dict[str, int],
    df: dict[str, int],
    min_sources: int,
) -> tuple[list[SourceCluster], int]:
    """Turn raw chunk communities into published source clusters.

    Returns ``(clusters, noise_chunk_count)`` where ``noise_chunk_count``
    is the number of chunks whose community failed the source filter.
    """
    ordered = sorted(communities.items(), key=lambda pair: (-len(pair[1]), pair[0]))
    total_chunks = len(records)
    clusters: list[SourceCluster] = []
    noise = 0
    for idx, (_, members) in enumerate(ordered):
        kept_sources = _filter_sources(members, records, source_totals)
        if len(kept_sources) < min_sources:
            noise += len(members)
            continue
        cluster_id = f"embedding-{idx}"
        label = _label_community(members, records, df, total_chunks, fallback=cluster_id)
        clusters.append(
            SourceCluster(
                cluster_id=cluster_id,
                label=label,
                sources=kept_sources,
            )
        )
    return clusters, noise