Coverage for src / lilbee / retrieval / query / searcher.py: 100%
316 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
1"""RAG search pipeline -- embed, search, expand, rerank, generate."""
3from __future__ import annotations
5import logging
6from collections.abc import Generator
7from datetime import datetime
8from typing import TYPE_CHECKING, Any, cast
10from pydantic import BaseModel
11from typing_extensions import TypedDict
13from lilbee.core.config import Config
14from lilbee.core.config.enums import ChatMode
15from lilbee.data.store import (
16 ChunkType,
17 MemoryKind,
18 MemoryRow,
19 SearchChunk,
20 Store,
21 cosine_sim,
22 local_owner_predicate,
23)
24from lilbee.providers.base import LLMProvider
25from lilbee.retrieval.embedder import Embedder
26from lilbee.retrieval.query.dedup import (
27 _greedy_cover,
28 _relevance_weight,
29 deduplicate_sources,
30 filter_results,
31 prepare_results,
32)
33from lilbee.retrieval.query.expansion import EXPANSION_MAX_TOKENS, EXPANSION_PROMPT
34from lilbee.retrieval.query.formatting import (
35 CONTEXT_TEMPLATE,
36 _extract_cited_indices,
37 build_context,
38 strip_llm_citations,
39)
40from lilbee.retrieval.query.memory import format_memory_block
41from lilbee.retrieval.query.tokenize import _idf_weights, _tokenize
42from lilbee.retrieval.reasoning import (
43 StreamToken,
44 cap_events_as_stream_tokens,
45 effective_reasoning_cap,
46 stream_chat_with_cap,
47 strip_reasoning,
48)
50if TYPE_CHECKING:
51 from lilbee.retrieval.concepts import ConceptGraph
52 from lilbee.retrieval.reranker import Reranker
# Module-level logger shared by every Searcher method in this file.
log = logging.getLogger(__name__)

# BM25 probe needs at least this many hits to compare top vs. runner-up
# scores for the expansion-skip heuristic. With fewer hits, a top score
# that clears the skip threshold is enough on its own to skip expansion.
_MIN_BM25_PROBE_RESULTS = 2
class ChatMessage(TypedDict):
    """A single chat message with role and content.

    This is the shape used for conversation history and for the message
    lists this module builds before handing them to the LLM provider.
    """

    # Speaker of the message; this module itself emits "system" and "user".
    role: str
    # Text of the message.
    content: str
class AskResult(BaseModel):
    """Structured result from ask_raw -- answer text + raw search results."""

    # Provider response as text; reasoning is stripped unless the config
    # asks for it to be shown.
    answer: str
    # Chunks that were placed in the prompt context; empty when the answer
    # was produced without retrieval.
    sources: list[SearchChunk]
class Searcher:
    """RAG search pipeline -- embed, search, expand, rerank, generate.

    All search and answer operations go through this class: ``search`` for
    candidate retrieval, ``build_rag_context`` for prompt assembly, and
    ``ask_raw`` / ``ask`` / ``ask_stream`` for answer generation.
    Constructed with injected dependencies via the Services container.
    """
