Coverage for src/lilbee/server/handlers/rag.py: 100% (74 statements)
1"""Search, ask, and chat handlers (one-shot and streaming)."""
3from __future__ import annotations
5import asyncio
6import logging
7import threading
8from collections.abc import AsyncGenerator
9from typing import TYPE_CHECKING, Any, cast

from lilbee.app.search import clean_result
from lilbee.app.services import get_services
from lilbee.core.config import cfg
from lilbee.core.results import DocumentResult, group
from lilbee.retrieval.reasoning import (
    CAP_NOTICE_TEMPLATE,
    CapNotice,
    effective_reasoning_cap,
    stream_chat_with_cap,
)
from lilbee.runtime.progress import SseEvent
from lilbee.server.handlers.sse import (
    SseStream,
    _resolve_generation_options,
    classify_load_error,
    sse_done,
    sse_error,
    sse_event,
)
from lilbee.server.models import AskResponse, CleanedChunk

if TYPE_CHECKING:
    from lilbee.retrieval.query import ChatMessage

log = logging.getLogger(__name__)


async def search(q: str, top_k: int = 5, chunk_type: str | None = None) -> list[DocumentResult]:
    """Search and return grouped DocumentResults."""
    if not q or not q.strip():
        raise ValueError("query must not be empty")
    results = get_services().searcher.search(q, top_k=top_k, chunk_type=chunk_type)
    # Drop hits whose vector distance exceeds the configured relevance cutoff.
    results = [r for r in results if r.distance is None or r.distance <= cfg.max_distance]
    return group(results)
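
# Usage sketch (hypothetical caller; query text and chunk_type value are
# illustrative only):
#
#     docs = await search("embedding pipeline", top_k=3, chunk_type="code")
#     for doc in docs:
#         ...  # each item is a DocumentResult grouping chunks per document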


async def ask(
    question: str,
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: str | None = None,
) -> AskResponse:
    """One-shot RAG answer. Returns answer and sources."""
    if not question or not question.strip():
        raise ValueError("question must not be empty")
    opts = _resolve_generation_options(options)
    result = get_services().searcher.ask_raw(
        question, top_k=top_k, options=opts, chunk_type=chunk_type
    )
    return AskResponse(
        answer=result.answer,
        sources=[CleanedChunk(**clean_result(s)) for s in result.sources],
    )
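
# Usage sketch (hypothetical caller):
#
#     resp = await ask("How are documents chunked?", top_k=5)
#     print(resp.answer)   # generated answer text
#     print(resp.sources)  # list of CleanedChunk source entries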


def _run_llm_stream(
    messages: list[ChatMessage],
    opts: dict[str, Any] | None,
    queue: asyncio.Queue[str | None],
    cancel: threading.Event,
    error_holder: list[str],
) -> None:
    """Forward tokens from the cap-aware chat orchestrator into the SSE queue."""
    try:
        events = stream_chat_with_cap(
            get_services().provider,
            cast("list[dict[str, Any]]", messages),
            options=opts,
            model=cfg.chat_model,
            show_reasoning=cfg.show_reasoning,
            cap_chars=effective_reasoning_cap(),
        )
        for event in events:
            if cancel.is_set():
                # Close the generator so its cleanup runs before we stop
                # pulling events.
                events.close()
                break
            if isinstance(event, CapNotice):
                queue.put_nowait(
                    sse_event(
                        SseEvent.REASONING,
                        {"token": CAP_NOTICE_TEMPLATE.format(chars=event.cap_chars)},
                    )
                )
            elif event.content:
                kind = SseEvent.REASONING if event.is_reasoning else SseEvent.TOKEN
                queue.put_nowait(sse_event(kind, {"token": event.content}))
    except Exception as exc:
        error_holder.append(str(exc))
    finally:
        # None is the sentinel that tells the async consumer the stream ended.
        queue.put_nowait(None)
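
# Threading note: _run_llm_stream executes in a worker thread (scheduled via
# run_in_executor in _stream_rag_response below) while the event loop drains
# the queue; errors cross the thread boundary through error_holder.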


async def _stream_rag_response(
    question: str,
    history: list[ChatMessage] | None = None,
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: str | None = None,
) -> AsyncGenerator[str, None]:
    """Shared SSE streaming for ask_stream and chat_stream."""
    yield ""  # force generator

    rag = get_services().searcher.build_rag_context(
        question, top_k=top_k, history=history, chunk_type=chunk_type
    )
    if rag is None:
        yield sse_error("No relevant documents found.")
        return

    results, messages = rag
    opts = _resolve_generation_options(options) or cfg.generation_options()

    sse = SseStream()
    error_holder: list[str] = []

    executor_fut = sse.loop.run_in_executor(
        None, _run_llm_stream, messages, opts, sse.queue, sse.cancel, error_holder
    )
    task = asyncio.ensure_future(executor_fut)
    async for event in sse.drain(task, "RAG stream"):
        yield event

    if error_holder:
        raw = error_holder[0]
        code, user_message = classify_load_error(raw)
        log.warning("Stream error: %s", raw)
        yield sse_error(user_message, code=code, detail=raw if code else None)
        sse.cancel.set()
        return

    # Ensure executor thread has finished before yielding final events
    await executor_fut

    yield sse_event(SseEvent.SOURCES, [clean_result(s) for s in results])
    yield sse_done({})
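
# Illustrative frame order for a successful stream (the exact wire format is
# whatever sse_event/sse_error/sse_done produce; names mirror SseEvent):
#
#     token*      -> {"token": "..."} for each answer fragment
#     reasoning*  -> {"token": "..."} when reasoning is shown or capped
#     sources     -> [ ...cleaned result dicts... ]
#     done        -> {}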


def ask_stream(
    question: str,
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: str | None = None,
) -> AsyncGenerator[str, None]:
    """Yield SSE events: token, sources, done."""
    return _stream_rag_response(question, top_k=top_k, options=options, chunk_type=chunk_type)
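
# Usage sketch (hypothetical consumer; send_text is a placeholder for the
# transport, not part of this module):
#
#     async for frame in ask_stream("What does the indexer store?"):
#         await send_text(frame)  # each frame is a pre-formatted SSE string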


async def chat(
    question: str,
    history: list[ChatMessage],
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: str | None = None,
) -> AskResponse:
    """Chat with history. Returns answer and sources."""
    opts = _resolve_generation_options(options)
    result = get_services().searcher.ask_raw(
        question, top_k=top_k, history=history, options=opts, chunk_type=chunk_type
    )
    return AskResponse(
        answer=result.answer,
        sources=[CleanedChunk(**clean_result(s)) for s in result.sources],
    )
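
# Usage sketch (hypothetical history; the role/content dict shape is an
# assumption consistent with the cast in _run_llm_stream):
#
#     history = [
#         {"role": "user", "content": "What does the indexer do?"},
#         {"role": "assistant", "content": "It chunks and embeds documents."},
#     ]
#     resp = await chat("How often does it run?", history)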


def chat_stream(
    question: str,
    history: list[ChatMessage],
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: str | None = None,
) -> AsyncGenerator[str, None]:
    """Yield SSE events with chat history support."""
    return _stream_rag_response(
        question, history=history, top_k=top_k, options=options, chunk_type=chunk_type
    )
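
# Wiring sketch (hypothetical FastAPI endpoint; app and ChatRequest are
# assumptions, not part of this module):
#
#     from fastapi.responses import StreamingResponse
#
#     @app.post("/chat/stream")
#     async def chat_stream_endpoint(body: ChatRequest):
#         gen = chat_stream(body.question, body.history)
#         return StreamingResponse(gen, media_type="text/event-stream")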