Coverage for src / lilbee / server / handlers / rag.py: 100%
105 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
1"""Search, ask, and chat handlers (one-shot and streaming)."""
3from __future__ import annotations
5import asyncio
6import logging
7import threading
8from collections.abc import AsyncGenerator
9from typing import TYPE_CHECKING, Any, cast
11from lilbee.app.memory import auto_extract, auto_extract_enabled
12from lilbee.app.search import clean_result
13from lilbee.app.services import get_services
14from lilbee.core.config import cfg
15from lilbee.core.results import DocumentResult, group
16from lilbee.data.store import ChunkType, EmbeddingModelMismatchError, SearchChunk
17from lilbee.providers.base import ProviderError, ProviderErrorKind
18from lilbee.retrieval.reasoning import (
19 CAP_NOTICE_TEMPLATE,
20 CapNotice,
21 effective_reasoning_cap,
22 stream_chat_with_cap,
23)
24from lilbee.runtime.progress import SseErrorCode, SseEvent
25from lilbee.server.handlers.sse import (
26 SseStream,
27 _resolve_generation_options,
28 classify_load_error,
29 sse_done,
30 sse_error,
31 sse_event,
32)
33from lilbee.server.models import (
34 AskResponse,
35 CleanedChunk,
36 MemoryExtractedEvent,
37 MemoryExtractedItem,
38)
40if TYPE_CHECKING:
41 from lilbee.retrieval.query import ChatMessage
43log = logging.getLogger(__name__)
async def search(
    q: str, top_k: int = 5, chunk_type: ChunkType | None = None
) -> list[DocumentResult]:
    """Run a retrieval-only search for *q* and return hits grouped into DocumentResults.

    Hits whose ``distance`` exceeds ``cfg.max_distance`` are dropped; hits that
    carry no distance at all (``None``) are kept.

    Raises:
        ValueError: If *q* is empty or contains only whitespace.
    """
    if not (q and q.strip()):
        raise ValueError("query must not be empty")
    searcher = get_services().searcher
    hits = searcher.search(q, top_k=top_k, chunk_type=chunk_type)
    within_range = [
        hit for hit in hits if hit.distance is None or hit.distance <= cfg.max_distance
    ]
    return group(within_range)
async def ask(
    question: str,
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: ChunkType | None = None,
) -> AskResponse:
    """One-shot RAG answer. Returns answer and sources.

    Args:
        question: The user's question; must contain non-whitespace text.
        top_k: Forwarded to ``ask_raw`` (0 presumably means "use the configured
            default" — confirm against the searcher).
        options: Per-request generation overrides, resolved through
            ``_resolve_generation_options``.
        chunk_type: Optional chunk filter forwarded to retrieval.

    Returns:
        An ``AskResponse`` with the answer text and cleaned source chunks.

    Raises:
        ValueError: If *question* is empty or whitespace-only.
    """
    if not question or not question.strip():
        raise ValueError("question must not be empty")
    opts = _resolve_generation_options(options)
    searcher = get_services().searcher
    # ask_raw is synchronous and produces the whole answer (sources + generation).
    # Run it in a worker thread, as the streaming path and auto-extraction already
    # do, so this async handler does not stall the event loop for other requests.
    result = await asyncio.to_thread(
        searcher.ask_raw, question, top_k=top_k, options=opts, chunk_type=chunk_type
    )
    return AskResponse(
        answer=result.answer,
        sources=[CleanedChunk(**clean_result(s)) for s in result.sources],
    )
def _run_llm_stream(
    messages: list[ChatMessage],
    opts: dict[str, Any] | None,
    queue: asyncio.Queue[str | None],
    cancel: threading.Event,
    error_holder: list[Exception],
    answer_parts: list[str],
) -> None:
    """Drive the cap-aware chat stream and push SSE strings onto *queue*.

    Meant to run in a worker thread (``_stream_rag_response`` schedules it via
    ``run_in_executor``). Each ``CapNotice`` becomes a reasoning event built from
    ``CAP_NOTICE_TEMPLATE``. Every chunk with non-empty content becomes a reasoning
    or token event depending on ``is_reasoning``. Token (answer) text is also
    appended to *answer_parts* so the caller can run auto-extraction on the
    finished answer.

    *cancel* is checked before each event is forwarded; once set, the event
    stream is closed and the loop stops. Any ``Exception`` is recorded in
    *error_holder* instead of propagating, and a ``None`` sentinel is always
    enqueued last.

    NOTE(review): ``asyncio.Queue`` is not thread-safe, yet ``put_nowait`` is
    called here from the worker thread. Confirm that ``SseStream.drain`` does not
    depend on the queue waking a blocked ``get()`` (e.g. that it polls).
    """
    try:
        stream = stream_chat_with_cap(
            get_services().provider,
            cast("list[dict[str, Any]]", messages),
            options=opts,
            model=cfg.chat_model,
            show_reasoning=cfg.show_reasoning,
            cap_chars=effective_reasoning_cap(),
        )
        for item in stream:
            if cancel.is_set():
                # Close the generator so it stops producing further output.
                stream.close()
                break
            if isinstance(item, CapNotice):
                notice = CAP_NOTICE_TEMPLATE.format(chars=item.cap_chars)
                queue.put_nowait(sse_event(SseEvent.REASONING, {"token": notice}))
                continue
            text = item.content
            if not text:
                continue
            if item.is_reasoning:
                queue.put_nowait(sse_event(SseEvent.REASONING, {"token": text}))
            else:
                answer_parts.append(text)
                queue.put_nowait(sse_event(SseEvent.TOKEN, {"token": text}))
    except Exception as exc:
        error_holder.append(exc)
    finally:
        queue.put_nowait(None)
