1"""ConceptGraph: extracts, stores, and queries concept relationships."""
3from __future__ import annotations
5import logging
6from collections import Counter
7from typing import Any
9import pyarrow as pa
10import pyarrow.compute as pc
12from lilbee.core.config import (
13 CHUNK_CONCEPTS_TABLE,
14 CONCEPT_EDGES_TABLE,
15 CONCEPT_NODES_TABLE,
16 Config,
17)
18from lilbee.data.store import Store, escape_sql_string
19from lilbee.retrieval.concepts.community import Community, _compute_pmi, _leiden_partition
20from lilbee.retrieval.concepts.nlp import _ensure_spacy_model, _filter_noun_chunks
21from lilbee.retrieval.concepts.schema import (
22 _chunk_concepts_schema,
23 _concept_edges_schema,
24 _concept_nodes_schema,
25)
27log = logging.getLogger(__name__)
30class ConceptGraph:
31 """Concept graph -- extracts, stores, and queries concept relationships."""
33 def __init__(self, config: Config, store: Store) -> None:
34 self._config = config
35 self._store = store
36 self._nlp: Any = None
37 self._nlp_unavailable: bool = False
39 def _ensure_nlp(self) -> Any | None:
40 """Lazy-load and cache the spaCy model. Returns None if unavailable."""
41 if self._nlp_unavailable:
42 return None
43 if self._nlp is None:
44 try:
45 self._nlp = _ensure_spacy_model()
46 except ImportError:
47 log.warning("Concept graph disabled: spaCy model unavailable")
48 self._nlp_unavailable = True
49 return None
50 return self._nlp
52 def extract_concepts(self, text: str, max_concepts: int | None = None) -> list[str]:
53 """Extract noun-phrase concepts from text via spaCy."""
54 if max_concepts is None:
55 max_concepts = self._config.concept_max_per_chunk
56 if not text.strip():
57 return []
58 nlp = self._ensure_nlp()
59 if nlp is None:
60 return []
61 doc = nlp(text)
62 return _filter_noun_chunks(doc, max_concepts)
64 def extract_concepts_batch(self, texts: list[str]) -> list[list[str]]:
65 """Batch-extract concepts from multiple texts."""
66 if not texts:
67 return []
68 nlp = self._ensure_nlp()
69 if nlp is None:
70 return [[] for _ in texts]
71 max_concepts = self._config.concept_max_per_chunk
72 return [_filter_noun_chunks(doc, max_concepts) for doc in nlp.pipe(texts)]
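
    # Illustrative usage (not part of this module; assumes a configured
    # Config and Store). Batch extraction goes through spaCy's nlp.pipe,
    # which streams parsed Docs instead of calling nlp(text) per string:
    #
    #     graph = ConceptGraph(config, store)
    #     concept_lists = graph.extract_concepts_batch(
    #         ["Bees communicate via dance.", "Hives store surplus honey."]
    #     )
    #     # -> one list of noun-phrase concepts per input text; the exact
    #     # phrases depend on the loaded model and _filter_noun_chunks.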

    def build_from_chunks(
        self, chunk_ids: list[tuple[str, int]], concept_lists: list[list[str]]
    ) -> None:
        """Build co-occurrence graph from chunk concepts, compute PMI, store tables."""
        from lilbee.data.store import ensure_table
        from lilbee.runtime.lock import write_lock

        if not chunk_ids:
            return

        cooccurrences: Counter[tuple[str, str]] = Counter()
        concept_counts: Counter[str] = Counter()
        chunk_concept_records: list[dict[str, Any]] = []

        for (source, idx), concepts in zip(chunk_ids, concept_lists, strict=True):
            for c in concepts:
                concept_counts[c] += 1
                chunk_concept_records.append(
                    {"chunk_source": source, "chunk_index": idx, "concept": c}
                )
            for i, a in enumerate(concepts):
                for b in concepts[i + 1 :]:
                    pair = (min(a, b), max(a, b))
                    cooccurrences[pair] += 1

        pmi_weights = _compute_pmi(cooccurrences, concept_counts, len(chunk_ids))

        edge_records = [
            {"source": a, "target": b, "weight": w} for (a, b), w in pmi_weights.items()
        ]

        node_records = [
            {"concept": c, "cluster_id": 0, "degree": count} for c, count in concept_counts.items()
        ]

        with write_lock():
            db = self._store.get_db()
            # Always create tables so get_graph() returns True even when
            # concept extraction yields no results for the current corpus.
            nodes_tbl = ensure_table(db, CONCEPT_NODES_TABLE, _concept_nodes_schema())
            edges_tbl = ensure_table(db, CONCEPT_EDGES_TABLE, _concept_edges_schema())
            cc_tbl = ensure_table(db, CHUNK_CONCEPTS_TABLE, _chunk_concepts_schema())
            if node_records:
                nodes_tbl.add(node_records)
            if edge_records:
                edges_tbl.add(edge_records)
            if chunk_concept_records:
                cc_tbl.add(chunk_concept_records)
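
    # Worked example (illustrative; the real weights come from _compute_pmi,
    # whose exact formula lives in concepts/community.py). Given two chunks
    # with concepts ["queen bee", "hive"] and ["hive", "nectar"], the loops
    # above produce:
    #
    #     concept_counts = {"queen bee": 1, "hive": 2, "nectar": 1}
    #     cooccurrences  = {("hive", "queen bee"): 1, ("hive", "nectar"): 1}
    #
    # Pairs use (min, max) ordering so ("a", "b") and ("b", "a") count as
    # the same undirected edge. Textbook PMI over n = 2 chunks would be
    # log(p(a, b) / (p(a) * p(b))); for ("hive", "queen bee") that is
    # log((1/2) / ((2/2) * (1/2))) = log(1) = 0.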

    def boost_results(self, results: list[Any], query_concepts: list[str]) -> list[Any]:
        """Boost search results whose chunks overlap with query concepts."""
        if not query_concepts or not results:
            return results
        query_set = set(query_concepts)
        boosted: list[Any] = []
        for r in results:
            chunk_concepts = set(self.get_chunk_concepts(r.source, r.chunk_index))
            overlap = len(query_set & chunk_concepts)
            if overlap > 0:
                boost = (overlap / len(query_set)) * self._config.concept_boost_weight
                r = r.model_copy()
                if r.relevance_score is not None:
                    r.relevance_score = r.relevance_score + boost
                elif r.distance is not None:
                    r.distance = max(self._config.concept_boost_floor, r.distance - boost)
            boosted.append(r)
        return boosted
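
    # Worked example (illustrative numbers): with query_set = {"hive", "bee"}
    # and a chunk whose concepts include only "hive", overlap = 1 and
    # boost = (1 / 2) * concept_boost_weight. If concept_boost_weight = 0.2,
    # a relevance_score of 0.6 rises to 0.7, while a distance of 0.35 drops
    # to max(concept_boost_floor, 0.25). Results with no overlap pass
    # through unchanged; model_copy() keeps callers' originals unmodified.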

    def get_chunk_concepts(self, source: str, chunk_index: int) -> list[str]:
        """Get concepts associated with a specific chunk."""
        table = self._store.open_table(CHUNK_CONCEPTS_TABLE)
        if table is None:
            return []
        escaped = escape_sql_string(source)
        try:
            rows = (
                table.search()
                .where(f"chunk_source = '{escaped}' AND chunk_index = {chunk_index}")
                .to_list()
            )
            return [r["concept"] for r in rows]
        except Exception:
            return []

    def expand_query(self, query: str) -> list[str]:
        """Expand a query with related concepts from the graph."""
        concepts = self.extract_concepts(query)
        if not concepts:
            return []
        related: list[str] = []
        seen = set(concepts)
        for concept in concepts:
            for neighbor in self.get_related_concepts(concept):
                if neighbor not in seen:
                    related.append(neighbor)
                    seen.add(neighbor)
        return related

    def get_related_concepts(self, concept: str, depth: int = 1) -> list[str]:
        """Find concepts related to *concept* via graph edges, up to *depth* hops.

        One batched query per depth level: O(depth) DB round-trips,
        independent of frontier size.
        """
        table = self._store.open_table(CONCEPT_EDGES_TABLE)
        if table is None:
            return []
        visited: set[str] = {concept}
        frontier: list[str] = [concept]
        for _ in range(depth):
            if not frontier:
                break
            escaped_list = ", ".join(f"'{escape_sql_string(n)}'" for n in frontier)
            try:
                rows = (
                    table.search()
                    .where(f"source IN ({escaped_list}) OR target IN ({escaped_list})")
                    .to_list()
                )
            except Exception:
                log.debug(
                    "concept expand batch failed at frontier size %d",
                    len(frontier),
                    exc_info=True,
                )
                break
            next_frontier: list[str] = []
            for row in rows:
                for endpoint in (row["source"], row["target"]):
                    if endpoint not in visited:
                        visited.add(endpoint)
                        next_frontier.append(endpoint)
            frontier = next_frontier
        return [c for c in visited if c != concept]
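
    # Illustrative walk over hypothetical edges ("bee", "hive") and
    # ("hive", "nectar"): depth=1 from "bee" issues one IN query with
    # frontier ["bee"], visits "hive", and returns ["hive"]. depth=2 issues
    # a second query with frontier ["hive"] and also returns "nectar";
    # two round-trips total, no matter how many neighbors each level found.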

    def top_communities(self, k: int = 10) -> list[Community]:
        """Return the *k* largest concept communities.

        Uses ``pyarrow.compute.value_counts`` to pick the top-k
        cluster_ids in columnar memory, then materializes only those
        clusters' members. Peak Python memory scales with members of
        the top *k* clusters, not the total node count.
        """
        table = self._store.open_table(CONCEPT_NODES_TABLE)
        if table is None:
            return []
        arrow_tbl = table.to_arrow()
        if arrow_tbl.num_rows == 0:
            return []
        counts = pc.value_counts(arrow_tbl["cluster_id"]).to_pylist()
        top = sorted(counts, key=lambda entry: entry["counts"], reverse=True)[:k]
        top_ids = [entry["values"] for entry in top if entry["values"] is not None]
        if not top_ids:
            return []
        member_rows = arrow_tbl.filter(
            pc.is_in(arrow_tbl["cluster_id"], value_set=pa.array(top_ids))
        ).to_pylist()
        by_cluster: dict[int, list[str]] = {}
        for row in member_rows:
            by_cluster.setdefault(row["cluster_id"], []).append(row["concept"])
        return [
            Community(
                cluster_id=cid,
                size=len(by_cluster.get(cid, [])),
                concepts=by_cluster.get(cid, []),
            )
            for cid in top_ids
            if by_cluster.get(cid)
        ]
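
    # Shape of the intermediates, for orientation (cluster ids hypothetical):
    # pc.value_counts(arrow_tbl["cluster_id"]).to_pylist() yields entries
    # like [{"values": 3, "counts": 41}, {"values": 7, "counts": 12}, ...].
    # Sorting by "counts" and slicing [:k] picks the largest clusters, and
    # the pc.is_in filter materializes only those clusters' rows in Python.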

    def rebuild_clusters(self) -> None:
        """Re-run Leiden clustering on the existing edge table."""
        from lilbee.data.store import ensure_table
        from lilbee.runtime.lock import write_lock

        edges_table = self._store.open_table(CONCEPT_EDGES_TABLE)
        if edges_table is None:
            return
        edge_rows = edges_table.to_arrow().to_pylist()
        if not edge_rows:
            return

        partition, degree_map = _leiden_partition(edge_rows)

        node_records = [
            {
                "concept": node,
                "cluster_id": cluster_id,
                "degree": degree_map.get(node, 0),
            }
            for node, cluster_id in partition.items()
        ]

        self._store.clear_table(CONCEPT_NODES_TABLE, "concept IS NOT NULL")
        if node_records:
            with write_lock():
                db = self._store.get_db()
                nodes_table = ensure_table(db, CONCEPT_NODES_TABLE, _concept_nodes_schema())
                nodes_table.add(node_records)

    def get_cluster_sources(self, min_sources: int = 3) -> dict[int, set[str]]:
        """Return clusters that span at least *min_sources* distinct sources.

        Joins concept_nodes (concept -> cluster_id) with chunk_concepts
        (concept -> chunk_source) to find which document sources each
        cluster touches.
        """
        nodes_table = self._store.open_table(CONCEPT_NODES_TABLE)
        cc_table = self._store.open_table(CHUNK_CONCEPTS_TABLE)
        if nodes_table is None or cc_table is None:
            return {}

        node_rows = nodes_table.to_arrow().to_pylist()
        concept_to_cluster: dict[str, int] = {r["concept"]: r["cluster_id"] for r in node_rows}

        cc_rows = cc_table.to_arrow().to_pylist()
        cluster_sources: dict[int, set[str]] = {}
        for row in cc_rows:
            cid = concept_to_cluster.get(row["concept"])
            if cid is None:
                continue
            cluster_sources.setdefault(cid, set()).add(row["chunk_source"])

        return {
            cid: sources for cid, sources in cluster_sources.items() if len(sources) >= min_sources
        }
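
    # Worked example (hypothetical rows): if concept_to_cluster is
    # {"hive": 1, "nectar": 1, "pollen": 2} and chunk_concepts links
    # "hive" -> "a.md", "nectar" -> "b.md", "pollen" -> "a.md", then
    # cluster_sources = {1: {"a.md", "b.md"}, 2: {"a.md"}}, and with
    # min_sources = 2 only cluster 1 is returned.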

    def get_cluster_label(self, cluster_id: int) -> str:
        """Return a human-readable label for *cluster_id* (highest-degree concept)."""
        table = self._store.open_table(CONCEPT_NODES_TABLE)
        if table is None:
            return f"cluster-{cluster_id}"
        try:
            rows = table.search().where(f"cluster_id = {int(cluster_id)}").to_list()
        except Exception:
            log.debug("get_cluster_label query failed", exc_info=True)
            return f"cluster-{cluster_id}"
        if not rows:
            return f"cluster-{cluster_id}"
        best = max(rows, key=lambda r: r["degree"])
        return str(best["concept"])

    def get_graph(self) -> bool:
        """Check whether a concept graph exists in the store."""
        if not self._config.concept_graph:
            return False
        return self._store.open_table(CONCEPT_NODES_TABLE) is not None

    def reset_nlp_cache(self) -> None:
        """Clear the spaCy model cache. For testing only."""
        self._nlp = None
        self._nlp_unavailable = False
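
# Minimal end-to-end sketch (illustrative, not part of the module; assumes a
# configured Config and Store, and hypothetical chunk ids and texts):
#
#     graph = ConceptGraph(config, store)
#     texts = ["Bees build hives.", "Hives store honey."]
#     concept_lists = graph.extract_concepts_batch(texts)
#     graph.build_from_chunks([("doc.md", 0), ("doc.md", 1)], concept_lists)
#     graph.rebuild_clusters()          # assign Leiden cluster ids
#     related = graph.expand_query("honey bees")  # neighbors of query concepts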