Coverage for src / lilbee / data / store / core.py: 100%

486 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-28 01:01 +0000

1"""The ``Store`` class: high-level LanceDB read/write API used across lilbee.""" 

2 

3from __future__ import annotations 

4 

5import logging 

6from collections.abc import Callable 

7from datetime import UTC, datetime 

8from pathlib import Path 

9from typing import TYPE_CHECKING 

10 

11import pyarrow as pa 

12import pyarrow.compute as pc 

13 

14from lilbee.core.config import ( 

15 CHUNKS_TABLE, 

16 CITATIONS_TABLE, 

17 MEMORIES_TABLE, 

18 META_TABLE, 

19 PAGE_TEXTS_TABLE, 

20 SOURCES_TABLE, 

21 Config, 

22) 

23from lilbee.core.security import validate_path_within 

24from lilbee.runtime.lock import write_lock 

25 

26from .lance_helpers import ( 

27 _chunk_type_predicate, 

28 _has_fts_index, 

29 _has_vector_index, 

30 _safe_delete_unlocked, 

31 _sources_search_filter, 

32 _table_names, 

33 ensure_table, 

34 escape_sql_string, 

35 refs_compatible, 

36) 

37from .ranking import mmr_rerank 

38from .schema import _citations_schema, _meta_schema, _page_texts_schema, _sources_schema 

39from .types import ( 

40 META_DELETE_ALL_PREDICATE, 

41 META_SCHEMA_VERSION, 

42 READ_CONSISTENCY_INTERVAL, 

43 ChunkType, 

44 CitationRecord, 

45 EmbeddingModelMismatchError, 

46 MemoryKind, 

47 MemoryRow, 

48 PageTextRecord, 

49 RemoveResult, 

50 SearchChunk, 

51 SourceRecord, 

52 SourceType, 

53 StoreMeta, 

54) 

55 

56if TYPE_CHECKING: 

57 import lancedb 

58 import lancedb.table 

59 

60log = logging.getLogger(__name__) 

61 

62 

63def _hybrid_search( 

64 table: lancedb.table.Table, 

65 query_text: str, 

66 query_vector: list[float], 

67 top_k: int, 

68 chunk_type: ChunkType | None = None, 

69) -> list[SearchChunk]: 

70 """Run hybrid (vector + FTS) search with RRF reranking. 

71 

72 When ``chunk_type`` is set, the predicate is pushed into the query so 

73 the limit applies *after* the type filter. Post-filtering would 

74 silently starve wiki-only queries whose matches live past the top-K 

75 hybrid window. 

76 """ 

77 from lancedb.rerankers import RRFReranker 

78 

79 query = ( 

80 table.search(query_type="hybrid") 

81 .vector(query_vector) 

82 .text(query_text) 

83 .rerank(RRFReranker()) 

84 ) 

85 if chunk_type: 

86 query = query.where(_chunk_type_predicate(chunk_type)) 

87 rows = query.limit(top_k).to_list() 

88 return [SearchChunk(**r) for r in rows] 

89 

90 

91_MAX_THRESHOLD = 1.0 

92_MAX_FILTER_ITERATIONS = 20 # safety cap to prevent runaway loops 

93 

94# Vector ANN index. IVF_PQ compresses vectors so search scales to millions; 

95# refine_factor re-ranks the PQ candidates against full vectors to recover recall. 

96_VECTOR_METRIC = "cosine" 

97_ANN_INDEX_TYPE = "IVF_PQ" 

98_ANN_NPROBES = 20 

99_ANN_REFINE_FACTOR = 10 

100 

101 

102def _get_distance(chunk: SearchChunk) -> float: 

103 """Extract distance as a sortable float (inf for None).""" 

104 return chunk.distance if chunk.distance is not None else float("inf") 

105 

106 

107def _count_within_threshold(sorted_results: list[SearchChunk], threshold: float) -> int: 

108 """Count results whose distance is within the given threshold.""" 

109 for i, r in enumerate(sorted_results): 

110 if _get_distance(r) > threshold: 

111 return i 

112 return len(sorted_results) 

113 

114 

115class Store: 

116 """LanceDB vector store: wraps all DB operations with config-driven defaults.""" 

117 

118 def __init__(self, config: Config) -> None: 

119 self._config = config 

120 self._fts_ready: bool = False 

121 self._db: lancedb.DBConnection | None = None 

122 # Cache of {filename: ingested_at} rebuilt only when sources 

123 # mutate; callers (temporal filter) hit it per-query. 

124 self._source_ingested_cache: dict[str, str] | None = None 

125 

126 def _invalidate_source_cache(self) -> None: 

127 """Drop the cached {filename: ingested_at} map.""" 

128 self._source_ingested_cache = None 

129 

130 def source_ingested_at_map(self) -> dict[str, str]: 

131 """Return {filename: ingested_at} for every source, cached until mutation. 

132 

133 Best-effort: a reader racing a concurrent invalidation can store a 

134 pre-mutation snapshot. The only consumer (temporal query filter) 

135 treats a missing/stale entry as "do not filter," so staleness 

136 degrades ranking precision, never correctness. 

137 """ 