async def _emit_extracted_memories(question: str, answer: str) -> AsyncGenerator[str, None]:
    """Emit at most one ``memory_extracted`` SSE event for a finished turn.

    The extraction pass (``auto_extract``) runs in a worker thread via
    ``asyncio.to_thread``. Nothing is yielded in three cases: *answer* is empty,
    auto-extraction is disabled, or the pass stored no memories. Consumers that
    don't know about this event therefore see no change.
    """
    if answer and auto_extract_enabled():
        saved = await asyncio.to_thread(auto_extract, question, answer)
        if saved:
            items = [
                MemoryExtractedItem(id=memory.id, kind=memory.kind, text=memory.text)
                for memory in saved
            ]
            payload = MemoryExtractedEvent(count=len(saved), items=items)
            yield sse_event(SseEvent.MEMORY_EXTRACTED, payload.model_dump(mode="json"))
def _error_event(exc: Exception) -> str:
    """Translate a failure from the generation worker into an SSE error event.

    ``ProviderError`` messages are already user-facing (rate limits, auth, bad
    model), so they are forwarded verbatim. They are tagged with the error kind
    unless that kind is ``ProviderErrorKind.UNKNOWN``.

    Any other exception goes through ``classify_load_error`` (the llama.cpp OOM
    classifier). When the classifier yields a code, the raw exception text is
    attached as ``detail``; otherwise only the classifier's message is sent.

    Both paths log a warning.
    """
    if not isinstance(exc, ProviderError):
        raw_text = str(exc)
        code, friendly = classify_load_error(raw_text)
        log.warning("Stream error: %s", raw_text)
        return sse_error(friendly, code=code, detail=raw_text if code else None)
    log.warning("Provider error during stream: %s", exc)
    code = None if exc.kind is ProviderErrorKind.UNKNOWN else exc.kind
    return sse_error(str(exc), code=code)
