Coverage for src / lilbee / retrieval / query / searcher.py: 100%
298 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
1"""RAG search pipeline -- embed, search, expand, rerank, generate."""
3from __future__ import annotations
5import logging
6from collections.abc import Generator
7from datetime import datetime
8from typing import TYPE_CHECKING, Any, cast
10from pydantic import BaseModel
11from typing_extensions import TypedDict
13from lilbee.core.config import Config
14from lilbee.core.config.enums import ChatMode
15from lilbee.data.store import (
16 CHUNK_TYPE_RAW,
17 CHUNK_TYPE_WIKI,
18 SearchChunk,
19 Store,
20 cosine_sim,
21)
22from lilbee.providers.base import LLMProvider
23from lilbee.retrieval.embedder import Embedder
24from lilbee.retrieval.query.dedup import (
25 _greedy_cover,
26 _relevance_weight,
27 deduplicate_sources,
28 filter_results,
29 prepare_results,
30)
31from lilbee.retrieval.query.expansion import EXPANSION_MAX_TOKENS, EXPANSION_PROMPT
32from lilbee.retrieval.query.formatting import (
33 CONTEXT_TEMPLATE,
34 _extract_cited_indices,
35 build_context,
36 strip_llm_citations,
37)
38from lilbee.retrieval.query.tokenize import _idf_weights, _tokenize
39from lilbee.retrieval.reasoning import (
40 StreamToken,
41 cap_events_as_stream_tokens,
42 effective_reasoning_cap,
43 stream_chat_with_cap,
44 strip_reasoning,
45)
47if TYPE_CHECKING:
48 from lilbee.retrieval.concepts import ConceptGraph
49 from lilbee.retrieval.reranker import Reranker
51log = logging.getLogger(__name__)
# BM25 probe needs at least this many hits to compare top vs. runner-up
# scores for the expansion-skip heuristic.
_MIN_BM25_PROBE_RESULTS: int = 2
class ChatMessage(TypedDict):
    """A single chat message with role and content."""

    # Speaker role; this module builds "system" and "user" entries and
    # passes roles from caller-supplied history through unchanged.
    role: str
    # Message body text (prompt, RAG context template, or model output).
    content: str
class AskResult(BaseModel):
    """Structured result from ask_raw -- answer text + raw search results."""

    # Final answer text; reasoning is stripped unless config.show_reasoning.
    answer: str
    # Chunks used to build the RAG context; empty when retrieval was
    # skipped (chat mode / no embedder) or produced no usable results.
    sources: list[SearchChunk]
