Coverage for src/lilbee/server/handlers/rag.py: 100%

74 statements  

coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

1"""Search, ask, and chat handlers (one-shot and streaming).""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import logging 

7import threading 

8from collections.abc import AsyncGenerator 

9from typing import TYPE_CHECKING, Any, cast 

10 

11from lilbee.app.search import clean_result 

12from lilbee.app.services import get_services 

13from lilbee.core.config import cfg 

14from lilbee.core.results import DocumentResult, group 

15from lilbee.retrieval.reasoning import ( 

16 CAP_NOTICE_TEMPLATE, 

17 CapNotice, 

18 effective_reasoning_cap, 

19 stream_chat_with_cap, 

20) 

21from lilbee.runtime.progress import SseEvent 

22from lilbee.server.handlers.sse import ( 

23 SseStream, 

24 _resolve_generation_options, 

25 classify_load_error, 

26 sse_done, 

27 sse_error, 

28 sse_event, 

29) 

30from lilbee.server.models import AskResponse, CleanedChunk 

31 

32if TYPE_CHECKING: 

33 from lilbee.retrieval.query import ChatMessage 

34 

35log = logging.getLogger(__name__) 

36 

37 

38async def search(q: str, top_k: int = 5, chunk_type: str | None = None) -> list[DocumentResult]: 

39 """Search and return grouped DocumentResults.""" 

40 if not q or not q.strip(): 

41 raise ValueError("query must not be empty") 

42 results = get_services().searcher.search(q, top_k=top_k, chunk_type=chunk_type) 

43 results = [r for r in results if r.distance is None or r.distance <= cfg.max_distance] 

44 return group(results) 

45 

46 

47async def ask( 

48 question: str, 

49 top_k: int = 0, 

50 options: dict[str, Any] | None = None, 

51 chunk_type: str | None = None, 

52) -> AskResponse: 

53 """One-shot RAG answer. Returns answer and sources.""" 

54 if not question or not question.strip(): 

55 raise ValueError("question must not be empty") 

56 opts = _resolve_generation_options(options) 

57 result = get_services().searcher.ask_raw( 

58 question, top_k=top_k, options=opts, chunk_type=chunk_type 

59 ) 

60 return AskResponse( 

61 answer=result.answer, 

62 sources=[CleanedChunk(**clean_result(s)) for s in result.sources], 

63 ) 

64 

65 

66def _run_llm_stream( 

67 messages: list[ChatMessage], 

68 opts: dict[str, Any] | None, 

69 queue: asyncio.Queue[str | None], 

70 cancel: threading.Event, 

71 error_holder: list[str], 

72) -> None: 

73 """Forward tokens from the cap-aware chat orchestrator into the SSE queue.""" 

74 try: 

75 events = stream_chat_with_cap( 

76 get_services().provider, 

77 cast("list[dict[str, Any]]", messages), 

78 options=opts, 

79 model=cfg.chat_model, 

80 show_reasoning=cfg.show_reasoning, 

81 cap_chars=effective_reasoning_cap(), 

82 ) 

83 for event in events: 

84 if cancel.is_set(): 

85 events.close() 

86 break 

87 if isinstance(event, CapNotice): 

88 queue.put_nowait( 

89 sse_event( 

90 SseEvent.REASONING, 

91 {"token": CAP_NOTICE_TEMPLATE.format(chars=event.cap_chars)}, 

92 ) 

93 ) 

94 elif event.content: 

95 kind = SseEvent.REASONING if event.is_reasoning else SseEvent.TOKEN 

96 queue.put_nowait(sse_event(kind, {"token": event.content})) 

97 except Exception as exc: 

98 error_holder.append(str(exc)) 

99 finally: 

100 queue.put_nowait(None) 

101 

102 

103async def _stream_rag_response( 

104 question: str, 

105 history: list[ChatMessage] | None = None, 

106 top_k: int = 0, 

107 options: dict[str, Any] | None = None, 

108 chunk_type: str | None = None, 

109) -> AsyncGenerator[str, None]: 

110 """Shared SSE streaming for ask_stream and chat_stream.""" 

111 yield "" # force generator 

112 

113 rag = get_services().searcher.build_rag_context( 

114 question, top_k=top_k, history=history, chunk_type=chunk_type 

115 ) 

116 if rag is None: 

117 yield sse_error("No relevant documents found.") 

118 return 

119 

120 results, messages = rag 

121 opts = _resolve_generation_options(options) or cfg.generation_options() 

122 

123 sse = SseStream() 

124 error_holder: list[str] = [] 

125 

126 executor_fut = sse.loop.run_in_executor( 

127 None, _run_llm_stream, messages, opts, sse.queue, sse.cancel, error_holder 

128 ) 

129 task = asyncio.ensure_future(executor_fut) 

130 async for event in sse.drain(task, "RAG stream"): 

131 yield event 

132 

133 if error_holder: 

134 raw = error_holder[0] 

135 code, user_message = classify_load_error(raw) 

136 log.warning("Stream error: %s", raw) 

137 yield sse_error(user_message, code=code, detail=raw if code else None) 

138 sse.cancel.set() 

139 return 

140 

141 # Ensure executor thread has finished before yielding final events 

142 await executor_fut 

143 

144 yield sse_event(SseEvent.SOURCES, [clean_result(s) for s in results]) 

145 yield sse_done({}) 

146 

147 

148def ask_stream( 

149 question: str, 

150 top_k: int = 0, 

151 options: dict[str, Any] | None = None, 

152 chunk_type: str | None = None, 

153) -> AsyncGenerator[str, None]: 

154 """Yield SSE events: token, sources, done.""" 

155 return _stream_rag_response(question, top_k=top_k, options=options, chunk_type=chunk_type) 

156 

157 

158async def chat( 

159 question: str, 

160 history: list[ChatMessage], 

161 top_k: int = 0, 

162 options: dict[str, Any] | None = None, 

163 chunk_type: str | None = None, 

164) -> AskResponse: 

165 """Chat with history. Returns answer and sources.""" 

166 opts = _resolve_generation_options(options) 

167 result = get_services().searcher.ask_raw( 

168 question, top_k=top_k, history=history, options=opts, chunk_type=chunk_type 

169 ) 

170 return AskResponse( 

171 answer=result.answer, 

172 sources=[CleanedChunk(**clean_result(s)) for s in result.sources], 

173 ) 

174 

175 

176def chat_stream( 

177 question: str, 

178 history: list[ChatMessage], 

179 top_k: int = 0, 

180 options: dict[str, Any] | None = None, 

181 chunk_type: str | None = None, 

182) -> AsyncGenerator[str, None]: 

183 """Yield SSE events with chat history support.""" 

184 return _stream_rag_response( 

185 question, history=history, top_k=top_k, options=options, chunk_type=chunk_type 

186 )
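
Illustrative only: a minimal sketch of how a caller might drain the SSE strings yielded by ask_stream(). It assumes a configured lilbee environment (services, index, and chat model available); the _demo driver and the sample question are hypothetical and are not part of this module.

import asyncio

from lilbee.server.handlers.rag import ask_stream


async def _demo() -> None:
    # ask_stream() returns an async generator of SSE-formatted strings.
    # The first yield is an empty priming string and can be skipped.
    async for event in ask_stream("How does the reasoning cap work?", top_k=3):
        if event:
            print(event, end="")


if __name__ == "__main__":
    asyncio.run(_demo())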