Coverage for src / lilbee / server / handlers / sse.py: 100%

95 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-28 01:01 +0000

1"""SSE stream primitives shared by every streaming handler.""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import json 

7import logging 

8import threading 

9import time 

10from collections.abc import AsyncGenerator 

11from typing import Any 

12 

13from pydantic import BaseModel 

14 

15from lilbee.core.config import cfg 

16from lilbee.providers.base import ProviderErrorKind, filter_options 

17from lilbee.runtime.progress import ( 

18 DetailedProgressCallback, 

19 EventType, 

20 ProgressEvent, 

21 SseErrorCode, 

22 SseEvent, 

23) 

24 

# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)

# Type of the machine-readable ``code`` field on an SSE error event.
# Load-time failures use SseErrorCode; failed provider calls reuse
# ProviderErrorKind directly rather than mirroring it into a second enum.
SseErrorCodeValue = SseErrorCode | ProviderErrorKind

31 

32 

def sse_event(event: str, data: Any) -> str:
    """Render one Server-Sent Event frame.

    The frame is an ``event:`` line carrying *event*, a ``data:`` line
    carrying *data* encoded with ``json.dumps``, and a terminating blank line.
    """
    encoded = json.dumps(data)
    header = f"event: {event}\n"
    return f"{header}data: {encoded}\n\n"

36 

37 

def sse_error(
    message: str, *, code: SseErrorCodeValue | None = None, detail: str | None = None
) -> str:
    """Build an ``SseEvent.ERROR`` frame.

    The payload always contains ``message``. ``code`` and ``detail`` are
    added, in that order, only when the caller supplied a non-``None`` value.
    """
    extras = {"code": code, "detail": detail}
    payload: dict[str, Any] = {
        "message": message,
        **{key: value for key, value in extras.items() if value is not None},
    }
    return sse_event(SseEvent.ERROR, payload)

48 

49 

# Lower-case substrings searched for in a load-failure message. The first
# tuple maps to MODEL_TOO_LARGE, the second to MODEL_NOT_INSTALLED.
_OOM_MARKERS = ("failed to load", "free ram", "try a smaller model", "llama_context")
_NOT_INSTALLED_MARKERS = ("not found in registry", "is not available", "pull it first")


def classify_load_error(message: str) -> tuple[SseErrorCode | None, str]:
    """Translate a raw load-failure message into ``(code, user_message)``.

    Matching is a case-insensitive substring search. Out-of-memory markers
    are checked first and yield ``SseErrorCode.MODEL_TOO_LARGE``; the
    "model isn't installed" markers yield ``SseErrorCode.MODEL_NOT_INSTALLED``.
    A message matching neither gets no code and a generic user message.
    """
    haystack = message.lower()

    for needle in _OOM_MARKERS:
        if needle in haystack:
            return SseErrorCode.MODEL_TOO_LARGE, "Model too large for available RAM"

    for needle in _NOT_INSTALLED_MARKERS:
        if needle in haystack:
            return (
                SseErrorCode.MODEL_NOT_INSTALLED,
                "Active model isn't installed. Pull it from the catalog.",
            )

    return None, "Internal error"

70 

71 

def sse_done(data: dict[str, Any]) -> str:
    """Serialize *data* as the payload of an ``SseEvent.DONE`` event."""
    frame = sse_event(SseEvent.DONE, data)
    return frame

75 

76 

77def _resolve_generation_options(options: dict[str, Any] | None) -> dict[str, Any] | None: 

78 """Merge HTTP-supplied options with config, allowlisting sampling keys only. 

79 

80 ``filter_options`` is the validation boundary for untrusted callers: it 

81 drops anything outside the sampling allowlist (e.g. injected ``api_base`` / 

82 ``api_key``) before the values reach a provider. 