class Searcher:
    """RAG search pipeline -- embed, search, expand, rerank, generate.

    All search and answer operations go through this class.
    Constructed with injected dependencies via the Services container.
    """

    def __init__(
        self,
        config: Config,
        provider: LLMProvider,
        store: Store,
        embedder: Embedder,
        reranker: Reranker,
        concepts: ConceptGraph,
    ) -> None:
        self._config = config
        self._provider = provider
        self._store = store
        self._embedder = embedder
        self._reranker = reranker
        self._concepts = concepts

    def _apply_temporal_filter(
        self, results: list[SearchChunk], question: str
    ) -> list[SearchChunk]:
        """Restrict *results* to sources ingested within a detected date range.

        No-op when temporal filtering is disabled or *question* contains no
        temporal keyword. Chunks with a missing or unparseable ingestion
        date are kept (fail open), and if filtering would remove everything
        the unfiltered *results* are returned instead.
        """
        if not self._config.temporal_filtering:
            return results
        # Imported lazily so runtime.temporal is only loaded when the
        # feature is actually enabled.
        from lilbee.runtime.temporal import detect_temporal, resolve_date_range

        keyword = detect_temporal(question)
        if keyword is None:
            return results
        date_range = resolve_date_range(keyword)
        source_dates = self._store.source_ingested_at_map()
        filtered: list[SearchChunk] = []
        for r in results:
            ingested_at = source_dates.get(r.source, "")
            if not ingested_at:
                # No ingestion date on record -- keep rather than drop.
                filtered.append(r)
                continue
            try:
                # NOTE(review): assumes stored timestamps and the resolved
                # date_range bounds share the same naive/aware convention --
                # confirm against the ingestion path.
                doc_date = datetime.fromisoformat(ingested_at)
                if date_range.start <= doc_date <= date_range.end:
                    filtered.append(r)
            except (ValueError, TypeError):
                # Unparseable timestamp -- keep the chunk (fail open).
                filtered.append(r)
        return filtered if filtered else results

    def _apply_guardrails(
        self,
        variants: list[tuple[str, list[float]]],
        question_vec: list[float],
    ) -> list[tuple[str, list[float]]]:
        """Drop expansion variants whose embedding drifts too far from the question."""
        if not self._config.expansion_guardrails:
            return variants
        threshold = self._config.expansion_similarity_threshold
        return [
            (text, vec)
            for text, vec in variants
            if cosine_sim(question_vec, vec) >= threshold
        ]

    def _concept_query_expansion(self, question: str) -> list[str]:
        """Expand *question* via the concept graph; best-effort, ``[]`` on failure."""
        if not self._config.concept_graph:
            return []
        try:
            if not self._concepts.get_graph():
                return []
            return self._concepts.expand_query(question)
        except Exception:
            log.debug("Concept query expansion failed", exc_info=True)
            return []

    def _llm_expand(self, question: str, count: int) -> list[str]:
        """Call the LLM to produce ``count`` alternative phrasings."""
        prompt = EXPANSION_PROMPT.format(count=count, question=question)
        messages = [{"role": "user", "content": prompt}]
        response = self._provider.chat(
            messages, stream=False, options={"num_predict": EXPANSION_MAX_TOKENS}
        )
        if not isinstance(response, str):
            return []
        # One variant per non-empty line, capped at the requested count.
        variants = [line.strip() for line in response.strip().split("\n") if line.strip()]
        return variants[:count]

    def _expand_query(
        self, question: str, question_vec: list[float]
    ) -> list[tuple[str, list[float]]]:
        """Return ``(variant, variant_vec)`` pairs for downstream search.

        LLM variants run through ``_apply_guardrails``; concept-graph
        variants bypass it since they come from deterministic traversal.
        Embeddings batch per source: one provider round-trip per source.
        Returns ``[]`` (with a warning) if anything in the pipeline fails.
        """
        count = self._config.query_expansion_count
        if count <= 0 and not self._config.concept_graph:
            return []
        # Short queries skip LLM expansion: BM25/vector signal is already strong
        # and the LLM round-trip dominates latency on small local models.
        # Concept-graph expansion still runs.
        short_threshold = self._config.expansion_short_query_tokens
        skip_llm = short_threshold > 0 and len(_tokenize(question)) <= short_threshold
        try:
            llm_variants: list[tuple[str, list[float]]] = []
            if count > 0 and not skip_llm:
                llm_texts = list(self._llm_expand(question, count))
                if llm_texts:
                    llm_vectors = self._embedder.embed_batch(llm_texts)
                    llm_variants = list(zip(llm_texts, llm_vectors, strict=True))
                    llm_variants = self._apply_guardrails(llm_variants, question_vec)

            concept_texts = list(self._concept_query_expansion(question))
            if concept_texts:
                concept_vectors = self._embedder.embed_batch(concept_texts)
                llm_variants.extend(zip(concept_texts, concept_vectors, strict=True))

            return llm_variants
        except Exception as exc:
            # Expansion is an enhancement, never a hard requirement.
            log.warning("Query expansion disabled for this call: %s", exc)
            log.debug("Query expansion exception", exc_info=True)
            return []

    def _should_skip_expansion(self, question: str) -> bool:
        """Return True when a BM25 probe shows expansion is unlikely to help.

        Probes the store with the raw question; skips expansion when the
        top score clears ``expansion_skip_threshold`` and either no
        runner-up exists or the top-vs-runner-up lead is at least
        ``expansion_skip_gap``.
        """
        if self._config.expansion_skip_threshold <= 0:
            return False
        # Fetch exactly as many hits as the gap heuristic needs (was a
        # hard-coded 2, which would silently desync if the constant changed).
        results = self._store.bm25_probe(question, top_k=_MIN_BM25_PROBE_RESULTS)
        if not results:
            return False
        top_score = results[0].relevance_score or 0
        if top_score < self._config.expansion_skip_threshold:
            return False
        if len(results) < _MIN_BM25_PROBE_RESULTS:
            # Single confident hit with no competition: safe to skip.
            return True
        second_score = results[1].relevance_score or 0
        return (top_score - second_score) >= self._config.expansion_skip_gap

    def _apply_concept_boost(self, results: list[SearchChunk], question: str) -> list[SearchChunk]:
        """Boost *results* sharing concepts with *question*; best-effort, never raises."""
        if not self._config.concept_graph or not results:
            return results
        try:
            if not self._concepts.get_graph():
                return results
            query_concepts = self._concepts.extract_concepts(question)
            if not query_concepts:
                return results
            return self._concepts.boost_results(results, query_concepts)
        except Exception:
            log.debug("Concept boost failed", exc_info=True)
            return results

    def _hyde_search(self, question: str, top_k: int) -> list[SearchChunk]:
        """Hypothetical Document Embedding search.

        Gao et al. 2022, "Precise Zero-Shot Dense Retrieval without
        Relevance Labels" -- generates a hypothetical answer passage,
        embeds it, and uses the embedding to search for real documents.
        Returns ``[]`` on any failure (best-effort).
        """
        try:
            response = self._provider.chat(
                [{"role": "user", "content": self._config.hyde_prompt.format(question=question)}],
                stream=False,
                options={"num_predict": EXPANSION_MAX_TOKENS},
            )
            if not isinstance(response, str) or not response.strip():
                return []
            hyde_vec = self._embedder.embed(response.strip())
            # query_text=None: pure vector search, no BM25 blending.
            return self._store.search(hyde_vec, top_k=top_k, query_text=None)
        except Exception:
            log.debug("HyDE search failed", exc_info=True)
            return []

    def _normalize_chunk_type(self, chunk_type: str | None) -> str | None:
        """Drop ``chunk_type="wiki"`` when wiki generation is disabled.

        With wiki off the chunks table contains only raw rows, so the
        filter would return empty. Logging once keeps the surprise out
        of the user's way while surfacing the misuse in logs.
        """
        if chunk_type == CHUNK_TYPE_WIKI and not self._config.wiki:
            log.warning(
                "wiki scope requested but wiki is disabled; searching the full pool instead"
            )
            return None
        return chunk_type

    def _parse_structured_query(self, question: str) -> tuple[str | None, str]:
        """Split a leading ``mode:`` prefix off *question*.

        Returns ``(mode, remainder)`` for a recognized prefix, else
        ``(None, question)`` unchanged. Matching is case-insensitive.
        """
        # Hoist the strip/lower out of the loop -- the original recomputed
        # both on every prefix iteration.
        stripped = question.strip()
        lowered = stripped.lower()
        for prefix in ("term:", "vec:", "hyde:", "wiki:", "raw:"):
            if lowered.startswith(prefix):
                return prefix[:-1], stripped[len(prefix) :].strip()
        return None, question

    def _search_structured(
        self,
        mode: str,
        query: str,
        top_k: int,
        chunk_type: str | None = None,
    ) -> list[SearchChunk]:
        """Dispatch a prefixed query to its dedicated search strategy.

        ``term`` -> BM25 only; ``vec`` -> vector only; ``hyde`` -> HyDE;
        ``wiki``/``raw`` -> hybrid search restricted to that chunk type.
        Unknown modes return ``[]``.
        """
        if mode == "term":
            return self._store.bm25_probe(query, top_k=top_k)
        if mode == "vec":
            query_vec = self._embedder.embed(query)
            return self._store.search(query_vec, top_k=top_k, query_text=None)
        if mode == "hyde":
            return self._hyde_search(query, top_k)
        if mode in (CHUNK_TYPE_WIKI, CHUNK_TYPE_RAW):
            # Explicit ``chunk_type`` arg beats the prefix shortcut.
            effective = chunk_type if chunk_type is not None else mode
            query_vec = self._embedder.embed(query)
            return self._store.search(
                query_vec, top_k=top_k, query_text=query, chunk_type=effective
            )
        return []

    def select_context(
        self, results: list[SearchChunk], question: str, max_sources: int | None = None
    ) -> list[SearchChunk]:
        """Pick ``max_sources`` chunks by greedy IDF-weighted set cover.

        Falls back to a simple head slice when the question yields no
        usable terms or every term weight is zero.
        """
        if max_sources is None:
            max_sources = self._config.max_context_sources
        if len(results) <= max_sources:
            return results

        question_terms = set(_tokenize(question))
        if not question_terms:
            return results[:max_sources]

        chunk_tokens = [set(_tokenize(r.chunk)) for r in results]
        term_weights = _idf_weights(question_terms, chunk_tokens)
        if not any(term_weights.values()):
            return results[:max_sources]

        weights = [_relevance_weight(r) for r in results]
        selected = _greedy_cover(chunk_tokens, question_terms, term_weights, max_sources, weights)
        # Preserve the original (relevance) ordering of the picked chunks.
        selected.sort()
        return [results[i] for i in selected]

    def _merge_variant_results(
        self,
        question: str,
        query_vec: list[float],
        results: list[SearchChunk],
        seen: set[tuple[str, int]],
        top_k: int,
        chunk_type: str | None,
    ) -> None:
        """Append unseen variant-search hits to ``results`` (in place)."""
        for variant, variant_vec in self._expand_query(question, query_vec):
            variant_results = self._store.search(
                variant_vec,
                top_k=top_k,
                query_text=variant,
                chunk_type=chunk_type,
            )
            for r in variant_results:
                key = (r.source, r.chunk_index)
                if key not in seen:
                    results.append(r)
                    seen.add(key)

    def _merge_hyde_results(
        self,
        question: str,
        results: list[SearchChunk],
        seen: set[tuple[str, int]],
        top_k: int,
    ) -> None:
        """Append unseen HyDE hits to ``results`` (in place), reweighted by ``hyde_weight``."""
        for r in self._hyde_search(question, top_k):
            key = (r.source, r.chunk_index)
            if key in seen:
                continue
            if r.distance is not None and self._config.hyde_weight > 0:
                # Dividing distance by hyde_weight > 1 shrinks it, boosting
                # HyDE hits relative to direct-query hits.
                r = r.model_copy(update={"distance": r.distance / self._config.hyde_weight})
            results.append(r)
            seen.add(key)

    def search(
        self,
        question: str,
        top_k: int = 0,
        *,
        chunk_type: str | None = None,
    ) -> list[SearchChunk]:
        """Embed question and search with expansion, HyDE, and concept boost.

        Returns up to top_k*2 candidates for downstream filtering.

        When *chunk_type* is set (``"raw"`` or ``"wiki"``), only chunks of
        that type are returned. An explicit ``chunk_type`` always wins
        over the ``wiki:``/``raw:`` prefix shortcut in *question* so the
        user-facing scope choice has the final say.

        When ``chunk_type="wiki"`` but wiki generation is disabled on the
        config, the filter is normalized to ``None`` (mixed pool) and a
        warning is logged: with wiki off the chunks table has no wiki rows,
        so honouring the filter would silently return zero results.
        """
        if top_k == 0:
            top_k = self._config.top_k
        chunk_type = self._normalize_chunk_type(chunk_type)
        mode, clean_query = self._parse_structured_query(question)
        if mode is not None:
            return self._search_structured(mode, clean_query, top_k, chunk_type=chunk_type)
        query_vec = self._embedder.embed(question)
        results = self._store.search(
            query_vec,
            top_k=top_k,
            query_text=question,
            chunk_type=chunk_type,
        )
        if self._should_skip_expansion(question):
            # Confident BM25 winner: expansion/HyDE would add latency, not recall.
            return results[: top_k * 2]
        seen = {(r.source, r.chunk_index) for r in results}
        self._merge_variant_results(question, query_vec, results, seen, top_k, chunk_type)
        if self._config.hyde:
            self._merge_hyde_results(question, results, seen, top_k)
        results = self._apply_concept_boost(results, question)
        return results[: top_k * 2]

    def build_rag_context(
        self,
        question: str,
        top_k: int = 0,
        history: list[ChatMessage] | None = None,
        *,
        chunk_type: str | None = None,
    ) -> tuple[list[SearchChunk], list[ChatMessage]] | None:
        """Build RAG context from search results.

        ``chunk_type`` restricts the pool to ``"raw"`` or ``"wiki"`` rows;
        ``None`` (default) searches the mixed pool. Returns ``None`` when
        no result survives distance/relevance filtering.
        """
        results = self.search(question, top_k=top_k, chunk_type=chunk_type)
        results = filter_results(
            results, self._config.max_distance, self._config.min_relevance_score
        )
        if not results:
            return None
        results = prepare_results(results)
        if self._config.reranker_model:
            results = self._reranker.rerank(question, results)
        results = self._apply_temporal_filter(results, question)
        results = self.select_context(results, question)
        context = build_context(results)
        prompt = CONTEXT_TEMPLATE.format(context=context, question=question)
        messages: list[ChatMessage] = [
            {"role": "system", "content": self._config.rag_system_prompt}
        ]
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": prompt})
        return results, messages

    def _direct_messages(
        self, question: str, history: list[ChatMessage] | None = None
    ) -> list[ChatMessage]:
        """Build messages for direct LLM chat (no RAG context)."""
        messages: list[ChatMessage] = [
            {"role": "system", "content": self._config.general_system_prompt}
        ]
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": question})
        return messages

    def _messages_for_provider(self, messages: list[ChatMessage]) -> list[dict[str, str]]:
        """Convert ChatMessage list to provider-expected format."""
        return [{"role": m["role"], "content": m["content"]} for m in messages]

    def _direct_chat(
        self,
        question: str,
        history: list[ChatMessage] | None,
        options: dict[str, Any] | None,
    ) -> str:
        """Run a no-RAG chat turn and return the cleaned response."""
        messages = self._direct_messages(question, history)
        provider_messages = self._messages_for_provider(messages)
        opts = options if options is not None else self._config.generation_options()
        raw = str(self._provider.chat(provider_messages, options=opts or None) or "")
        return raw if self._config.show_reasoning else strip_reasoning(raw)

    def ask_raw(
        self,
        question: str,
        top_k: int = 0,
        history: list[ChatMessage] | None = None,
        options: dict[str, Any] | None = None,
        *,
        chunk_type: str | None = None,
    ) -> AskResult:
        """Ask a question. Skips retrieval when chat_mode is 'chat' or
        when no embedding model is configured."""
        if (
            self._config.chat_mode == ChatMode.CHAT.value
            or not self._embedder.embedding_available()
        ):
            return AskResult(answer=self._direct_chat(question, history, options), sources=[])
        rag = self.build_rag_context(question, top_k=top_k, history=history, chunk_type=chunk_type)
        if rag is None:
            # Retrieval found nothing usable: fall back to plain chat.
            return AskResult(answer=self._direct_chat(question, history, options), sources=[])
        results, messages = rag
        provider_messages = self._messages_for_provider(messages)
        opts = options if options is not None else self._config.generation_options()
        raw = str(self._provider.chat(provider_messages, options=opts or None) or "")
        clean = raw if self._config.show_reasoning else strip_reasoning(raw)
        return AskResult(answer=clean, sources=results)

    def ask(
        self,
        question: str,
        top_k: int = 0,
        history: list[ChatMessage] | None = None,
        options: dict[str, Any] | None = None,
        *,
        chunk_type: str | None = None,
    ) -> str:
        """Ask a question and get a formatted answer with citations."""
        result = self.ask_raw(
            question, top_k=top_k, history=history, options=options, chunk_type=chunk_type
        )
        if not result.sources:
            return result.answer
        cited = _extract_cited_indices(result.answer)
        # Citation markers are 1-based; out-of-range indices are ignored.
        used = [result.sources[i - 1] for i in sorted(cited) if 1 <= i <= len(result.sources)]
        answer = strip_llm_citations(result.answer)
        # If the model cited nothing, list every retrieved source instead.
        source_list = used if used else result.sources
        citations = deduplicate_sources(source_list)
        return f"{answer}\n\nSources:\n" + "\n".join(citations)

    def _stream_direct(
        self,
        question: str,
        history: list[ChatMessage] | None,
        options: dict[str, Any] | None,
    ) -> Generator[StreamToken, None, None]:
        """Streaming branch with the general system prompt (no RAG context)."""
        messages = self._direct_messages(question, history)
        provider_messages = self._messages_for_provider(messages)
        opts = options if options is not None else self._config.generation_options()
        events = stream_chat_with_cap(
            self._provider,
            cast("list[dict[str, Any]]", provider_messages),
            options=opts,
            model=self._config.chat_model,
            show_reasoning=self._config.show_reasoning,
            cap_chars=effective_reasoning_cap(),
        )
        try:
            yield from cap_events_as_stream_tokens(events)
        except (ConnectionError, OSError) as exc:
            # Surface transport failures to the consumer as a visible token.
            yield StreamToken(content=f"\n\n[Connection lost: {exc}]", is_reasoning=False)

    def ask_stream(
        self,
        question: str,
        top_k: int = 0,
        history: list[ChatMessage] | None = None,
        options: dict[str, Any] | None = None,
        *,
        chunk_type: str | None = None,
    ) -> Generator[StreamToken, None, None]:
        """Stream answer tokens with citations appended at the end."""
        if (
            self._config.chat_mode == ChatMode.CHAT.value
            or not self._embedder.embedding_available()
        ):
            yield from self._stream_direct(question, history, options)
            return

        rag = self.build_rag_context(question, top_k=top_k, history=history, chunk_type=chunk_type)
        if rag is None:
            yield from self._stream_direct(question, history, options)
            return
        results, messages = rag
        provider_messages = self._messages_for_provider(messages)
        opts = options if options is not None else self._config.generation_options()
        answer_parts: list[str] = []
        events = stream_chat_with_cap(
            self._provider,
            cast("list[dict[str, Any]]", provider_messages),
            options=opts,
            model=self._config.chat_model,
            show_reasoning=self._config.show_reasoning,
            cap_chars=effective_reasoning_cap(),
        )
        try:
            for token in cap_events_as_stream_tokens(events):
                if not token.is_reasoning:
                    # Accumulate only answer text for citation extraction.
                    answer_parts.append(token.content)
                yield token
        except (ConnectionError, OSError) as exc:
            yield StreamToken(content=f"\n\n[Connection lost: {exc}]", is_reasoning=False)
        # LLM-generated citation blocks in streamed tokens cannot be
        # retroactively stripped. The system prompt discourages them; this
        # only filters the code-appended Sources block to cited chunks.
        full_answer = "".join(answer_parts)
        cited = _extract_cited_indices(full_answer)
        used = [results[i - 1] for i in sorted(cited) if 1 <= i <= len(results)]
        source_list = used if used else results
        citations = deduplicate_sources(source_list)
        yield StreamToken(content="\n\nSources:\n" + "\n".join(citations), is_reasoning=False)