138 if self._source_ingested_cache is not None: 

139 return self._source_ingested_cache 

140 mapping = {s["filename"]: s.get("ingested_at", "") for s in self.get_sources()} 

141 self._source_ingested_cache = mapping 

142 return mapping 

143 

144 def _chunks_schema(self) -> pa.Schema: 

145 return pa.schema( 

146 [ 

147 pa.field("source", pa.utf8()), 

148 pa.field("content_type", pa.utf8()), 

149 pa.field("chunk_type", pa.utf8()), 

150 pa.field("page_start", pa.int32()), 

151 pa.field("page_end", pa.int32()), 

152 pa.field("line_start", pa.int32()), 

153 pa.field("line_end", pa.int32()), 

154 pa.field("chunk", pa.utf8()), 

155 pa.field("chunk_index", pa.int32()), 

156 pa.field("vector", pa.list_(pa.float32(), self._config.embedding_dim)), 

157 ] 

158 ) 

159 

160 def get_meta(self) -> StoreMeta | None: 

161 """Return the persisted store metadata row, or ``None`` if unset.""" 

162 table = self.open_table(META_TABLE) 

163 if table is None: 

164 return None 

165 rows = table.search().limit(1).to_list() 

166 if not rows: 

167 return None 

168 row = rows[0] 

169 return StoreMeta( 

170 embedding_model=row["embedding_model"], 

171 embedding_dim=int(row["embedding_dim"]), 

172 schema_version=int(row["schema_version"]), 

173 updated_at=row["updated_at"], 

174 ) 

175 

176 def _write_meta_unlocked(self, *, embedding_model: str, embedding_dim: int) -> None: 

177 """Overwrite the single ``_meta`` row with the supplied identity. 

178 

179 Caller must hold ``write_lock()``. Args are passed explicitly rather than 

180 re-read from ``self._config`` so the caller can snapshot cfg at a coherent 

181 instant and not race with a concurrent ``set_embedding_model``. 

182 """ 

183 db = self.get_db() 

184 table = ensure_table(db, META_TABLE, _meta_schema()) 

185 _safe_delete_unlocked(table, META_DELETE_ALL_PREDICATE) 

186 table.add( 

187 [ 

188 { 

189 "embedding_model": embedding_model, 

190 "embedding_dim": embedding_dim, 

191 "schema_version": META_SCHEMA_VERSION, 

192 "updated_at": datetime.now(UTC).isoformat(), 

193 } 

194 ] 

195 ) 

196 

197 def _has_chunks(self) -> bool: 

198 """Return True when the chunks table exists and has at least one row.""" 

199 chunks = self.open_table(CHUNKS_TABLE) 

200 return chunks is not None and chunks.count_rows() > 0 

201 

202 def has_chunks(self) -> bool: 

203 """Public predicate: True iff the store currently holds at least one chunk.""" 

204 return self._has_chunks() 

205 

206 def initialize_meta_if_legacy(self) -> bool: 

207 """Pin a legacy store's identity to the current cfg if not already set. 

208 

209 Returns ``True`` when a meta row was just written. No-op when meta already 

210 exists or no chunks are present. This is the path that converts a 

211 pre-upgrade store (chunks present, no ``_meta``) into a gated store. It 

212 snapshots cfg under the write lock to keep the recorded identity coherent 

213 with what the gate is comparing against. 

214 """ 

215 if self.get_meta() is not None: 

216 return False 

217 if not self._has_chunks(): 

218 return False 

219 embedding_model = self._config.embedding_model 

220 embedding_dim = self._config.embedding_dim 

221 with write_lock(): 

222 # Re-check under the lock so two callers do not both warn-and-write. 

223 if self.get_meta() is not None: 

224 return False 

225 log.warning( 

226 "Legacy store has chunks but no _meta row. Initializing _meta from " 

227 "the current configuration (embedding_model=%s, embedding_dim=%d). " 

228 "If you changed embedding_model before upgrading, run `lilbee rebuild` " 

229 "to ensure the store is consistent.", 

230 embedding_model, 

231 embedding_dim, 

232 ) 

233 self._write_meta_unlocked(embedding_model=embedding_model, embedding_dim=embedding_dim) 

234 return True 

235 

236 def _ensure_embedding_compat(self) -> None: 

237 """Raise when the persisted embedding identity drifts from cfg. 

238 

239 Pure check, no side effects. Migration of legacy stores (chunks present, 

240 no ``_meta``) is the caller's responsibility via ``initialize_meta_if_legacy``; 

241 rewriting a legacy bare-repo ``_meta`` row to the canonical full ref is 

242 the caller's responsibility via ``canonicalize_meta_if_legacy``. This 

243 method stays safe to call from inside an existing ``write_lock()`` (no 

244 recursive lock attempt). cfg fields are snapshotted at entry so the 

245 comparison is coherent even if another thread mutates them mid-call. 

246 """ 

