Coverage for src / lilbee / server / handlers / sse.py: 100%
95 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
1"""SSE stream primitives shared by every streaming handler."""
3from __future__ import annotations
5import asyncio
6import json
7import logging
8import threading
9import time
10from collections.abc import AsyncGenerator
11from typing import Any
13from pydantic import BaseModel
15from lilbee.core.config import cfg
16from lilbee.providers.base import ProviderErrorKind, filter_options
17from lilbee.runtime.progress import (
18 DetailedProgressCallback,
19 EventType,
20 ProgressEvent,
21 SseErrorCode,
22 SseEvent,
23)
log = logging.getLogger(name=__name__)

# Union of every value allowed in the machine-readable ``code`` field of an
# SSE error event. Load-time failures use SseErrorCode, while failed provider
# calls reuse ProviderErrorKind directly instead of mirroring it into a
# second enum.
SseErrorCodeValue = SseErrorCode | ProviderErrorKind
def sse_event(event: str, data: Any) -> str:
    """Render one Server-Sent Event frame.

    The frame is an ``event:`` line naming the event, a ``data:`` line holding
    *data* encoded with ``json.dumps``, and the blank line that terminates the
    frame.
    """
    encoded = json.dumps(data)
    return "".join((f"event: {event}\n", f"data: {encoded}\n", "\n"))
def sse_error(
    message: str, *, code: SseErrorCodeValue | None = None, detail: str | None = None
) -> str:
    """Build an SSE ``error`` frame.

    The payload always contains ``message``. ``code`` and ``detail`` are added
    in that order, and only when they are not ``None``, so clients never see
    explicit nulls for the optional fields.
    """
    optional_fields = (("code", code), ("detail", detail))
    payload: dict[str, Any] = {"message": message}
    payload.update((key, value) for key, value in optional_fields if value is not None)
    return sse_event(SseEvent.ERROR, payload)
_OOM_MARKERS = ("failed to load", "free ram", "try a smaller model", "llama_context")
_NOT_INSTALLED_MARKERS = ("not found in registry", "is not available", "pull it first")


def classify_load_error(message: str) -> tuple[SseErrorCode | None, str]:
    """Map a model-load failure to ``(code, user_message)`` for an SSE error.

    Matching is a case-insensitive substring search. The out-of-memory
    markers are checked first, then the "model isn't installed" markers.
    Anything matching neither yields ``(None, "Internal error")``.
    """
    haystack = message.lower()

    def mentions(markers: tuple[str, ...]) -> bool:
        # True when any marker occurs anywhere in the lowered message.
        return any(marker in haystack for marker in markers)

    if mentions(_OOM_MARKERS):
        return SseErrorCode.MODEL_TOO_LARGE, "Model too large for available RAM"
    if mentions(_NOT_INSTALLED_MARKERS):
        return (
            SseErrorCode.MODEL_NOT_INSTALLED,
            "Active model isn't installed. Pull it from the catalog.",
        )
    return None, "Internal error"
def sse_done(data: dict[str, Any]) -> str:
    """Build the terminal SSE ``done`` frame carrying *data* as its payload."""
    return sse_event(event=SseEvent.DONE, data=data)