83 """ 

84 return cfg.generation_options(**filter_options(options)) if options else None 

85 

86 

class SseStream:
    """Context object for SSE streaming with cancellation support.

    Bundles the queue, cancel event, and progress callback that every SSE
    endpoint needs. Call :meth:`drain` to yield events until the task
    completes or the client disconnects.

    Must be constructed while an event loop is running: ``__init__`` captures
    it with ``asyncio.get_running_loop()``.
    """

    def __init__(self) -> None:
        # Items are fully formatted SSE strings; ``None`` is the
        # end-of-stream sentinel that drain() stops on.
        self.queue: asyncio.Queue[str | None] = asyncio.Queue()
        # Set by drain() when the consumer goes away.
        # NOTE(review): presumably polled by producers to stop early -- confirm.
        self.cancel = threading.Event()
        self.loop = asyncio.get_running_loop()
        self.callback: DetailedProgressCallback = self._build_callback()

    def _build_callback(self) -> DetailedProgressCallback:
        """Create a progress callback that serializes events into the queue.

        Safe to call from both the event-loop thread and worker threads: on
        the owning loop the event is enqueued directly, anywhere else it is
        handed over with ``call_soon_threadsafe``.
        """
        loop = self.loop
        queue = self.queue

        def _callback(event_type: EventType, data: ProgressEvent) -> None:
            serialized = data.model_dump() if isinstance(data, BaseModel) else data
            # Reuse the shared formatter so the wire format has one definition.
            payload = sse_event(event_type, serialized)
            try:
                running = asyncio.get_running_loop()
            except RuntimeError:
                # No loop is running in the calling thread.
                running = None
            if running is loop:
                queue.put_nowait(payload)
            else:
                loop.call_soon_threadsafe(queue.put_nowait, payload)

        return _callback

    async def _flush_pending(self) -> AsyncGenerator[str, None]:
        """Events left behind the sentinel by a producer that outran the consumer.

        A fast producer can enqueue its sentinel before its threadsafe progress
        callbacks run; one loop tick lets them land.
        """
        await asyncio.sleep(0)
        while not self.queue.empty():
            leftover = self.queue.get_nowait()
            # Skip any additional sentinels; only real events are forwarded.
            if leftover is not None:
                yield leftover

    async def drain(
        self, task: asyncio.Task[Any] | asyncio.Future[Any], label: str
    ) -> AsyncGenerator[str, None]:
        """Yield SSE strings until a sentinel arrives; cancel *task* on client disconnect.

        Emits a ``heartbeat`` event whenever the producer queue stays
        idle longer than ``cfg.sse_heartbeat_interval`` seconds so
        clients that enforce a stream-idle timeout don't abort.

        The pending ``queue.get`` survives across poll rounds (``asyncio.wait``,
        not ``wait_for``): cancelling a completed get on the timeout boundary
        would drop the event it already popped from the queue.

        Args:
            task: The producer; cancelled if the consumer disconnects, and
                polled with ``done()`` as a fallback end-of-stream signal.
            label: Name used in the "cancelled by client" log line.

        Yields:
            Formatted SSE event strings, in queue order.
        """
        last_yielded = time.monotonic()
        getter: asyncio.Future[str | None] | None = None
        try:
            while True:
                if getter is None:
                    getter = asyncio.ensure_future(self.queue.get())
                # Short poll so heartbeats and the dead-producer check still
                # run while the queue is idle.
                done, _ = await asyncio.wait({getter}, timeout=0.1)
                if not done:
                    now = time.monotonic()
                    heartbeat_interval = cfg.sse_heartbeat_interval
                    if heartbeat_interval > 0 and now - last_yielded >= heartbeat_interval:
                        last_yielded = now
                        yield sse_event(SseEvent.HEARTBEAT, {"ts": time.time()})
                    # Fallback for producers that die without a sentinel.
                    if task.done() and self.queue.empty():
                        break
                    continue
                item = getter.result()
                getter = None
                if item is None:
                    async for leftover in self._flush_pending():
                        last_yielded = time.monotonic()
                        yield leftover
                    break
                last_yielded = time.monotonic()
                yield item
        except (asyncio.CancelledError, GeneratorExit):
            # NOTE(review): the exception is deliberately not re-raised here,
            # matching existing behaviour; the generator simply finishes.
            log.info("%s cancelled by client", label)
            self.cancel.set()
            task.cancel()
        finally:
            # Never leave a pending ``queue.get`` behind, whatever the exit
            # path: an orphaned getter would swallow the next queued item.
            if getter is not None:
                getter.cancel()