247 current_model = self._config.embedding_model 

248 current_dim = self._config.embedding_dim 

249 meta = self.get_meta() 

250 if meta is None: 

251 return 

252 if refs_compatible( 

253 meta["embedding_model"], current_model, meta["embedding_dim"], current_dim 

254 ): 

255 return 

256 raise EmbeddingModelMismatchError( 

257 persisted_model=meta["embedding_model"], 

258 persisted_dim=meta["embedding_dim"], 

259 current_model=current_model, 

260 current_dim=current_dim, 

261 ) 

262 

263 def assert_embedding_compatible(self) -> None: 

264 """Run the full embedding-identity gate (legacy init, canonicalize, check). 

265 

266 Mirrors the gate ``search`` applies. Callers that write under a fresh 

267 embedder (import) use this to fail before any destructive work when the 

268 store was built by a different model. 

269 """ 

270 self.initialize_meta_if_legacy() 

271 self.canonicalize_meta_if_legacy() 

272 self._ensure_embedding_compat() 

273 

274 def _needs_canonical_meta_rewrite( 

275 self, meta: StoreMeta | None, current_model: str, current_dim: int 

276 ) -> bool: 

277 """True iff *meta* is the legacy form and refs-compatible with current cfg.""" 

278 if meta is None or meta["embedding_model"] == current_model: 

279 return False 

280 return refs_compatible( 

281 meta["embedding_model"], current_model, meta["embedding_dim"], current_dim 

282 ) 

283 

284 def canonicalize_meta_if_legacy(self) -> bool: 

285 """Rewrite a legacy bare-repo ``_meta`` row to the canonical full ref. 

286 

287 Pre-canonical lilbee persisted only ``<org>/<repo>`` in 

288 ``_meta.embedding_model``. The current code persists the full 

289 ``<org>/<repo>/<filename>.gguf``. When the two refer to the same 

290 model under :func:`refs_compatible` but differ as raw strings, the 

291 meta row is rewritten so the legacy name never surfaces. Returns 

292 ``True`` on write; ``False`` when missing, already canonical, or 

293 incompatible (the gate handles incompatibility). 

294 """ 

295 current_model = self._config.embedding_model 

296 current_dim = self._config.embedding_dim 

297 if not self._needs_canonical_meta_rewrite(self.get_meta(), current_model, current_dim): 

298 return False 

299 with write_lock(): 

300 meta = self.get_meta() # re-read under the lock for racing callers 

301 if not self._needs_canonical_meta_rewrite(meta, current_model, current_dim): 

302 return False 

303 assert meta is not None # filtered above # noqa: S101 

304 log.info( 

305 "Migrating legacy embedding ref in store meta: %r -> %r", 

306 meta["embedding_model"], 

307 current_model, 

308 ) 

309 self._write_meta_unlocked(embedding_model=current_model, embedding_dim=current_dim) 

310 return True 

311 

312 def get_db(self) -> lancedb.DBConnection: 

313 if self._db is None: 

314 import lancedb as _lancedb 

315 

316 self._config.lancedb_dir.mkdir(parents=True, exist_ok=True) 

317 self._db = _lancedb.connect( 

318 str(self._config.lancedb_dir), 

319 read_consistency_interval=READ_CONSISTENCY_INTERVAL, 

320 ) 

321 return self._db 

322 

323 def open_table(self, name: str) -> lancedb.table.Table | None: 

324 """Open a table if it exists, otherwise return None.""" 

325 db = self.get_db() 

326 if name not in _table_names(db): 

327 return None 

328 return db.open_table(name) 

329 

330 def ensure_fts_index(self) -> None: 

331 """Create the chunks FTS index, or run ``optimize()`` once it exists. 

332 

333 ``optimize()`` folds newly added rows into the FTS index and also 

334 runs LanceDB's default compaction + version pruning (default prune 

335 window: 7 days). Work scales with recent deltas rather than total 

336 chunk count, so large corpora no longer pay the full 

337 ``create_fts_index(replace=True)`` rebuild cost on every sync. 

338 """ 

339 with write_lock(): 

340 table = self.open_table(CHUNKS_TABLE) 

341 if table is None: 

342 return 

343 try: 

344 if _has_fts_index(table): 

345 table.optimize() 

346 log.debug("FTS index optimized on '%s'", CHUNKS_TABLE) 

347 else: 

348 table.create_fts_index("chunk", replace=False) 

349 log.debug("FTS index created on '%s'", CHUNKS_TABLE) 

350 self._fts_ready = True 

351 except Exception: 

352 log.debug("FTS index ensure failed (empty table?)", exc_info=True) 

353 

354 def ensure_vector_index(self, *, force: bool = False) -> bool: 

355 """Build or refresh the ANN vector index when the corpus is large enough. 

356 

357 Below ``cfg.ann_index_threshold`` (or when it is 0) the store keeps exact 

358 flat search, which is faster and exact for small vaults and is all a 

359 laptop needs. Once an index exists, ``optimize()`` folds new rows in. 

360 Pass ``force=True`` to build regardless of the threshold (publish flow). 

361 Returns True when an index was created or refreshed. 

362 """ 