77def _resolve_generation_options(options: dict[str, Any] | None) -> dict[str, Any] | None:
78 """Merge HTTP-supplied options with config, allowlisting sampling keys only.
80 ``filter_options`` is the validation boundary for untrusted callers: it
81 drops anything outside the sampling allowlist (e.g. injected ``api_base`` /
82 ``api_key``) before the values reach a provider.
83 """
84 return cfg.generation_options(**filter_options(options)) if options else None
class SseStream:
    """Context object for SSE streaming with cancellation support.

    Bundles the queue, cancel event, and progress callback that every SSE
    endpoint needs. Call :meth:`drain` to yield events until the task
    completes or the client disconnects.

    Must be constructed while an event loop is running in the current thread:
    ``__init__`` captures that loop via ``asyncio.get_running_loop()`` so the
    progress callback can hand events back to it from worker threads.
    """

    def __init__(self) -> None:
        # Pre-formatted SSE frames; ``None`` is the end-of-stream sentinel.
        self.queue: asyncio.Queue[str | None] = asyncio.Queue()
        # Set when the client disconnects, so producers can check for it.
        self.cancel = threading.Event()
        self.loop = asyncio.get_running_loop()
        self.callback: DetailedProgressCallback = self._build_callback()

    def _build_callback(self) -> DetailedProgressCallback:
        """Create a progress callback that serializes events into the queue.

        Safe to call from both the event-loop thread and worker threads. On
        the loop thread the frame is enqueued directly. From any other thread
        it is scheduled onto the loop with ``call_soon_threadsafe``, because
        ``asyncio.Queue`` is not thread-safe.
        """
        loop = self.loop
        queue = self.queue

        def _callback(event_type: EventType, data: ProgressEvent) -> None:
            # Pydantic models are dumped to plain data so json can encode them.
            serialized = data.model_dump() if isinstance(data, BaseModel) else data
            # Reuse the module's single SSE frame formatter.
            payload = sse_event(event_type, serialized)
            try:
                running = asyncio.get_running_loop()
            except RuntimeError:
                # No loop running in this thread, so we are on a worker thread.
                running = None
            if running is loop:
                queue.put_nowait(payload)
            else:
                loop.call_soon_threadsafe(queue.put_nowait, payload)

        return _callback

    async def _flush_pending(self) -> AsyncGenerator[str, None]:
        """Yield events left behind the sentinel by a producer that outran the consumer.

        A fast producer can enqueue its sentinel before its threadsafe progress
        callbacks run. One loop tick lets them land, after which everything
        still queued is yielded (extra ``None`` sentinels are skipped).
        """
        await asyncio.sleep(0)
        while not self.queue.empty():
            leftover = self.queue.get_nowait()
            if leftover is not None:
                yield leftover

    async def drain(
        self, task: asyncio.Task[Any] | asyncio.Future[Any], label: str
    ) -> AsyncGenerator[str, None]:
        """Yield SSE strings until a sentinel arrives; cancel *task* on client disconnect.

        Emits a ``heartbeat`` event whenever the producer queue stays idle
        longer than ``cfg.sse_heartbeat_interval`` seconds (a non-positive
        interval disables heartbeats). This keeps clients that enforce a
        stream-idle timeout from aborting.

        The pending ``queue.get`` survives across poll rounds (``asyncio.wait``,
        not ``wait_for``): cancelling a completed get on the timeout boundary
        would drop the event it already popped from the queue.

        On cancellation or ``aclose()``, the method sets :attr:`cancel`,
        cancels *task*, and re-raises so the cancellation still reaches the
        consumer. The pending getter is always cancelled on exit so no
        ``queue.get`` task is left dangling.
        """
        last_yielded = time.monotonic()
        getter: asyncio.Future[str | None] | None = None
        try:
            while True:
                if getter is None:
                    getter = asyncio.ensure_future(self.queue.get())
                done, _ = await asyncio.wait({getter}, timeout=0.1)
                if not done:
                    now = time.monotonic()
                    heartbeat_interval = cfg.sse_heartbeat_interval
                    if heartbeat_interval > 0 and now - last_yielded >= heartbeat_interval:
                        last_yielded = now
                        yield sse_event(SseEvent.HEARTBEAT, {"ts": time.time()})
                    # Fallback for producers that die without a sentinel;
                    # the ``finally`` below cancels the idle getter.
                    if task.done() and self.queue.empty():
                        break
                    continue
                item = getter.result()
                getter = None
                if item is None:
                    async for leftover in self._flush_pending():
                        last_yielded = time.monotonic()
                        yield leftover
                    break
                last_yielded = time.monotonic()
                yield item
        except (asyncio.CancelledError, GeneratorExit):
            log.info("%s cancelled by client", label)
            self.cancel.set()
            task.cancel()
            # Propagate: swallowing CancelledError would hide the cancellation
            # from the task consuming this generator.
            raise
        finally:
            # Covers every exit path, including unexpected exceptions, so the
            # outstanding ``queue.get`` task never leaks. ``cancel()`` on an
            # already finished future is a harmless no-op.
            if getter is not None:
                getter.cancel()