Coverage for src/lilbee/retrieval/clustering_embedding/helpers.py: 100%


"""Numeric and tokenization helpers for the embedding clusterer.

Mutual-kNN is hub-robust: a pathological hub ends up in many one-way
neighborhoods but can reciprocate at most ``k`` of them, so hub-driven
bridging across topics is broken at the graph-construction step without
any post-hoc similarity rescaling.

The similarity kernel is blocked in row chunks of ``_BLOCK_SIZE`` to keep
peak memory bounded regardless of corpus size, so it scales comfortably
to tens of thousands of chunks on a laptop.
"""

from __future__ import annotations

import math
from collections import Counter

import numpy as np

from lilbee.core.config import CHUNKS_TABLE
from lilbee.data.store import Store
from lilbee.retrieval.clustering import SourceCluster
from lilbee.retrieval.clustering_embedding.types import ChunkRecord

# Block size for the similarity kernel. With N=10000 and D=768 this caps
# peak float32 memory at block * N * 4 bytes ~= 40 MB.
_BLOCK_SIZE = 1024

# Label Propagation hard iteration cap. Convergence is typically reached
# in well under 10 passes on real corpora.
_MAX_LPA_ITERATIONS = 30

# Minimum non-zero L2 norm for a row vector to be kept.
_MIN_VECTOR_NORM = 1e-12

# Source-membership thresholds. A source joins a chunk community when it
# contributes at least `min(_MIN_SOURCE_CHUNKS, ceil(total * _MIN_SOURCE_FRACTION))`
# of its chunks. The smaller (laxer) side wins, which keeps the bar
# reachable for short documents, while a long document still needs
# `_MIN_SOURCE_CHUNKS` chunks, so a single stray chunk from a long
# document never pulls the whole source into an unrelated cluster.
_MIN_SOURCE_CHUNKS = 3
_MIN_SOURCE_FRACTION = 0.2
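# Worked example of the cutoff (illustrative totals): a source with 100
# chunks needs min(3, ceil(100 * 0.2)) = min(3, 20) = 3 chunks in a
# community to join it; a 4-chunk source needs min(3, ceil(0.8)) = 1.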

# TF-IDF labeling knobs.
_LABEL_TOP_TERMS = 3
# Chunks with fewer tokens than this are down-weighted when accumulating
# term frequency so short boilerplate (headings, captions) cannot dominate
# a cluster label. 20 tokens roughly matches the token count of a section
# heading or a two-sentence summary.
_SHORT_CHUNK_TOKEN_CAP = 20

# kNN auto-scaling bounds. Formula: clamp(round(log2(N)+2), _MIN_K, _MAX_K).
_MIN_K = 5
_MAX_K = 20

# Minimum token length for TF-IDF labeling. Shorter tokens are mostly
# articles, prepositions, and single letters: noise that inflates term
# counts without adding topic signal. Three characters keeps useful
# acronyms (api, xml, sql).
_MIN_TF_TOKEN_LEN = 3


def _tokenize_for_tf(text: str) -> list[str]:
    """Lowercase alphanumeric tokens for TF-IDF scoring.

    Deliberately has NO stopword list: common words like "the" or "and"
    get an IDF near zero (they appear in almost every chunk) so TF-IDF
    filters them automatically. A hand-curated English stoplist would
    add maintenance burden and break on non-English corpora with no
    quality gain.
    """
    result: list[str] = []
    for raw in text.lower().split():
        word = "".join(ch for ch in raw if ch.isalnum())
        if len(word) >= _MIN_TF_TOKEN_LEN:
            result.append(word)
    return result
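# Worked example for _tokenize_for_tf (illustrative input):
#   _tokenize_for_tf("The API, the XML!") -> ["the", "api", "the", "xml"]
# Stopwords such as "the" survive tokenization by design; their near-zero
# IDF removes them during scoring instead.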

def auto_k(n: int) -> int:
    """Pick a neighborhood size from corpus size via ``clamp(log2(N)+2)``."""
    if n <= 1:
        return _MIN_K
    raw = round(math.log2(max(n, 4)) + 2)
    return max(_MIN_K, min(_MAX_K, raw))
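# Worked values for auto_k (illustrative corpus sizes):
#   auto_k(10) == 5, auto_k(100) == 9, auto_k(50_000) == 18,
#   auto_k(10**7) == 20 (clamped to _MAX_K).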

def _parse_chunk_row(
    row: dict[str, object],
) -> tuple[ChunkRecord, list[float] | tuple[float, ...]] | None:
    """Extract a chunk record + vector from a raw Arrow row; ``None`` if invalid."""
    vector = row.get("vector")
    if not isinstance(vector, (list, tuple)):
        return None
    source = row.get("source")
    if not isinstance(source, str):
        return None
    raw_text = row.get("chunk")
    chunk_text = raw_text if isinstance(raw_text, str) else ""
    raw_index = row.get("chunk_index")
    chunk_index = raw_index if isinstance(raw_index, int) else 0
    record = ChunkRecord(
        source=source,
        chunk_index=chunk_index,
        text=chunk_text,
        tokens=_tokenize_for_tf(chunk_text),
    )
    return record, vector
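# Example row shape (field names as read above; values illustrative):
#   {"source": "notes.md", "chunk_index": 0, "chunk": "text", "vector": [0.1, 0.2]}
# parses into a ChunkRecord plus its raw vector; a row whose "vector" is
# missing or not a list/tuple yields None and is skipped by the loader.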

def _load_chunk_records(
    store: Store,
) -> tuple[list[ChunkRecord], np.ndarray]:
    """Scan the chunks table once and return records plus a float32 matrix.

    Rows with an unparseable vector are skipped. Records are sorted by
    ``(source, chunk_index)`` so downstream cluster IDs are stable
    regardless of LanceDB's row return order. Records are tokenized once
    here so TF-IDF labeling does not re-tokenize. The vector matrix is
    preallocated and populated via numpy row-assignment, which pushes
    the Python-level float cast into numpy's C loop and avoids building
    a transient ``list[list[float]]``.
    """
    table = store.open_table(CHUNKS_TABLE)
    if table is None:
        return [], np.zeros((0, 0), dtype=np.float32)

    parsed = [pair for pair in map(_parse_chunk_row, table.to_arrow().to_pylist()) if pair]
    if not parsed:
        return [], np.zeros((0, 0), dtype=np.float32)

    parsed.sort(key=lambda pair: (pair[0].source, pair[0].chunk_index))
    dim = len(parsed[0][1])
    matrix = np.empty((len(parsed), dim), dtype=np.float32)
    records: list[ChunkRecord] = []
    for row_idx, (record, vector) in enumerate(parsed):
        records.append(record)
        matrix[row_idx] = vector
    return records, matrix
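# Sketch of the intended call sequence (assumes a populated Store; the
# realignment step mirrors the keep-mask contract of normalize_rows):
#   records, matrix = _load_chunk_records(store)
#   unit_rows, keep = normalize_rows(matrix)
#   records = [rec for rec, kept in zip(records, keep) if kept]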

def normalize_rows(matrix: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Return (normalized_matrix, keep_mask). Zero-norm rows are dropped."""
    if matrix.size == 0:
        return matrix, np.zeros(0, dtype=bool)
    norms = np.linalg.norm(matrix, axis=1)
    keep = norms > _MIN_VECTOR_NORM
    if not keep.all():
        matrix = matrix[keep]
        norms = norms[keep]
    return matrix / norms[:, None], keep
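# Worked example for normalize_rows (illustrative values):
#   normalize_rows(np.array([[3.0, 4.0], [0.0, 0.0]], dtype=np.float32))
# returns ([[0.6, 0.8]], [True, False]): the zero row is dropped, and the
# mask keeps one entry per input row so callers can realign metadata.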

def mutual_knn(matrix: np.ndarray, k: int) -> dict[int, set[int]]:
    """Build a mutual k-nearest-neighbors graph over L2-normalized rows.

    Computes similarity in row blocks so peak memory stays bounded.
    Self-similarity is masked so each row's neighbors exclude itself.
    Mutuality is enforced by keeping only edges ``(i, j)`` where
    ``j`` is in row ``i``'s top-k AND ``i`` is in row ``j``'s top-k;
    this single rule breaks hub-driven bridging without any extra
    similarity rescaling.
    """
    n = matrix.shape[0]
    if n == 0 or k <= 0:
        return {}

    effective_k = min(k, n - 1)
    if effective_k <= 0:
        return {i: set() for i in range(n)}

    top_neighbors: list[set[int]] = [set() for _ in range(n)]

    for start in range(0, n, _BLOCK_SIZE):
        stop = min(start + _BLOCK_SIZE, n)
        sim_block = matrix[start:stop] @ matrix.T  # (block, n)
        # Mask self-similarity so each row's own index is never returned.
        # For the tail block where stop-start < _BLOCK_SIZE the fancy-index
        # pairing still lines up because both sides are length stop-start.
        block_rows = np.arange(stop - start)
        sim_block[block_rows, np.arange(start, stop)] = -math.inf
        # Partition by largest similarities without allocating a negated
        # copy of the block: pass a negative kth to select the tail.
        neighbor_idx = np.argpartition(sim_block, -effective_k, axis=1)[:, -effective_k:]
        for local_row, global_row in enumerate(range(start, stop)):
            top_neighbors[global_row] = set(neighbor_idx[local_row].tolist())

    mutual: dict[int, set[int]] = {i: set() for i in range(n)}
    for i in range(n):
        for j in top_neighbors[i]:
            if i in top_neighbors[j]:
                mutual[i].add(j)
    return mutual
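# Tiny worked example for mutual_knn (k=1): rows 0 and 1 both pick row 2
# as their nearest neighbor, but row 2 reciprocates only row 0, so the
# 1-2 edge never forms:
#   unit, _ = normalize_rows(np.array([[1, 0], [0, 1], [2, 1]], dtype=np.float32))
#   mutual_knn(unit, k=1) -> {0: {2}, 1: set(), 2: {0}}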

def label_propagation(
    adjacency: dict[int, set[int]],
    order: list[int],
) -> list[int]:
    """Async Label Propagation with deterministic min-label tie-breaking.

    Each node adopts the most common label among its neighbors. Ties are
    broken by smallest label id so the outcome is reproducible across
    runs on the same corpus. The caller is responsible for passing a
    complete ``adjacency`` (one entry per node 0..n-1) and an ``order``
    that covers every node: ``get_clusters`` guarantees both.
    """
    n = len(adjacency)
    labels = list(range(n))
    for _ in range(_MAX_LPA_ITERATIONS):
        changed = False
        for node in order:
            neighbors = adjacency.get(node)
            if not neighbors:
                continue
            counts: Counter[int] = Counter(labels[j] for j in neighbors)
            top_count = max(counts.values())
            best = min(label for label, count in counts.items() if count == top_count)
            if labels[node] != best:
                labels[node] = best
                changed = True
        if not changed:
            break
    return labels
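# Tiny worked example for label_propagation: two disjoint triangles each
# collapse to a single label:
#   adj = {0: {1, 2}, 1: {0, 2}, 2: {0, 1}, 3: {4, 5}, 4: {3, 5}, 5: {3, 4}}
#   label_propagation(adj, order=list(range(6))) -> [1, 1, 1, 4, 4, 4]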

def communities_by_label(labels: list[int]) -> dict[int, list[int]]:
    """Group node indices by their final community label."""
    communities: dict[int, list[int]] = {}
    for node, label in enumerate(labels):
        communities.setdefault(label, []).append(node)
    return communities
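# Continuing the example above:
#   communities_by_label([1, 1, 1, 4, 4, 4]) -> {1: [0, 1, 2], 4: [3, 4, 5]}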

def _source_totals(records: list[ChunkRecord]) -> dict[str, int]:
    """Return the total chunk count per source across the whole corpus."""
    totals: dict[str, int] = {}
    for record in records:
        totals[record.source] = totals.get(record.source, 0) + 1
    return totals


def _filter_sources(
    member_indices: list[int],
    records: list[ChunkRecord],
    source_totals: dict[str, int],
) -> frozenset[str]:
    """Apply the source-membership threshold to a community's members."""
    per_source: dict[str, int] = {}
    for idx in member_indices:
        source = records[idx].source
        per_source[source] = per_source.get(source, 0) + 1
    kept: set[str] = set()
    for source, count in per_source.items():
        total = source_totals.get(source, count)
        fractional_cutoff = math.ceil(total * _MIN_SOURCE_FRACTION)
        cutoff = min(_MIN_SOURCE_CHUNKS, fractional_cutoff)
        if count >= cutoff:
            kept.add(source)
    return frozenset(kept)

def _corpus_document_frequency(records: list[ChunkRecord]) -> dict[str, int]:
    """Compute document frequency (the number of chunks containing each term)."""
    df: dict[str, int] = {}
    for record in records:
        for term in set(record.tokens):
            df[term] = df.get(term, 0) + 1
    return df


def _label_community(
    member_indices: list[int],
    records: list[ChunkRecord],
    df: dict[str, int],
    total_chunks: int,
    fallback: str,
) -> str:
    """Pick a topic label for a community using sublinear TF-IDF scoring."""
    tf: dict[str, float] = {}
    for idx in member_indices:
        tokens = records[idx].tokens
        if not tokens:
            continue
        weight = min(1.0, len(tokens) / _SHORT_CHUNK_TOKEN_CAP)
        counts: Counter[str] = Counter(tokens)
        for term, count in counts.items():
            tf[term] = tf.get(term, 0.0) + weight * (1.0 + math.log(count))

    scored: list[tuple[float, str]] = []
    for term, term_tf in tf.items():
        # Standard ``log(N / (1 + df))`` smoothing: the +1 keeps the
        # denominator non-zero for unseen terms, and it pushes the idf
        # of terms that appear in every chunk below zero, so the
        # ``idf <= 0`` guard filters them out entirely.
        idf = math.log(total_chunks / (1 + df.get(term, 0)))
        if idf <= 0:
            continue
        scored.append((term_tf * idf, term))
    if not scored:
        return fallback

    scored.sort(key=lambda pair: (-pair[0], pair[1]))
    return " ".join(term for _, term in scored[:_LABEL_TOP_TERMS])
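# Worked scoring example for _label_community (hypothetical counts): a
# 30-token member chunk using "kubernetes" 4 times adds
# min(1.0, 30 / 20) * (1 + log 4) ~= 2.39 to tf["kubernetes"]; with 200
# corpus chunks and df == 9, idf = log(200 / 10) ~= 3.0, so the term
# scores ~= 7.2 and is a strong candidate for the 3-term label.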

def _build_clusters(
    communities: dict[int, list[int]],
    records: list[ChunkRecord],
    source_totals: dict[str, int],
    df: dict[str, int],
    min_sources: int,
) -> tuple[list[SourceCluster], int]:
    """Turn raw chunk communities into published source clusters.

    Returns ``(clusters, noise_chunk_count)`` where ``noise_chunk_count``
    is the number of chunks whose community failed the source filter.
    """
    ordered = sorted(communities.items(), key=lambda pair: (-len(pair[1]), pair[0]))
    total_chunks = len(records)
    clusters: list[SourceCluster] = []
    noise = 0
    for idx, (_, members) in enumerate(ordered):
        kept_sources = _filter_sources(members, records, source_totals)
        if len(kept_sources) < min_sources:
            noise += len(members)
            continue
        cluster_id = f"embedding-{idx}"
        label = _label_community(members, records, df, total_chunks, fallback=cluster_id)
        clusters.append(
            SourceCluster(
                cluster_id=cluster_id,
                label=label,
                sources=kept_sources,
            )
        )
    return clusters, noise
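# Note on ids: communities are visited in (-size, label) order and ``idx``
# comes from enumerate over that full ordering, so a community that fails
# the source filter still consumes an index. Cluster ids may therefore
# skip numbers (e.g. embedding-0, embedding-2) while staying deterministic.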