Coverage for src / lilbee / retrieval / query / searcher.py: 100%

316 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-28 01:01 +0000

1"""RAG search pipeline -- embed, search, expand, rerank, generate.""" 

2 

3from __future__ import annotations 

4 

5import logging 

6from collections.abc import Generator 

7from datetime import datetime 

8from typing import TYPE_CHECKING, Any, cast 

9 

10from pydantic import BaseModel 

11from typing_extensions import TypedDict 

12 

13from lilbee.core.config import Config 

14from lilbee.core.config.enums import ChatMode 

15from lilbee.data.store import ( 

16 ChunkType, 

17 MemoryKind, 

18 MemoryRow, 

19 SearchChunk, 

20 Store, 

21 cosine_sim, 

22 local_owner_predicate, 

23) 

24from lilbee.providers.base import LLMProvider 

25from lilbee.retrieval.embedder import Embedder 

26from lilbee.retrieval.query.dedup import ( 

27 _greedy_cover, 

28 _relevance_weight, 

29 deduplicate_sources, 

30 filter_results, 

31 prepare_results, 

32) 

33from lilbee.retrieval.query.expansion import EXPANSION_MAX_TOKENS, EXPANSION_PROMPT 

34from lilbee.retrieval.query.formatting import ( 

35 CONTEXT_TEMPLATE, 

36 _extract_cited_indices, 

37 build_context, 

38 strip_llm_citations, 

39) 

40from lilbee.retrieval.query.memory import format_memory_block 

41from lilbee.retrieval.query.tokenize import _idf_weights, _tokenize 

42from lilbee.retrieval.reasoning import ( 

43 StreamToken, 

44 cap_events_as_stream_tokens, 

45 effective_reasoning_cap, 

46 stream_chat_with_cap, 

47 strip_reasoning, 

48) 

49 

50if TYPE_CHECKING: 

51 from lilbee.retrieval.concepts import ConceptGraph 

52 from lilbee.retrieval.reranker import Reranker 

53 

log = logging.getLogger(__name__)

# The expansion-skip heuristic compares the top BM25 probe score against the
# runner-up score, so the probe needs at least this many hits before that
# comparison can be made.
_MIN_BM25_PROBE_RESULTS = 2

59 

60 

class ChatMessage(TypedDict):
    """A single chat message with role and content."""

    # Speaker role; this module itself emits "system" and "user" messages.
    role: str
    # Message text handed to the provider.
    content: str

66 

67 

class AskResult(BaseModel):
    """Structured result from ask_raw -- answer text + raw search results."""

    # Answer text; reasoning is stripped unless ``show_reasoning`` is configured.
    answer: str
    # Chunks supplied to the model as context; empty when the answer came
    # from the direct (no-RAG) chat path.
    sources: list[SearchChunk]

73 

74 

