Coverage for src / lilbee / server / handlers / rag.py: 100%

105 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-28 01:01 +0000

1"""Search, ask, and chat handlers (one-shot and streaming).""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import logging 

7import threading 

8from collections.abc import AsyncGenerator 

9from typing import TYPE_CHECKING, Any, cast 

10 

11from lilbee.app.memory import auto_extract, auto_extract_enabled 

12from lilbee.app.search import clean_result 

13from lilbee.app.services import get_services 

14from lilbee.core.config import cfg 

15from lilbee.core.results import DocumentResult, group 

16from lilbee.data.store import ChunkType, EmbeddingModelMismatchError, SearchChunk 

17from lilbee.providers.base import ProviderError, ProviderErrorKind 

18from lilbee.retrieval.reasoning import ( 

19 CAP_NOTICE_TEMPLATE, 

20 CapNotice, 

21 effective_reasoning_cap, 

22 stream_chat_with_cap, 

23) 

24from lilbee.runtime.progress import SseErrorCode, SseEvent 

25from lilbee.server.handlers.sse import ( 

26 SseStream, 

27 _resolve_generation_options, 

28 classify_load_error, 

29 sse_done, 

30 sse_error, 

31 sse_event, 

32) 

33from lilbee.server.models import ( 

34 AskResponse, 

35 CleanedChunk, 

36 MemoryExtractedEvent, 

37 MemoryExtractedItem, 

38) 

39 

40if TYPE_CHECKING: 

41 from lilbee.retrieval.query import ChatMessage 

42 

43log = logging.getLogger(__name__) 

44 

45 

async def search(
    q: str, top_k: int = 5, chunk_type: ChunkType | None = None
) -> list[DocumentResult]:
    """Run a search for *q* and return the hits grouped into DocumentResults.

    Hits whose ``distance`` exceeds ``cfg.max_distance`` are discarded; hits
    with no distance (``None``) are always kept. Raises ``ValueError`` when
    *q* is empty or contains only whitespace.
    """
    if not (q and q.strip()):
        raise ValueError("query must not be empty")

    hits = get_services().searcher.search(q, top_k=top_k, chunk_type=chunk_type)

    kept = []
    for hit in hits:
        # Positive form on purpose: a NaN distance fails ``<=`` and is dropped.
        if hit.distance is None or hit.distance <= cfg.max_distance:
            kept.append(hit)
    return group(kept)

55 

56 

async def ask(
    question: str,
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: ChunkType | None = None,
) -> AskResponse:
    """Answer a single question via RAG, without chat history.

    The returned ``AskResponse`` carries the answer text plus the source
    chunks after they have been passed through ``clean_result``. Raises
    ``ValueError`` when *question* is empty or contains only whitespace.
    """
    if not (question and question.strip()):
        raise ValueError("question must not be empty")

    # Resolve the generation options before touching the service registry.
    generation_opts = _resolve_generation_options(options)
    searcher = get_services().searcher
    outcome = searcher.ask_raw(
        question,
        top_k=top_k,
        options=generation_opts,
        chunk_type=chunk_type,
    )

    answer_text = outcome.answer
    cleaned_sources = [CleanedChunk(**clean_result(src)) for src in outcome.sources]
    return AskResponse(answer=answer_text, sources=cleaned_sources)

74 

75 

76def _run_llm_stream( 

77 messages: list[ChatMessage], 

78 opts: dict[str, Any] | None, 

79 queue: asyncio.Queue[str | None], 

80 cancel: threading.Event, 

81 error_holder: list[Exception], 

82 answer_parts: list[str], 

83) -> None: 

84 """Forward tokens from the cap-aware chat orchestrator into the SSE queue. 

85 

86 Answer tokens (not reasoning) are also accumulated into *answer_parts* so the 

87 caller can feed the finished answer to auto-extraction. 

88 """ 

89 try: 

90 events = stream_chat_with_cap( 

91 get_services().provider, 

92 cast("list[dict[str, Any]]", messages), 

93 options=opts, 

94 model=cfg.chat_model, 

95 show_reasoning=cfg.show_reasoning, 

96 cap_chars=effective_reasoning_cap(), 

97 ) 

98 for event in events: 

99 if cancel.is_set(): 

100 events.close() 

101 break 

102 if isinstance(event, CapNotice): 

103 queue.put_nowait( 

104 sse_event( 

105 SseEvent.REASONING, 

106 {"token": CAP_NOTICE_TEMPLATE.format(chars=event.cap_chars)}, 

107 ) 

108 ) 

109 elif event.content: 

110 kind = SseEvent.REASONING if event.is_reasoning else SseEvent.TOKEN 

111 if kind is SseEvent.TOKEN: 

112 answer_parts.append(event.content) 

113 queue.put_nowait(sse_event(kind, {"token": event.content})) 

114 except Exception as exc: 

115 error_holder.append(exc) 

116 finally: 

117 queue.put_nowait(None) 

118 

119 

120async def _emit_extracted_memories(question: str, answer: str) -> AsyncGenerator[str, None]: 

121 """Yield a ``memory_extracted`` SSE event if the turn auto-saved any memories. 

122 

123 Runs the extraction LLM pass off the event loop. Silent (yields nothing) 

124 when the answer is empty, auto-extraction is off, or nothing was extracted, 

125 so existing consumers are unaffected. 

