Coverage for src/lilbee/retrieval/concepts/graph.py: 100%

208 statements  

coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

1"""ConceptGraph: extracts, stores, and queries concept relationships.""" 

2 

3from __future__ import annotations 

4 

5import logging 

6from collections import Counter 

7from typing import Any 

8 

9import pyarrow as pa 

10import pyarrow.compute as pc 

11 

12from lilbee.core.config import ( 

13 CHUNK_CONCEPTS_TABLE, 

14 CONCEPT_EDGES_TABLE, 

15 CONCEPT_NODES_TABLE, 

16 Config, 

17) 

18from lilbee.data.store import Store, escape_sql_string 

19from lilbee.retrieval.concepts.community import Community, _compute_pmi, _leiden_partition 

20from lilbee.retrieval.concepts.nlp import _ensure_spacy_model, _filter_noun_chunks 

21from lilbee.retrieval.concepts.schema import ( 

22 _chunk_concepts_schema, 

23 _concept_edges_schema, 

24 _concept_nodes_schema, 

25) 

26 

27log = logging.getLogger(__name__) 

28 

29 

class ConceptGraph:
    """Concept graph -- extracts, stores, and queries concept relationships."""

    def __init__(self, config: Config, store: Store) -> None:
        self._config = config
        self._store = store
        self._nlp: Any = None
        self._nlp_unavailable: bool = False

    def _ensure_nlp(self) -> Any | None:
        """Lazy-load and cache the spaCy model. Returns None if unavailable."""
        if self._nlp_unavailable:
            return None
        if self._nlp is None:
            try:
                self._nlp = _ensure_spacy_model()
            except ImportError:
                log.warning("Concept graph disabled: spaCy model unavailable")
                self._nlp_unavailable = True
                return None
        return self._nlp
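
    # The lazy loader caches both outcomes: a loaded model in self._nlp and a
    # failed load in self._nlp_unavailable, so a missing spaCy model is probed
    # at most once per instance and later calls fall through to the cheap
    # empty-result paths in the extraction methods below.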

    def extract_concepts(self, text: str, max_concepts: int | None = None) -> list[str]:
        """Extract noun-phrase concepts from text via spaCy."""
        if max_concepts is None:
            max_concepts = self._config.concept_max_per_chunk
        if not text.strip():
            return []
        nlp = self._ensure_nlp()
        if nlp is None:
            return []
        doc = nlp(text)
        return _filter_noun_chunks(doc, max_concepts)

    def extract_concepts_batch(self, texts: list[str]) -> list[list[str]]:
        """Batch-extract concepts from multiple texts."""
        if not texts:
            return []
        nlp = self._ensure_nlp()
        if nlp is None:
            return [[] for _ in texts]
        max_concepts = self._config.concept_max_per_chunk
        return [_filter_noun_chunks(doc, max_concepts) for doc in nlp.pipe(texts)]
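
    # Usage sketch (hypothetical caller; constructing `config` and `store` is
    # outside this module, and the exact concepts depend on the spaCy model
    # and _filter_noun_chunks):
    #
    #     graph = ConceptGraph(config, store)
    #     graph.extract_concepts_batch(
    #         ["Bees carry pollen between flowers.", "The hive stores honey."]
    #     )
    #     # -> e.g. [["bees", "pollen", "flowers"], ["hive", "honey"]]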

    def build_from_chunks(
        self, chunk_ids: list[tuple[str, int]], concept_lists: list[list[str]]
    ) -> None:
        """Build co-occurrence graph from chunk concepts, compute PMI, store tables."""
        from lilbee.data.store import ensure_table
        from lilbee.runtime.lock import write_lock

        if not chunk_ids:
            return

        cooccurrences: Counter[tuple[str, str]] = Counter()
        concept_counts: Counter[str] = Counter()
        chunk_concept_records: list[dict[str, Any]] = []

        for (source, idx), concepts in zip(chunk_ids, concept_lists, strict=True):
            for c in concepts:
                concept_counts[c] += 1
                chunk_concept_records.append(
                    {"chunk_source": source, "chunk_index": idx, "concept": c}
                )
            for i, a in enumerate(concepts):
                for b in concepts[i + 1 :]:
                    pair = (min(a, b), max(a, b))
                    cooccurrences[pair] += 1

        pmi_weights = _compute_pmi(cooccurrences, concept_counts, len(chunk_ids))

        edge_records = [
            {"source": a, "target": b, "weight": w} for (a, b), w in pmi_weights.items()
        ]

        node_records = [
            {"concept": c, "cluster_id": 0, "degree": count} for c, count in concept_counts.items()
        ]

        with write_lock():
            db = self._store.get_db()
            # Always create tables so get_graph() returns True even when
            # concept extraction yields no results for the current corpus.
            nodes_tbl = ensure_table(db, CONCEPT_NODES_TABLE, _concept_nodes_schema())
            edges_tbl = ensure_table(db, CONCEPT_EDGES_TABLE, _concept_edges_schema())
            cc_tbl = ensure_table(db, CHUNK_CONCEPTS_TABLE, _chunk_concepts_schema())
            if node_records:
                nodes_tbl.add(node_records)
            if edge_records:
                edges_tbl.add(edge_records)
            if chunk_concept_records:
                cc_tbl.add(chunk_concept_records)
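
    # Worked PMI example (a sketch: _compute_pmi lives in
    # lilbee.retrieval.concepts.community and is not shown here; the standard
    # pointwise mutual information estimate is assumed):
    #
    #     PMI(a, b) = log( p(a, b) / (p(a) * p(b)) )
    #
    # With 100 chunks, concept a counted in 10, concept b in 20, and the pair
    # co-occurring 5 times: p(a, b) = 0.05, p(a) = 0.10, p(b) = 0.20, so
    # PMI = log(0.05 / 0.02) = log(2.5) ≈ 0.92. A positive weight means the
    # pair co-occurs more often than independence would predict.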

    def boost_results(self, results: list[Any], query_concepts: list[str]) -> list[Any]:
        """Boost search results whose chunks overlap with query concepts."""
        if not query_concepts or not results:
            return results
        query_set = set(query_concepts)
        boosted: list[Any] = []
        for r in results:
            chunk_concepts = set(self.get_chunk_concepts(r.source, r.chunk_index))
            overlap = len(query_set & chunk_concepts)
            if overlap > 0:
                boost = (overlap / len(query_set)) * self._config.concept_boost_weight
                r = r.model_copy()
                if r.relevance_score is not None:
                    r.relevance_score = r.relevance_score + boost
                elif r.distance is not None:
                    r.distance = max(self._config.concept_boost_floor, r.distance - boost)
            boosted.append(r)
        return boosted
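
    # Worked boost example (numbers illustrative; the real
    # concept_boost_weight and concept_boost_floor defaults come from Config
    # and are not shown here). Query concepts {"hive", "honey", "pollen"},
    # a chunk sharing {"hive", "honey"}, weight 0.3:
    #
    #     boost = (2 / 3) * 0.3 = 0.2
    #
    # A relevance_score of 0.5 is raised to 0.7; a distance of 0.25 is lowered
    # to max(floor, 0.05), since smaller distances rank higher.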

    def get_chunk_concepts(self, source: str, chunk_index: int) -> list[str]:
        """Get concepts associated with a specific chunk."""
        table = self._store.open_table(CHUNK_CONCEPTS_TABLE)
        if table is None:
            return []
        escaped = escape_sql_string(source)
        try:
            rows = (
                table.search()
                .where(f"chunk_source = '{escaped}' AND chunk_index = {chunk_index}")
                .to_list()
            )
            return [r["concept"] for r in rows]
        except Exception:
            return []
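
    # The where() predicate is a SQL-like filter string, which is why the
    # source path goes through escape_sql_string before interpolation (the
    # integer chunk_index interpolates safely as-is). Assuming the usual SQL
    # convention of doubling embedded quotes, a source like "bee's.md" would
    # filter as chunk_source = 'bee''s.md' -- escape_sql_string's definition
    # is not shown here.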

    def expand_query(self, query: str) -> list[str]:
        """Expand a query with related concepts from the graph."""
        concepts = self.extract_concepts(query)
        if not concepts:
            return []
        related: list[str] = []
        seen = set(concepts)
        for concept in concepts:
            for neighbor in self.get_related_concepts(concept):
                if neighbor not in seen:
                    related.append(neighbor)
                    seen.add(neighbor)
        return related
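
    # Example flow (concept values hypothetical): expand_query("how do bees
    # make honey") extracts, say, ["bees", "honey"], then collects their
    # depth-1 neighbors from the edge table -- e.g. ["nectar", "hive"] --
    # skipping anything already present in the query's own concepts.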

    def get_related_concepts(self, concept: str, depth: int = 1) -> list[str]:
        """Find concepts related to *concept* via graph edges, up to *depth* hops.

        One batched query per depth level: O(depth) DB round-trips,
        independent of frontier size.
        """
        table = self._store.open_table(CONCEPT_EDGES_TABLE)
        if table is None:
            return []
        visited: set[str] = {concept}
        frontier: list[str] = [concept]
        for _ in range(depth):
            if not frontier:
                break
            escaped_list = ", ".join(f"'{escape_sql_string(n)}'" for n in frontier)
            try:
                rows = (
                    table.search()
                    .where(f"source IN ({escaped_list}) OR target IN ({escaped_list})")
                    .to_list()
                )
            except Exception:
                log.debug(
                    "concept expand batch failed at frontier size %d",
                    len(frontier),
                    exc_info=True,
                )
                break
            next_frontier: list[str] = []
            for row in rows:
                for endpoint in (row["source"], row["target"]):
                    if endpoint not in visited:
                        visited.add(endpoint)
                        next_frontier.append(endpoint)
            frontier = next_frontier
        return [c for c in visited if c != concept]
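
    # Expansion trace (illustrative graph with edges a--b, b--c, c--d):
    #
    #     depth 1: frontier [a] -> one query, matches a--b       -> visit b
    #     depth 2: frontier [b] -> one query, matches a--b, b--c -> visit c
    #
    # get_related_concepts("a", depth=2) returns ["b", "c"] (set order, so
    # unordered); reaching "d" would take depth=3.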

    def top_communities(self, k: int = 10) -> list[Community]:
        """Return the *k* largest concept communities.

        Uses ``pyarrow.compute.value_counts`` to pick the top-k
        cluster_ids in columnar memory, then materializes only those
        clusters' members. Peak Python memory scales with members of
        the top *k* clusters, not the total node count.
        """
        table = self._store.open_table(CONCEPT_NODES_TABLE)
        if table is None:
            return []
        arrow_tbl = table.to_arrow()
        if arrow_tbl.num_rows == 0:
            return []
        counts = pc.value_counts(arrow_tbl["cluster_id"]).to_pylist()
        top = sorted(counts, key=lambda entry: entry["counts"], reverse=True)[:k]
        top_ids = [entry["values"] for entry in top if entry["values"] is not None]
        if not top_ids:
            return []
        member_rows = arrow_tbl.filter(
            pc.is_in(arrow_tbl["cluster_id"], value_set=pa.array(top_ids))
        ).to_pylist()
        by_cluster: dict[int, list[str]] = {}
        for row in member_rows:
            by_cluster.setdefault(row["cluster_id"], []).append(row["concept"])
        return [
            Community(
                cluster_id=cid,
                size=len(by_cluster.get(cid, [])),
                concepts=by_cluster.get(cid, []),
            )
            for cid in top_ids
            if by_cluster.get(cid)
        ]
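
    # pyarrow.compute.value_counts returns a struct array with "values" (the
    # distinct cluster_id) and "counts" (its frequency) fields, so to_pylist()
    # yields e.g. [{"values": 3, "counts": 41}, {"values": 7, "counts": 12}];
    # sorting that small list by "counts" picks the biggest clusters without
    # materializing every node row in Python.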

    def rebuild_clusters(self) -> None:
        """Re-run Leiden clustering on the existing edge table."""
        from lilbee.data.store import ensure_table
        from lilbee.runtime.lock import write_lock

        edges_table = self._store.open_table(CONCEPT_EDGES_TABLE)
        if edges_table is None:
            return
        edge_rows = edges_table.to_arrow().to_pylist()
        if not edge_rows:
            return

        partition, degree_map = _leiden_partition(edge_rows)

        node_records = [
            {
                "concept": node,
                "cluster_id": cluster_id,
                "degree": degree_map.get(node, 0),
            }
            for node, cluster_id in partition.items()
        ]

        self._store.clear_table(CONCEPT_NODES_TABLE, "concept IS NOT NULL")
        if node_records:
            with write_lock():
                db = self._store.get_db()
                nodes_table = ensure_table(db, CONCEPT_NODES_TABLE, _concept_nodes_schema())
                nodes_table.add(node_records)
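
    # Judging by the usage above, _leiden_partition (from
    # lilbee.retrieval.concepts.community, not shown) maps each node to a
    # community id and reports node degrees. Illustrative shapes:
    #
    #     partition  = {"hive": 0, "honey": 0, "pollen": 1}
    #     degree_map = {"hive": 2, "honey": 1, "pollen": 1}
    #
    # Each entry becomes one refreshed row in CONCEPT_NODES_TABLE.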

    def get_cluster_sources(self, min_sources: int = 3) -> dict[int, set[str]]:
        """Return clusters that span at least *min_sources* distinct sources.

        Joins concept_nodes (concept -> cluster_id) with chunk_concepts
        (concept -> chunk_source) to find which document sources each
        cluster touches.
        """
        nodes_table = self._store.open_table(CONCEPT_NODES_TABLE)
        cc_table = self._store.open_table(CHUNK_CONCEPTS_TABLE)
        if nodes_table is None or cc_table is None:
            return {}

        node_rows = nodes_table.to_arrow().to_pylist()
        concept_to_cluster: dict[str, int] = {r["concept"]: r["cluster_id"] for r in node_rows}

        cc_rows = cc_table.to_arrow().to_pylist()
        cluster_sources: dict[int, set[str]] = {}
        for row in cc_rows:
            cid = concept_to_cluster.get(row["concept"])
            if cid is None:
                continue
            cluster_sources.setdefault(cid, set()).add(row["chunk_source"])

        return {
            cid: sources for cid, sources in cluster_sources.items() if len(sources) >= min_sources
        }
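
    # Join example (source names hypothetical): with concept_to_cluster =
    # {"hive": 0, "honey": 0} and chunk_concepts rows ("a.md", "hive"),
    # ("b.md", "honey"), ("c.md", "hive"), cluster 0 touches sources
    # {"a.md", "b.md", "c.md"} and passes the default min_sources=3.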

    def get_cluster_label(self, cluster_id: int) -> str:
        """Return a human-readable label for *cluster_id* (highest-degree concept)."""
        table = self._store.open_table(CONCEPT_NODES_TABLE)
        if table is None:
            return f"cluster-{cluster_id}"
        try:
            rows = table.search().where(f"cluster_id = {int(cluster_id)}").to_list()
        except Exception:
            log.debug("get_cluster_label query failed", exc_info=True)
            return f"cluster-{cluster_id}"
        if not rows:
            return f"cluster-{cluster_id}"
        best = max(rows, key=lambda r: r["degree"])
        return str(best["concept"])

    def get_graph(self) -> bool:
        """Check whether a concept graph exists in the store."""
        if not self._config.concept_graph:
            return False
        return self._store.open_table(CONCEPT_NODES_TABLE) is not None

    def reset_nlp_cache(self) -> None:
        """Clear the spaCy model cache. For testing only."""
        self._nlp = None
        self._nlp_unavailable = False