class Searcher:
    """RAG search pipeline -- embed, search, expand, rerank, generate.

    All search and answer operations go through this class.
    Constructed with injected dependencies via the Services container.
    """

    def __init__(
        self,
        config: Config,
        provider: LLMProvider,
        store: Store,
        embedder: Embedder,
        reranker: Reranker,
        concepts: ConceptGraph,
    ) -> None:
        """Store the injected collaborators; no I/O happens here."""
        self._config = config
        self._provider = provider
        self._store = store
        self._embedder = embedder
        self._reranker = reranker
        self._concepts = concepts

    def _apply_temporal_filter(
        self, results: list[SearchChunk], question: str
    ) -> list[SearchChunk]:
        """Keep results whose source was ingested inside the question's date range.

        Returns *results* unchanged when temporal filtering is disabled or no
        temporal keyword is detected in *question*. Results whose ingestion
        date is missing, unparsable, or not comparable with the range are
        kept. If the filter would drop everything, the unfiltered list is
        returned instead.
        """
        if not self._config.temporal_filtering:
            return results
        from lilbee.runtime.temporal import detect_temporal, resolve_date_range

        keyword = detect_temporal(question)
        if keyword is None:
            return results
        date_range = resolve_date_range(keyword)
        source_dates = self._store.source_ingested_at_map()
        filtered: list[SearchChunk] = []
        for r in results:
            ingested_at = source_dates.get(r.source, "")
            if not ingested_at:
                # Unknown ingestion date: keep rather than guess.
                filtered.append(r)
                continue
            try:
                doc_date = datetime.fromisoformat(ingested_at)
                if date_range.start <= doc_date <= date_range.end:
                    filtered.append(r)
            except (ValueError, TypeError):
                # Bad ISO string, or a comparison Python refuses (TypeError):
                # keep the chunk rather than dropping it on a data problem.
                filtered.append(r)
        return filtered if filtered else results

    def _apply_guardrails(
        self,
        variants: list[tuple[str, list[float]]],
        question_vec: list[float],
    ) -> list[tuple[str, list[float]]]:
        """Drop expansion variants whose embedding drifts too far from the question.

        Returns *variants* unchanged when ``expansion_guardrails`` is off;
        otherwise keeps pairs whose cosine similarity to *question_vec* is at
        least ``expansion_similarity_threshold``.
        """
        if not self._config.expansion_guardrails:
            return variants
        threshold = self._config.expansion_similarity_threshold
        return [(text, vec) for text, vec in variants if cosine_sim(question_vec, vec) >= threshold]

    def _concept_query_expansion(self, question: str) -> list[str]:
        """Return concept-graph query variants, or ``[]`` when unavailable.

        Best-effort: disabled config, a missing graph, or any exception from
        the concept graph all yield an empty list (failures are logged at
        debug level).
        """
        if not self._config.concept_graph:
            return []
        try:
            if not self._concepts.get_graph():
                return []
            return self._concepts.expand_query(question)
        except Exception:
            log.debug("Concept query expansion failed", exc_info=True)
            return []

    def _llm_expand(self, question: str, count: int) -> list[str]:
        """Call the LLM to produce ``count`` alternative phrasings.

        The response is split on newlines; blank lines are dropped and at
        most *count* variants are returned. A non-string response yields
        ``[]``.
        """
        prompt = EXPANSION_PROMPT.format(count=count, question=question)
        messages = [{"role": "user", "content": prompt}]
        response = self._provider.chat(
            messages, stream=False, options={"num_predict": EXPANSION_MAX_TOKENS}
        )
        if not isinstance(response, str):
            return []
        variants = [line.strip() for line in response.strip().split("\n") if line.strip()]
        return variants[:count]

    def _expand_query(
        self, question: str, question_vec: list[float]
    ) -> list[tuple[str, list[float]]]:
        """Return ``(variant, variant_vec)`` pairs for downstream search.

        LLM variants run through ``_apply_guardrails``; concept-graph
        variants bypass it since they come from deterministic traversal.
        Embeddings batch per source: one provider round-trip per source.
        Any exception disables expansion for this call and returns ``[]``.
        """
        count = self._config.query_expansion_count
        if count <= 0 and not self._config.concept_graph:
            return []
        # Short queries skip LLM expansion: BM25/vector signal is already strong
        # and the LLM round-trip dominates latency on small local models.
        # Concept-graph expansion still runs.
        short_threshold = self._config.expansion_short_query_tokens
        skip_llm = short_threshold > 0 and len(_tokenize(question)) <= short_threshold
        try:
            variants: list[tuple[str, list[float]]] = []
            if count > 0 and not skip_llm:
                llm_texts = list(self._llm_expand(question, count))
                if llm_texts:
                    llm_vectors = self._embedder.embed_batch(llm_texts)
                    # strict=True surfaces a text/vector count mismatch as an
                    # error (caught below) instead of silently truncating.
                    variants = list(zip(llm_texts, llm_vectors, strict=True))
                    variants = self._apply_guardrails(variants, question_vec)

            concept_texts = list(self._concept_query_expansion(question))
            if concept_texts:
                concept_vectors = self._embedder.embed_batch(concept_texts)
                variants.extend(zip(concept_texts, concept_vectors, strict=True))

            return variants
        except Exception as exc:
            log.warning("Query expansion disabled for this call: %s", exc)
            log.debug("Query expansion exception", exc_info=True)
            return []

    def _should_skip_expansion(self, question: str) -> bool:
        """Decide from a small BM25 probe whether expansion can be skipped.

        Skips only when the heuristic is enabled (positive
        ``expansion_skip_threshold``), the top probe score reaches that
        threshold, and either there is no runner-up or the top score leads
        the runner-up by at least ``expansion_skip_gap``.
        """
        if self._config.expansion_skip_threshold <= 0:
            return False
        # Probe exactly as many hits as the top-vs-runner-up comparison needs.
        results = self._store.bm25_probe(question, top_k=_MIN_BM25_PROBE_RESULTS)
        if not results:
            return False
        top_score = results[0].relevance_score or 0
        if top_score < self._config.expansion_skip_threshold:
            return False
        if len(results) < _MIN_BM25_PROBE_RESULTS:
            return True
        second_score = results[1].relevance_score or 0
        return (top_score - second_score) >= self._config.expansion_skip_gap

    def _apply_concept_boost(self, results: list[SearchChunk], question: str) -> list[SearchChunk]:
        """Re-score *results* via the concept graph; best-effort.

        Returns *results* unchanged when the concept graph is disabled, there
        are no results, no graph is loaded, no concepts are extracted from
        *question*, or the concept graph raises (logged at debug level).
        """
        if not self._config.concept_graph or not results:
            return results
        try:
            if not self._concepts.get_graph():
                return results
            query_concepts = self._concepts.extract_concepts(question)
            if not query_concepts:
                return results
            return self._concepts.boost_results(results, query_concepts)
        except Exception:
            log.debug("Concept boost failed", exc_info=True)
            return results

    def _hyde_search(
        self, question: str, top_k: int, chunk_type: ChunkType | None = None
    ) -> list[SearchChunk]:
        """Hypothetical Document Embedding search.

        Gao et al. 2022, "Precise Zero-Shot Dense Retrieval without
        Relevance Labels" -- generates a hypothetical answer passage,
        embeds it, and uses the embedding to search for real documents.

        *chunk_type* is forwarded to the store so HyDE hits honour the same
        scope as the primary search. Best-effort: an empty or non-string LLM
        response, or any exception, yields ``[]``.
        """
        try:
            response = self._provider.chat(
                [{"role": "user", "content": self._config.hyde_prompt.format(question=question)}],
                stream=False,
                options={"num_predict": EXPANSION_MAX_TOKENS},
            )
            if not isinstance(response, str) or not response.strip():
                return []
            hyde_vec = self._embedder.embed(response.strip())
            return self._store.search(
                hyde_vec, top_k=top_k, query_text=None, chunk_type=chunk_type
            )
        except Exception:
            log.debug("HyDE search failed", exc_info=True)
            return []

    def _normalize_chunk_type(self, chunk_type: ChunkType | None) -> ChunkType | None:
        """Drop ``chunk_type=ChunkType.WIKI`` when wiki generation is disabled.

        With wiki off the chunks table contains only raw rows, so the
        filter would return empty. A warning is logged so the misuse shows
        up in logs while the caller still gets results from the full pool.
        """
        if chunk_type == ChunkType.WIKI and not self._config.wiki:
            log.warning(
                "wiki scope requested but wiki is disabled; searching the full pool instead"
            )
            return None
        return chunk_type

    def _parse_structured_query(self, question: str) -> tuple[str | None, str]:
        """Split a ``mode:`` prefix off *question*.

        Returns ``(mode, remainder)`` when the stripped question starts
        (case-insensitively) with one of the known prefixes, with the
        remainder stripped; otherwise ``(None, question)`` with the question
        untouched.
        """
        stripped = question.strip()
        lowered = stripped.lower()
        for prefix in ("term:", "vec:", "hyde:", "wiki:", "raw:"):
            if lowered.startswith(prefix):
                return prefix[:-1], stripped[len(prefix) :].strip()
        return None, question

    def _search_structured(
        self,
        mode: str,
        query: str,
        top_k: int,
        chunk_type: ChunkType | None = None,
    ) -> list[SearchChunk]:
        """Run the search selected by a structured-query prefix.

        ``term`` is a BM25 probe, ``vec`` a vector-only search, ``hyde`` a
        HyDE search, and ``wiki``/``raw`` a hybrid search scoped to that chunk
        type. An explicit *chunk_type* is forwarded wherever the store call
        accepts one. Unknown modes return ``[]``.
        """
        if mode == "term":
            # NOTE(review): bm25_probe is called without a chunk_type filter;
            # confirm whether the store can scope the probe.
            return self._store.bm25_probe(query, top_k=top_k)
        if mode == "vec":
            query_vec = self._embedder.embed(query)
            return self._store.search(
                query_vec, top_k=top_k, query_text=None, chunk_type=chunk_type
            )
        if mode == "hyde":
            return self._hyde_search(query, top_k, chunk_type)
        if mode in (ChunkType.WIKI, ChunkType.RAW):
            # Explicit ``chunk_type`` arg beats the ``wiki:``/``raw:`` prefix shortcut.
            effective = chunk_type if chunk_type is not None else ChunkType(mode)
            # The prefix shortcut must obey the same wiki-disabled rule as an
            # explicit argument, otherwise ``wiki:`` returns nothing.
            effective = self._normalize_chunk_type(effective)
            query_vec = self._embedder.embed(query)
            return self._store.search(
                query_vec, top_k=top_k, query_text=query, chunk_type=effective
            )
        return []

    def select_context(
        self, results: list[SearchChunk], question: str, max_sources: int | None = None
    ) -> list[SearchChunk]:
        """Pick ``max_sources`` chunks.

        Results carrying ``rerank_score`` keep the cross-encoder order (top
        ``max_sources``); otherwise greedy IDF-weighted set cover. When
        *max_sources* is ``None`` the configured ``max_context_sources`` is
        used. Falls back to the first ``max_sources`` results when the
        question has no tokens or no question term carries weight.
        """
        if max_sources is None:
            max_sources = self._config.max_context_sources
        if len(results) <= max_sources:
            return results
        if any(r.rerank_score is not None for r in results):
            return results[:max_sources]

        question_terms = set(_tokenize(question))
        if not question_terms:
            return results[:max_sources]

        chunk_tokens = [set(_tokenize(r.chunk)) for r in results]
        term_weights = _idf_weights(question_terms, chunk_tokens)
        if not any(term_weights.values()):
            return results[:max_sources]

        weights = [_relevance_weight(r) for r in results]
        selected = _greedy_cover(chunk_tokens, question_terms, term_weights, max_sources, weights)
        # Restore the incoming ranking order among the selected chunks.
        selected.sort()
        return [results[i] for i in selected]

    def _merge_variant_results(
        self,
        question: str,
        query_vec: list[float],
        results: list[SearchChunk],
        seen: set[tuple[str, int]],
        top_k: int,
        chunk_type: ChunkType | None,
    ) -> None:
        """Append unseen variant-search hits to ``results`` (in place).

        ``seen`` holds ``(source, chunk_index)`` keys and is updated in place.
        """
        for variant, variant_vec in self._expand_query(question, query_vec):
            variant_results = self._store.search(
                variant_vec,
                top_k=top_k,
                query_text=variant,
                chunk_type=chunk_type,
            )
            for r in variant_results:
                key = (r.source, r.chunk_index)
                if key not in seen:
                    results.append(r)
                    seen.add(key)

    def _merge_hyde_results(
        self,
        question: str,
        results: list[SearchChunk],
        seen: set[tuple[str, int]],
        top_k: int,
        chunk_type: ChunkType | None = None,
    ) -> None:
        """Append unseen HyDE hits to ``results`` (in place), reweighted by ``hyde_weight``.

        *chunk_type* scopes the HyDE search like the primary search. A hit's
        distance is divided by ``hyde_weight`` only when the distance is set
        and the weight is positive.
        """
        for hit in self._hyde_search(question, top_k, chunk_type):
            key = (hit.source, hit.chunk_index)
            if key in seen:
                continue
            if hit.distance is not None and self._config.hyde_weight > 0:
                hit = hit.model_copy(update={"distance": hit.distance / self._config.hyde_weight})
            results.append(hit)
            seen.add(key)

    def search(
        self,
        question: str,
        top_k: int = 0,
        *,
        chunk_type: ChunkType | None = None,
    ) -> list[SearchChunk]:
        """Embed question and search with expansion, HyDE, and concept boost.

        Returns up to top_k*2 candidates for downstream filtering. A
        ``top_k`` of 0 means "use the configured ``top_k``".

        When *chunk_type* is set (``"raw"`` or ``"wiki"``), only chunks of
        that type are returned. An explicit ``chunk_type`` always wins
        over the ``wiki:``/``raw:`` prefix shortcut in *question* so the
        user-facing scope choice has the final say.

        When ``chunk_type="wiki"`` but wiki generation is disabled on the
        config, the filter is normalized to ``None`` (mixed pool) and a
        warning is logged: with wiki off the chunks table has no wiki rows,
        so honouring the filter would silently return zero results.
        """
        if top_k == 0:
            top_k = self._config.top_k
        chunk_type = self._normalize_chunk_type(chunk_type)
        mode, clean_query = self._parse_structured_query(question)
        if mode is not None:
            return self._search_structured(mode, clean_query, top_k, chunk_type=chunk_type)
        query_vec = self._embedder.embed(question)
        results = self._store.search(
            query_vec,
            top_k=top_k,
            query_text=question,
            chunk_type=chunk_type,
        )
        if self._should_skip_expansion(question):
            return results[: top_k * 2]
        seen = {(r.source, r.chunk_index) for r in results}
        self._merge_variant_results(question, query_vec, results, seen, top_k, chunk_type)
        if self._config.hyde:
            self._merge_hyde_results(question, results, seen, top_k, chunk_type)
        results = self._apply_concept_boost(results, question)
        return results[: top_k * 2]

    def build_rag_context(
        self,
        question: str,
        top_k: int = 0,
        history: list[ChatMessage] | None = None,
        *,
        chunk_type: ChunkType | None = None,
    ) -> tuple[list[SearchChunk], list[ChatMessage]] | None:
        """Build RAG context from search results.

        ``chunk_type`` restricts the pool to ``"raw"`` or ``"wiki"`` rows;
        ``None`` (default) searches the mixed pool.

        Returns ``(selected_chunks, messages)`` where *messages* is the system
        prompt, then *history*, then the user prompt rendered from
        ``CONTEXT_TEMPLATE``. Returns ``None`` when nothing survives the
        distance/relevance filter.
        """
        results = self.search(question, top_k=top_k, chunk_type=chunk_type)
        results = filter_results(
            results, self._config.max_distance, self._config.min_relevance_score
        )
        if not results:
            return None
        results = prepare_results(results)
        if self._config.reranker_model:
            results = self._reranker.rerank(question, results)
        results = self._apply_temporal_filter(results, question)
        results = self.select_context(results, question)
        context = build_context(results)
        prompt = CONTEXT_TEMPLATE.format(context=context, question=question)
        messages: list[ChatMessage] = [
            {
                "role": "system",
                "content": self._system_with_memory(self._config.rag_system_prompt, question),
            }
        ]
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": prompt})
        return results, messages

    def _system_with_memory(self, base_prompt: str, question: str) -> str:
        """Append the local-owner memory block to *base_prompt* when memory is enabled."""
        block = self._memory_block(question)
        return f"{base_prompt}\n\n{block}" if block else base_prompt

    def _memory_block(self, question: str) -> str:
        """Recall the local human's preferences and relevant facts as a system block.

        Preferences are always included; facts are recalled by similarity. Empty
        when memory is disabled or nothing matches. MCP agents never reach this path
        (their tools recall explicitly under their own owner).
        """
        if not self._config.memory_enabled:
            return ""
        owner_predicate = local_owner_predicate()
        preferences = self._store.get_memories(
            owner_predicate=owner_predicate,
            kind=MemoryKind.PREFERENCE,
        )
        facts: list[MemoryRow] = []
        # Fact recall needs an embedding; skip it when none is available.
        if self._config.memory_top_k > 0 and self._embedder.embedding_available():
            vector = self._embedder.embed(question)
            facts = self._store.search_memories(
                vector,
                owner_predicate=owner_predicate,
                top_k=self._config.memory_top_k,
                max_distance=self._config.memory_max_distance,
            )
        return format_memory_block(preferences, facts, self._config.memory_token_budget)

    def skip_retrieval(self) -> bool:
        """Whether this turn should bypass RAG: chat-only mode or no embedder."""
        return (
            self._config.chat_mode == ChatMode.CHAT.value
            or not self._embedder.embedding_available()
        )

    def direct_messages(
        self, question: str, history: list[ChatMessage] | None = None
    ) -> list[ChatMessage]:
        """Build messages for direct LLM chat (no RAG context).

        Order: general system prompt (plus memory block), *history*, then the
        raw *question* as the user message.
        """
        messages: list[ChatMessage] = [
            {
                "role": "system",
                "content": self._system_with_memory(self._config.general_system_prompt, question),
            }
        ]
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": question})
        return messages

    def _messages_for_provider(self, messages: list[ChatMessage]) -> list[dict[str, str]]:
        """Convert ChatMessage list to provider-expected format.

        Only the ``role`` and ``content`` keys are copied into fresh dicts.
        """
        return [{"role": m["role"], "content": m["content"]} for m in messages]

    def _complete(self, messages: list[ChatMessage], options: dict[str, Any] | None) -> str:
        """Run one non-streaming chat turn and return the cleaned answer text.

        *options* falls back to ``generation_options()`` when ``None``; an
        empty options mapping is passed to the provider as ``None``. A falsy
        provider response becomes ``""``. Reasoning is stripped unless
        ``show_reasoning`` is configured.
        """
        provider_messages = self._messages_for_provider(messages)
        opts = options if options is not None else self._config.generation_options()
        raw = str(self._provider.chat(provider_messages, options=opts or None) or "")
        return raw if self._config.show_reasoning else strip_reasoning(raw)

    def _direct_chat(
        self,
        question: str,
        history: list[ChatMessage] | None,
        options: dict[str, Any] | None,
    ) -> str:
        """Run a no-RAG chat turn and return the cleaned response."""
        return self._complete(self.direct_messages(question, history), options)

    def ask_raw(
        self,
        question: str,
        top_k: int = 0,
        history: list[ChatMessage] | None = None,
        options: dict[str, Any] | None = None,
        *,
        chunk_type: ChunkType | None = None,
    ) -> AskResult:
        """Ask a question. Skips retrieval when chat_mode is 'chat' or
        when no embedding model is configured.

        Also falls back to direct chat (empty ``sources``) when retrieval
        produces no usable context.
        """
        if self.skip_retrieval():
            return AskResult(answer=self._direct_chat(question, history, options), sources=[])
        rag = self.build_rag_context(question, top_k=top_k, history=history, chunk_type=chunk_type)
        if rag is None:
            return AskResult(answer=self._direct_chat(question, history, options), sources=[])
        results, messages = rag
        return AskResult(answer=self._complete(messages, options), sources=results)

    def _citation_lines(self, answer: str, sources: list[SearchChunk]) -> list[str]:
        """Return deduplicated citation lines for the sources *answer* cites.

        Cited indices are 1-based positions into *sources*; out-of-range
        indices are ignored. When no valid citation is found, all *sources*
        are listed.
        """
        cited = _extract_cited_indices(answer)
        used = [sources[i - 1] for i in sorted(cited) if 1 <= i <= len(sources)]
        return deduplicate_sources(used if used else sources)

    def ask(
        self,
        question: str,
        top_k: int = 0,
        history: list[ChatMessage] | None = None,
        options: dict[str, Any] | None = None,
        *,
        chunk_type: ChunkType | None = None,
    ) -> str:
        """Ask a question and get a formatted answer with citations.

        When the answer has no sources (direct chat), the answer text is
        returned as-is without a Sources block.
        """
        result = self.ask_raw(
            question, top_k=top_k, history=history, options=options, chunk_type=chunk_type
        )
        if not result.sources:
            return result.answer
        citations = self._citation_lines(result.answer, result.sources)
        answer = strip_llm_citations(result.answer)
        return f"{answer}\n\nSources:\n" + "\n".join(citations)

    def _stream_events(self, messages: list[ChatMessage], options: dict[str, Any] | None) -> Any:
        """Start a capped streaming chat for *messages* and return its event stream.

        *options* falls back to ``generation_options()`` when ``None``.
        """
        provider_messages = self._messages_for_provider(messages)
        opts = options if options is not None else self._config.generation_options()
        return stream_chat_with_cap(
            self._provider,
            cast("list[dict[str, Any]]", provider_messages),
            options=opts,
            model=self._config.chat_model,
            show_reasoning=self._config.show_reasoning,
            cap_chars=effective_reasoning_cap(),
        )

    def _stream_direct(
        self,
        question: str,
        history: list[ChatMessage] | None,
        options: dict[str, Any] | None,
    ) -> Generator[StreamToken, None, None]:
        """Streaming branch with the general system prompt (no RAG context).

        A ``ConnectionError``/``OSError`` during streaming is reported as a
        final "[Connection lost: ...]" token instead of propagating.
        """
        events = self._stream_events(self.direct_messages(question, history), options)
        try:
            yield from cap_events_as_stream_tokens(events)
        except (ConnectionError, OSError) as exc:
            yield StreamToken(content=f"\n\n[Connection lost: {exc}]", is_reasoning=False)

    def ask_stream(
        self,
        question: str,
        top_k: int = 0,
        history: list[ChatMessage] | None = None,
        options: dict[str, Any] | None = None,
        *,
        chunk_type: ChunkType | None = None,
    ) -> Generator[StreamToken, None, None]:
        """Stream answer tokens with citations appended at the end.

        Falls back to the direct (no-RAG) stream, without a Sources block,
        when retrieval is skipped or yields no usable context.
        """
        if self.skip_retrieval():
            yield from self._stream_direct(question, history, options)
            return

        rag = self.build_rag_context(question, top_k=top_k, history=history, chunk_type=chunk_type)
        if rag is None:
            yield from self._stream_direct(question, history, options)
            return
        results, messages = rag
        answer_parts: list[str] = []
        events = self._stream_events(messages, options)
        try:
            for token in cap_events_as_stream_tokens(events):
                # Reasoning tokens are streamed through but never scanned
                # for citations.
                if not token.is_reasoning:
                    answer_parts.append(token.content)
                yield token
        except (ConnectionError, OSError) as exc:
            yield StreamToken(content=f"\n\n[Connection lost: {exc}]", is_reasoning=False)
        # LLM-generated citation blocks in streamed tokens cannot be
        # retroactively stripped. The system prompt discourages them; this
        # only filters the code-appended Sources block to cited chunks.
        citations = self._citation_lines("".join(answer_parts), results)
        yield StreamToken(content="\n\nSources:\n" + "\n".join(citations), is_reasoning=False)