Coverage for src / lilbee / server / routes / search.py: 100%

77 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-28 01:01 +0000

1"""Search, ask, ask_stream, chat, and chat_stream route handlers.""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import logging 

7from collections.abc import AsyncGenerator 

8 

9from litestar import get, post 

10from litestar.exceptions import HTTPException, ValidationException 

11from litestar.params import Parameter 

12from litestar.response import Stream 

13 

14from lilbee.core.results import DocumentResult 

15from lilbee.data.store import EmbeddingModelMismatchError, scope_to_chunk_type 

16from lilbee.retrieval.query import ChatMessage as ChatMessageDict 

17from lilbee.server import handlers 

18from lilbee.server.auth import read_only 

19from lilbee.server.handlers.sse import sse_error 

20from lilbee.server.models import ( 

21 AskRequest, 

22 AskResponse, 

23 ChatRequest, 

24) 

25 

log = logging.getLogger(__name__)

# Process-wide lock that gates the two streaming chat endpoints
# (``/api/ask/stream`` and ``/api/chat/stream``) to one in-flight request at
# a time. The llama-cpp provider already serializes concurrent chat() calls
# under a thread lock, so a second concurrent stream would block the client
# for many seconds with no feedback. Returning 429 + Retry-After immediately
# lets clients surface a real error and decide what to do.
# The lock binds to the worker's running event loop on first acquire.
_chat_inflight_lock = asyncio.Lock()

35 

36 

def _embedding_mismatch_http(exc: EmbeddingModelMismatchError) -> HTTPException:
    """Build the 409 Conflict used when the index and the active embedder disagree.

    ``extra`` carries what a client needs to offer adopting the persisted
    embedder: the persisted model name and dimension, the currently configured
    model, and ``adoptable`` (whether the dimensions match). After the user
    confirms, the client sets the embedder through
    ``PUT /api/models/embedding`` and retries. The server itself never switches
    the embedder unprompted.
    """
    adoption_facts = {
        "persisted_model": exc.persisted_model,
        "persisted_dim": exc.persisted_dim,
        "current_model": exc.current_model,
        "adoptable": exc.dims_match,
    }
    return HTTPException(status_code=409, detail=str(exc), extra=adoption_facts)

54 

55 

def _acquire_chat_lock_or_raise() -> None:
    """Fail fast with 429 if a streaming chat request is already in flight.

    Despite its name, this only *checks* ``_chat_inflight_lock``; it does not
    take it. ``asyncio.Lock.acquire`` is a coroutine and cannot be called from
    this synchronous helper. Callers must follow this call immediately with
    ``await _chat_inflight_lock.acquire()``, with no other ``await`` in
    between.

    That pairing is race-free because route handlers run on a single event
    loop thread, and ``Lock.acquire()`` on a free lock returns without
    yielding. No other handler can run between the check and the acquire.

    Raises:
        HTTPException: 429 with a ``Retry-After: 1`` header when the lock is
            already held.
    """
    if _chat_inflight_lock.locked():
        raise HTTPException(status_code=429, headers={"Retry-After": "1"})

66 

67 

async def _gated_stream(
    generator: AsyncGenerator[str, None],
) -> AsyncGenerator[str, None]:
    """Relay *generator* and release the chat lock when the stream ends.

    The caller must already hold ``_chat_inflight_lock``. Once iteration has
    started, the lock is released in three cases:

    - on natural completion;
    - on an error;
    - on client disconnect (closing this generator raises ``GeneratorExit``
      at the current ``yield``, which runs the ``finally`` block).

    Before releasing, the wrapped *generator* is closed explicitly. Its own
    cleanup then finishes while the lock is still held, instead of being
    deferred to garbage collection after the next request may have started.

    A failure inside *generator* is logged and becomes an SSE error event.
    Raising after the 201 headers would drop the connection with no body for
    the client to read.

    NOTE(review): an async generator that is never started runs no code when
    closed. If the response is discarded before its first chunk is pulled,
    the ``finally`` below never executes and the lock stays held. Confirm the
    framework always begins iterating a returned ``Stream``.
    """
    try:
        async for chunk in generator:
            yield chunk
    except Exception as exc:
        log.exception("streaming chat handler failed")
        yield sse_error(str(exc))
    finally:
        try:
            # Deterministically finish the inner generator on early exit.
            # This is a no-op when it is already exhausted or has raised.
            await generator.aclose()
        finally:
            _chat_inflight_lock.release()

87 

88 

@get("/api/search")
@read_only
async def search_route(
    q: str = Parameter(query="q"),
    top_k: int = Parameter(query="top_k", default=5, le=100),
    chunk_type: str | None = Parameter(query="chunk_type", default=None),
) -> list[DocumentResult]:
    """Search indexed documents by semantic similarity. No LLM call required.

    Args:
        q: Query text.
        top_k: Maximum number of results (default 5, at most 100).
        chunk_type: Optional scope, normalized via ``scope_to_chunk_type``.

    Raises:
        ValidationException: if ``scope_to_chunk_type`` or the search handler
            raises ``ValueError``.
        HTTPException: 409 on an embedder mismatch (built by
            ``_embedding_mismatch_http``). 503 for any other failure, which is
            logged with its traceback.
    """
    try:
        chunk_type = scope_to_chunk_type(chunk_type)
    except ValueError as exc:
        raise ValidationException(str(exc)) from exc
    try:
        return await handlers.search(q, top_k=top_k, chunk_type=chunk_type)
    except EmbeddingModelMismatchError as exc:
        raise _embedding_mismatch_http(exc) from exc
    except ValueError as exc:
        raise ValidationException(str(exc)) from exc
    except Exception as exc:
        # The 503 body only carries str(exc); keep the traceback in the log.
        log.exception("search handler failed")
        raise HTTPException(status_code=503, detail=str(exc)) from exc

109 

110 

@post("/api/ask")
async def ask_route(data: AskRequest) -> AskResponse:
    """One-shot RAG question returning an answer with source chunks.

    Raises:
        ValidationException: if the ask handler raises ``ValueError``.
        HTTPException: 409 on an embedder mismatch (built by
            ``_embedding_mismatch_http``). 503 for any other failure, which is
            logged with its traceback.
    """
    try:
        return await handlers.ask(
            question=data.question,
            top_k=data.top_k,
            options=data.options,
            chunk_type=data.chunk_type,
        )
    except EmbeddingModelMismatchError as exc:
        raise _embedding_mismatch_http(exc) from exc
    except ValueError as exc:
        raise ValidationException(str(exc)) from exc
    except Exception as exc:
        # The 503 body only carries str(exc); keep the traceback in the log.
        log.exception("ask handler failed")
        raise HTTPException(status_code=503, detail=str(exc)) from exc

127 

128 

@post("/api/ask/stream")
async def ask_stream_route(data: AskRequest) -> Stream:
    """Streaming SSE version of ask, emitting token-by-token answer chunks.

    Only one streaming request runs at a time. A concurrent call gets 429
    (see ``_acquire_chat_lock_or_raise``). The lock taken here is handed to
    ``_gated_stream``, which releases it when the stream ends.
    """
    _acquire_chat_lock_or_raise()
    # No ``await`` may run between the check above and this acquire. On a free
    # lock, acquire() completes without yielding, so check + acquire is atomic.
    await _chat_inflight_lock.acquire()
    try:
        answer_chunks = handlers.ask_stream(
            question=data.question,
            top_k=data.top_k,
            options=data.options,
            chunk_type=data.chunk_type,
        )
        return Stream(
            _gated_stream(answer_chunks),
            media_type="text/event-stream",
        )
    except BaseException:
        # _gated_stream releases the lock only once it is iterated. If setup
        # fails before the Stream exists, release here; otherwise every later
        # streaming request would get 429 until the process restarts.
        _chat_inflight_lock.release()
        raise

145 

146 

@post("/api/chat")
async def chat_route(data: ChatRequest) -> AskResponse:
    """RAG chat with conversation history, returning an answer with sources.

    Raises:
        ValidationException: if the chat handler raises ``ValueError``.
        HTTPException: 409 on an embedder mismatch (built by
            ``_embedding_mismatch_http``). 503 for any other failure, which is
            logged with its traceback.
    """
    history: list[ChatMessageDict] = [
        ChatMessageDict(role=m.role, content=m.content) for m in data.history
    ]
    try:
        return await handlers.chat(
            question=data.question,
            history=history,
            top_k=data.top_k,
            options=data.options,
            chunk_type=data.chunk_type,
        )
    except EmbeddingModelMismatchError as exc:
        raise _embedding_mismatch_http(exc) from exc
    except ValueError as exc:
        raise ValidationException(str(exc)) from exc
    except Exception as exc:
        # The 503 body only carries str(exc); keep the traceback in the log.
        log.exception("chat handler failed")
        raise HTTPException(status_code=503, detail=str(exc)) from exc

167 

168 

@post("/api/chat/stream")
async def chat_stream_route(data: ChatRequest) -> Stream:
    """Streaming SSE version of chat with conversation history.

    Shares the single-flight gate with ``/api/ask/stream``. A concurrent
    streaming request gets 429 (see ``_acquire_chat_lock_or_raise``). The
    lock taken here is handed to ``_gated_stream``, which releases it when the
    stream ends.
    """
    # Build the history before taking the lock, so as little as possible runs
    # between acquiring the lock and handing it to the stream.
    history: list[ChatMessageDict] = [
        ChatMessageDict(role=m.role, content=m.content) for m in data.history
    ]
    _acquire_chat_lock_or_raise()
    # No ``await`` may run between the check above and this acquire. On a free
    # lock, acquire() completes without yielding, so check + acquire is atomic.
    await _chat_inflight_lock.acquire()
    try:
        answer_chunks = handlers.chat_stream(
            question=data.question,
            history=history,
            top_k=data.top_k,
            options=data.options,
            chunk_type=data.chunk_type,
        )
        return Stream(
            _gated_stream(answer_chunks),
            media_type="text/event-stream",
        )
    except BaseException:
        # _gated_stream releases the lock only once it is iterated. If setup
        # fails before the Stream exists, release here; otherwise every later
        # streaming request would get 429 until the process restarts.
        _chat_inflight_lock.release()
        raise
188 )