Coverage for src/lilbee/data/store/core.py: 100%

331 statements  

coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

1"""The ``Store`` class: high-level LanceDB read/write API used across lilbee.""" 

2 

3from __future__ import annotations 

4 

5import logging 

6from datetime import UTC, datetime 

7from pathlib import Path 

8from typing import TYPE_CHECKING 

9 

10import pyarrow as pa 

11import pyarrow.compute as pc 

12 

13from lilbee.core.config import ( 

14 CHUNKS_TABLE, 

15 CITATIONS_TABLE, 

16 META_TABLE, 

17 SOURCES_TABLE, 

18 Config, 

19) 

20from lilbee.core.security import validate_path_within 

21from lilbee.runtime.lock import write_lock 

22 

23from .lance_helpers import ( 

24 _chunk_type_predicate, 

25 _embedding_mismatch_message, 

26 _has_fts_index, 

27 _safe_delete_unlocked, 

28 _sources_search_filter, 

29 _table_names, 

30 ensure_table, 

31 escape_sql_string, 

32 refs_compatible, 

33) 

34from .ranking import mmr_rerank 

35from .schema import _citations_schema, _meta_schema, _sources_schema 

36from .types import ( 

37 META_DELETE_ALL_PREDICATE, 

38 META_SCHEMA_VERSION, 

39 READ_CONSISTENCY_INTERVAL, 

40 CitationRecord, 

41 EmbeddingModelMismatchError, 

42 RemoveResult, 

43 SearchChunk, 

44 SourceRecord, 

45 StoreMeta, 

46) 

47 

48if TYPE_CHECKING: 

49 import lancedb 

50 import lancedb.table 

51 

52log = logging.getLogger(__name__) 

53 

54 

55def _hybrid_search( 

56 table: lancedb.table.Table, 

57 query_text: str, 

58 query_vector: list[float], 

59 top_k: int, 

60 chunk_type: str | None = None, 

61) -> list[SearchChunk]: 

62 """Run hybrid (vector + FTS) search with RRF reranking. 

63 

64 When ``chunk_type`` is set, the predicate is pushed into the query so 

65 the limit applies *after* the type filter. Post-filtering would 

66 silently starve wiki-only queries whose matches live past the top-K 

67 hybrid window. 

68 """ 

69 from lancedb.rerankers import RRFReranker 

70 

71 query = ( 

72 table.search(query_type="hybrid") 

73 .vector(query_vector) 

74 .text(query_text) 

75 .rerank(RRFReranker()) 

76 ) 

77 if chunk_type: 

78 query = query.where(_chunk_type_predicate(chunk_type)) 

79 rows = query.limit(top_k).to_list() 

80 return [SearchChunk(**r) for r in rows] 

81 

82 
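# Usage sketch (illustrative only): a direct call to ``_hybrid_search``, assuming
# an already-open chunks table and a query vector produced by the configured
# embedding model. The names ``tbl`` and ``vec`` are hypothetical.
#
#     hits = _hybrid_search(tbl, "reset password", vec, top_k=5, chunk_type="wiki")
#     for h in hits:
#         print(h.source, h.chunk_index)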

_MAX_THRESHOLD = 1.0
_MAX_FILTER_ITERATIONS = 20  # safety cap to prevent runaway loops


def _get_distance(chunk: SearchChunk) -> float:
    """Extract distance as a sortable float (inf for None)."""
    return chunk.distance if chunk.distance is not None else float("inf")


def _count_within_threshold(sorted_results: list[SearchChunk], threshold: float) -> int:
    """Count results whose distance is within the given threshold."""
    for i, r in enumerate(sorted_results):
        if _get_distance(r) > threshold:
            return i
    return len(sorted_results)
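# Worked example (hypothetical distances): with results sorted by distance as
# [0.21, 0.38, 0.64, inf], ``_count_within_threshold(results, 0.5)`` stops at the
# first distance above 0.5 and returns 2; a threshold of 1.0 returns 3, because
# ``inf`` (a missing distance) never passes.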

class Store:
    """LanceDB vector store: wraps all DB operations with config-driven defaults."""

    def __init__(self, config: Config) -> None:
        self._config = config
        self._fts_ready: bool = False
        self._db: lancedb.DBConnection | None = None
        # Cache of {filename: ingested_at} rebuilt only when sources
        # mutate; callers (temporal filter) hit it per-query.
        self._source_ingested_cache: dict[str, str] | None = None

    def _invalidate_source_cache(self) -> None:
        """Drop the cached {filename: ingested_at} map."""
        self._source_ingested_cache = None

    def source_ingested_at_map(self) -> dict[str, str]:
        """Return {filename: ingested_at} for every source, cached until mutation.

        Best-effort: a reader racing a concurrent invalidation can store a
        pre-mutation snapshot. The only consumer (temporal query filter)
        treats a missing/stale entry as "do not filter," so staleness
        degrades ranking precision, never correctness.
        """
        if self._source_ingested_cache is not None:
            return self._source_ingested_cache
        mapping = {s["filename"]: s.get("ingested_at", "") for s in self.get_sources()}
        self._source_ingested_cache = mapping
        return mapping

    def _chunks_schema(self) -> pa.Schema:
        return pa.schema(
            [
                pa.field("source", pa.utf8()),
                pa.field("content_type", pa.utf8()),
                pa.field("chunk_type", pa.utf8()),
                pa.field("page_start", pa.int32()),
                pa.field("page_end", pa.int32()),
                pa.field("line_start", pa.int32()),
                pa.field("line_end", pa.int32()),
                pa.field("chunk", pa.utf8()),
                pa.field("chunk_index", pa.int32()),
                pa.field("vector", pa.list_(pa.float32(), self._config.embedding_dim)),
            ]
        )
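    # Illustrative record shape (hypothetical values) matching ``_chunks_schema``;
    # ``add_chunks`` expects a list of dicts like this, with ``vector`` holding
    # exactly ``config.embedding_dim`` floats:
    #
    #     {
    #         "source": "handbook.pdf",
    #         "content_type": "application/pdf",
    #         "chunk_type": "raw",
    #         "page_start": 3,
    #         "page_end": 3,
    #         "line_start": 0,
    #         "line_end": 0,
    #         "chunk": "Employees accrue leave monthly...",
    #         "chunk_index": 12,
    #         "vector": [0.01, -0.02, ...],  # embedding_dim floats
    #     }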

    def get_meta(self) -> StoreMeta | None:
        """Return the persisted store metadata row, or ``None`` if unset."""
        table = self.open_table(META_TABLE)
        if table is None:
            return None
        rows = table.search().limit(1).to_list()
        if not rows:
            return None
        row = rows[0]
        return StoreMeta(
            embedding_model=row["embedding_model"],
            embedding_dim=int(row["embedding_dim"]),
            schema_version=int(row["schema_version"]),
            updated_at=row["updated_at"],
        )

    def _write_meta_unlocked(self, *, embedding_model: str, embedding_dim: int) -> None:
        """Overwrite the single ``_meta`` row with the supplied identity.

        Caller must hold ``write_lock()``. Args are passed explicitly rather than
        re-read from ``self._config`` so the caller can snapshot cfg at a coherent
        instant and not race with a concurrent ``set_embedding_model``.
        """
        db = self.get_db()
        table = ensure_table(db, META_TABLE, _meta_schema())
        _safe_delete_unlocked(table, META_DELETE_ALL_PREDICATE)
        table.add(
            [
                {
                    "embedding_model": embedding_model,
                    "embedding_dim": embedding_dim,
                    "schema_version": META_SCHEMA_VERSION,
                    "updated_at": datetime.now(UTC).isoformat(),
                }
            ]
        )

    def _has_chunks(self) -> bool:
        """Return True when the chunks table exists and has at least one row."""
        chunks = self.open_table(CHUNKS_TABLE)
        return chunks is not None and chunks.count_rows() > 0

    def has_chunks(self) -> bool:
        """Public predicate: True iff the store currently holds at least one chunk."""
        return self._has_chunks()

    def initialize_meta_if_legacy(self) -> bool:
        """Pin a legacy store's identity to the current cfg if not already set.

        Returns ``True`` when a meta row was just written. No-op when meta already
        exists or no chunks are present. This is the path that converts a
        pre-upgrade store (chunks present, no ``_meta``) into a gated store. It
        snapshots cfg under the write lock to keep the recorded identity coherent
        with what the gate is comparing against.
        """
        if self.get_meta() is not None:
            return False
        if not self._has_chunks():
            return False
        embedding_model = self._config.embedding_model
        embedding_dim = self._config.embedding_dim
        with write_lock():
            # Re-check under the lock so two callers do not both warn-and-write.
            if self.get_meta() is not None:
                return False
            log.warning(
                "Legacy store has chunks but no _meta row. Initializing _meta from "
                "the current configuration (embedding_model=%s, embedding_dim=%d). "
                "If you changed embedding_model before upgrading, run `lilbee rebuild` "
                "to ensure the store is consistent.",
                embedding_model,
                embedding_dim,
            )
            self._write_meta_unlocked(embedding_model=embedding_model, embedding_dim=embedding_dim)
            return True

    def _ensure_embedding_compat(self) -> None:
        """Raise when the persisted embedding identity drifts from cfg.

        Pure check, no side effects. Migration of legacy stores (chunks present,
        no ``_meta``) is the caller's responsibility via ``initialize_meta_if_legacy``;
        rewriting a legacy bare-repo ``_meta`` row to the canonical full ref is
        the caller's responsibility via ``canonicalize_meta_if_legacy``. This
        method stays safe to call from inside an existing ``write_lock()`` (no
        recursive lock attempt). cfg fields are snapshotted at entry so the
        comparison is coherent even if another thread mutates them mid-call.
        """
        current_model = self._config.embedding_model
        current_dim = self._config.embedding_dim
        meta = self.get_meta()
        if meta is None:
            return
        if refs_compatible(
            meta["embedding_model"], current_model, meta["embedding_dim"], current_dim
        ):
            return
        raise EmbeddingModelMismatchError(
            _embedding_mismatch_message(
                persisted_model=meta["embedding_model"],
                persisted_dim=meta["embedding_dim"],
                current_model=current_model,
                current_dim=current_dim,
            )
        )

    def _needs_canonical_meta_rewrite(
        self, meta: StoreMeta | None, current_model: str, current_dim: int
    ) -> bool:
        """True iff *meta* is the legacy form and refs-compatible with current cfg."""
        if meta is None or meta["embedding_model"] == current_model:
            return False
        return refs_compatible(
            meta["embedding_model"], current_model, meta["embedding_dim"], current_dim
        )

    def canonicalize_meta_if_legacy(self) -> bool:
        """Rewrite a legacy bare-repo ``_meta`` row to the canonical full ref.

        Pre-canonical lilbee persisted only ``<org>/<repo>`` in
        ``_meta.embedding_model``. The current code persists the full
        ``<org>/<repo>/<filename>.gguf``. When the two refer to the same
        model under :func:`refs_compatible` but differ as raw strings, the
        meta row is rewritten so the legacy name never surfaces. Returns
        ``True`` on write; ``False`` when missing, already canonical, or
        incompatible (the gate handles incompatibility).
        """
        current_model = self._config.embedding_model
        current_dim = self._config.embedding_dim
        if not self._needs_canonical_meta_rewrite(self.get_meta(), current_model, current_dim):
            return False
        with write_lock():
            meta = self.get_meta()  # re-read under the lock for racing callers
            if not self._needs_canonical_meta_rewrite(meta, current_model, current_dim):
                return False
            assert meta is not None  # filtered above  # noqa: S101
            log.info(
                "Migrating legacy embedding ref in store meta: %r -> %r",
                meta["embedding_model"],
                current_model,
            )
            self._write_meta_unlocked(embedding_model=current_model, embedding_dim=current_dim)
            return True
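    # Example of the migration above (hypothetical model names): a legacy row
    # storing "acme/embed-small" is rewritten to the canonical
    # "acme/embed-small/embed-small-q8_0.gguf" once ``refs_compatible`` confirms
    # both strings refer to the same model at the same dimension.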

    def get_db(self) -> lancedb.DBConnection:
        if self._db is None:
            import lancedb as _lancedb

            self._config.lancedb_dir.mkdir(parents=True, exist_ok=True)
            self._db = _lancedb.connect(
                str(self._config.lancedb_dir),
                read_consistency_interval=READ_CONSISTENCY_INTERVAL,
            )
        return self._db

    def open_table(self, name: str) -> lancedb.table.Table | None:
        """Open a table if it exists, otherwise return None."""
        db = self.get_db()
        if name not in _table_names(db):
            return None
        return db.open_table(name)

    def ensure_fts_index(self) -> None:
        """Create the chunks FTS index, or run ``optimize()`` once it exists.

        ``optimize()`` folds newly added rows into the FTS index and also
        runs LanceDB's default compaction + version pruning (default prune
        window: 7 days). Work scales with recent deltas rather than total
        chunk count, so large corpora no longer pay the full
        ``create_fts_index(replace=True)`` rebuild cost on every sync.
        """
        with write_lock():
            table = self.open_table(CHUNKS_TABLE)
            if table is None:
                return
            try:
                if _has_fts_index(table):
                    table.optimize()
                    log.debug("FTS index optimized on '%s'", CHUNKS_TABLE)
                else:
                    table.create_fts_index("chunk", replace=False)
                    log.debug("FTS index created on '%s'", CHUNKS_TABLE)
                self._fts_ready = True
            except Exception:
                log.debug("FTS index ensure failed (empty table?)", exc_info=True)

    def add_chunks(self, records: list[dict]) -> int:
        """Add chunk records to the store. Returns count added.

        Raises ``EmbeddingModelMismatchError`` if the persisted ``_meta`` row was
        written under a different embedding model than the current ``cfg``. On the
        first write to a fresh store, ``_meta`` is initialized from the current cfg.

        The gate runs inside the write lock and uses a single cfg snapshot so a
        concurrent ``set_embedding_model`` cannot slip a write in past a stale
        compatibility check.
        """
        with write_lock():
            embedding_model = self._config.embedding_model
            embedding_dim = self._config.embedding_dim
            self._ensure_embedding_compat()
            self._fts_ready = False
            if not records:
                return 0
            for rec in records:
                vec = rec.get("vector", [])
                if len(vec) != embedding_dim:
                    raise ValueError(
                        f"Vector dimension mismatch: expected {embedding_dim}, "
                        f"got {len(vec)} (source={rec.get('source', '?')})"
                    )
            db = self.get_db()
            table = ensure_table(db, CHUNKS_TABLE, self._chunks_schema())
            table.add(records)
            if self.get_meta() is None:
                self._write_meta_unlocked(
                    embedding_model=embedding_model, embedding_dim=embedding_dim
                )
            return len(records)
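    # Usage sketch (illustrative): ``records`` is a list of dicts shaped like the
    # ``_chunks_schema`` example above; ``store`` and ``records`` are hypothetical.
    #
    #     added = store.add_chunks(records)   # raises on embedding-model drift
    #     store.ensure_fts_index()            # fold the new rows into the FTS index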

    def bm25_probe(self, query_text: str, top_k: int = 5) -> list[SearchChunk]:
        """Quick BM25-only search for confidence checking. Returns up to top_k results."""
        table = self.open_table(CHUNKS_TABLE)
        if table is None:
            return []
        if not self._fts_ready:
            self.ensure_fts_index()
            if not self._fts_ready:
                return []  # pragma: no cover
        try:
            rows = table.search(query_text, query_type="fts").limit(top_k).to_list()
            return [SearchChunk(**r) for r in rows]
        except Exception:
            log.debug("BM25 probe failed", exc_info=True)
            return []

    def search(
        self,
        query_vector: list[float],
        top_k: int | None = None,
        max_distance: float | None = None,
        query_text: str | None = None,
        chunk_type: str | None = None,
    ) -> list[SearchChunk]:
        """Search for similar chunks. Hybrid when FTS available, else vector-only.

        Results with distance > max_distance are filtered out (vector-only path).
        Pass max_distance=0 to disable filtering.
        When *chunk_type* is set, only chunks of that type ("raw" or "wiki") are returned.

        Raises ``EmbeddingModelMismatchError`` if the persisted ``_meta`` row was
        written under a different embedding model than the current ``cfg``.
        """
        if top_k is None:
            top_k = self._config.top_k
        if max_distance is None:
            max_distance = self._config.max_distance
        table = self.open_table(CHUNKS_TABLE)
        if table is None:
            return []
        self.initialize_meta_if_legacy()
        self.canonicalize_meta_if_legacy()
        self._ensure_embedding_compat()

        if query_text and not self._fts_ready:
            self.ensure_fts_index()

        if query_text and self._fts_ready:
            try:
                return _hybrid_search(table, query_text, query_vector, top_k, chunk_type)
            except Exception:
                log.debug("Hybrid search failed, falling back to vector-only", exc_info=True)

        candidate_k = top_k * self._config.candidate_multiplier
        query = table.search(query_vector).metric("cosine").limit(candidate_k)
        if chunk_type:
            query = query.where(_chunk_type_predicate(chunk_type))
        rows = query.to_list()
        log.debug(
            "Vector search: query=%r, candidates=%d, max_distance=%.2f",
            query_text or "vector-only",
            len(rows),
            max_distance,
        )
        if rows:
            distances = [r.get("distance", 0) for r in rows[:5]]
            log.debug("Top 5 distances: %s", distances)
        results = [SearchChunk(**r) for r in rows]
        return self._filter_and_rerank(results, query_vector, top_k, max_distance)
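    # Usage sketch (illustrative; ``store``, ``vec`` and the query string are
    # hypothetical): hybrid search over wiki chunks only, falling back to
    # vector-only automatically when the FTS index is unavailable.
    #
    #     hits = store.search(vec, top_k=8, query_text="rotate API keys", chunk_type="wiki")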

    def _filter_and_rerank(
        self,
        results: list[SearchChunk],
        query_vector: list[float],
        top_k: int,
        max_distance: float,
    ) -> list[SearchChunk]:
        """Apply the configured distance filter, then MMR-rerank down to top_k."""
        if max_distance > 0:
            before = len(results)
            if self._config.adaptive_threshold:
                results = self._adaptive_filter(results, top_k, max_distance)
                filter_name = "adaptive"
            else:
                results = self._fixed_filter(results, max_distance)
                filter_name = "fixed"
            log.debug(
                "After %s filter: %d/%d results, threshold=%.2f",
                filter_name,
                len(results),
                before,
                max_distance,
            )
        if len(results) > top_k:
            results = mmr_rerank(query_vector, results, top_k, self._config.mmr_lambda)
        return results

    def _adaptive_filter(
        self, results: list[SearchChunk], top_k: int, initial_threshold: float
    ) -> list[SearchChunk]:
464 """Widen cosine distance threshold when too few results. 

465 Inspired by grantflow's (grantflow-ai/grantflow) adaptive retrieval 

466 pattern which widens thresholds on recursive retry. Step size and 

467 cap are configurable via ``self._config.adaptive_threshold_step``. 

468 

469 Pre-sorts results by distance for a single-pass cutoff search. 

470 Step size is ``self._config.adaptive_threshold_step`` (default 0.2). 

471 """ 

        cap = max(initial_threshold, _MAX_THRESHOLD)
        step = self._config.adaptive_threshold_step

        sorted_results = sorted(results, key=_get_distance)

        threshold = initial_threshold
        for _ in range(_MAX_FILTER_ITERATIONS):
            if threshold > cap:
                break
            cutoff = _count_within_threshold(sorted_results, threshold)
            if cutoff >= top_k:
                return sorted_results[:cutoff]
            threshold += step
        # Final pass at cap
        cutoff = _count_within_threshold(sorted_results, cap)
        return sorted_results[:cutoff]
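    # Worked example (hypothetical distances): with top_k=3, initial_threshold=0.4,
    # step 0.2, and sorted distances [0.3, 0.55, 0.72, 0.9], only one result passes
    # at 0.4, two pass at 0.6, and three pass at 0.8, so the first three are
    # returned without ever reaching the 1.0 cap.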

    def _fixed_filter(self, results: list[SearchChunk], threshold: float) -> list[SearchChunk]:
        """Simple fixed threshold filter - keep only results within distance threshold."""
        return [r for r in results if _get_distance(r) <= threshold]

    def get_chunks_by_source(self, source: str) -> list[SearchChunk]:
        """Return every chunk whose ``source`` equals *source*."""
        table = self.open_table(CHUNKS_TABLE)
        if table is None:
            return []
        escaped = escape_sql_string(source)
        try:
            rows = table.search().where(f"source = '{escaped}'").limit(None).to_list()
        except Exception:
            # FTS-enabled tables return a query builder that cannot
            # handle .where() on arbitrary columns; fall through to a
            # pyarrow.compute filter on the Arrow table so the source
            # match runs in C++ without materializing non-matching rows.
            log.debug("get_chunks_by_source search() failed, using Arrow fallback", exc_info=True)
            arrow_tbl = table.to_arrow()
            filtered = arrow_tbl.filter(pc.equal(arrow_tbl["source"], source))
            rows = filtered.to_pylist()
        return [SearchChunk(**r) for r in rows]

    def delete_by_source(self, source: str) -> None:
        """Delete all chunks from a given source file."""
        with write_lock():
            table = self.open_table(CHUNKS_TABLE)
            if table is not None:
                _safe_delete_unlocked(table, f"source = '{escape_sql_string(source)}'")
            self._invalidate_source_cache()

    def get_sources(
        self,
        *,
        search: str | None = None,
        limit: int | None = None,
        offset: int = 0,
    ) -> list[SourceRecord]:
        """Return source records, filtered by *search* and sliced by offset/limit."""
        table = self.open_table(SOURCES_TABLE)
        if table is None:
            return []
        query = table.search()
        where = _sources_search_filter(search)
        if where is not None:
            query = query.where(where)
        if offset:
            query = query.offset(offset)
        query = query.limit(limit)
        result: list[SourceRecord] = query.to_list()  # type: ignore[assignment]
        return result

    def count_sources(self, *, search: str | None = None) -> int:
        """Count tracked sources matching *search* without materializing rows."""
        table = self.open_table(SOURCES_TABLE)
        if table is None:
            return 0
        where = _sources_search_filter(search)
        count: int = table.count_rows() if where is None else table.count_rows(filter=where)
        return count

    def upsert_source(
        self,
        filename: str,
        file_hash: str,
        chunk_count: int,
        source_type: str = "document",
    ) -> None:
        """Add or update a source file tracking record."""
        with write_lock():
            db = self.get_db()
            table = ensure_table(db, SOURCES_TABLE, _sources_schema())
            _safe_delete_unlocked(table, f"filename = '{escape_sql_string(filename)}'")
            table.add(
                [
                    {
                        "filename": filename,
                        "file_hash": file_hash,
                        "ingested_at": datetime.now(UTC).isoformat(),
                        "chunk_count": chunk_count,
                        "source_type": source_type,
                    }
                ]
            )
            self._invalidate_source_cache()

    def delete_source(self, filename: str) -> None:
        """Remove a source file tracking record."""
        with write_lock():
            table = self.open_table(SOURCES_TABLE)
            if table is not None:
                _safe_delete_unlocked(table, f"filename = '{escape_sql_string(filename)}'")
            self._invalidate_source_cache()

    def remove_documents(
        self,
        names: list[str],
        *,
        delete_files: bool = False,
        documents_dir: Path | None = None,
    ) -> RemoveResult:
        """Remove documents from the knowledge base by source name.

        Looks up known sources, deletes chunks and source records for each.
        If *delete_files* is True, resolves the path and verifies it is
        contained within *documents_dir* before unlinking (path traversal guard).

        Returns a RemoveResult with removed and not_found lists.
        """
        if documents_dir is None:
            documents_dir = self._config.documents_dir

        known = {s["filename"] for s in self.get_sources()}
        removed: list[str] = []
        not_found: list[str] = []

        for name in names:
            if name not in known:
                not_found.append(name)
                continue
            self.delete_by_source(name)
            self.delete_source(name)
            removed.append(name)
            if delete_files:
                try:
                    path = validate_path_within(documents_dir / name, documents_dir)
                except ValueError:
                    log.warning("Path traversal blocked: %s escapes %s", name, documents_dir)
                    continue
                if path.exists():
                    path.unlink()

        return RemoveResult(removed=removed, not_found=not_found)
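    # Example of the guard above (hypothetical name): if a tracked source were
    # named "../secrets.env", its chunks and source row are still removed, but the
    # unlink step is skipped and a warning logged because the resolved path
    # escapes ``documents_dir``.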

    def clear_table(self, name: str, predicate: str) -> None:
        """Delete rows matching *predicate* from *name*. Acquires write lock."""
        with write_lock():
            table = self.open_table(name)
            if table is not None:
                _safe_delete_unlocked(table, predicate)

    def add_citations(self, records: list[CitationRecord]) -> int:
        """Add citation records to the store. Returns count added."""
        if not records:
            return 0
        with write_lock():
            db = self.get_db()
            table = ensure_table(db, CITATIONS_TABLE, _citations_schema())
            table.add(records)
        return len(records)

    def get_citations_for_wiki(self, wiki_source: str) -> list[CitationRecord]:
        """Get all citations for a wiki page."""
        table = self.open_table(CITATIONS_TABLE)
        if table is None:
            return []
        escaped = escape_sql_string(wiki_source)
        rows: list[CitationRecord] = table.search().where(f"wiki_source = '{escaped}'").to_list()
        return rows

    def get_citations_for_source(self, source_filename: str) -> list[CitationRecord]:
        """Get all citations that reference a source document (reverse lookup)."""
        table = self.open_table(CITATIONS_TABLE)
        if table is None:
            return []
        escaped = escape_sql_string(source_filename)
        rows: list[CitationRecord] = (
            table.search().where(f"source_filename = '{escaped}'").to_list()
        )
        return rows

    def delete_citations_for_wiki(self, wiki_source: str) -> None:
        """Delete all citations for a wiki page (used before regeneration)."""
        self.clear_table(
            CITATIONS_TABLE,
            f"wiki_source = '{escape_sql_string(wiki_source)}'",
        )

    def close(self) -> None:
        """Release the database connection and reset state."""
        self._db = None
        self._fts_ready = False

    def drop_all(self) -> None:
        """Drop all tables -- used by rebuild."""
        with write_lock():
            self._fts_ready = False
            db = self.get_db()
            for name in _table_names(db):
                db.drop_table(name)
            self._invalidate_source_cache()
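# End-to-end sketch (illustrative only; ``cfg``, ``records``, and ``vec`` are
# hypothetical and assume a valid ``Config`` plus pre-computed embeddings):
#
#     store = Store(cfg)
#     store.add_chunks(records)          # gated on the persisted embedding identity
#     store.upsert_source("handbook.pdf", file_hash="abc123", chunk_count=len(records))
#     store.ensure_fts_index()
#     hits = store.search(vec, query_text="vacation policy")
#     store.close()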