async def _stream_rag_response(
    question: str,
    history: list[ChatMessage] | None = None,
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: ChunkType | None = None,
    *,
    honor_chat_mode: bool = False,
) -> AsyncGenerator[str, None]:
    """Shared SSE streaming for ask_stream and chat_stream.

    With ``honor_chat_mode`` (chat_stream), a chat-only mode or a missing embedder
    streams a direct answer with no retrieval and no sources.

    Order of output on success:
      1. an initial empty chunk;
      2. whatever ``SseStream.drain`` yields while the worker runs (presumably
         the ``SseEvent.TOKEN`` / ``SseEvent.REASONING`` events that
         ``_run_llm_stream`` enqueues);
      3. ``SseEvent.SOURCES``;
      4. ``sse_done``;
      5. optionally, a trailing ``SseEvent.MEMORY_EXTRACTED`` event.

    On an embedder mismatch, an empty retrieval, or a worker exception, an
    ``sse_error`` chunk is yielded instead of sources/done and the generator
    stops.

    Args:
        question: The user's question.
        history: Prior chat turns, forwarded to retrieval or ``direct_messages``.
        top_k: Forwarded to ``build_rag_context``.
        options: Per-request generation overrides. When they resolve to a falsy
            value, ``cfg.generation_options()`` is used instead.
        chunk_type: Optional chunk filter forwarded to retrieval.
        honor_chat_mode: When true, the searcher may skip retrieval entirely
            (``skip_retrieval()``).
    """
    # Hand control back before the synchronous retrieval below runs, so the
    # consumer receives a first (empty) chunk right away. NOTE(review): the body
    # already contains other yields, so this is not what makes it a generator;
    # presumably it is kept to start the response early — confirm.
    yield ""  # force generator

    searcher = get_services().searcher
    # Stays empty on the direct-answer path, so SOURCES is sent as an empty list.
    results: list[SearchChunk] = []
    if honor_chat_mode and searcher.skip_retrieval():
        messages = searcher.direct_messages(question, history)
    else:
        # NOTE(review): build_rag_context is called synchronously, i.e. on the
        # event loop; if retrieval is slow it stalls other requests meanwhile.
        try:
            rag = searcher.build_rag_context(
                question, top_k=top_k, history=history, chunk_type=chunk_type
            )
        except EmbeddingModelMismatchError as exc:
            # detail carries the index's embedder so the client can offer to adopt it.
            detail = exc.persisted_model if exc.dims_match else None
            yield sse_error(str(exc), code=SseErrorCode.INDEX_EMBEDDER_MISMATCH, detail=detail)
            return
        if rag is None:
            yield sse_error("No relevant documents found.")
            return
        results, messages = rag

    # Fall back to the configured defaults when the resolved options are falsy.
    opts = _resolve_generation_options(options) or cfg.generation_options()

    sse = SseStream()
    # Out-parameters filled by the worker thread: the exception it caught (at
    # most one, since its loop ends on the first) and the answer-token text.
    error_holder: list[Exception] = []
    answer_parts: list[str] = []

    # Run the LLM stream in the loop's default executor. The worker pushes SSE
    # strings onto sse.queue and always finishes with a None sentinel.
    executor_fut = sse.loop.run_in_executor(
        None, _run_llm_stream, messages, opts, sse.queue, sse.cancel, error_holder, answer_parts
    )
    # ensure_future returns an existing Future unchanged, so task is executor_fut.
    task = asyncio.ensure_future(executor_fut)
    async for event in sse.drain(task, "RAG stream"):
        yield event

    if error_holder:
        # Tokens may already have been streamed; the error replaces sources/done.
        yield _error_event(error_holder[0])
        sse.cancel.set()
        return

    # Ensure executor thread has finished before yielding final events
    await executor_fut

    yield sse_event(SseEvent.SOURCES, [clean_result(s) for s in results])
    yield sse_done({})

    # Auto-extraction (and its notification) trails ``done`` so clients that stop
    # at ``done`` are unaffected; the memories are stored regardless.
    async for event in _emit_extracted_memories(question, "".join(answer_parts)):
        yield event
def ask_stream(
    question: str,
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: ChunkType | None = None,
) -> AsyncGenerator[str, None]:
    """Stream a one-shot RAG answer as SSE events.

    Delegates to the shared RAG streamer with no history and with chat mode
    ignored, so retrieval always runs. The streamer emits token/reasoning events,
    then sources and done (or an error event instead).
    """
    return _stream_rag_response(
        question, None, top_k, options, chunk_type, honor_chat_mode=False
    )
async def chat(
    question: str,
    history: list[ChatMessage],
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: ChunkType | None = None,
) -> AskResponse:
    """Chat with history. Returns answer and sources.

    Args:
        question: The new user turn; must contain non-whitespace text.
        history: Prior conversation turns passed through to ``ask_raw``.
        top_k: Forwarded to ``ask_raw``.
        options: Per-request generation overrides, resolved through
            ``_resolve_generation_options``.
        chunk_type: Optional chunk filter forwarded to retrieval.

    Returns:
        An ``AskResponse`` with the answer text and cleaned source chunks.

    Raises:
        ValueError: If *question* is empty or whitespace-only (same contract
            as ``ask``).
    """
    if not question or not question.strip():
        raise ValueError("question must not be empty")
    opts = _resolve_generation_options(options)
    searcher = get_services().searcher
    # ask_raw is synchronous and produces the whole answer. Run it in a worker
    # thread, as the streaming path and auto-extraction already do, so the event
    # loop keeps serving other requests meanwhile.
    result = await asyncio.to_thread(
        searcher.ask_raw,
        question,
        top_k=top_k,
        history=history,
        options=opts,
        chunk_type=chunk_type,
    )
    return AskResponse(
        answer=result.answer,
        sources=[CleanedChunk(**clean_result(s)) for s in result.sources],
    )
def chat_stream(
    question: str,
    history: list[ChatMessage],
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: ChunkType | None = None,
) -> AsyncGenerator[str, None]:
    """Stream a history-aware chat answer as SSE events.

    Delegates to the shared RAG streamer with ``honor_chat_mode`` enabled. The
    searcher may therefore skip retrieval and answer directly, in which case the
    sources event carries an empty list.
    """
    return _stream_rag_response(
        question, history, top_k, options, chunk_type, honor_chat_mode=True
    )