363 threshold = self._config.ann_index_threshold 

364 with write_lock(): 

365 table = self.open_table(CHUNKS_TABLE) 

366 if table is None: 

367 return False 

368 if _has_vector_index(table): 

369 table.optimize() 

370 log.debug("Vector index optimized on '%s'", CHUNKS_TABLE) 

371 return True 

372 if not force and (threshold <= 0 or table.count_rows() < threshold): 

373 return False 

374 try: 

375 table.create_index(metric=_VECTOR_METRIC, index_type=_ANN_INDEX_TYPE) 

376 log.info("Vector ANN index created on '%s'", CHUNKS_TABLE) 

377 return True 

378 except Exception: 

379 log.debug("Vector index build failed (too few rows?)", exc_info=True) 

380 return False 

381 

382 def add_chunks(self, records: list[dict]) -> int: 

383 """Add chunk records to the store. Returns count added. 

384 

385 Raises ``EmbeddingModelMismatchError`` if the persisted ``_meta`` row was 

386 written under a different embedding model than the current ``cfg``. On the 

387 first write to a fresh store, ``_meta`` is initialized from the current cfg. 

388 

389 The gate runs inside the write lock and uses a single cfg snapshot so a 

390 concurrent ``set_embedding_model`` cannot slip a write in past a stale 

391 compatibility check. 

392 """ 

393 with write_lock(): 

394 embedding_model = self._config.embedding_model 

395 embedding_dim = self._config.embedding_dim 

396 self._ensure_embedding_compat() 

397 self._fts_ready = False 

398 if not records: 

399 return 0 

400 for rec in records: 

401 vec = rec.get("vector", []) 

402 if len(vec) != embedding_dim: 

403 raise ValueError( 

404 f"Vector dimension mismatch: expected {embedding_dim}, " 

405 f"got {len(vec)} (source={rec.get('source', '?')})" 

406 ) 

407 db = self.get_db() 

408 table = ensure_table(db, CHUNKS_TABLE, self._chunks_schema()) 

409 table.add(records) 

410 if self.get_meta() is None: 

411 self._write_meta_unlocked( 

412 embedding_model=embedding_model, embedding_dim=embedding_dim 

413 ) 

414 return len(records) 

415 

416 def bm25_probe(self, query_text: str, top_k: int = 5) -> list[SearchChunk]: 

417 """Quick BM25-only search for confidence checking. Returns up to top_k results.""" 

418 table = self.open_table(CHUNKS_TABLE) 

419 if table is None: 

420 return [] 

421 if not self._fts_ready: 

422 self.ensure_fts_index() 

423 if not self._fts_ready: 

424 return [] 

425 try: 

426 rows = table.search(query_text, query_type="fts").limit(top_k).to_list() 

427 return [SearchChunk(**r) for r in rows] 

428 except Exception: 

429 log.debug("BM25 probe failed", exc_info=True) 

430 return [] 

431 

432 def search( 

433 self, 

434 query_vector: list[float], 

435 top_k: int | None = None, 

436 max_distance: float | None = None, 

437 query_text: str | None = None, 

438 chunk_type: ChunkType | None = None, 

439 ) -> list[SearchChunk]: 

440 """Search for similar chunks. Hybrid when FTS available, else vector-only. 

441 

442 Results with distance > max_distance are filtered out (vector-only path). 

443 Pass max_distance=0 to disable filtering. 

444 When *chunk_type* is set, only chunks of that type ("raw" or "wiki") are returned. 

445 

446 Raises ``EmbeddingModelMismatchError`` if the persisted ``_meta`` row was 

447 written under a different embedding model than the current ``cfg``. 

448 """ 

449 if top_k is None: 

450 top_k = self._config.top_k 

451 if max_distance is None: 

452 max_distance = self._config.max_distance 

453 table = self.open_table(CHUNKS_TABLE) 

454 if table is None: 

455 return [] 

456 self.initialize_meta_if_legacy() 

457 self.canonicalize_meta_if_legacy() 

458 self._ensure_embedding_compat() 

459 

460 if query_text and not self._fts_ready: 

461 self.ensure_fts_index() 

462 

463 if query_text and self._fts_ready: 

464 try: 

465 return _hybrid_search(table, query_text, query_vector, top_k, chunk_type) 

466 except Exception: 

467 log.debug("Hybrid search failed, falling back to vector-only", exc_info=True) 

468 

469 candidate_k = top_k * self._config.candidate_multiplier 

470 query = table.search(query_vector).metric(_VECTOR_METRIC).limit(candidate_k) 

471 if _has_vector_index(table): 

472 # IVF_PQ is lossy; probe more partitions and refine against full 

473 # vectors so recall stays close to the exact flat scan. 

