Coverage for src/lilbee/data/store/core.py: 100%
331 statements
coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
1"""The ``Store`` class: high-level LanceDB read/write API used across lilbee."""
3from __future__ import annotations
5import logging
6from datetime import UTC, datetime
7from pathlib import Path
8from typing import TYPE_CHECKING
10import pyarrow as pa
11import pyarrow.compute as pc
13from lilbee.core.config import (
14 CHUNKS_TABLE,
15 CITATIONS_TABLE,
16 META_TABLE,
17 SOURCES_TABLE,
18 Config,
19)
20from lilbee.core.security import validate_path_within
21from lilbee.runtime.lock import write_lock
23from .lance_helpers import (
24 _chunk_type_predicate,
25 _embedding_mismatch_message,
26 _has_fts_index,
27 _safe_delete_unlocked,
28 _sources_search_filter,
29 _table_names,
30 ensure_table,
31 escape_sql_string,
32 refs_compatible,
33)
34from .ranking import mmr_rerank
35from .schema import _citations_schema, _meta_schema, _sources_schema
36from .types import (
37 META_DELETE_ALL_PREDICATE,
38 META_SCHEMA_VERSION,
39 READ_CONSISTENCY_INTERVAL,
40 CitationRecord,
41 EmbeddingModelMismatchError,
42 RemoveResult,
43 SearchChunk,
44 SourceRecord,
45 StoreMeta,
46)
48if TYPE_CHECKING:
49 import lancedb
50 import lancedb.table
52log = logging.getLogger(__name__)


def _hybrid_search(
    table: lancedb.table.Table,
    query_text: str,
    query_vector: list[float],
    top_k: int,
    chunk_type: str | None = None,
) -> list[SearchChunk]:
    """Run hybrid (vector + FTS) search with RRF reranking.

    When ``chunk_type`` is set, the predicate is pushed into the query so
    the limit applies *after* the type filter. Post-filtering would
    silently starve wiki-only queries whose matches live past the top-K
    hybrid window.
    """
    from lancedb.rerankers import RRFReranker

    query = (
        table.search(query_type="hybrid")
        .vector(query_vector)
        .text(query_text)
        .rerank(RRFReranker())
    )
    if chunk_type:
        query = query.where(_chunk_type_predicate(chunk_type))
    rows = query.limit(top_k).to_list()
    return [SearchChunk(**r) for r in rows]
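
# Editor's sketch of the failure mode the docstring above describes (toy
# data, not real query results): with ten hybrid hits where only the last
# two are wiki chunks, post-filtering a top-5 window returns nothing, while
# pushing the predicate down returns both.
#
#     hits = [("raw", i) for i in range(8)] + [("wiki", 8), ("wiki", 9)]
#     post_filtered = [h for h in hits[:5] if h[0] == "wiki"]  # [] -- starved
#     pushed_down = [h for h in hits if h[0] == "wiki"][:5]    # both wiki hits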


_MAX_THRESHOLD = 1.0
_MAX_FILTER_ITERATIONS = 20  # safety cap to prevent runaway loops


def _get_distance(chunk: SearchChunk) -> float:
    """Extract distance as a sortable float (inf for None)."""
    return chunk.distance if chunk.distance is not None else float("inf")


def _count_within_threshold(sorted_results: list[SearchChunk], threshold: float) -> int:
    """Count results whose distance is within the given threshold."""
    for i, r in enumerate(sorted_results):
        if _get_distance(r) > threshold:
            return i
    return len(sorted_results)
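
# Worked example (editor's sketch) for _count_within_threshold: with sorted
# distances [0.1, 0.3, 0.7] and threshold 0.5, the scan stops at index 2, so
# two results qualify. The early return is only correct because the input is
# pre-sorted by distance.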


class Store:
    """LanceDB vector store: wraps all DB operations with config-driven defaults."""

    def __init__(self, config: Config) -> None:
        self._config = config
        self._fts_ready: bool = False
        self._db: lancedb.DBConnection | None = None
        # Cache of {filename: ingested_at} rebuilt only when sources
        # mutate; callers (temporal filter) hit it per-query.
        self._source_ingested_cache: dict[str, str] | None = None

    def _invalidate_source_cache(self) -> None:
        """Drop the cached {filename: ingested_at} map."""
        self._source_ingested_cache = None

    def source_ingested_at_map(self) -> dict[str, str]:
        """Return {filename: ingested_at} for every source, cached until mutation.

        Best-effort: a reader racing a concurrent invalidation can store a
        pre-mutation snapshot. The only consumer (temporal query filter)
        treats a missing/stale entry as "do not filter," so staleness
        degrades ranking precision, never correctness.
        """
        if self._source_ingested_cache is not None:
            return self._source_ingested_cache
        mapping = {s["filename"]: s.get("ingested_at", "") for s in self.get_sources()}
        self._source_ingested_cache = mapping
        return mapping
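
    # Editor's sketch of the consumer contract described above (hypothetical
    # caller; `cutoff_iso` stands in for the temporal filter's bound): a
    # missing or stale entry yields "", which means "do not filter".
    #
    #     ingested = store.source_ingested_at_map()
    #     ts = ingested.get(chunk.source, "")
    #     keep = (not ts) or (ts >= cutoff_iso)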

    def _chunks_schema(self) -> pa.Schema:
        return pa.schema(
            [
                pa.field("source", pa.utf8()),
                pa.field("content_type", pa.utf8()),
                pa.field("chunk_type", pa.utf8()),
                pa.field("page_start", pa.int32()),
                pa.field("page_end", pa.int32()),
                pa.field("line_start", pa.int32()),
                pa.field("line_end", pa.int32()),
                pa.field("chunk", pa.utf8()),
                pa.field("chunk_index", pa.int32()),
                pa.field("vector", pa.list_(pa.float32(), self._config.embedding_dim)),
            ]
        )

    def get_meta(self) -> StoreMeta | None:
        """Return the persisted store metadata row, or ``None`` if unset."""
        table = self.open_table(META_TABLE)
        if table is None:
            return None
        rows = table.search().limit(1).to_list()
        if not rows:
            return None
        row = rows[0]
        return StoreMeta(
            embedding_model=row["embedding_model"],
            embedding_dim=int(row["embedding_dim"]),
            schema_version=int(row["schema_version"]),
            updated_at=row["updated_at"],
        )

    def _write_meta_unlocked(self, *, embedding_model: str, embedding_dim: int) -> None:
        """Overwrite the single ``_meta`` row with the supplied identity.

        Caller must hold ``write_lock()``. Args are passed explicitly rather than
        re-read from ``self._config`` so the caller can snapshot cfg at a coherent
        instant and not race with a concurrent ``set_embedding_model``.
        """
        db = self.get_db()
        table = ensure_table(db, META_TABLE, _meta_schema())
        _safe_delete_unlocked(table, META_DELETE_ALL_PREDICATE)
        table.add(
            [
                {
                    "embedding_model": embedding_model,
                    "embedding_dim": embedding_dim,
                    "schema_version": META_SCHEMA_VERSION,
                    "updated_at": datetime.now(UTC).isoformat(),
                }
            ]
        )

    def _has_chunks(self) -> bool:
        """Return True when the chunks table exists and has at least one row."""
        chunks = self.open_table(CHUNKS_TABLE)
        return chunks is not None and chunks.count_rows() > 0

    def has_chunks(self) -> bool:
        """Public predicate: True iff the store currently holds at least one chunk."""
        return self._has_chunks()

    def initialize_meta_if_legacy(self) -> bool:
        """Pin a legacy store's identity to the current cfg if not already set.

        Returns ``True`` when a meta row was just written. No-op when meta already
        exists or no chunks are present. This is the path that converts a
        pre-upgrade store (chunks present, no ``_meta``) into a gated store. It
        snapshots cfg under the write lock to keep the recorded identity coherent
        with what the gate is comparing against.
        """
        if self.get_meta() is not None:
            return False
        if not self._has_chunks():
            return False
        with write_lock():
            # Re-check under the lock so two callers do not both warn-and-write.
            if self.get_meta() is not None:
                return False
            # Snapshot cfg under the lock, as the docstring promises, so the
            # recorded identity cannot race a concurrent set_embedding_model.
            embedding_model = self._config.embedding_model
            embedding_dim = self._config.embedding_dim
            log.warning(
                "Legacy store has chunks but no _meta row. Initializing _meta from "
                "the current configuration (embedding_model=%s, embedding_dim=%d). "
                "If you changed embedding_model before upgrading, run `lilbee rebuild` "
                "to ensure the store is consistent.",
                embedding_model,
                embedding_dim,
            )
            self._write_meta_unlocked(embedding_model=embedding_model, embedding_dim=embedding_dim)
            return True
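
    # The method above uses check / lock / re-check (double-checked locking).
    # A minimal editor's sketch of the shape, with hypothetical names:
    #
    #     if already_done():           # cheap, lock-free fast path
    #         return False
    #     with write_lock():
    #         if already_done():       # re-check: a racing writer may have won
    #             return False
    #         do_write()
    #         return True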

    def _ensure_embedding_compat(self) -> None:
        """Raise when the persisted embedding identity drifts from cfg.

        Pure check, no side effects. Migration of legacy stores (chunks present,
        no ``_meta``) is the caller's responsibility via ``initialize_meta_if_legacy``;
        rewriting a legacy bare-repo ``_meta`` row to the canonical full ref is
        the caller's responsibility via ``canonicalize_meta_if_legacy``. This
        method stays safe to call from inside an existing ``write_lock()`` (no
        recursive lock attempt). cfg fields are snapshotted at entry so the
        comparison is coherent even if another thread mutates them mid-call.
        """
        current_model = self._config.embedding_model
        current_dim = self._config.embedding_dim
        meta = self.get_meta()
        if meta is None:
            return
        if refs_compatible(
            meta["embedding_model"], current_model, meta["embedding_dim"], current_dim
        ):
            return
        raise EmbeddingModelMismatchError(
            _embedding_mismatch_message(
                persisted_model=meta["embedding_model"],
                persisted_dim=meta["embedding_dim"],
                current_model=current_model,
                current_dim=current_dim,
            )
        )

    def _needs_canonical_meta_rewrite(
        self, meta: StoreMeta | None, current_model: str, current_dim: int
    ) -> bool:
        """True iff *meta* is the legacy form and refs-compatible with current cfg."""
        if meta is None or meta["embedding_model"] == current_model:
            return False
        return refs_compatible(
            meta["embedding_model"], current_model, meta["embedding_dim"], current_dim
        )

    def canonicalize_meta_if_legacy(self) -> bool:
        """Rewrite a legacy bare-repo ``_meta`` row to the canonical full ref.

        Pre-canonical lilbee persisted only ``<org>/<repo>`` in
        ``_meta.embedding_model``. The current code persists the full
        ``<org>/<repo>/<filename>.gguf``. When the two refer to the same
        model under :func:`refs_compatible` but differ as raw strings, the
        meta row is rewritten so the legacy name never surfaces. Returns
        ``True`` on write; ``False`` when missing, already canonical, or
        incompatible (the gate handles incompatibility).
        """
        current_model = self._config.embedding_model
        current_dim = self._config.embedding_dim
        if not self._needs_canonical_meta_rewrite(self.get_meta(), current_model, current_dim):
            return False
        with write_lock():
            meta = self.get_meta()  # re-read under the lock for racing callers
            if not self._needs_canonical_meta_rewrite(meta, current_model, current_dim):
                return False
            assert meta is not None  # filtered above  # noqa: S101
            log.info(
                "Migrating legacy embedding ref in store meta: %r -> %r",
                meta["embedding_model"],
                current_model,
            )
            self._write_meta_unlocked(embedding_model=current_model, embedding_dim=current_dim)
            return True
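
    # Editor's sketch of the two ref shapes this migration reconciles (model
    # names are hypothetical):
    #
    #     legacy    = "example-org/embed-model"                # bare repo ref
    #     canonical = "example-org/embed-model/model-q8.gguf"  # full ref
    #
    # When refs_compatible() says both name the same model at the same dim,
    # the row is rewritten in place; a true mismatch is left for the gate.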

    def get_db(self) -> lancedb.DBConnection:
        if self._db is None:
            import lancedb as _lancedb

            self._config.lancedb_dir.mkdir(parents=True, exist_ok=True)
            self._db = _lancedb.connect(
                str(self._config.lancedb_dir),
                read_consistency_interval=READ_CONSISTENCY_INTERVAL,
            )
        return self._db

    def open_table(self, name: str) -> lancedb.table.Table | None:
        """Open a table if it exists, otherwise return None."""
        db = self.get_db()
        if name not in _table_names(db):
            return None
        return db.open_table(name)

    def ensure_fts_index(self) -> None:
        """Create the chunks FTS index, or run ``optimize()`` once it exists.

        ``optimize()`` folds newly added rows into the FTS index and also
        runs LanceDB's default compaction + version pruning (default prune
        window: 7 days). Work scales with recent deltas rather than total
        chunk count, so large corpora no longer pay the full
        ``create_fts_index(replace=True)`` rebuild cost on every sync.
        """
        with write_lock():
            table = self.open_table(CHUNKS_TABLE)
            if table is None:
                return
            try:
                if _has_fts_index(table):
                    table.optimize()
                    log.debug("FTS index optimized on '%s'", CHUNKS_TABLE)
                else:
                    table.create_fts_index("chunk", replace=False)
                    log.debug("FTS index created on '%s'", CHUNKS_TABLE)
                self._fts_ready = True
            except Exception:
                log.debug("FTS index ensure failed (empty table?)", exc_info=True)
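
    # Editor's sketch of the write-then-index pattern this enables during a
    # sync (hypothetical driver loop, not part of this module):
    #
    #     for batch in chunk_batches:
    #         store.add_chunks(batch)
    #     store.ensure_fts_index()   # one incremental optimize per sync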

    def add_chunks(self, records: list[dict]) -> int:
        """Add chunk records to the store. Returns count added.

        Raises ``EmbeddingModelMismatchError`` if the persisted ``_meta`` row was
        written under a different embedding model than the current ``cfg``. On the
        first write to a fresh store, ``_meta`` is initialized from the current cfg.

        The gate runs inside the write lock and uses a single cfg snapshot so a
        concurrent ``set_embedding_model`` cannot slip a write in past a stale
        compatibility check.
        """
        with write_lock():
            embedding_model = self._config.embedding_model
            embedding_dim = self._config.embedding_dim
            self._ensure_embedding_compat()
            self._fts_ready = False
            if not records:
                return 0
            for rec in records:
                vec = rec.get("vector", [])
                if len(vec) != embedding_dim:
                    raise ValueError(
                        f"Vector dimension mismatch: expected {embedding_dim}, "
                        f"got {len(vec)} (source={rec.get('source', '?')})"
                    )
            db = self.get_db()
            table = ensure_table(db, CHUNKS_TABLE, self._chunks_schema())
            table.add(records)
            if self.get_meta() is None:
                self._write_meta_unlocked(
                    embedding_model=embedding_model, embedding_dim=embedding_dim
                )
            return len(records)
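
    # Editor's sketch of a record accepted by add_chunks(), with every field
    # from _chunks_schema() (values illustrative; the vector length must
    # equal cfg.embedding_dim):
    #
    #     store.add_chunks([{
    #         "source": "notes.md",
    #         "content_type": "text/markdown",
    #         "chunk_type": "raw",
    #         "page_start": 1, "page_end": 1,
    #         "line_start": 1, "line_end": 12,
    #         "chunk": "First section of the document...",
    #         "chunk_index": 0,
    #         "vector": [0.0] * embedding_dim,
    #     }])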

    def bm25_probe(self, query_text: str, top_k: int = 5) -> list[SearchChunk]:
        """Quick BM25-only search for confidence checking. Returns up to top_k results."""
        table = self.open_table(CHUNKS_TABLE)
        if table is None:
            return []
        if not self._fts_ready:
            self.ensure_fts_index()
        if not self._fts_ready:
            return []  # pragma: no cover
        try:
            rows = table.search(query_text, query_type="fts").limit(top_k).to_list()
            return [SearchChunk(**r) for r in rows]
        except Exception:
            log.debug("BM25 probe failed", exc_info=True)
            return []

    def search(
        self,
        query_vector: list[float],
        top_k: int | None = None,
        max_distance: float | None = None,
        query_text: str | None = None,
        chunk_type: str | None = None,
    ) -> list[SearchChunk]:
        """Search for similar chunks. Hybrid when FTS available, else vector-only.

        Results with distance > max_distance are filtered out (vector-only path).
        Pass max_distance=0 to disable filtering.
        When *chunk_type* is set, only chunks of that type ("raw" or "wiki") are returned.

        Raises ``EmbeddingModelMismatchError`` if the persisted ``_meta`` row was
        written under a different embedding model than the current ``cfg``.
        """
        if top_k is None:
            top_k = self._config.top_k
        if max_distance is None:
            max_distance = self._config.max_distance
        table = self.open_table(CHUNKS_TABLE)
        if table is None:
            return []
        self.initialize_meta_if_legacy()
        self.canonicalize_meta_if_legacy()
        self._ensure_embedding_compat()

        if query_text and not self._fts_ready:
            self.ensure_fts_index()

        if query_text and self._fts_ready:
            try:
                return _hybrid_search(table, query_text, query_vector, top_k, chunk_type)
            except Exception:
                log.debug("Hybrid search failed, falling back to vector-only", exc_info=True)

        candidate_k = top_k * self._config.candidate_multiplier
        query = table.search(query_vector).metric("cosine").limit(candidate_k)
        if chunk_type:
            query = query.where(_chunk_type_predicate(chunk_type))
        rows = query.to_list()
        log.debug(
            "Vector search: query=%r, candidates=%d, max_distance=%.2f",
            query_text or "vector-only",
            len(rows),
            max_distance,
        )
        if rows:
            distances = [r.get("distance", 0) for r in rows[:5]]
            log.debug("Top 5 distances: %s", distances)
        results = [SearchChunk(**r) for r in rows]
        return self._filter_and_rerank(results, query_vector, top_k, max_distance)
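
    # Editor's usage sketch (embed() is a hypothetical stand-in for the real
    # embedding call; top_k / max_distance fall back to Config defaults):
    #
    #     vec = embed("how are sources deduplicated?")
    #     hits = store.search(vec, query_text="source dedup", chunk_type="wiki")
    #     for h in hits:
    #         print(h.source, h.distance)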

    def _filter_and_rerank(
        self,
        results: list[SearchChunk],
        query_vector: list[float],
        top_k: int,
        max_distance: float,
    ) -> list[SearchChunk]:
        """Apply the configured distance filter, then MMR-rerank down to top_k."""
        if max_distance > 0:
            before = len(results)
            if self._config.adaptive_threshold:
                results = self._adaptive_filter(results, top_k, max_distance)
                filter_name = "adaptive"
            else:
                results = self._fixed_filter(results, max_distance)
                filter_name = "fixed"
            log.debug(
                "After %s filter: %d/%d results, threshold=%.2f",
                filter_name,
                len(results),
                before,
                max_distance,
            )
        if len(results) > top_k:
            results = mmr_rerank(query_vector, results, top_k, self._config.mmr_lambda)
        return results

    def _adaptive_filter(
        self, results: list[SearchChunk], top_k: int, initial_threshold: float
    ) -> list[SearchChunk]:
        """Widen the cosine distance threshold when too few results survive.

        Inspired by grantflow's (grantflow-ai/grantflow) adaptive retrieval
        pattern, which widens thresholds on recursive retry. The step size is
        ``self._config.adaptive_threshold_step`` (default 0.2); the threshold
        is capped at ``max(initial_threshold, _MAX_THRESHOLD)``.

        Pre-sorts results by distance for a single-pass cutoff search.
        """
        cap = max(initial_threshold, _MAX_THRESHOLD)
        step = self._config.adaptive_threshold_step

        sorted_results = sorted(results, key=_get_distance)

        threshold = initial_threshold
        for _ in range(_MAX_FILTER_ITERATIONS):
            if threshold > cap:
                break
            cutoff = _count_within_threshold(sorted_results, threshold)
            if cutoff >= top_k:
                return sorted_results[:cutoff]
            threshold += step
        # Final pass at cap
        cutoff = _count_within_threshold(sorted_results, cap)
        return sorted_results[:cutoff]
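
    # Worked example (editor's sketch): top_k=3, initial_threshold=0.3,
    # step=0.2, sorted distances [0.1, 0.25, 0.4, 0.45, 0.9].
    # At t=0.3 only 2 results qualify (< top_k), so widen; at t=0.5 four
    # qualify (>= top_k) and all four are returned. _filter_and_rerank then
    # lets MMR trim 4 down to top_k=3.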

    def _fixed_filter(self, results: list[SearchChunk], threshold: float) -> list[SearchChunk]:
        """Simple fixed-threshold filter: keep only results within the distance threshold."""
        return [r for r in results if _get_distance(r) <= threshold]

    def get_chunks_by_source(self, source: str) -> list[SearchChunk]:
        """Return every chunk whose ``source`` equals *source*."""
        table = self.open_table(CHUNKS_TABLE)
        if table is None:
            return []
        escaped = escape_sql_string(source)
        try:
            rows = table.search().where(f"source = '{escaped}'").limit(None).to_list()
        except Exception:
            # FTS-enabled tables return a query builder that cannot
            # handle .where() on arbitrary columns; fall through to a
            # pyarrow.compute filter on the Arrow table so the source
            # match runs in C++ without materializing non-matching rows.
            log.debug("get_chunks_by_source search() failed, using Arrow fallback", exc_info=True)
            arrow_tbl = table.to_arrow()
            filtered = arrow_tbl.filter(pc.equal(arrow_tbl["source"], source))
            rows = filtered.to_pylist()
        return [SearchChunk(**r) for r in rows]
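
    # Editor's self-contained sketch of the Arrow fallback path:
    #
    #     import pyarrow as pa
    #     import pyarrow.compute as pc
    #     tbl = pa.table({"source": ["a.md", "b.md", "a.md"], "chunk": ["x", "y", "z"]})
    #     out = tbl.filter(pc.equal(tbl["source"], "a.md"))
    #     assert out.num_rows == 2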

    def delete_by_source(self, source: str) -> None:
        """Delete all chunks from a given source file."""
        with write_lock():
            table = self.open_table(CHUNKS_TABLE)
            if table is not None:
                _safe_delete_unlocked(table, f"source = '{escape_sql_string(source)}'")
            self._invalidate_source_cache()

    def get_sources(
        self,
        *,
        search: str | None = None,
        limit: int | None = None,
        offset: int = 0,
    ) -> list[SourceRecord]:
        """Return source records, filtered by *search* and sliced by offset/limit."""
        table = self.open_table(SOURCES_TABLE)
        if table is None:
            return []
        query = table.search()
        where = _sources_search_filter(search)
        if where is not None:
            query = query.where(where)
        if offset:
            query = query.offset(offset)
        query = query.limit(limit)
        result: list[SourceRecord] = query.to_list()  # type: ignore[assignment]
        return result

    def count_sources(self, *, search: str | None = None) -> int:
        """Count tracked sources matching *search* without materializing rows."""
        table = self.open_table(SOURCES_TABLE)
        if table is None:
            return 0
        where = _sources_search_filter(search)
        count: int = table.count_rows() if where is None else table.count_rows(filter=where)
        return count

    def upsert_source(
        self,
        filename: str,
        file_hash: str,
        chunk_count: int,
        source_type: str = "document",
    ) -> None:
        """Add or update a source file tracking record."""
        with write_lock():
            db = self.get_db()
            table = ensure_table(db, SOURCES_TABLE, _sources_schema())
            _safe_delete_unlocked(table, f"filename = '{escape_sql_string(filename)}'")
            table.add(
                [
                    {
                        "filename": filename,
                        "file_hash": file_hash,
                        "ingested_at": datetime.now(UTC).isoformat(),
                        "chunk_count": chunk_count,
                        "source_type": source_type,
                    }
                ]
            )
            self._invalidate_source_cache()

    def delete_source(self, filename: str) -> None:
        """Remove a source file tracking record."""
        with write_lock():
            table = self.open_table(SOURCES_TABLE)
            if table is not None:
                _safe_delete_unlocked(table, f"filename = '{escape_sql_string(filename)}'")
            self._invalidate_source_cache()

    def remove_documents(
        self,
        names: list[str],
        *,
        delete_files: bool = False,
        documents_dir: Path | None = None,
    ) -> RemoveResult:
        """Remove documents from the knowledge base by source name.

        Looks up known sources, deletes chunks and source records for each.
        If *delete_files* is True, resolves the path and verifies it is
        contained within *documents_dir* before unlinking (path traversal guard).

        Returns a RemoveResult with removed and not_found lists.
        """
        if documents_dir is None:
            documents_dir = self._config.documents_dir

        known = {s["filename"] for s in self.get_sources()}
        removed: list[str] = []
        not_found: list[str] = []

        for name in names:
            if name not in known:
                not_found.append(name)
                continue
            self.delete_by_source(name)
            self.delete_source(name)
            removed.append(name)
            if delete_files:
                try:
                    path = validate_path_within(documents_dir / name, documents_dir)
                except ValueError:
                    log.warning("Path traversal blocked: %s escapes %s", name, documents_dir)
                    continue
                if path.exists():
                    path.unlink()

        return RemoveResult(removed=removed, not_found=not_found)
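
    # Editor's note on the guard above: record deletion always proceeds; only
    # the optional unlink is gated. A sketch of the check (hypothetical paths):
    #
    #     validate_path_within(documents_dir / "notes.md", documents_dir)  # ok
    #     validate_path_within(documents_dir / "../x", documents_dir)      # raises ValueError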

    def clear_table(self, name: str, predicate: str) -> None:
        """Delete rows matching *predicate* from *name*. Acquires write lock."""
        with write_lock():
            table = self.open_table(name)
            if table is not None:
                _safe_delete_unlocked(table, predicate)

    def add_citations(self, records: list[CitationRecord]) -> int:
        """Add citation records to the store. Returns count added."""
        if not records:
            return 0
        with write_lock():
            db = self.get_db()
            table = ensure_table(db, CITATIONS_TABLE, _citations_schema())
            table.add(records)
            return len(records)

    def get_citations_for_wiki(self, wiki_source: str) -> list[CitationRecord]:
        """Get all citations for a wiki page."""
        table = self.open_table(CITATIONS_TABLE)
        if table is None:
            return []
        escaped = escape_sql_string(wiki_source)
        rows: list[CitationRecord] = table.search().where(f"wiki_source = '{escaped}'").to_list()
        return rows

    def get_citations_for_source(self, source_filename: str) -> list[CitationRecord]:
        """Get all citations that reference a source document (reverse lookup)."""
        table = self.open_table(CITATIONS_TABLE)
        if table is None:
            return []
        escaped = escape_sql_string(source_filename)
        rows: list[CitationRecord] = (
            table.search().where(f"source_filename = '{escaped}'").to_list()
        )
        return rows

    def delete_citations_for_wiki(self, wiki_source: str) -> None:
        """Delete all citations for a wiki page (used before regeneration)."""
        self.clear_table(
            CITATIONS_TABLE,
            f"wiki_source = '{escape_sql_string(wiki_source)}'",
        )

    def close(self) -> None:
        """Release the database connection and reset state."""
        self._db = None
        self._fts_ready = False

    def drop_all(self) -> None:
        """Drop all tables -- used by rebuild."""
        with write_lock():
            self._fts_ready = False
            db = self.get_db()
            for name in _table_names(db):
                db.drop_table(name)
            self._invalidate_source_cache()
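
# Editor's end-to-end usage sketch (hypothetical wiring; Config() arguments
# and embed() are stand-ins, not defined in this module):
#
#     cfg = Config()
#     store = Store(cfg)
#     store.add_chunks(records)                  # gated by the _meta identity
#     store.upsert_source("notes.md", file_hash, chunk_count=len(records))
#     store.ensure_fts_index()                   # incremental optimize
#     hits = store.search(embed("query"), query_text="query")
#     store.close()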