81 def __init__(
82 self,
83 config: Config,
84 provider: LLMProvider,
85 store: Store,
86 embedder: Embedder,
87 reranker: Reranker,
88 concepts: ConceptGraph,
89 ) -> None:
90 self._config = config
91 self._provider = provider
92 self._store = store
93 self._embedder = embedder
94 self._reranker = reranker
95 self._concepts = concepts
97 def _apply_temporal_filter(
98 self, results: list[SearchChunk], question: str
99 ) -> list[SearchChunk]:
100 if not self._config.temporal_filtering:
101 return results
102 from lilbee.runtime.temporal import detect_temporal, resolve_date_range
104 keyword = detect_temporal(question)
105 if keyword is None:
106 return results
107 date_range = resolve_date_range(keyword)
108 source_dates = self._store.source_ingested_at_map()
109 filtered: list[SearchChunk] = []
110 for r in results:
111 ingested_at = source_dates.get(r.source, "")
112 if not ingested_at:
113 filtered.append(r)
114 continue
115 try:
116 doc_date = datetime.fromisoformat(ingested_at)
117 if date_range.start <= doc_date <= date_range.end:
118 filtered.append(r)
119 except (ValueError, TypeError):
120 filtered.append(r)
121 return filtered if filtered else results
123 def _apply_guardrails(
124 self,
125 variants: list[tuple[str, list[float]]],
126 question_vec: list[float],
127 ) -> list[tuple[str, list[float]]]:
128 """Drop expansion variants whose embedding drifts too far from the question."""
129 if not self._config.expansion_guardrails:
130 return variants
131 threshold = self._config.expansion_similarity_threshold
132 return [(text, vec) for text, vec in variants if cosine_sim(question_vec, vec) >= threshold]
134 def _concept_query_expansion(self, question: str) -> list[str]:
135 if not self._config.concept_graph:
136 return []
137 try:
138 if not self._concepts.get_graph():
139 return []
140 return self._concepts.expand_query(question)
141 except Exception:
142 log.debug("Concept query expansion failed", exc_info=True)
143 return []
145 def _llm_expand(self, question: str, count: int) -> list[str]:
146 """Call the LLM to produce ``count`` alternative phrasings."""
147 prompt = EXPANSION_PROMPT.format(count=count, question=question)
148 messages = [{"role": "user", "content": prompt}]
149 response = self._provider.chat(
150 messages, stream=False, options={"num_predict": EXPANSION_MAX_TOKENS}
151 )
152 if not isinstance(response, str):
153 return []
154 variants = [line.strip() for line in response.strip().split("\n") if line.strip()]
155 return variants[:count]
157 def _expand_query(
158 self, question: str, question_vec: list[float]
159 ) -> list[tuple[str, list[float]]]:
160 """Return ``(variant, variant_vec)`` pairs for downstream search.
162 LLM variants run through ``_apply_guardrails``; concept-graph
163 variants bypass it since they come from deterministic traversal.
164 Embeddings batch per source: one provider round-trip per source.
165 """
166 count = self._config.query_expansion_count
167 if count <= 0 and not self._config.concept_graph:
168 return []
169 # Short queries skip LLM expansion: BM25/vector signal is already strong
170 # and the LLM round-trip dominates latency on small local models.
171 # Concept-graph expansion still runs.
172 short_threshold = self._config.expansion_short_query_tokens
173 skip_llm = short_threshold > 0 and len(_tokenize(question)) <= short_threshold
174 try:
175 llm_variants: list[tuple[str, list[float]]] = []
176 if count > 0 and not skip_llm:
177 llm_texts = list(self._llm_expand(question, count))
178 if llm_texts:
179 llm_vectors = self._embedder.embed_batch(llm_texts)
180 llm_variants = list(zip(llm_texts, llm_vectors, strict=True))
181 llm_variants = self._apply_guardrails(llm_variants, question_vec)
183 concept_texts = list(self._concept_query_expansion(question))
184 if concept_texts:
185 concept_vectors = self._embedder.embed_batch(concept_texts)
186 llm_variants.extend(zip(concept_texts, concept_vectors, strict=True))
188 return llm_variants
189 except Exception as exc:
190 log.warning("Query expansion disabled for this call: %s", exc)
191 log.debug("Query expansion exception", exc_info=True)
192 return []
194 def _should_skip_expansion(self, question: str) -> bool:
195 if self._config.expansion_skip_threshold <= 0:
196 return False
197 results = self._store.bm25_probe(question, top_k=2)
198 if not results:
199 return False
200 top_score = results[0].relevance_score or 0
201 if top_score < self._config.expansion_skip_threshold:
202 return False
203 if len(results) < _MIN_BM25_PROBE_RESULTS:
204 return True
205 second_score = results[1].relevance_score or 0
206 return (top_score - second_score) >= self._config.expansion_skip_gap
208 def _apply_concept_boost(self, results: list[SearchChunk], question: str) -> list[SearchChunk]:
209 if not self._config.concept_graph or not results:
210 return results
211 try:
212 if not self._concepts.get_graph():
213 return results
214 query_concepts = self._concepts.extract_concepts(question)
215 if not query_concepts:
216 return results
217 return self._concepts.boost_results(results, query_concepts)
218 except Exception:
219 log.debug("Concept boost failed", exc_info=True)
220 return results
222 def _hyde_search(self, question: str, top_k: int) -> list[SearchChunk]:
223 """Hypothetical Document Embedding search.
224 Gao et al. 2022, "Precise Zero-Shot Dense Retrieval without
225 Relevance Labels" -- generates a hypothetical answer passage,
226 embeds it, and uses the embedding to search for real documents.
227 """
228 try:
229 response = self._provider.chat(
230 [{"role": "user", "content": self._config.hyde_prompt.format(question=question)}],
231 stream=False,
232 options={"num_predict": EXPANSION_MAX_TOKENS},
233 )
234 if not isinstance(response, str) or not response.strip():
235 return []
236 hyde_vec = self._embedder.embed(response.strip())
237 return self._store.search(hyde_vec, top_k=top_k, query_text=None)
238 except Exception:
239 log.debug("HyDE search failed", exc_info=True)
240 return []
242 def _normalize_chunk_type(self, chunk_type: ChunkType | None) -> ChunkType | None:
243 """Drop ``chunk_type=ChunkType.WIKI`` when wiki generation is disabled.
245 With wiki off the chunks table contains only raw rows, so the
246 filter would return empty. Logging once keeps the surprise out
247 of the user's way while surfacing the misuse in logs.
248 """
249 if chunk_type == ChunkType.WIKI and not self._config.wiki:
250 log.warning(
251 "wiki scope requested but wiki is disabled; searching the full pool instead"
252 )
253 return None
254 return chunk_type
256 def _parse_structured_query(self, question: str) -> tuple[str | None, str]:
257 for prefix in ("term:", "vec:", "hyde:", "wiki:", "raw:"):
258 if question.strip().lower().startswith(prefix):
259 return prefix[:-1], question.strip()[len(prefix) :].strip()
260 return None, question
262 def _search_structured(
263 self,
264 mode: str,
265 query: str,
266 top_k: int,
267 chunk_type: ChunkType | None = None,
268 ) -> list[SearchChunk]:
269 if mode == "term":
270 return self._store.bm25_probe(query, top_k=top_k)
271 if mode == "vec":
272 query_vec = self._embedder.embed(query)
273 return self._store.search(query_vec, top_k=top_k, query_text=None)
274 if mode == "hyde":
275 return self._hyde_search(query, top_k)
276 if mode in (ChunkType.WIKI, ChunkType.RAW):
277 # Explicit ``chunk_type`` arg beats the ``wiki:``/``raw:`` prefix shortcut.
278 effective = chunk_type if chunk_type is not None else ChunkType(mode)
279 query_vec = self._embedder.embed(query)
280 return self._store.search(
281 query_vec, top_k=top_k, query_text=query, chunk_type=effective
282 )
283 return []
285 def select_context(
286 self, results: list[SearchChunk], question: str, max_sources: int | None = None
287 ) -> list[SearchChunk]:
288 """Pick ``max_sources`` chunks.
290 Results carrying ``rerank_score`` keep the cross-encoder order (top
291 ``max_sources``); otherwise greedy IDF-weighted set cover.
292 """
293 if max_sources is None:
294 max_sources = self._config.max_context_sources
295 if len(results) <= max_sources:
296 return results
297 if any(r.rerank_score is not None for r in results):
298 return results[:max_sources]
300 question_terms = set(_tokenize(question))
301 if not question_terms:
302 return results[:max_sources]
304 chunk_tokens = [set(_tokenize(r.chunk)) for r in results]
305 term_weights = _idf_weights(question_terms, chunk_tokens)
306 if not any(term_weights.values()):
307 return results[:max_sources]
309 weights = [_relevance_weight(r) for r in results]
310 selected = _greedy_cover(chunk_tokens, question_terms, term_weights, max_sources, weights)
311 selected.sort()
312 return [results[i] for i in selected]
314 def _merge_variant_results(
315 self,
316 question: str,
317 query_vec: list[float],
318 results: list[SearchChunk],
319 seen: set[tuple[str, int]],
320 top_k: int,
321 chunk_type: ChunkType | None,
322 ) -> None:
323 """Append unseen variant-search hits to ``results`` (in place)."""
324 for variant, variant_vec in self._expand_query(question, query_vec):
325 variant_results = self._store.search(
326 variant_vec,
327 top_k=top_k,
328 query_text=variant,
329 chunk_type=chunk_type,
330 )
331 for r in variant_results:
332 key = (r.source, r.chunk_index)
333 if key not in seen:
334 results.append(r)
335 seen.add(key)
337 def _merge_hyde_results(
338 self,
339 question: str,
340 results: list[SearchChunk],
341 seen: set[tuple[str, int]],
342 top_k: int,
343 ) -> None:
344 """Append unseen HyDE hits to ``results`` (in place), reweighted by ``hyde_weight``."""
345 for r in self._hyde_search(question, top_k):
346 key = (r.source, r.chunk_index)
347 if key in seen:
348 continue
349 if r.distance is not None and self._config.hyde_weight > 0:
350 r = r.model_copy(update={"distance": r.distance / self._config.hyde_weight})
351 results.append(r)
352 seen.add(key)
354 def search(
355 self,
356 question: str,
357 top_k: int = 0,
358 *,
359 chunk_type: ChunkType | None = None,
360 ) -> list[SearchChunk]:
361 """Embed question and search with expansion, HyDE, and concept boost.
362 Returns up to top_k*2 candidates for downstream filtering.
364 When *chunk_type* is set (``"raw"`` or ``"wiki"``), only chunks of
365 that type are returned. An explicit ``chunk_type`` always wins
366 over the ``wiki:``/``raw:`` prefix shortcut in *question* so the
367 user-facing scope choice has the final say.
369 When ``chunk_type="wiki"`` but wiki generation is disabled on the
370 config, the filter is normalized to ``None`` (mixed pool) and a
371 warning is logged: with wiki off the chunks table has no wiki rows,
372 so honouring the filter would silently return zero results.
373 """
374 if top_k == 0:
375 top_k = self._config.top_k
376 chunk_type = self._normalize_chunk_type(chunk_type)
377 mode, clean_query = self._parse_structured_query(question)
378 if mode is not None:
379 return self._search_structured(mode, clean_query, top_k, chunk_type=chunk_type)
380 query_vec = self._embedder.embed(question)
381 results = self._store.search(
382 query_vec,
383 top_k=top_k,
384 query_text=question,
385 chunk_type=chunk_type,
386 )
387 if self._should_skip_expansion(question):
388 return results[: top_k * 2]
389 seen = {(r.source, r.chunk_index) for r in results}
390 self._merge_variant_results(question, query_vec, results, seen, top_k, chunk_type)
391 if self._config.hyde:
392 self._merge_hyde_results(question, results, seen, top_k)
393 results = self._apply_concept_boost(results, question)
394 return results[: top_k * 2]
396 def build_rag_context(
397 self,
398 question: str,
399 top_k: int = 0,
400 history: list[ChatMessage] | None = None,
401 *,
402 chunk_type: ChunkType | None = None,
403 ) -> tuple[list[SearchChunk], list[ChatMessage]] | None:
404 """Build RAG context from search results.
406 ``chunk_type`` restricts the pool to ``"raw"`` or ``"wiki"`` rows;
407 ``None`` (default) searches the mixed pool.
408 """
409 results = self.search(question, top_k=top_k, chunk_type=chunk_type)
410 results = filter_results(
411 results, self._config.max_distance, self._config.min_relevance_score
412 )
413 if not results:
414 return None
415 results = prepare_results(results)
416 if self._config.reranker_model:
417 results = self._reranker.rerank(question, results)
418 results = self._apply_temporal_filter(results, question)
419 results = self.select_context(results, question)
420 context = build_context(results)
421 prompt = CONTEXT_TEMPLATE.format(context=context, question=question)
422 messages: list[ChatMessage] = [
423 {
424 "role": "system",
425 "content": self._system_with_memory(self._config.rag_system_prompt, question),
426 }
427 ]
428 if history:
429 messages.extend(history)
430 messages.append({"role": "user", "content": prompt})
431 return results, messages
433 def _system_with_memory(self, base_prompt: str, question: str) -> str:
434 """Append the local-owner memory block to *base_prompt* when memory is enabled."""
435 block = self._memory_block(question)
436 return f"{base_prompt}\n\n{block}" if block else base_prompt
438 def _memory_block(self, question: str) -> str:
439 """Recall the local human's preferences and relevant facts as a system block.
441 Preferences are always included; facts are recalled by similarity. Empty
442 when memory is disabled or nothing matches. MCP agents never reach this path
443 (their tools recall explicitly under their own owner).
444 """
445 if not self._config.memory_enabled:
446 return ""
447 owner_predicate = local_owner_predicate()
448 preferences = self._store.get_memories(
449 owner_predicate=owner_predicate,
450 kind=MemoryKind.PREFERENCE,
451 )
452 facts: list[MemoryRow] = []
453 if self._config.memory_top_k > 0 and self._embedder.embedding_available():
454 vector = self._embedder.embed(question)
455 facts = self._store.search_memories(
456 vector,
457 owner_predicate=owner_predicate,
458 top_k=self._config.memory_top_k,
459 max_distance=self._config.memory_max_distance,
460 )
461 return format_memory_block(preferences, facts, self._config.memory_token_budget)
463 def skip_retrieval(self) -> bool:
464 """Whether this turn should bypass RAG: chat-only mode or no embedder."""
465 return (
466 self._config.chat_mode == ChatMode.CHAT.value
467 or not self._embedder.embedding_available()
468 )
470 def direct_messages(
471 self, question: str, history: list[ChatMessage] | None = None
472 ) -> list[ChatMessage]:
473 """Build messages for direct LLM chat (no RAG context)."""
474 messages: list[ChatMessage] = [
475 {
476 "role": "system",
477 "content": self._system_with_memory(self._config.general_system_prompt, question),
478 }
479 ]
480 if history:
481 messages.extend(history)
482 messages.append({"role": "user", "content": question})
483 return messages
485 def _messages_for_provider(self, messages: list[ChatMessage]) -> list[dict[str, str]]:
486 """Convert ChatMessage list to provider-expected format."""
487 return [{"role": m["role"], "content": m["content"]} for m in messages]
489 def _direct_chat(
490 self,
491 question: str,
492 history: list[ChatMessage] | None,
493 options: dict[str, Any] | None,
494 ) -> str:
495 """Run a no-RAG chat turn and return the cleaned response."""
496 messages = self.direct_messages(question, history)
497 provider_messages = self._messages_for_provider(messages)
498 opts = options if options is not None else self._config.generation_options()
499 raw = str(self._provider.chat(provider_messages, options=opts or None) or "")
500 return raw if self._config.show_reasoning else strip_reasoning(raw)
502 def ask_raw(
503 self,
504 question: str,
505 top_k: int = 0,
506 history: list[ChatMessage] | None = None,
507 options: dict[str, Any] | None = None,
508 *,
509 chunk_type: ChunkType | None = None,
510 ) -> AskResult:
511 """Ask a question. Skips retrieval when chat_mode is 'chat' or
512 when no embedding model is configured."""
513 if self.skip_retrieval():
514 return AskResult(answer=self._direct_chat(question, history, options), sources=[])
515 rag = self.build_rag_context(question, top_k=top_k, history=history, chunk_type=chunk_type)
516 if rag is None:
517 return AskResult(answer=self._direct_chat(question, history, options), sources=[])
518 results, messages = rag
519 provider_messages = self._messages_for_provider(messages)
520 opts = options if options is not None else self._config.generation_options()
521 raw = str(self._provider.chat(provider_messages, options=opts or None) or "")
522 clean = raw if self._config.show_reasoning else strip_reasoning(raw)
523 return AskResult(answer=clean, sources=results)
525 def ask(
526 self,
527 question: str,
528 top_k: int = 0,
529 history: list[ChatMessage] | None = None,
530 options: dict[str, Any] | None = None,
531 *,
532 chunk_type: ChunkType | None = None,
533 ) -> str:
534 """Ask a question and get a formatted answer with citations."""
535 result = self.ask_raw(
536 question, top_k=top_k, history=history, options=options, chunk_type=chunk_type
537 )
538 if not result.sources:
539 return result.answer
540 cited = _extract_cited_indices(result.answer)
541 used = [result.sources[i - 1] for i in sorted(cited) if 1 <= i <= len(result.sources)]
542 answer = strip_llm_citations(result.answer)
543 source_list = used if used else result.sources
544 citations = deduplicate_sources(source_list)
545 return f"{answer}\n\nSources:\n" + "\n".join(citations)
547 def _stream_direct(
548 self,
549 question: str,
550 history: list[ChatMessage] | None,
551 options: dict[str, Any] | None,
552 ) -> Generator[StreamToken, None, None]:
553 """Streaming branch with the general system prompt (no RAG context)."""
554 messages = self.direct_messages(question, history)
555 provider_messages = self._messages_for_provider(messages)
556 opts = options if options is not None else self._config.generation_options()
557 events = stream_chat_with_cap(
558 self._provider,
559 cast("list[dict[str, Any]]", provider_messages),
560 options=opts,
561 model=self._config.chat_model,
562 show_reasoning=self._config.show_reasoning,
563 cap_chars=effective_reasoning_cap(),
564 )
565 try:
566 yield from cap_events_as_stream_tokens(events)
567 except (ConnectionError, OSError) as exc:
568 yield StreamToken(content=f"\n\n[Connection lost: {exc}]", is_reasoning=False)
570 def ask_stream(
571 self,
572 question: str,
573 top_k: int = 0,
574 history: list[ChatMessage] | None = None,
575 options: dict[str, Any] | None = None,
576 *,
577 chunk_type: ChunkType | None = None,
578 ) -> Generator[StreamToken, None, None]:
579 """Stream answer tokens with citations appended at the end."""
580 if self.skip_retrieval():
581 yield from self._stream_direct(question, history, options)
582 return
584 rag = self.build_rag_context(question, top_k=top_k, history=history, chunk_type=chunk_type)
585 if rag is None:
586 yield from self._stream_direct(question, history, options)
587 return
588 results, messages = rag
589 provider_messages = self._messages_for_provider(messages)
590 opts = options if options is not None else self._config.generation_options()
591 answer_parts: list[str] = []
592 events = stream_chat_with_cap(
593 self._provider,
594 cast("list[dict[str, Any]]", provider_messages),
595 options=opts,
596 model=self._config.chat_model,
597 show_reasoning=self._config.show_reasoning,
598 cap_chars=effective_reasoning_cap(),
599 )
600 try:
601 for token in cap_events_as_stream_tokens(events):
602 if not token.is_reasoning:
603 answer_parts.append(token.content)
604 yield token
605 except (ConnectionError, OSError) as exc:
606 yield StreamToken(content=f"\n\n[Connection lost: {exc}]", is_reasoning=False)
607 # LLM-generated citation blocks in streamed tokens cannot be
608 # retroactively stripped. The system prompt discourages them; this
609 # only filters the code-appended Sources block to cited chunks.
610 full_answer = "".join(answer_parts)
611 cited = _extract_cited_indices(full_answer)
612 used = [results[i - 1] for i in sorted(cited) if 1 <= i <= len(results)]
613 source_list = used if used else results
614 citations = deduplicate_sources(source_list)
615 yield StreamToken(content="\n\nSources:\n" + "\n".join(citations), is_reasoning=False)