474 query = query.nprobes(_ANN_NPROBES).refine_factor(_ANN_REFINE_FACTOR) 

475 if chunk_type: 

476 query = query.where(_chunk_type_predicate(chunk_type)) 

477 rows = query.to_list() 

478 log.debug( 

479 "Vector search: query=%r, candidates=%d, max_distance=%.2f", 

480 query_text or "vector-only", 

481 len(rows), 

482 max_distance, 

483 ) 

484 if rows: 

485 distances = [r.get("distance", 0) for r in rows[:5]] 

486 log.debug("Top 5 distances: %s", distances) 

487 results = [SearchChunk(**r) for r in rows] 

488 return self._filter_and_rerank(results, query_vector, top_k, max_distance) 

489 

490 def _filter_and_rerank( 

491 self, 

492 results: list[SearchChunk], 

493 query_vector: list[float], 

494 top_k: int, 

495 max_distance: float, 

496 ) -> list[SearchChunk]: 

497 """Apply the configured distance filter, then MMR-rerank down to top_k.""" 

498 if max_distance > 0: 

499 before = len(results) 

500 if self._config.adaptive_threshold: 

501 results = self._adaptive_filter(results, top_k, max_distance) 

502 filter_name = "adaptive" 

503 else: 

504 results = self._fixed_filter(results, max_distance) 

505 filter_name = "fixed" 

506 log.debug( 

507 "After %s filter: %d/%d results, threshold=%.2f", 

508 filter_name, 

509 len(results), 

510 before, 

511 max_distance, 

512 ) 

513 if len(results) > top_k: 

514 results = mmr_rerank(query_vector, results, top_k, self._config.mmr_lambda) 

515 return results 

516 

517 def _adaptive_filter( 

518 self, results: list[SearchChunk], top_k: int, initial_threshold: float 

519 ) -> list[SearchChunk]: 

520 """Widen cosine distance threshold when too few results. 

521 Inspired by grantflow's (grantflow-ai/grantflow) adaptive retrieval 

522 pattern which widens thresholds on recursive retry. Step size and 

523 cap are configurable via ``self._config.adaptive_threshold_step``. 

524 

525 Pre-sorts results by distance for a single-pass cutoff search. 

526 Step size is ``self._config.adaptive_threshold_step`` (default 0.2). 

527 """ 

528 cap = max(initial_threshold, _MAX_THRESHOLD) 

529 step = self._config.adaptive_threshold_step 

530 

531 sorted_results = sorted(results, key=_get_distance) 

532 

533 threshold = initial_threshold 

534 for _ in range(_MAX_FILTER_ITERATIONS): 

535 if threshold > cap: 

536 break 

537 cutoff = _count_within_threshold(sorted_results, threshold) 

538 if cutoff >= top_k: 

539 return sorted_results[:cutoff] 

540 threshold += step 

541 # Final pass at cap 

542 cutoff = _count_within_threshold(sorted_results, cap) 

543 return sorted_results[:cutoff] 

544 

545 def _fixed_filter(self, results: list[SearchChunk], threshold: float) -> list[SearchChunk]: 

546 """Simple fixed threshold filter - keep only results within distance threshold.""" 

547 return [r for r in results if _get_distance(r) <= threshold] 

548 

549 def get_chunks_by_source(self, source: str) -> list[SearchChunk]: 

550 """Return every chunk whose ``source`` equals *source*.""" 

551 table = self.open_table(CHUNKS_TABLE) 

552 if table is None: 

553 return [] 

554 escaped = escape_sql_string(source) 

555 try: 

556 rows = table.search().where(f"source = '{escaped}'").limit(None).to_list() 

557 except Exception: 

558 # FTS-enabled tables return a query builder that cannot 

559 # handle .where() on arbitrary columns; fall through to a 

560 # pyarrow.compute filter on the Arrow table so the source 

561 # match runs in C++ without materializing non-matching rows. 

562 log.debug("get_chunks_by_source search() failed, using Arrow fallback", exc_info=True) 

563 arrow_tbl = table.to_arrow() 

564 filtered = arrow_tbl.filter(pc.equal(arrow_tbl["source"], source)) 

565 rows = filtered.to_pylist() 

566 return [SearchChunk(**r) for r in rows] 

567 

568 def _delete_by_source_unlocked(self, source: str) -> None: 

569 """Delete a source's chunks and page texts. Caller must hold ``write_lock()``.""" 

570 predicate = f"source = '{escape_sql_string(source)}'" 

571 for name in (CHUNKS_TABLE, PAGE_TEXTS_TABLE): 

572 table = self.open_table(name) 

573 if table is not None: 

574 _safe_delete_unlocked(table, predicate) 

575 

576 def delete_by_source(self, source: str) -> None: 

577 """Delete a source's chunks and page texts.""" 

578 with write_lock(): 

579 self._delete_by_source_unlocked(source) 

580 self._invalidate_source_cache() 

581 

582 def add_page_texts(self, records: list[dict]) -> int: 

583 """Add per-page text rows (no vectors). Returns count added.""" 

