Coverage for src / lilbee / data / store / core.py: 100%
486 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
1"""The ``Store`` class: high-level LanceDB read/write API used across lilbee."""
3from __future__ import annotations
5import logging
6from collections.abc import Callable
7from datetime import UTC, datetime
8from pathlib import Path
9from typing import TYPE_CHECKING
11import pyarrow as pa
12import pyarrow.compute as pc
14from lilbee.core.config import (
15 CHUNKS_TABLE,
16 CITATIONS_TABLE,
17 MEMORIES_TABLE,
18 META_TABLE,
19 PAGE_TEXTS_TABLE,
20 SOURCES_TABLE,
21 Config,
22)
23from lilbee.core.security import validate_path_within
24from lilbee.runtime.lock import write_lock
26from .lance_helpers import (
27 _chunk_type_predicate,
28 _has_fts_index,
29 _has_vector_index,
30 _safe_delete_unlocked,
31 _sources_search_filter,
32 _table_names,
33 ensure_table,
34 escape_sql_string,
35 refs_compatible,
36)
37from .ranking import mmr_rerank
38from .schema import _citations_schema, _meta_schema, _page_texts_schema, _sources_schema
39from .types import (
40 META_DELETE_ALL_PREDICATE,
41 META_SCHEMA_VERSION,
42 READ_CONSISTENCY_INTERVAL,
43 ChunkType,
44 CitationRecord,
45 EmbeddingModelMismatchError,
46 MemoryKind,
47 MemoryRow,
48 PageTextRecord,
49 RemoveResult,
50 SearchChunk,
51 SourceRecord,
52 SourceType,
53 StoreMeta,
54)
56if TYPE_CHECKING:
57 import lancedb
58 import lancedb.table
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
def _hybrid_search(
    table: lancedb.table.Table,
    query_text: str,
    query_vector: list[float],
    top_k: int,
    chunk_type: ChunkType | None = None,
) -> list[SearchChunk]:
    """Query *table* with vector and full-text search combined, fused by RRF.

    A non-empty ``chunk_type`` is applied as a ``where`` predicate on the
    query itself, so ``top_k`` limits the rows that already passed the type
    filter. Filtering afterwards would silently starve wiki-only queries
    whose matches sit beyond the top-K hybrid window.
    """
    from lancedb.rerankers import RRFReranker

    builder = table.search(query_type="hybrid")
    builder = builder.vector(query_vector)
    builder = builder.text(query_text)
    builder = builder.rerank(RRFReranker())
    if chunk_type:
        type_filter = _chunk_type_predicate(chunk_type)
        builder = builder.where(type_filter)
    hits = builder.limit(top_k).to_list()
    return [SearchChunk(**hit) for hit in hits]
# Adaptive distance filter: default ceiling for the widened threshold, and a
# hard cap on widening steps so the loop cannot run away.
_MAX_THRESHOLD = 1.0
_MAX_FILTER_ITERATIONS = 20

# Vector ANN index settings. IVF_PQ compresses vectors so search scales to
# millions of rows; the refine factor re-ranks the PQ candidates against the
# full vectors to recover recall.
_VECTOR_METRIC = "cosine"
_ANN_INDEX_TYPE = "IVF_PQ"
_ANN_NPROBES = 20
_ANN_REFINE_FACTOR = 10
def _get_distance(chunk: SearchChunk) -> float:
    """Return *chunk*'s distance as a sort key; a missing distance sorts last (+inf)."""
    distance = chunk.distance
    if distance is None:
        return float("inf")
    return distance
def _count_within_threshold(sorted_results: list[SearchChunk], threshold: float) -> int:
    """Return how many leading results have a distance no greater than *threshold*.

    The input is expected to be sorted by ascending distance: counting stops
    at the first result that exceeds the threshold.
    """
    within = 0
    for result in sorted_results:
        if _get_distance(result) > threshold:
            break
        within += 1
    return within
