Coverage for src/lilbee/retrieval/query/searcher.py: 100%

298 statements  

coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

"""RAG search pipeline -- embed, search, expand, rerank, generate."""

from __future__ import annotations

import logging
from collections.abc import Generator
from datetime import datetime
from typing import TYPE_CHECKING, Any, cast

from pydantic import BaseModel
from typing_extensions import TypedDict

from lilbee.core.config import Config
from lilbee.core.config.enums import ChatMode
from lilbee.data.store import (
    CHUNK_TYPE_RAW,
    CHUNK_TYPE_WIKI,
    SearchChunk,
    Store,
    cosine_sim,
)
from lilbee.providers.base import LLMProvider
from lilbee.retrieval.embedder import Embedder
from lilbee.retrieval.query.dedup import (
    _greedy_cover,
    _relevance_weight,
    deduplicate_sources,
    filter_results,
    prepare_results,
)
from lilbee.retrieval.query.expansion import EXPANSION_MAX_TOKENS, EXPANSION_PROMPT
from lilbee.retrieval.query.formatting import (
    CONTEXT_TEMPLATE,
    _extract_cited_indices,
    build_context,
    strip_llm_citations,
)
from lilbee.retrieval.query.tokenize import _idf_weights, _tokenize
from lilbee.retrieval.reasoning import (
    StreamToken,
    cap_events_as_stream_tokens,
    effective_reasoning_cap,
    stream_chat_with_cap,
    strip_reasoning,
)

if TYPE_CHECKING:
    from lilbee.retrieval.concepts import ConceptGraph
    from lilbee.retrieval.reranker import Reranker

log = logging.getLogger(__name__)

# BM25 probe needs at least this many hits to compare top vs. runner-up
# scores for the expansion-skip heuristic.
_MIN_BM25_PROBE_RESULTS = 2


class ChatMessage(TypedDict):
    """A single chat message with role and content."""

    role: str
    content: str


class AskResult(BaseModel):
    """Structured result from ask_raw -- answer text + raw search results."""

    answer: str
    sources: list[SearchChunk]


class Searcher:
    """RAG search pipeline -- embed, search, expand, rerank, generate.

    All search and answer operations go through this class.
    Constructed with injected dependencies via the Services container.
    """

    def __init__(
        self,
        config: Config,
        provider: LLMProvider,
        store: Store,
        embedder: Embedder,
        reranker: Reranker,
        concepts: ConceptGraph,
    ) -> None:
        self._config = config
        self._provider = provider
        self._store = store
        self._embedder = embedder
        self._reranker = reranker
        self._concepts = concepts
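
    # Illustrative wiring -- a sketch only. The Services container normally
    # supplies these dependencies; the bare names below are assumed locals,
    # not part of this module:
    #
    #     searcher = Searcher(
    #         config=config,
    #         provider=provider,
    #         store=store,
    #         embedder=embedder,
    #         reranker=reranker,
    #         concepts=concepts,
    #     )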

    def _apply_temporal_filter(
        self, results: list[SearchChunk], question: str
    ) -> list[SearchChunk]:
        if not self._config.temporal_filtering:
            return results
        from lilbee.runtime.temporal import detect_temporal, resolve_date_range

        keyword = detect_temporal(question)
        if keyword is None:
            return results
        date_range = resolve_date_range(keyword)
        source_dates = self._store.source_ingested_at_map()
        filtered: list[SearchChunk] = []
        for r in results:
            ingested_at = source_dates.get(r.source, "")
            if not ingested_at:
                filtered.append(r)
                continue
            try:
                doc_date = datetime.fromisoformat(ingested_at)
                if date_range.start <= doc_date <= date_range.end:
                    filtered.append(r)
            except (ValueError, TypeError):
                filtered.append(r)
        return filtered if filtered else results
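
    # Example behaviour (hedged: the exact keywords recognised by
    # detect_temporal() live in lilbee.runtime.temporal): for a question like
    # "what changed last week?", a temporal keyword resolves to a date range,
    # results whose source was ingested outside that range are dropped, and
    # results with a missing or unparsable ingested_at are always kept.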

    def _apply_guardrails(
        self,
        variants: list[tuple[str, list[float]]],
        question_vec: list[float],
    ) -> list[tuple[str, list[float]]]:
        """Drop expansion variants whose embedding drifts too far from the question."""
        if not self._config.expansion_guardrails:
            return variants
        threshold = self._config.expansion_similarity_threshold
        return [(text, vec) for text, vec in variants if cosine_sim(question_vec, vec) >= threshold]
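
    # Example (hypothetical threshold of 0.5): an LLM rephrasing whose
    # embedding scores cosine_sim 0.42 against the question vector is
    # dropped; one scoring 0.73 is kept.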

    def _concept_query_expansion(self, question: str) -> list[str]:
        if not self._config.concept_graph:
            return []
        try:
            if not self._concepts.get_graph():
                return []
            return self._concepts.expand_query(question)
        except Exception:
            log.debug("Concept query expansion failed", exc_info=True)
            return []

    def _llm_expand(self, question: str, count: int) -> list[str]:
        """Call the LLM to produce ``count`` alternative phrasings."""
        prompt = EXPANSION_PROMPT.format(count=count, question=question)
        messages = [{"role": "user", "content": prompt}]
        response = self._provider.chat(
            messages, stream=False, options={"num_predict": EXPANSION_MAX_TOKENS}
        )
        if not isinstance(response, str):
            return []
        variants = [line.strip() for line in response.strip().split("\n") if line.strip()]
        return variants[:count]

    def _expand_query(
        self, question: str, question_vec: list[float]
    ) -> list[tuple[str, list[float]]]:
        """Return ``(variant, variant_vec)`` pairs for downstream search.

        LLM variants run through ``_apply_guardrails``; concept-graph
        variants bypass it since they come from deterministic traversal.
        Embeddings are batched per variant source (LLM, concept graph),
        so each source costs a single embedding round-trip.
        """
        count = self._config.query_expansion_count
        if count <= 0 and not self._config.concept_graph:
            return []
        # Short queries skip LLM expansion: BM25/vector signal is already strong
        # and the LLM round-trip dominates latency on small local models.
        # Concept-graph expansion still runs.
        short_threshold = self._config.expansion_short_query_tokens
        skip_llm = short_threshold > 0 and len(_tokenize(question)) <= short_threshold
        try:
            llm_variants: list[tuple[str, list[float]]] = []
            if count > 0 and not skip_llm:
                llm_texts = list(self._llm_expand(question, count))
                if llm_texts:
                    llm_vectors = self._embedder.embed_batch(llm_texts)
                    llm_variants = list(zip(llm_texts, llm_vectors, strict=True))
                    llm_variants = self._apply_guardrails(llm_variants, question_vec)

            concept_texts = list(self._concept_query_expansion(question))
            if concept_texts:
                concept_vectors = self._embedder.embed_batch(concept_texts)
                llm_variants.extend(zip(concept_texts, concept_vectors, strict=True))

            return llm_variants
        except Exception as exc:
            log.warning("Query expansion disabled for this call: %s", exc)
            log.debug("Query expansion exception", exc_info=True)
            return []

    def _should_skip_expansion(self, question: str) -> bool:
        if self._config.expansion_skip_threshold <= 0:
            return False
        results = self._store.bm25_probe(question, top_k=2)
        if not results:
            return False
        top_score = results[0].relevance_score or 0
        if top_score < self._config.expansion_skip_threshold:
            return False
        if len(results) < _MIN_BM25_PROBE_RESULTS:
            return True
        second_score = results[1].relevance_score or 0
        return (top_score - second_score) >= self._config.expansion_skip_gap
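
    # Worked example (hypothetical config: expansion_skip_threshold=8.0,
    # expansion_skip_gap=2.0): a BM25 probe scoring [9.5, 6.0] has a confident
    # top hit (9.5 >= 8.0) and a wide margin (3.5 >= 2.0), so the LLM
    # expansion round-trip is skipped for this query.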

    def _apply_concept_boost(self, results: list[SearchChunk], question: str) -> list[SearchChunk]:
        if not self._config.concept_graph or not results:
            return results
        try:
            if not self._concepts.get_graph():
                return results
            query_concepts = self._concepts.extract_concepts(question)
            if not query_concepts:
                return results
            return self._concepts.boost_results(results, query_concepts)
        except Exception:
            log.debug("Concept boost failed", exc_info=True)
            return results

    def _hyde_search(self, question: str, top_k: int) -> list[SearchChunk]:
        """Hypothetical Document Embedding search.

        Gao et al. 2022, "Precise Zero-Shot Dense Retrieval without
        Relevance Labels" -- generates a hypothetical answer passage,
        embeds it, and uses the embedding to search for real documents.
        """
        try:
            response = self._provider.chat(
                [{"role": "user", "content": self._config.hyde_prompt.format(question=question)}],
                stream=False,
                options={"num_predict": EXPANSION_MAX_TOKENS},
            )
            if not isinstance(response, str) or not response.strip():
                return []
            hyde_vec = self._embedder.embed(response.strip())
            return self._store.search(hyde_vec, top_k=top_k, query_text=None)
        except Exception:
            log.debug("HyDE search failed", exc_info=True)
            return []

    def _normalize_chunk_type(self, chunk_type: str | None) -> str | None:
        """Drop ``chunk_type="wiki"`` when wiki generation is disabled.

        With wiki off the chunks table contains only raw rows, so honouring
        the filter would silently return nothing. The warning keeps that
        fallback visible in the logs without getting in the user's way.
        """
        if chunk_type == CHUNK_TYPE_WIKI and not self._config.wiki:
            log.warning(
                "wiki scope requested but wiki is disabled; searching the full pool instead"
            )
            return None
        return chunk_type

    def _parse_structured_query(self, question: str) -> tuple[str | None, str]:
        for prefix in ("term:", "vec:", "hyde:", "wiki:", "raw:"):
            if question.strip().lower().startswith(prefix):
                return prefix[:-1], question.strip()[len(prefix) :].strip()
        return None, question
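
    # Examples, derived from the prefixes above:
    #   _parse_structured_query("term: error budget")  -> ("term", "error budget")
    #   _parse_structured_query("wiki: release notes") -> ("wiki", "release notes")
    #   _parse_structured_query("how do retries work") -> (None, "how do retries work")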

    def _search_structured(
        self,
        mode: str,
        query: str,
        top_k: int,
        chunk_type: str | None = None,
    ) -> list[SearchChunk]:
        if mode == "term":
            return self._store.bm25_probe(query, top_k=top_k)
        if mode == "vec":
            query_vec = self._embedder.embed(query)
            return self._store.search(query_vec, top_k=top_k, query_text=None)
        if mode == "hyde":
            return self._hyde_search(query, top_k)
        if mode in (CHUNK_TYPE_WIKI, CHUNK_TYPE_RAW):
            # Explicit ``chunk_type`` arg beats the prefix shortcut.
            effective = chunk_type if chunk_type is not None else mode
            query_vec = self._embedder.embed(query)
            return self._store.search(
                query_vec, top_k=top_k, query_text=query, chunk_type=effective
            )
        return []

    def select_context(
        self, results: list[SearchChunk], question: str, max_sources: int | None = None
    ) -> list[SearchChunk]:
        """Pick ``max_sources`` chunks by greedy IDF-weighted set cover."""
        if max_sources is None:
            max_sources = self._config.max_context_sources
        if len(results) <= max_sources:
            return results

        question_terms = set(_tokenize(question))
        if not question_terms:
            return results[:max_sources]

        chunk_tokens = [set(_tokenize(r.chunk)) for r in results]
        term_weights = _idf_weights(question_terms, chunk_tokens)
        if not any(term_weights.values()):
            return results[:max_sources]

        weights = [_relevance_weight(r) for r in results]
        selected = _greedy_cover(chunk_tokens, question_terms, term_weights, max_sources, weights)
        selected.sort()
        return [results[i] for i in selected]
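
    # Intuition (hypothetical numbers): with question terms {"ingest", "schema"}
    # and max_sources=2, the greedy cover prefers one chunk that mentions both
    # terms over two near-duplicate chunks that each mention only "ingest";
    # rarer question terms carry more IDF weight, and the per-chunk relevance
    # weights nudge the pick toward higher-scoring chunks.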

    def _merge_variant_results(
        self,
        question: str,
        query_vec: list[float],
        results: list[SearchChunk],
        seen: set[tuple[str, int]],
        top_k: int,
        chunk_type: str | None,
    ) -> None:
        """Append unseen variant-search hits to ``results`` (in place)."""
        for variant, variant_vec in self._expand_query(question, query_vec):
            variant_results = self._store.search(
                variant_vec,
                top_k=top_k,
                query_text=variant,
                chunk_type=chunk_type,
            )
            for r in variant_results:
                key = (r.source, r.chunk_index)
                if key not in seen:
                    results.append(r)
                    seen.add(key)

    def _merge_hyde_results(
        self,
        question: str,
        results: list[SearchChunk],
        seen: set[tuple[str, int]],
        top_k: int,
    ) -> None:
        """Append unseen HyDE hits to ``results`` (in place), reweighted by ``hyde_weight``."""
        for r in self._hyde_search(question, top_k):
            key = (r.source, r.chunk_index)
            if key in seen:
                continue
            if r.distance is not None and self._config.hyde_weight > 0:
                r = r.model_copy(update={"distance": r.distance / self._config.hyde_weight})
            results.append(r)
            seen.add(key)
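
    # Example (hypothetical hyde_weight=2.0): a HyDE hit at distance 0.40 is
    # rescored to 0.20, so weights above 1.0 make HyDE hits look closer and
    # rank earlier, while weights below 1.0 penalise them.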

    def search(
        self,
        question: str,
        top_k: int = 0,
        *,
        chunk_type: str | None = None,
    ) -> list[SearchChunk]:
        """Embed question and search with expansion, HyDE, and concept boost.
        Returns up to top_k*2 candidates for downstream filtering.

        When *chunk_type* is set (``"raw"`` or ``"wiki"``), only chunks of
        that type are returned. An explicit ``chunk_type`` always wins
        over the ``wiki:``/``raw:`` prefix shortcut in *question* so the
        user-facing scope choice has the final say.

        When ``chunk_type="wiki"`` but wiki generation is disabled on the
        config, the filter is normalized to ``None`` (mixed pool) and a
        warning is logged: with wiki off the chunks table has no wiki rows,
        so honouring the filter would silently return zero results.
        """
        if top_k == 0:
            top_k = self._config.top_k
        chunk_type = self._normalize_chunk_type(chunk_type)
        mode, clean_query = self._parse_structured_query(question)
        if mode is not None:
            return self._search_structured(mode, clean_query, top_k, chunk_type=chunk_type)
        query_vec = self._embedder.embed(question)
        results = self._store.search(
            query_vec,
            top_k=top_k,
            query_text=question,
            chunk_type=chunk_type,
        )
        if self._should_skip_expansion(question):
            return results[: top_k * 2]
        seen = {(r.source, r.chunk_index) for r in results}
        self._merge_variant_results(question, query_vec, results, seen, top_k, chunk_type)
        if self._config.hyde:
            self._merge_hyde_results(question, results, seen, top_k)
        results = self._apply_concept_boost(results, question)
        return results[: top_k * 2]
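
    # Illustrative calls (assuming a constructed ``searcher``):
    #
    #     searcher.search("vec: zero-downtime deploys")   # pure vector search
    #     searcher.search("wiki: architecture overview")  # wiki scope via prefix
    #     searcher.search("architecture overview", chunk_type="raw")  # explicit scope wins
    #     searcher.search("how are retries scheduled?")   # full pipeline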

    def build_rag_context(
        self,
        question: str,
        top_k: int = 0,
        history: list[ChatMessage] | None = None,
        *,
        chunk_type: str | None = None,
    ) -> tuple[list[SearchChunk], list[ChatMessage]] | None:
        """Build RAG context from search results.

        ``chunk_type`` restricts the pool to ``"raw"`` or ``"wiki"`` rows;
        ``None`` (default) searches the mixed pool.
        """
        results = self.search(question, top_k=top_k, chunk_type=chunk_type)
        results = filter_results(
            results, self._config.max_distance, self._config.min_relevance_score
        )
        if not results:
            return None
        results = prepare_results(results)
        if self._config.reranker_model:
            results = self._reranker.rerank(question, results)
        results = self._apply_temporal_filter(results, question)
        results = self.select_context(results, question)
        context = build_context(results)
        prompt = CONTEXT_TEMPLATE.format(context=context, question=question)
        messages: list[ChatMessage] = [
            {"role": "system", "content": self._config.rag_system_prompt}
        ]
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": prompt})
        return results, messages
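
    # Stage order, as implemented above: search (expansion / HyDE / concept
    # boost) -> distance and relevance filtering -> prepare_results ->
    # optional rerank -> temporal filter -> greedy context selection ->
    # prompt assembly with the RAG system prompt.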

    def _direct_messages(
        self, question: str, history: list[ChatMessage] | None = None
    ) -> list[ChatMessage]:
        """Build messages for direct LLM chat (no RAG context)."""
        messages: list[ChatMessage] = [
            {"role": "system", "content": self._config.general_system_prompt}
        ]
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": question})
        return messages

    def _messages_for_provider(self, messages: list[ChatMessage]) -> list[dict[str, str]]:
        """Convert ChatMessage list to provider-expected format."""
        return [{"role": m["role"], "content": m["content"]} for m in messages]

    def _direct_chat(
        self,
        question: str,
        history: list[ChatMessage] | None,
        options: dict[str, Any] | None,
    ) -> str:
        """Run a no-RAG chat turn and return the cleaned response."""
        messages = self._direct_messages(question, history)
        provider_messages = self._messages_for_provider(messages)
        opts = options if options is not None else self._config.generation_options()
        raw = str(self._provider.chat(provider_messages, options=opts or None) or "")
        return raw if self._config.show_reasoning else strip_reasoning(raw)

    def ask_raw(
        self,
        question: str,
        top_k: int = 0,
        history: list[ChatMessage] | None = None,
        options: dict[str, Any] | None = None,
        *,
        chunk_type: str | None = None,
    ) -> AskResult:
        """Ask a question. Skips retrieval when chat_mode is 'chat' or
        when no embedding model is configured."""
        if (
            self._config.chat_mode == ChatMode.CHAT.value
            or not self._embedder.embedding_available()
        ):
            return AskResult(answer=self._direct_chat(question, history, options), sources=[])
        rag = self.build_rag_context(question, top_k=top_k, history=history, chunk_type=chunk_type)
        if rag is None:
            return AskResult(answer=self._direct_chat(question, history, options), sources=[])
        results, messages = rag
        provider_messages = self._messages_for_provider(messages)
        opts = options if options is not None else self._config.generation_options()
        raw = str(self._provider.chat(provider_messages, options=opts or None) or "")
        clean = raw if self._config.show_reasoning else strip_reasoning(raw)
        return AskResult(answer=clean, sources=results)

    def ask(
        self,
        question: str,
        top_k: int = 0,
        history: list[ChatMessage] | None = None,
        options: dict[str, Any] | None = None,
        *,
        chunk_type: str | None = None,
    ) -> str:
        """Ask a question and get a formatted answer with citations."""
        result = self.ask_raw(
            question, top_k=top_k, history=history, options=options, chunk_type=chunk_type
        )
        if not result.sources:
            return result.answer
        cited = _extract_cited_indices(result.answer)
        used = [result.sources[i - 1] for i in sorted(cited) if 1 <= i <= len(result.sources)]
        answer = strip_llm_citations(result.answer)
        source_list = used if used else result.sources
        citations = deduplicate_sources(source_list)
        return f"{answer}\n\nSources:\n" + "\n".join(citations)
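
    # Example (assuming _extract_cited_indices picks up bracketed markers such
    # as "[2]"): if the model's answer cites only source 2, just that chunk
    # survives the filter and appears under "Sources:"; if the answer contains
    # no citation markers, every retrieved source is listed after deduplication.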

    def _stream_direct(
        self,
        question: str,
        history: list[ChatMessage] | None,
        options: dict[str, Any] | None,
    ) -> Generator[StreamToken, None, None]:
        """Streaming branch with the general system prompt (no RAG context)."""
        messages = self._direct_messages(question, history)
        provider_messages = self._messages_for_provider(messages)
        opts = options if options is not None else self._config.generation_options()
        events = stream_chat_with_cap(
            self._provider,
            cast("list[dict[str, Any]]", provider_messages),
            options=opts,
            model=self._config.chat_model,
            show_reasoning=self._config.show_reasoning,
            cap_chars=effective_reasoning_cap(),
        )
        try:
            yield from cap_events_as_stream_tokens(events)
        except (ConnectionError, OSError) as exc:
            yield StreamToken(content=f"\n\n[Connection lost: {exc}]", is_reasoning=False)

    def ask_stream(
        self,
        question: str,
        top_k: int = 0,
        history: list[ChatMessage] | None = None,
        options: dict[str, Any] | None = None,
        *,
        chunk_type: str | None = None,
    ) -> Generator[StreamToken, None, None]:
        """Stream answer tokens with citations appended at the end."""
        if (
            self._config.chat_mode == ChatMode.CHAT.value
            or not self._embedder.embedding_available()
        ):
            yield from self._stream_direct(question, history, options)
            return

        rag = self.build_rag_context(question, top_k=top_k, history=history, chunk_type=chunk_type)
        if rag is None:
            yield from self._stream_direct(question, history, options)
            return
        results, messages = rag
        provider_messages = self._messages_for_provider(messages)
        opts = options if options is not None else self._config.generation_options()
        answer_parts: list[str] = []
        events = stream_chat_with_cap(
            self._provider,
            cast("list[dict[str, Any]]", provider_messages),
            options=opts,
            model=self._config.chat_model,
            show_reasoning=self._config.show_reasoning,
            cap_chars=effective_reasoning_cap(),
        )
        try:
            for token in cap_events_as_stream_tokens(events):
                if not token.is_reasoning:
                    answer_parts.append(token.content)
                yield token
        except (ConnectionError, OSError) as exc:
            yield StreamToken(content=f"\n\n[Connection lost: {exc}]", is_reasoning=False)
        # LLM-generated citation blocks in streamed tokens cannot be
        # retroactively stripped. The system prompt discourages them; this
        # only filters the code-appended Sources block to cited chunks.
        full_answer = "".join(answer_parts)
        cited = _extract_cited_indices(full_answer)
        used = [results[i - 1] for i in sorted(cited) if 1 <= i <= len(results)]
        source_list = used if used else results
        citations = deduplicate_sources(source_list)
        yield StreamToken(content="\n\nSources:\n" + "\n".join(citations), is_reasoning=False)
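
# Illustrative consumer of the stream (``searcher`` assumed constructed as in
# the sketch near __init__):
#
#     for token in searcher.ask_stream("summarise the ingest pipeline"):
#         print(token.content, end="", flush=True)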