584 if not records: 

585 return 0 

586 with write_lock(): 

587 db = self.get_db() 

588 table = ensure_table(db, PAGE_TEXTS_TABLE, _page_texts_schema()) 

589 table.add(records) 

590 return len(records) 

591 

592 def get_page_texts(self, source: str | None = None) -> list[PageTextRecord]: 

593 """Return per-page text rows, all or for a single *source*.""" 

594 table = self.open_table(PAGE_TEXTS_TABLE) 

595 if table is None: 

596 return [] 

597 query = table.search() 

598 if source is not None: 

599 query = query.where(f"source = '{escape_sql_string(source)}'") 

600 rows: list[PageTextRecord] = query.limit(None).to_list() 

601 return rows 

602 

603 def page_text_sources(self) -> set[str]: 

604 """Return the distinct sources present in the page-text table.""" 

605 table = self.open_table(PAGE_TEXTS_TABLE) 

606 if table is None: 

607 return set() 

608 return {row["source"] for row in table.search().select(["source"]).limit(None).to_list()} 

609 

610 def get_sources( 

611 self, 

612 *, 

613 search: str | None = None, 

614 limit: int | None = None, 

615 offset: int = 0, 

616 ) -> list[SourceRecord]: 

617 """Return source records, filtered by *search* and sliced by offset/limit.""" 

618 table = self.open_table(SOURCES_TABLE) 

619 if table is None: 

620 return [] 

621 query = table.search() 

622 where = _sources_search_filter(search) 

623 if where is not None: 

624 query = query.where(where) 

625 if offset: 

626 query = query.offset(offset) 

627 query = query.limit(limit) 

628 result: list[SourceRecord] = query.to_list() # type: ignore[assignment] 

629 return result 

630 

631 def count_sources(self, *, search: str | None = None) -> int: 

632 """Count tracked sources matching *search* without materializing rows.""" 

633 table = self.open_table(SOURCES_TABLE) 

634 if table is None: 

635 return 0 

636 where = _sources_search_filter(search) 

637 count: int = table.count_rows() if where is None else table.count_rows(filter=where) 

638 return count 

639 

640 def upsert_source( 

641 self, 

642 filename: str, 

643 file_hash: str, 

644 chunk_count: int, 

645 source_type: SourceType = SourceType.DOCUMENT, 

646 ) -> None: 

647 """Add or update a source tracking record.""" 

648 with write_lock(): 

649 db = self.get_db() 

650 table = ensure_table(db, SOURCES_TABLE, _sources_schema()) 

651 _safe_delete_unlocked(table, f"filename = '{escape_sql_string(filename)}'") 

652 table.add( 

653 [ 

654 { 

655 "filename": filename, 

656 "file_hash": file_hash, 

657 "ingested_at": datetime.now(UTC).isoformat(), 

658 "chunk_count": chunk_count, 

659 "source_type": str(source_type), 

660 } 

661 ] 

662 ) 

663 self._invalidate_source_cache() 

664 

665 def _delete_source_unlocked(self, filename: str) -> None: 

666 """Remove the *filename* source record. Caller must hold ``write_lock()``.""" 

667 table = self.open_table(SOURCES_TABLE) 

668 if table is not None: 

669 _safe_delete_unlocked(table, f"filename = '{escape_sql_string(filename)}'") 

670 

671 def delete_source(self, filename: str) -> None: 

672 """Remove a source file tracking record.""" 

673 with write_lock(): 

674 self._delete_source_unlocked(filename) 

675 self._invalidate_source_cache() 

676 

677 def _remove_one_unlocked(self, name: str) -> None: 

678 """Delete a document's chunks and its source record together. 

679 

680 Both deletes run under the caller's single ``write_lock()`` so no 

681 reader can observe chunks whose source record is already gone. 

682 """ 

683 self._delete_by_source_unlocked(name) 

684 self._delete_source_unlocked(name) 

685 

686 def remove_documents( 

687 self, 

688 names: list[str], 

689 *, 

690 delete_files: bool = False, 

691 documents_dir: Path | None = None, 

692 ) -> RemoveResult: 

693 """Remove documents from the knowledge base by source name. 

694 Looks up known sources, deletes chunks and source records for each. 

695 If *delete_files* is True, resolves the path and verifies it is 

696 contained within *documents_dir* before unlinking (path traversal guard). 

697 

698 Returns a RemoveResult with removed and not_found lists. 

699 """ 

700 if documents_dir is None: 

701 documents_dir = self._config.documents_dir 

702 

703 known = {s["filename"] for s in self.get_sources()} 

704 removed: list[str] = [] 

705 not_found: list[str] = [] 

706 

707 for name in names: 

708 if name not in known: 

709 not_found.append(name) 

710 continue 

711 with write_lock(): 

712 self._remove_one_unlocked(name) 

713 self._invalidate_source_cache() 

714 removed.append(name) 

715 if delete_files: 

716 try: 

717 path = validate_path_within(documents_dir / name, documents_dir) 

