Coverage for src/lilbee/server/handlers/sse.py: 100%

74 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

1"""SSE stream primitives shared by every streaming handler.""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import json 

7import logging 

8import threading 

9import time 

10from collections.abc import AsyncGenerator 

11from typing import Any 

12 

13from pydantic import BaseModel 

14 

15from lilbee.core.config import cfg 

16from lilbee.runtime.progress import DetailedProgressCallback, EventType, ProgressEvent, SseEvent 

17 

18log = logging.getLogger(__name__) 

19 

20 

def sse_event(event: str, data: Any) -> str:
    """Serialize one Server-Sent Event frame as text.

    *event* becomes the ``event:`` field; *data* is JSON-encoded into the
    ``data:`` field. The frame is terminated by the blank line SSE requires.
    """
    body = json.dumps(data)
    return "event: " + event + "\ndata: " + body + "\n\n"

24 

25 

def sse_error(message: str, *, code: str | None = None, detail: str | None = None) -> str:
    """Build an SSE ``error`` event.

    The payload always carries ``message``; ``code`` and ``detail`` are
    included only when supplied, preserving the legacy minimal shape.
    """
    payload: dict[str, Any] = {"message": message}
    for key, value in (("code", code), ("detail", detail)):
        if value is not None:
            payload[key] = value
    return sse_event(SseEvent.ERROR, payload)

34 

35 

# Lower-cased substrings that identify a llama.cpp model-load / out-of-memory
# diagnostic in an error message; consumed by classify_load_error below.
_OOM_MARKERS = ("failed to load", "free ram", "try a smaller model", "llama_context")

37 

38 

def classify_load_error(message: str) -> tuple[str | None, str]:
    """Return ``(code, user_message)`` for an SSE error event.

    Recognises the llama.cpp OOM diagnostic (any marker from
    ``_OOM_MARKERS``) and maps it to the stable ``"model_too_large"`` code;
    everything else falls back to the legacy generic ``(None, "Internal
    error")`` shape.
    """
    haystack = message.lower()
    for marker in _OOM_MARKERS:
        if marker in haystack:
            return "model_too_large", "Model too large for available RAM"
    return None, "Internal error"

49 

50 

def sse_done(data: dict[str, Any]) -> str:
    """Emit the terminal ``done`` SSE frame carrying *data*."""
    frame = sse_event(SseEvent.DONE, data)
    return frame

54 

55 

def _resolve_generation_options(options: dict[str, Any] | None) -> dict[str, Any] | None:
    """Convert a raw options dict via ``cfg.generation_options``, or return None.

    An empty or missing dict yields ``None`` so callers can pass the result
    straight through as "no overrides".

    NOTE(review): the return annotation says ``dict`` but the docstring and
    ``cfg.generation_options`` suggest a GenerationOptions object — confirm
    which is intended.
    """
    if not options:
        return None
    return cfg.generation_options(**options)

59 

60 

class SseStream:
    """Context object for SSE streaming with cancellation support.

    Bundles the queue, cancel event, and progress callback that every SSE
    endpoint needs. Call :meth:`drain` to yield events until the task
    completes or the client disconnects.
    """

    def __init__(self) -> None:
        # Channel of pre-formatted SSE strings from producer to drain();
        # a ``None`` item is the sentinel meaning "producer finished".
        self.queue: asyncio.Queue[str | None] = asyncio.Queue()
        # Set by drain() on client disconnect so worker threads can stop early.
        self.cancel = threading.Event()
        # Captured so worker threads can marshal events back onto this loop.
        # Must be constructed inside a running event loop.
        self.loop = asyncio.get_running_loop()
        self.callback: DetailedProgressCallback = self._build_callback()

    def _build_callback(self) -> DetailedProgressCallback:
        """Create a progress callback that serializes events into the queue.
        Safe to call from both the event-loop thread and worker threads.
        """
        # Bind to locals so the closure does not keep resolving attributes
        # on self from foreign threads.
        loop = self.loop
        queue = self.queue

        def _callback(event_type: EventType, data: ProgressEvent) -> None:
            # Pydantic models are dumped to plain dicts; anything else is
            # passed to json.dumps as-is (assumed JSON-serializable).
            serialized = data.model_dump() if isinstance(data, BaseModel) else data
            payload = f"event: {event_type}\ndata: {json.dumps(serialized)}\n\n"
            try:
                running = asyncio.get_running_loop()
            except RuntimeError:
                # No running loop: we are on a plain worker thread.
                running = None
            if running is loop:
                # Already on the stream's own loop — enqueue directly.
                queue.put_nowait(payload)
            else:
                # Different (or no) loop — hop thread-safely onto ours.
                loop.call_soon_threadsafe(queue.put_nowait, payload)

        return _callback

    async def drain(
        self, task: asyncio.Task[Any] | asyncio.Future[Any], label: str
    ) -> AsyncGenerator[str, None]:
        """Yield SSE strings until a sentinel arrives; cancel *task* on client disconnect.

        Emits a ``heartbeat`` event whenever the producer queue stays
        idle longer than ``cfg.sse_heartbeat_interval`` seconds so
        clients that enforce a stream-idle timeout don't abort.
        """
        last_yielded = time.monotonic()
        try:
            while True:
                try:
                    # Short poll so the heartbeat and dead-producer checks
                    # run even when no events arrive.
                    item = await asyncio.wait_for(self.queue.get(), timeout=0.1)
                except TimeoutError:
                    # NOTE(review): asyncio.wait_for raises asyncio.TimeoutError,
                    # an alias of builtin TimeoutError only on Python 3.11+ —
                    # confirm the minimum supported runtime.
                    now = time.monotonic()
                    heartbeat_interval = cfg.sse_heartbeat_interval
                    # <= 0 disables heartbeats entirely.
                    if heartbeat_interval > 0 and now - last_yielded >= heartbeat_interval:
                        last_yielded = now
                        yield sse_event(SseEvent.HEARTBEAT, {"ts": time.time()})
                    # Fallback for producers that die without a sentinel.
                    if task.done() and self.queue.empty():
                        break
                    continue
                if item is None:
                    # Clean end-of-stream sentinel from the producer.
                    break
                last_yielded = time.monotonic()
                yield item
        except (asyncio.CancelledError, GeneratorExit):
            # Client went away: flag worker threads and cancel the producer.
            log.info("%s cancelled by client", label)
            self.cancel.set()
            task.cancel()