126 """ 

127 if not answer or not auto_extract_enabled(): 

128 return 

129 stored = await asyncio.to_thread(auto_extract, question, answer) 

130 if not stored: 

131 return 

132 event = MemoryExtractedEvent( 

133 count=len(stored), 

134 items=[MemoryExtractedItem(id=m.id, kind=m.kind, text=m.text) for m in stored], 

135 ) 

136 yield sse_event(SseEvent.MEMORY_EXTRACTED, event.model_dump(mode="json")) 

137 

138 

def _error_event(exc: Exception) -> str:
    """Translate a stream failure into an SSE error event string.

    A ``ProviderError`` already carries a user-facing message (rate limits,
    auth, bad model), so its text is passed through verbatim, with its kind as
    the error code unless that kind is ``UNKNOWN``. Any other exception is run
    through ``classify_load_error`` (the llama.cpp OOM classifier); the raw
    text is attached as ``detail`` only when the classifier returned a code.
    """
    if not isinstance(exc, ProviderError):
        raw = str(exc)
        code, user_message = classify_load_error(raw)
        log.warning("Stream error: %s", raw)
        detail = raw if code else None
        return sse_error(user_message, code=code, detail=detail)

    log.warning("Provider error during stream: %s", exc)
    if exc.kind is ProviderErrorKind.UNKNOWN:
        kind_code = None
    else:
        kind_code = exc.kind
    return sse_error(str(exc), code=kind_code)

154 

155 

156async def _stream_rag_response( 

157 question: str, 

158 history: list[ChatMessage] | None = None, 

159 top_k: int = 0, 

160 options: dict[str, Any] | None = None, 

161 chunk_type: ChunkType | None = None, 

162 *, 

163 honor_chat_mode: bool = False, 

164) -> AsyncGenerator[str, None]: 

165 """Shared SSE streaming for ask_stream and chat_stream. 

166 

167 With ``honor_chat_mode`` (chat_stream), a chat-only mode or a missing embedder 

168 streams a direct answer with no retrieval and no sources. 

169 """ 

170 yield "" # force generator 

171 

172 searcher = get_services().searcher 

173 results: list[SearchChunk] = [] 

174 if honor_chat_mode and searcher.skip_retrieval(): 

175 messages = searcher.direct_messages(question, history) 

176 else: 

177 try: 

178 rag = searcher.build_rag_context( 

179 question, top_k=top_k, history=history, chunk_type=chunk_type 

180 ) 

181 except EmbeddingModelMismatchError as exc: 

182 # detail carries the index's embedder so the client can offer to adopt it. 

183 detail = exc.persisted_model if exc.dims_match else None 

184 yield sse_error(str(exc), code=SseErrorCode.INDEX_EMBEDDER_MISMATCH, detail=detail) 

185 return 

186 if rag is None: 

187 yield sse_error("No relevant documents found.") 

188 return 

189 results, messages = rag 

190 

191 opts = _resolve_generation_options(options) or cfg.generation_options() 

192 

193 sse = SseStream() 

194 error_holder: list[Exception] = [] 

195 answer_parts: list[str] = [] 

196 

197 executor_fut = sse.loop.run_in_executor( 

198 None, _run_llm_stream, messages, opts, sse.queue, sse.cancel, error_holder, answer_parts 

199 ) 

200 task = asyncio.ensure_future(executor_fut) 

201 async for event in sse.drain(task, "RAG stream"): 

202 yield event 

203 

204 if error_holder: 

205 yield _error_event(error_holder[0]) 

206 sse.cancel.set() 

207 return 

208 

209 # Ensure executor thread has finished before yielding final events 

210 await executor_fut 

211 

212 yield sse_event(SseEvent.SOURCES, [clean_result(s) for s in results]) 

213 yield sse_done({}) 

214 

215 # Auto-extraction (and its notification) trails ``done`` so clients that stop 

216 # at ``done`` are unaffected; the memories are stored regardless. 

217 async for event in _emit_extracted_memories(question, "".join(answer_parts)): 

218 yield event 

219 

220 

def ask_stream(
    question: str,
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: ChunkType | None = None,
) -> AsyncGenerator[str, None]:
    """Return the SSE event stream (token, sources, done) for a one-shot question.

    Delegates to ``_stream_rag_response`` with no history and chat mode not honored.
    """
    stream = _stream_rag_response(
        question,
        top_k=top_k,
        options=options,
        chunk_type=chunk_type,
    )
    return stream

229 

230 

async def chat(
    question: str,
    history: list[ChatMessage],
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: ChunkType | None = None,
) -> AskResponse:
    """Chat with history. Returns answer and sources.

    Like ``ask``, an empty or whitespace-only *question* is rejected with
    ``ValueError`` before any retrieval or generation work is started.
    The returned ``AskResponse`` carries the answer text plus the source
    chunks after they have been passed through ``clean_result``.
    """
    # Same guard as ``ask``: never send a blank question to retrieval or the LLM.
    if not question or not question.strip():
        raise ValueError("question must not be empty")
    opts = _resolve_generation_options(options)
    result = get_services().searcher.ask_raw(
        question, top_k=top_k, history=history, options=opts, chunk_type=chunk_type
    )
    return AskResponse(
        answer=result.answer,
        sources=[CleanedChunk(**clean_result(s)) for s in result.sources],
    )

247 

248 

def chat_stream(
    question: str,
    history: list[ChatMessage],
    top_k: int = 0,
    options: dict[str, Any] | None = None,
    chunk_type: ChunkType | None = None,
) -> AsyncGenerator[str, None]:
    """Return the SSE event stream for a question asked with chat history.

    Delegates to ``_stream_rag_response`` with ``honor_chat_mode`` enabled, so
    the stream may skip retrieval when the searcher says to.
    """
    stream = _stream_rag_response(
        question, history=history, top_k=top_k, options=options, chunk_type=chunk_type,
        honor_chat_mode=True,
    )
    return stream

264 )