718 except ValueError: 

719 log.warning("Path traversal blocked: %s escapes %s", name, documents_dir) 

720 continue 

721 if path.exists(): 

722 path.unlink() 

723 

724 return RemoveResult(removed=removed, not_found=not_found) 

725 

726 def clear_table(self, name: str, predicate: str) -> None: 

727 """Delete rows matching *predicate* from *name*. Acquires write lock.""" 

728 with write_lock(): 

729 table = self.open_table(name) 

730 if table is not None: 

731 _safe_delete_unlocked(table, predicate) 

732 

733 def add_citations(self, records: list[CitationRecord]) -> int: 

734 """Add citation records to the store. Returns count added.""" 

735 if not records: 

736 return 0 

737 with write_lock(): 

738 db = self.get_db() 

739 table = ensure_table(db, CITATIONS_TABLE, _citations_schema()) 

740 table.add(records) 

741 return len(records) 

742 

743 def get_citations_for_wiki(self, wiki_source: str) -> list[CitationRecord]: 

744 """Get all citations for a wiki page.""" 

745 table = self.open_table(CITATIONS_TABLE) 

746 if table is None: 

747 return [] 

748 escaped = escape_sql_string(wiki_source) 

749 rows: list[CitationRecord] = table.search().where(f"wiki_source = '{escaped}'").to_list() 

750 return rows 

751 

752 def get_citations_for_source(self, source_filename: str) -> list[CitationRecord]: 

753 """Get all citations that reference a source document (reverse lookup).""" 

754 table = self.open_table(CITATIONS_TABLE) 

755 if table is None: 

756 return [] 

757 escaped = escape_sql_string(source_filename) 

758 rows: list[CitationRecord] = ( 

759 table.search().where(f"source_filename = '{escaped}'").to_list() 

760 ) 

761 return rows 

762 

763 def delete_citations_for_wiki(self, wiki_source: str) -> None: 

764 """Delete all citations for a wiki page (used before regeneration).""" 

765 self.clear_table( 

766 CITATIONS_TABLE, 

767 f"wiki_source = '{escape_sql_string(wiki_source)}'", 

768 ) 

769 

770 def _memories_schema(self) -> pa.Schema: 

771 return pa.schema( 

772 [ 

773 pa.field("id", pa.utf8()), 

774 pa.field("owner", pa.utf8()), 

775 pa.field("shared", pa.bool_()), 

776 pa.field("kind", pa.utf8()), 

777 pa.field("source", pa.utf8()), 

778 pa.field("text", pa.utf8()), 

779 pa.field("vector", pa.list_(pa.float32(), self._config.embedding_dim)), 

780 pa.field("created_at", pa.utf8()), 

781 pa.field("updated_at", pa.utf8()), 

782 ] 

783 ) 

784 

785 def _duplicate_memory_id_unlocked( 

786 self, table: lancedb.table.Table, record: MemoryRow 

787 ) -> str | None: 

788 """Return the id of a near-duplicate same-owner, same-kind memory, if any.""" 

789 if table.count_rows() == 0: 

790 return None 

791 predicate = ( 

792 f"owner = '{escape_sql_string(record.owner)}' " 

793 f"AND kind = '{escape_sql_string(record.kind)}'" 

794 ) 

795 rows = table.search(record.vector).metric("cosine").where(predicate).limit(1).to_list() 

796 if rows and rows[0].get("_distance", 1.0) <= self._config.memory_dedup_distance: 

797 return str(rows[0]["id"]) 

798 return None 

799 

800 def _evict_overflow_unlocked(self, table: lancedb.table.Table, owner: str) -> None: 

801 """Delete oldest memories for *owner* so an incoming insert stays within the cap.""" 

802 cap = self._config.memory_max_per_owner 

803 predicate = f"owner = '{escape_sql_string(owner)}'" 

804 rows = table.search().where(predicate).limit(None).to_list() 

805 if len(rows) < cap: 

806 return 

807 rows.sort(key=lambda r: r.get("created_at", "")) 

808 for row in rows[: len(rows) - (cap - 1)]: 

809 _safe_delete_unlocked(table, f"id = '{escape_sql_string(str(row['id']))}'") 

810 

811 def add_memory(self, record: MemoryRow) -> str: 

812 """Insert *record*, or update the nearest same-owner duplicate in place. 

813 

814 Returns the stored id. Raises ``EmbeddingModelMismatchError`` when the store 

815 was built under a different embedding model, and ``ValueError`` on a vector 

816 dimension mismatch. 

817 """ 

818 if len(record.vector) != self._config.embedding_dim: 

819 raise ValueError( 

820 f"Memory vector dimension mismatch: expected " 

821 f"{self._config.embedding_dim}, got {len(record.vector)}" 

822 ) 

823 with write_lock(): 

824 embedding_model = self._config.embedding_model 

825 embedding_dim = self._config.embedding_dim 

826 self._ensure_embedding_compat() 

827 db = self.get_db() 