115class Store:
116 """LanceDB vector store: wraps all DB operations with config-driven defaults."""
118 def __init__(self, config: Config) -> None:
119 self._config = config
120 self._fts_ready: bool = False
121 self._db: lancedb.DBConnection | None = None
122 # Cache of {filename: ingested_at} rebuilt only when sources
123 # mutate; callers (temporal filter) hit it per-query.
124 self._source_ingested_cache: dict[str, str] | None = None
126 def _invalidate_source_cache(self) -> None:
127 """Drop the cached {filename: ingested_at} map."""
128 self._source_ingested_cache = None
130 def source_ingested_at_map(self) -> dict[str, str]:
131 """Return {filename: ingested_at} for every source, cached until mutation.
133 Best-effort: a reader racing a concurrent invalidation can store a
134 pre-mutation snapshot. The only consumer (temporal query filter)
135 treats a missing/stale entry as "do not filter," so staleness
136 degrades ranking precision, never correctness.
137 """
138 if self._source_ingested_cache is not None:
139 return self._source_ingested_cache
140 mapping = {s["filename"]: s.get("ingested_at", "") for s in self.get_sources()}
141 self._source_ingested_cache = mapping
142 return mapping
144 def _chunks_schema(self) -> pa.Schema:
145 return pa.schema(
146 [
147 pa.field("source", pa.utf8()),
148 pa.field("content_type", pa.utf8()),
149 pa.field("chunk_type", pa.utf8()),
150 pa.field("page_start", pa.int32()),
151 pa.field("page_end", pa.int32()),
152 pa.field("line_start", pa.int32()),
153 pa.field("line_end", pa.int32()),
154 pa.field("chunk", pa.utf8()),
155 pa.field("chunk_index", pa.int32()),
156 pa.field("vector", pa.list_(pa.float32(), self._config.embedding_dim)),
157 ]
158 )
160 def get_meta(self) -> StoreMeta | None:
161 """Return the persisted store metadata row, or ``None`` if unset."""
162 table = self.open_table(META_TABLE)
163 if table is None:
164 return None
165 rows = table.search().limit(1).to_list()
166 if not rows:
167 return None
168 row = rows[0]
169 return StoreMeta(
170 embedding_model=row["embedding_model"],
171 embedding_dim=int(row["embedding_dim"]),
172 schema_version=int(row["schema_version"]),
173 updated_at=row["updated_at"],
174 )
176 def _write_meta_unlocked(self, *, embedding_model: str, embedding_dim: int) -> None:
177 """Overwrite the single ``_meta`` row with the supplied identity.
179 Caller must hold ``write_lock()``. Args are passed explicitly rather than
180 re-read from ``self._config`` so the caller can snapshot cfg at a coherent
181 instant and not race with a concurrent ``set_embedding_model``.
182 """
183 db = self.get_db()
184 table = ensure_table(db, META_TABLE, _meta_schema())
185 _safe_delete_unlocked(table, META_DELETE_ALL_PREDICATE)
186 table.add(
187 [
188 {
189 "embedding_model": embedding_model,
190 "embedding_dim": embedding_dim,
191 "schema_version": META_SCHEMA_VERSION,
192 "updated_at": datetime.now(UTC).isoformat(),
193 }
194 ]
195 )
197 def _has_chunks(self) -> bool:
198 """Return True when the chunks table exists and has at least one row."""
199 chunks = self.open_table(CHUNKS_TABLE)
200 return chunks is not None and chunks.count_rows() > 0
202 def has_chunks(self) -> bool:
203 """Public predicate: True iff the store currently holds at least one chunk."""
204 return self._has_chunks()
206 def initialize_meta_if_legacy(self) -> bool:
207 """Pin a legacy store's identity to the current cfg if not already set.
209 Returns ``True`` when a meta row was just written. No-op when meta already
210 exists or no chunks are present. This is the path that converts a
211 pre-upgrade store (chunks present, no ``_meta``) into a gated store. It
212 snapshots cfg under the write lock to keep the recorded identity coherent
213 with what the gate is comparing against.
214 """
215 if self.get_meta() is not None:
216 return False
217 if not self._has_chunks():
218 return False
219 embedding_model = self._config.embedding_model
220 embedding_dim = self._config.embedding_dim
221 with write_lock():
222 # Re-check under the lock so two callers do not both warn-and-write.
223 if self.get_meta() is not None:
224 return False
225 log.warning(
226 "Legacy store has chunks but no _meta row. Initializing _meta from "
227 "the current configuration (embedding_model=%s, embedding_dim=%d). "
228 "If you changed embedding_model before upgrading, run `lilbee rebuild` "
229 "to ensure the store is consistent.",
230 embedding_model,
231 embedding_dim,
232 )
233 self._write_meta_unlocked(embedding_model=embedding_model, embedding_dim=embedding_dim)
234 return True
236 def _ensure_embedding_compat(self) -> None:
237 """Raise when the persisted embedding identity drifts from cfg.
239 Pure check, no side effects. Migration of legacy stores (chunks present,
240 no ``_meta``) is the caller's responsibility via ``initialize_meta_if_legacy``;
241 rewriting a legacy bare-repo ``_meta`` row to the canonical full ref is
242 the caller's responsibility via ``canonicalize_meta_if_legacy``. This
243 method stays safe to call from inside an existing ``write_lock()`` (no
244 recursive lock attempt). cfg fields are snapshotted at entry so the
245 comparison is coherent even if another thread mutates them mid-call.
246 """
247 current_model = self._config.embedding_model
248 current_dim = self._config.embedding_dim
249 meta = self.get_meta()
250 if meta is None:
251 return
252 if refs_compatible(
253 meta["embedding_model"], current_model, meta["embedding_dim"], current_dim
254 ):
255 return
256 raise EmbeddingModelMismatchError(
257 persisted_model=meta["embedding_model"],
258 persisted_dim=meta["embedding_dim"],
259 current_model=current_model,
260 current_dim=current_dim,
261 )
263 def assert_embedding_compatible(self) -> None:
264 """Run the full embedding-identity gate (legacy init, canonicalize, check).
266 Mirrors the gate ``search`` applies. Callers that write under a fresh
267 embedder (import) use this to fail before any destructive work when the
268 store was built by a different model.
269 """
270 self.initialize_meta_if_legacy()
271 self.canonicalize_meta_if_legacy()
272 self._ensure_embedding_compat()
274 def _needs_canonical_meta_rewrite(
275 self, meta: StoreMeta | None, current_model: str, current_dim: int
276 ) -> bool:
277 """True iff *meta* is the legacy form and refs-compatible with current cfg."""
278 if meta is None or meta["embedding_model"] == current_model:
279 return False
280 return refs_compatible(
281 meta["embedding_model"], current_model, meta["embedding_dim"], current_dim
282 )
284 def canonicalize_meta_if_legacy(self) -> bool:
285 """Rewrite a legacy bare-repo ``_meta`` row to the canonical full ref.
287 Pre-canonical lilbee persisted only ``<org>/<repo>`` in
288 ``_meta.embedding_model``. The current code persists the full
289 ``<org>/<repo>/<filename>.gguf``. When the two refer to the same
290 model under :func:`refs_compatible` but differ as raw strings, the
291 meta row is rewritten so the legacy name never surfaces. Returns
292 ``True`` on write; ``False`` when missing, already canonical, or
293 incompatible (the gate handles incompatibility).
294 """
295 current_model = self._config.embedding_model
296 current_dim = self._config.embedding_dim
297 if not self._needs_canonical_meta_rewrite(self.get_meta(), current_model, current_dim):
298 return False
299 with write_lock():
300 meta = self.get_meta() # re-read under the lock for racing callers
301 if not self._needs_canonical_meta_rewrite(meta, current_model, current_dim):
302 return False
303 assert meta is not None # filtered above # noqa: S101
304 log.info(
305 "Migrating legacy embedding ref in store meta: %r -> %r",
306 meta["embedding_model"],
307 current_model,
308 )
309 self._write_meta_unlocked(embedding_model=current_model, embedding_dim=current_dim)
310 return True
312 def get_db(self) -> lancedb.DBConnection:
313 if self._db is None:
314 import lancedb as _lancedb
316 self._config.lancedb_dir.mkdir(parents=True, exist_ok=True)
317 self._db = _lancedb.connect(
318 str(self._config.lancedb_dir),
319 read_consistency_interval=READ_CONSISTENCY_INTERVAL,
320 )
321 return self._db
323 def open_table(self, name: str) -> lancedb.table.Table | None:
324 """Open a table if it exists, otherwise return None."""
325 db = self.get_db()
326 if name not in _table_names(db):
327 return None
328 return db.open_table(name)
330 def ensure_fts_index(self) -> None:
331 """Create the chunks FTS index, or run ``optimize()`` once it exists.
333 ``optimize()`` folds newly added rows into the FTS index and also
334 runs LanceDB's default compaction + version pruning (default prune
335 window: 7 days). Work scales with recent deltas rather than total
336 chunk count, so large corpora no longer pay the full
337 ``create_fts_index(replace=True)`` rebuild cost on every sync.
338 """
339 with write_lock():
340 table = self.open_table(CHUNKS_TABLE)
341 if table is None:
342 return
343 try:
344 if _has_fts_index(table):
345 table.optimize()
346 log.debug("FTS index optimized on '%s'", CHUNKS_TABLE)
347 else:
348 table.create_fts_index("chunk", replace=False)
349 log.debug("FTS index created on '%s'", CHUNKS_TABLE)
350 self._fts_ready = True
351 except Exception:
352 log.debug("FTS index ensure failed (empty table?)", exc_info=True)
354 def ensure_vector_index(self, *, force: bool = False) -> bool:
355 """Build or refresh the ANN vector index when the corpus is large enough.
357 Below ``cfg.ann_index_threshold`` (or when it is 0) the store keeps exact
358 flat search, which is faster and exact for small vaults and is all a
359 laptop needs. Once an index exists, ``optimize()`` folds new rows in.
360 Pass ``force=True`` to build regardless of the threshold (publish flow).
361 Returns True when an index was created or refreshed.
362 """
363 threshold = self._config.ann_index_threshold
364 with write_lock():
365 table = self.open_table(CHUNKS_TABLE)
366 if table is None:
367 return False
368 if _has_vector_index(table):
369 table.optimize()
370 log.debug("Vector index optimized on '%s'", CHUNKS_TABLE)
371 return True
372 if not force and (threshold <= 0 or table.count_rows() < threshold):
373 return False
374 try:
375 table.create_index(metric=_VECTOR_METRIC, index_type=_ANN_INDEX_TYPE)
376 log.info("Vector ANN index created on '%s'", CHUNKS_TABLE)
377 return True
378 except Exception:
379 log.debug("Vector index build failed (too few rows?)", exc_info=True)
380 return False
382 def add_chunks(self, records: list[dict]) -> int:
383 """Add chunk records to the store. Returns count added.
385 Raises ``EmbeddingModelMismatchError`` if the persisted ``_meta`` row was
386 written under a different embedding model than the current ``cfg``. On the
387 first write to a fresh store, ``_meta`` is initialized from the current cfg.
389 The gate runs inside the write lock and uses a single cfg snapshot so a
390 concurrent ``set_embedding_model`` cannot slip a write in past a stale
391 compatibility check.
392 """
393 with write_lock():
394 embedding_model = self._config.embedding_model
395 embedding_dim = self._config.embedding_dim
396 self._ensure_embedding_compat()
397 self._fts_ready = False
398 if not records:
399 return 0
400 for rec in records:
401 vec = rec.get("vector", [])
402 if len(vec) != embedding_dim:
403 raise ValueError(
404 f"Vector dimension mismatch: expected {embedding_dim}, "
405 f"got {len(vec)} (source={rec.get('source', '?')})"
406 )
407 db = self.get_db()
408 table = ensure_table(db, CHUNKS_TABLE, self._chunks_schema())
409 table.add(records)
410 if self.get_meta() is None:
411 self._write_meta_unlocked(
412 embedding_model=embedding_model, embedding_dim=embedding_dim
413 )
414 return len(records)
416 def bm25_probe(self, query_text: str, top_k: int = 5) -> list[SearchChunk]:
417 """Quick BM25-only search for confidence checking. Returns up to top_k results."""
418 table = self.open_table(CHUNKS_TABLE)
419 if table is None:
420 return []
421 if not self._fts_ready:
422 self.ensure_fts_index()
423 if not self._fts_ready:
424 return []
425 try:
426 rows = table.search(query_text, query_type="fts").limit(top_k).to_list()
427 return [SearchChunk(**r) for r in rows]
428 except Exception:
429 log.debug("BM25 probe failed", exc_info=True)
430 return []
432 def search(
433 self,
434 query_vector: list[float],
435 top_k: int | None = None,
436 max_distance: float | None = None,
437 query_text: str | None = None,
438 chunk_type: ChunkType | None = None,
439 ) -> list[SearchChunk]:
440 """Search for similar chunks. Hybrid when FTS available, else vector-only.
442 Results with distance > max_distance are filtered out (vector-only path).
443 Pass max_distance=0 to disable filtering.
444 When *chunk_type* is set, only chunks of that type ("raw" or "wiki") are returned.
446 Raises ``EmbeddingModelMismatchError`` if the persisted ``_meta`` row was
447 written under a different embedding model than the current ``cfg``.
448 """
449 if top_k is None:
450 top_k = self._config.top_k
451 if max_distance is None:
452 max_distance = self._config.max_distance
453 table = self.open_table(CHUNKS_TABLE)
454 if table is None:
455 return []
456 self.initialize_meta_if_legacy()
457 self.canonicalize_meta_if_legacy()
458 self._ensure_embedding_compat()
460 if query_text and not self._fts_ready:
461 self.ensure_fts_index()
463 if query_text and self._fts_ready:
464 try:
465 return _hybrid_search(table, query_text, query_vector, top_k, chunk_type)
466 except Exception:
467 log.debug("Hybrid search failed, falling back to vector-only", exc_info=True)
469 candidate_k = top_k * self._config.candidate_multiplier
470 query = table.search(query_vector).metric(_VECTOR_METRIC).limit(candidate_k)
471 if _has_vector_index(table):
472 # IVF_PQ is lossy; probe more partitions and refine against full
473 # vectors so recall stays close to the exact flat scan.
474 query = query.nprobes(_ANN_NPROBES).refine_factor(_ANN_REFINE_FACTOR)
475 if chunk_type:
476 query = query.where(_chunk_type_predicate(chunk_type))
477 rows = query.to_list()
478 log.debug(
479 "Vector search: query=%r, candidates=%d, max_distance=%.2f",
480 query_text or "vector-only",
481 len(rows),
482 max_distance,
483 )
484 if rows:
485 distances = [r.get("distance", 0) for r in rows[:5]]
486 log.debug("Top 5 distances: %s", distances)
487 results = [SearchChunk(**r) for r in rows]
488 return self._filter_and_rerank(results, query_vector, top_k, max_distance)
490 def _filter_and_rerank(
491 self,
492 results: list[SearchChunk],
493 query_vector: list[float],
494 top_k: int,
495 max_distance: float,
496 ) -> list[SearchChunk]:
497 """Apply the configured distance filter, then MMR-rerank down to top_k."""
498 if max_distance > 0:
499 before = len(results)
500 if self._config.adaptive_threshold:
501 results = self._adaptive_filter(results, top_k, max_distance)
502 filter_name = "adaptive"
503 else:
504 results = self._fixed_filter(results, max_distance)
505 filter_name = "fixed"
506 log.debug(
507 "After %s filter: %d/%d results, threshold=%.2f",
508 filter_name,
509 len(results),
510 before,
511 max_distance,
512 )
513 if len(results) > top_k:
514 results = mmr_rerank(query_vector, results, top_k, self._config.mmr_lambda)
515 return results
517 def _adaptive_filter(
518 self, results: list[SearchChunk], top_k: int, initial_threshold: float
519 ) -> list[SearchChunk]:
520 """Widen cosine distance threshold when too few results.
521 Inspired by grantflow's (grantflow-ai/grantflow) adaptive retrieval
522 pattern which widens thresholds on recursive retry. Step size and
523 cap are configurable via ``self._config.adaptive_threshold_step``.
525 Pre-sorts results by distance for a single-pass cutoff search.
526 Step size is ``self._config.adaptive_threshold_step`` (default 0.2).
527 """
528 cap = max(initial_threshold, _MAX_THRESHOLD)
529 step = self._config.adaptive_threshold_step
531 sorted_results = sorted(results, key=_get_distance)
533 threshold = initial_threshold
534 for _ in range(_MAX_FILTER_ITERATIONS):
535 if threshold > cap:
536 break
537 cutoff = _count_within_threshold(sorted_results, threshold)
538 if cutoff >= top_k:
539 return sorted_results[:cutoff]
540 threshold += step
541 # Final pass at cap
542 cutoff = _count_within_threshold(sorted_results, cap)
543 return sorted_results[:cutoff]
545 def _fixed_filter(self, results: list[SearchChunk], threshold: float) -> list[SearchChunk]:
546 """Simple fixed threshold filter - keep only results within distance threshold."""
547 return [r for r in results if _get_distance(r) <= threshold]
549 def get_chunks_by_source(self, source: str) -> list[SearchChunk]:
550 """Return every chunk whose ``source`` equals *source*."""
551 table = self.open_table(CHUNKS_TABLE)
552 if table is None:
553 return []
554 escaped = escape_sql_string(source)
555 try:
556 rows = table.search().where(f"source = '{escaped}'").limit(None).to_list()
557 except Exception:
558 # FTS-enabled tables return a query builder that cannot
559 # handle .where() on arbitrary columns; fall through to a
560 # pyarrow.compute filter on the Arrow table so the source
561 # match runs in C++ without materializing non-matching rows.
562 log.debug("get_chunks_by_source search() failed, using Arrow fallback", exc_info=True)
563 arrow_tbl = table.to_arrow()
564 filtered = arrow_tbl.filter(pc.equal(arrow_tbl["source"], source))
565 rows = filtered.to_pylist()
566 return [SearchChunk(**r) for r in rows]
568 def _delete_by_source_unlocked(self, source: str) -> None:
569 """Delete a source's chunks and page texts. Caller must hold ``write_lock()``."""
570 predicate = f"source = '{escape_sql_string(source)}'"
571 for name in (CHUNKS_TABLE, PAGE_TEXTS_TABLE):
572 table = self.open_table(name)
573 if table is not None:
574 _safe_delete_unlocked(table, predicate)
576 def delete_by_source(self, source: str) -> None:
577 """Delete a source's chunks and page texts."""
578 with write_lock():
579 self._delete_by_source_unlocked(source)
580 self._invalidate_source_cache()
582 def add_page_texts(self, records: list[dict]) -> int:
583 """Add per-page text rows (no vectors). Returns count added."""
584 if not records:
585 return 0
586 with write_lock():
587 db = self.get_db()
588 table = ensure_table(db, PAGE_TEXTS_TABLE, _page_texts_schema())
589 table.add(records)
590 return len(records)
592 def get_page_texts(self, source: str | None = None) -> list[PageTextRecord]:
593 """Return per-page text rows, all or for a single *source*."""
594 table = self.open_table(PAGE_TEXTS_TABLE)
595 if table is None:
596 return []
597 query = table.search()
598 if source is not None:
599 query = query.where(f"source = '{escape_sql_string(source)}'")
600 rows: list[PageTextRecord] = query.limit(None).to_list()
601 return rows
603 def page_text_sources(self) -> set[str]:
604 """Return the distinct sources present in the page-text table."""
605 table = self.open_table(PAGE_TEXTS_TABLE)
606 if table is None:
607 return set()
608 return {row["source"] for row in table.search().select(["source"]).limit(None).to_list()}
610 def get_sources(
611 self,
612 *,
613 search: str | None = None,
614 limit: int | None = None,
615 offset: int = 0,
616 ) -> list[SourceRecord]:
617 """Return source records, filtered by *search* and sliced by offset/limit."""
618 table = self.open_table(SOURCES_TABLE)
619 if table is None:
620 return []
621 query = table.search()
622 where = _sources_search_filter(search)
623 if where is not None:
624 query = query.where(where)
625 if offset:
626 query = query.offset(offset)
627 query = query.limit(limit)
628 result: list[SourceRecord] = query.to_list() # type: ignore[assignment]
629 return result
631 def count_sources(self, *, search: str | None = None) -> int:
632 """Count tracked sources matching *search* without materializing rows."""
633 table = self.open_table(SOURCES_TABLE)
634 if table is None:
635 return 0
636 where = _sources_search_filter(search)
637 count: int = table.count_rows() if where is None else table.count_rows(filter=where)
638 return count
640 def upsert_source(
641 self,
642 filename: str,
643 file_hash: str,
644 chunk_count: int,
645 source_type: SourceType = SourceType.DOCUMENT,
646 ) -> None:
647 """Add or update a source tracking record."""
648 with write_lock():
649 db = self.get_db()
650 table = ensure_table(db, SOURCES_TABLE, _sources_schema())
651 _safe_delete_unlocked(table, f"filename = '{escape_sql_string(filename)}'")
652 table.add(
653 [
654 {
655 "filename": filename,
656 "file_hash": file_hash,
657 "ingested_at": datetime.now(UTC).isoformat(),
658 "chunk_count": chunk_count,
659 "source_type": str(source_type),
660 }
661 ]
662 )
663 self._invalidate_source_cache()
665 def _delete_source_unlocked(self, filename: str) -> None:
666 """Remove the *filename* source record. Caller must hold ``write_lock()``."""
667 table = self.open_table(SOURCES_TABLE)
668 if table is not None:
669 _safe_delete_unlocked(table, f"filename = '{escape_sql_string(filename)}'")
671 def delete_source(self, filename: str) -> None:
672 """Remove a source file tracking record."""
673 with write_lock():
674 self._delete_source_unlocked(filename)
675 self._invalidate_source_cache()
677 def _remove_one_unlocked(self, name: str) -> None:
678 """Delete a document's chunks and its source record together.
680 Both deletes run under the caller's single ``write_lock()`` so no
681 reader can observe chunks whose source record is already gone.
682 """
683 self._delete_by_source_unlocked(name)
684 self._delete_source_unlocked(name)
686 def remove_documents(
687 self,
688 names: list[str],
689 *,
690 delete_files: bool = False,
691 documents_dir: Path | None = None,
692 ) -> RemoveResult:
693 """Remove documents from the knowledge base by source name.
694 Looks up known sources, deletes chunks and source records for each.
695 If *delete_files* is True, resolves the path and verifies it is
696 contained within *documents_dir* before unlinking (path traversal guard).
698 Returns a RemoveResult with removed and not_found lists.
699 """
700 if documents_dir is None:
701 documents_dir = self._config.documents_dir
703 known = {s["filename"] for s in self.get_sources()}
704 removed: list[str] = []
705 not_found: list[str] = []
707 for name in names:
708 if name not in known:
709 not_found.append(name)
710 continue
711 with write_lock():
712 self._remove_one_unlocked(name)
713 self._invalidate_source_cache()
714 removed.append(name)
715 if delete_files:
716 try:
717 path = validate_path_within(documents_dir / name, documents_dir)
718 except ValueError:
719 log.warning("Path traversal blocked: %s escapes %s", name, documents_dir)
720 continue
721 if path.exists():
722 path.unlink()
724 return RemoveResult(removed=removed, not_found=not_found)
726 def clear_table(self, name: str, predicate: str) -> None:
727 """Delete rows matching *predicate* from *name*. Acquires write lock."""
728 with write_lock():
729 table = self.open_table(name)
730 if table is not None:
731 _safe_delete_unlocked(table, predicate)
733 def add_citations(self, records: list[CitationRecord]) -> int:
734 """Add citation records to the store. Returns count added."""
735 if not records:
736 return 0
737 with write_lock():
738 db = self.get_db()
739 table = ensure_table(db, CITATIONS_TABLE, _citations_schema())
740 table.add(records)
741 return len(records)
743 def get_citations_for_wiki(self, wiki_source: str) -> list[CitationRecord]:
744 """Get all citations for a wiki page."""
745 table = self.open_table(CITATIONS_TABLE)
746 if table is None:
747 return []
748 escaped = escape_sql_string(wiki_source)
749 rows: list[CitationRecord] = table.search().where(f"wiki_source = '{escaped}'").to_list()
750 return rows
752 def get_citations_for_source(self, source_filename: str) -> list[CitationRecord]:
753 """Get all citations that reference a source document (reverse lookup)."""
754 table = self.open_table(CITATIONS_TABLE)
755 if table is None:
756 return []
757 escaped = escape_sql_string(source_filename)
758 rows: list[CitationRecord] = (
759 table.search().where(f"source_filename = '{escaped}'").to_list()
760 )
761 return rows
763 def delete_citations_for_wiki(self, wiki_source: str) -> None:
764 """Delete all citations for a wiki page (used before regeneration)."""
765 self.clear_table(
766 CITATIONS_TABLE,
767 f"wiki_source = '{escape_sql_string(wiki_source)}'",
768 )
770 def _memories_schema(self) -> pa.Schema:
771 return pa.schema(
772 [
773 pa.field("id", pa.utf8()),
774 pa.field("owner", pa.utf8()),
775 pa.field("shared", pa.bool_()),
776 pa.field("kind", pa.utf8()),
777 pa.field("source", pa.utf8()),
778 pa.field("text", pa.utf8()),
779 pa.field("vector", pa.list_(pa.float32(), self._config.embedding_dim)),
780 pa.field("created_at", pa.utf8()),
781 pa.field("updated_at", pa.utf8()),
782 ]
783 )
785 def _duplicate_memory_id_unlocked(
786 self, table: lancedb.table.Table, record: MemoryRow
787 ) -> str | None:
788 """Return the id of a near-duplicate same-owner, same-kind memory, if any."""
789 if table.count_rows() == 0:
790 return None
791 predicate = (
792 f"owner = '{escape_sql_string(record.owner)}' "
793 f"AND kind = '{escape_sql_string(record.kind)}'"
794 )
795 rows = table.search(record.vector).metric("cosine").where(predicate).limit(1).to_list()
796 if rows and rows[0].get("_distance", 1.0) <= self._config.memory_dedup_distance:
797 return str(rows[0]["id"])
798 return None
800 def _evict_overflow_unlocked(self, table: lancedb.table.Table, owner: str) -> None:
801 """Delete oldest memories for *owner* so an incoming insert stays within the cap."""
802 cap = self._config.memory_max_per_owner
803 predicate = f"owner = '{escape_sql_string(owner)}'"
804 rows = table.search().where(predicate).limit(None).to_list()
805 if len(rows) < cap:
806 return
807 rows.sort(key=lambda r: r.get("created_at", ""))
808 for row in rows[: len(rows) - (cap - 1)]:
809 _safe_delete_unlocked(table, f"id = '{escape_sql_string(str(row['id']))}'")
811 def add_memory(self, record: MemoryRow) -> str:
812 """Insert *record*, or update the nearest same-owner duplicate in place.
814 Returns the stored id. Raises ``EmbeddingModelMismatchError`` when the store
815 was built under a different embedding model, and ``ValueError`` on a vector
816 dimension mismatch.
817 """
818 if len(record.vector) != self._config.embedding_dim:
819 raise ValueError(
820 f"Memory vector dimension mismatch: expected "
821 f"{self._config.embedding_dim}, got {len(record.vector)}"
822 )
823 with write_lock():
824 embedding_model = self._config.embedding_model
825 embedding_dim = self._config.embedding_dim
826 self._ensure_embedding_compat()
827 db = self.get_db()
828 table = ensure_table(db, MEMORIES_TABLE, self._memories_schema())
829 duplicate_id = self._duplicate_memory_id_unlocked(table, record)
830 if duplicate_id is not None:
831 _safe_delete_unlocked(table, f"id = '{escape_sql_string(duplicate_id)}'")
832 record.id = duplicate_id
833 self._evict_overflow_unlocked(table, record.owner)
834 table.add([record.model_dump(mode="json")])
835 if self.get_meta() is None:
836 self._write_meta_unlocked(
837 embedding_model=embedding_model, embedding_dim=embedding_dim
838 )
839 return record.id
841 def get_memories(
842 self,
843 *,
844 owner_predicate: str,
845 kind: MemoryKind | None = None,
846 ) -> list[MemoryRow]:
847 """Return memories matching *owner_predicate* and optional *kind*, newest first."""
848 table = self.open_table(MEMORIES_TABLE)
849 if table is None:
850 return []
851 clauses = [f"({owner_predicate})"]
852 if kind is not None:
853 clauses.append(f"kind = '{escape_sql_string(kind)}'")
854 rows = table.search().where(" AND ".join(clauses)).limit(None).to_list()
855 memories = [MemoryRow(**r) for r in rows]
856 memories.sort(key=lambda m: m.created_at, reverse=True)
857 return memories
859 def search_memories(
860 self,
861 query_vector: list[float],
862 *,
863 owner_predicate: str,
864 top_k: int,
865 max_distance: float,
866 ) -> list[MemoryRow]:
867 """Vector-recall FACT memories within *max_distance*, best first."""
868 table = self.open_table(MEMORIES_TABLE)
869 if table is None or top_k <= 0:
870 return []
871 self._ensure_embedding_compat()
872 predicate = f"({owner_predicate}) AND kind = '{MemoryKind.FACT}'"
873 rows = table.search(query_vector).metric("cosine").where(predicate).limit(top_k).to_list()
874 return [MemoryRow(**r) for r in rows if r.get("_distance", 1.0) <= max_distance]
876 def update_memory(self, memory_id: str, *, shared: bool) -> bool:
877 """Set the *shared* flag on a memory by id. Returns True when found."""
878 with write_lock():
879 table = self.open_table(MEMORIES_TABLE)
880 if table is None:
881 return False
882 escaped = escape_sql_string(memory_id)
883 rows = table.search().where(f"id = '{escaped}'").limit(1).to_list()
884 if not rows:
885 return False
886 record = MemoryRow(**rows[0])
887 record.shared = shared
888 record.updated_at = datetime.now(UTC).isoformat()
889 _safe_delete_unlocked(table, f"id = '{escaped}'")
890 table.add([record.model_dump(mode="json")])
891 return True
893 def delete_memory(self, memory_id: str) -> None:
894 """Delete a memory by id."""
895 self.clear_table(MEMORIES_TABLE, f"id = '{escape_sql_string(memory_id)}'")
897 def rebuild_memory_embeddings(self, embed: Callable[[list[str]], list[list[float]]]) -> int:
898 """Re-embed every memory under the current model, recreating the table.
900 The vector column dimension is immutable, so a different-dim model needs a
901 fresh table; recreating unconditionally also covers the same-dim case. Memory
902 text is human-authored and re-embeddable, so no data is lost. Returns the count.
903 """
904 table = self.open_table(MEMORIES_TABLE)
905 if table is None:
906 return 0
907 rows = table.search().limit(None).to_list()
908 if not rows:
909 return 0
910 memories = [MemoryRow(**r) for r in rows]
911 vectors = embed([m.text for m in memories])
912 for memory, vector in zip(memories, vectors, strict=True):
913 memory.vector = vector
914 with write_lock():
915 db = self.get_db()
916 db.drop_table(MEMORIES_TABLE)
917 new_table = ensure_table(db, MEMORIES_TABLE, self._memories_schema())
918 new_table.add([m.model_dump(mode="json") for m in memories])
919 return len(memories)
921 def close(self) -> None:
922 """Release the database connection and reset state."""
923 self._db = None
924 self._fts_ready = False
926 def drop_all(self) -> None:
927 """Drop every table except ``_memories`` -- used by rebuild.
929 Memory is user-authored data with no on-disk source, not derived from
930 documents, so a rebuild preserves it. Only a factory reset (which deletes
931 the data directory) clears it.
932 """
933 with write_lock():
934 self._fts_ready = False
935 db = self.get_db()
936 for name in _table_names(db):
937 if name == MEMORIES_TABLE:
938 continue
939 db.drop_table(name)
940 self._invalidate_source_cache()