828 table = ensure_table(db, MEMORIES_TABLE, self._memories_schema()) 

829 duplicate_id = self._duplicate_memory_id_unlocked(table, record) 

830 if duplicate_id is not None: 

831 _safe_delete_unlocked(table, f"id = '{escape_sql_string(duplicate_id)}'") 

832 record.id = duplicate_id 

833 self._evict_overflow_unlocked(table, record.owner) 

834 table.add([record.model_dump(mode="json")]) 

835 if self.get_meta() is None: 

836 self._write_meta_unlocked( 

837 embedding_model=embedding_model, embedding_dim=embedding_dim 

838 ) 

839 return record.id 

840 

841 def get_memories( 

842 self, 

843 *, 

844 owner_predicate: str, 

845 kind: MemoryKind | None = None, 

846 ) -> list[MemoryRow]: 

847 """Return memories matching *owner_predicate* and optional *kind*, newest first.""" 

848 table = self.open_table(MEMORIES_TABLE) 

849 if table is None: 

850 return [] 

851 clauses = [f"({owner_predicate})"] 

852 if kind is not None: 

853 clauses.append(f"kind = '{escape_sql_string(kind)}'") 

854 rows = table.search().where(" AND ".join(clauses)).limit(None).to_list() 

855 memories = [MemoryRow(**r) for r in rows] 

856 memories.sort(key=lambda m: m.created_at, reverse=True) 

857 return memories 

858 

859 def search_memories( 

860 self, 

861 query_vector: list[float], 

862 *, 

863 owner_predicate: str, 

864 top_k: int, 

865 max_distance: float, 

866 ) -> list[MemoryRow]: 

867 """Vector-recall FACT memories within *max_distance*, best first.""" 

868 table = self.open_table(MEMORIES_TABLE) 

869 if table is None or top_k <= 0: 

870 return [] 

871 self._ensure_embedding_compat() 

872 predicate = f"({owner_predicate}) AND kind = '{MemoryKind.FACT}'" 

873 rows = table.search(query_vector).metric("cosine").where(predicate).limit(top_k).to_list() 

874 return [MemoryRow(**r) for r in rows if r.get("_distance", 1.0) <= max_distance] 

875 

876 def update_memory(self, memory_id: str, *, shared: bool) -> bool: 

877 """Set the *shared* flag on a memory by id. Returns True when found.""" 

878 with write_lock(): 

879 table = self.open_table(MEMORIES_TABLE) 

880 if table is None: 

881 return False 

882 escaped = escape_sql_string(memory_id) 

883 rows = table.search().where(f"id = '{escaped}'").limit(1).to_list() 

884 if not rows: 

885 return False 

886 record = MemoryRow(**rows[0]) 

887 record.shared = shared 

888 record.updated_at = datetime.now(UTC).isoformat() 

889 _safe_delete_unlocked(table, f"id = '{escaped}'") 

890 table.add([record.model_dump(mode="json")]) 

891 return True 

892 

893 def delete_memory(self, memory_id: str) -> None: 

894 """Delete a memory by id.""" 

895 self.clear_table(MEMORIES_TABLE, f"id = '{escape_sql_string(memory_id)}'") 

896 

897 def rebuild_memory_embeddings(self, embed: Callable[[list[str]], list[list[float]]]) -> int: 

898 """Re-embed every memory under the current model, recreating the table. 

899 

900 The vector column dimension is immutable, so a different-dim model needs a 

901 fresh table; recreating unconditionally also covers the same-dim case. Memory 

902 text is human-authored and re-embeddable, so no data is lost. Returns the count. 

903 """ 

904 table = self.open_table(MEMORIES_TABLE) 

905 if table is None: 

906 return 0 

907 rows = table.search().limit(None).to_list() 

908 if not rows: 

909 return 0 

910 memories = [MemoryRow(**r) for r in rows] 

911 vectors = embed([m.text for m in memories]) 

912 for memory, vector in zip(memories, vectors, strict=True): 

913 memory.vector = vector 

914 with write_lock(): 

915 db = self.get_db() 

916 db.drop_table(MEMORIES_TABLE) 

917 new_table = ensure_table(db, MEMORIES_TABLE, self._memories_schema()) 

918 new_table.add([m.model_dump(mode="json") for m in memories]) 

919 return len(memories) 

920 

921 def close(self) -> None: 

922 """Release the database connection and reset state.""" 

923 self._db = None 

924 self._fts_ready = False 

925 

926 def drop_all(self) -> None: 

927 """Drop every table except ``_memories`` -- used by rebuild. 

928 

929 Memory is user-authored data with no on-disk source, not derived from 

930 documents, so a rebuild preserves it. Only a factory reset (which deletes 

931 the data directory) clears it. 

932 """ 

933 with write_lock(): 

934 self._fts_ready = False 

935 db = self.get_db() 

936 for name in _table_names(db): 

937 if name == MEMORIES_TABLE: 

938 continue 

939 db.drop_table(name) 

940 self._invalidate_source_cache()