# src/lilbee/providers/worker/chat_worker.py
1"""Long-lived chat worker subprocess body, with token streaming."""
3from __future__ import annotations
5import contextlib
6import threading
7import time
8from typing import Any
10from lilbee.providers.worker.transport import ChatRequest, RoleConfig
11from lilbee.providers.worker.transport_pipe import _serialize_exception
12from lilbee.providers.worker.wire_kinds import WireKind
13from lilbee.providers.worker.worker_runtime import Reply, WorkerLoopState, run_worker
15_ABORT_BRIDGE_POLL_S = 0.025
16"""How often the abort bridge polls the parent's mp.Value flag.
1825 ms is the budget between the user pressing Esc and ggml's next
19abort_callback poll on a slow-token chat. Faster than this stops being
20visible to a human; slower than this leaves the user staring at an
21unresponsive UI for the duration of one stuck token.
22"""
24_STREAM_BATCH_MAX_CHUNKS = 16
25"""Flush a streaming-chat batch once this many tokens have queued.
27Bounded so a long answer at high tok/s doesn't grow the buffer without
28limit. Sized so each flush write batches ~16 syscalls into 1.
29"""
31_STREAM_BATCH_MAX_INTERVAL_S = 0.05
32"""Flush a streaming-chat batch at least this often (50 ms).
34The check fires only when the *next* token arrives (we never wake on a
35timer), so a generator that stalls after token N keeps token N+1's
36buffer parked until the next token lands. The eager-first-flush at
37the top of :func:`_handle_chat_streaming` guarantees the user sees
38something within the very first token, so the parked-tail case only
39delays subsequent batches, not initial output.
40"""


class _ChatSession:
    """Lazy-loaded Llama chat handle, kept alive for the worker's lifetime.

    Reloads in place when the parent passes a per-call ``model`` override
    different from the currently loaded one.
    """

    def __init__(self, role_config: RoleConfig, abort_flag: Any) -> None:
        self._role_config = role_config
        self._abort_flag = abort_flag
        self._llm: Any = None
        self._model_path: str = ""

    def chat(
        self,
        *,
        messages: list[dict[str, str]],
        stream: bool,
        options: dict[str, Any] | None,
        model: str | None,
    ) -> Any:
        """Run one chat completion and return the llama-cpp response."""
        llm = self._ensure_loaded(model)
        kwargs: dict[str, Any] = dict(options) if options else {}
        return llm.create_chat_completion(messages=messages, stream=stream, **kwargs)

    def _ensure_loaded(self, model_override: str | None) -> Any:
        from lilbee.providers.llama_cpp.provider import load_llama, resolve_model_path
        from lilbee.providers.model_cache import LoaderMode

        target_path = (
            resolve_model_path(model_override) if model_override else self._role_config.model_path
        )
        target_str = str(target_path)
        if self._llm is None or target_str != self._model_path:
            self._close_model()
            # No abort_callback_override: routing the cancel signal through
            # ggml's mid-token abort path crashes the worker on macOS Metal.
            # Cancel is enforced one token boundary later by the Python-side
            # polling loop in _handle_chat_streaming.
            self._llm = load_llama(target_path, mode=LoaderMode.CHAT)
            self._model_path = target_str
        return self._llm

    def _close_model(self) -> None:
        if self._llm is not None:
            with contextlib.suppress(Exception):
                self._llm.close()
            self._llm = None

    def close(self) -> None:
        """Release the loaded model. Idempotent."""
        self._close_model()
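
# Reload-in-place sketch (hypothetical model name and call sites, just
# illustrating _ensure_loaded's behaviour):
#
#   >>> session = _ChatSession(role_config, abort_flag)
#   >>> session.chat(messages=msgs, stream=False, options=None, model=None)
#   ...     # first call loads role_config.model_path lazily
#   >>> session.chat(messages=msgs, stream=False, options=None,
#   ...              model="some-other-model")  # hypothetical override
#   ...     # resolved path differs -> old handle closed, new one loaded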


def _extract_stream_content(chunk: Any) -> str | None:
    """Pull the text content out of one llama-cpp streaming chunk."""
    choices = chunk.get("choices") if isinstance(chunk, dict) else None
    if not choices:
        return None
    delta = choices[0].get("delta") if isinstance(choices[0], dict) else None
    if not isinstance(delta, dict):
        return None
    content = delta.get("content")
    return content if isinstance(content, str) and content else None
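
# Illustrative chunk shapes, assuming llama-cpp's OpenAI-style streaming
# deltas (doctest-style; not executed at import):
#
#   >>> _extract_stream_content({"choices": [{"delta": {"content": "Hi"}}]})
#   'Hi'
#   >>> _extract_stream_content({"choices": [{"delta": {"role": "assistant"}}]}) is None
#   True
#   >>> _extract_stream_content({"choices": []}) is None
#   True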


def _handle_chat_streaming(reply: Reply, response_iter: Any, state: WorkerLoopState) -> None:
    """Drain *response_iter* and emit batched stream_chunk frames on the data pipe.

    Polls ``state.session._abort_flag`` between chunks so a cancel from the
    parent flushes a clean ``stream_end`` at the next token boundary.
    Tokens are accumulated and flushed every ``_STREAM_BATCH_MAX_CHUNKS``
    or ``_STREAM_BATCH_MAX_INTERVAL_S``, whichever comes first, so the
    pipe sees ~one syscall per batch instead of one per token.

    Cancel path: ``break`` exits the for loop normally, control falls
    through to ``completed_cleanly = True``, the finally clause flushes
    any buffered tail, and ``stream_end`` fires. The parent's
    ``stream()`` reader then returns cleanly without a hang.
    """
    abort_flag = state.session._abort_flag
    buffer: list[str] = []
    last_flush = time.monotonic()
    seen_first_token = False
    completed_cleanly = False
    try:
        for raw_chunk in response_iter:
            if abort_flag.value:
                with contextlib.suppress(Exception):
                    response_iter.close()
                break
            content = _extract_stream_content(raw_chunk)
            if content is None:
                continue
            buffer.append(content)
            now = time.monotonic()
            # Flush the very first token immediately so a generator that
            # stalls after one token still surfaces something to the user.
            should_flush = (
                not seen_first_token
                or len(buffer) >= _STREAM_BATCH_MAX_CHUNKS
                or (now - last_flush) >= _STREAM_BATCH_MAX_INTERVAL_S
            )
            if should_flush:
                reply.send(WireKind.STREAM_CHUNK, "".join(buffer))
                buffer.clear()
                last_flush = now
            seen_first_token = True
        completed_cleanly = True
    finally:
        # Flush any buffered tokens regardless of how the loop exited so
        # the user sees partial output before the error frame the outer
        # handler may emit.
        if buffer:
            reply.send(WireKind.STREAM_CHUNK, "".join(buffer))
        if completed_cleanly:
            reply.send(WireKind.STREAM_END, None)
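
# Frame sequence sketch as the parent's reader might observe it for one
# streamed answer (illustrative token grouping):
#
#   STREAM_CHUNK "The"                       <- eager first-token flush
#   STREAM_CHUNK " quick brown fox jumps"    <- ~50 ms / 16-token batch
#   STREAM_CHUNK " over the lazy dog."       <- tail flushed in finally
#   STREAM_END                               <- clean completion or cancel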


def _extract_non_streaming_content(response: Any) -> str:
    """Pull the assistant text out of one llama-cpp non-streaming response."""
    if not isinstance(response, dict):
        raise TypeError(f"chat response must be dict, got {type(response).__name__}")
    choices = response.get("choices")
    if not isinstance(choices, list) or not choices:
        raise TypeError("chat response missing 'choices' list")
    first = choices[0]
    if not isinstance(first, dict):
        raise TypeError(f"chat choices[0] must be dict, got {type(first).__name__}")
    message = first.get("message")
    if not isinstance(message, dict):
        raise TypeError("chat choices[0].message missing or not dict")
    content = message.get("content")
    return content if isinstance(content, str) else ""
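
# Illustrative response shape, assuming llama-cpp's OpenAI-style
# non-streaming result (doctest-style; not executed at import):
#
#   >>> _extract_non_streaming_content(
#   ...     {"choices": [{"message": {"role": "assistant", "content": "Hello"}}]}
#   ... )
#   'Hello'
#   >>> _extract_non_streaming_content({"choices": [{"message": {"content": None}}]})
#   ''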


def _handle_chat_non_streaming(reply: Reply, response: Any) -> None:
    """Emit one result frame with the full assistant message text."""
    text = _extract_non_streaming_content(response)
    reply.send(WireKind.RESULT, text)


class _AbortBridge:
    """Mirror the parent's mp.Value abort flag into ggml's threading.Event.

    Without this, cancel only takes effect at Python-loop boundaries
    between yielded tokens. A token that takes 30+ seconds inside the
    ggml decode (full context, slow GPU, big buffer) keeps generating
    because the Python loop never gets a chance to read the flag.
    Calling ``request_abort()`` flips the threading.Event that the
    loaded llama's ``abort_callback`` polls inside ggml, so cancel
    takes effect at the next ggml poll point (every few tokens) instead
    of waiting for the next Python yield.
    """

    def __init__(self, abort_flag: Any) -> None:
        self._abort_flag = abort_flag
        self._stop = threading.Event()
        self._thread: threading.Thread | None = None

    def __enter__(self) -> _AbortBridge:
        from lilbee.providers.llama_cpp.abort_signal import clear_abort

        # Reset both flags before the chat starts: a stale parent-side
        # cancel from a prior call must not abort the new request.
        clear_abort()
        self._abort_flag.value = 0
        self._stop.clear()
        self._thread = threading.Thread(target=self._poll, name="chat-abort-bridge", daemon=True)
        self._thread.start()
        return self

    def __exit__(self, *_exc_info: Any) -> None:
        from lilbee.providers.llama_cpp.abort_signal import clear_abort

        self._stop.set()
        thread = self._thread
        if thread is not None:
            thread.join(timeout=1.0)
        # Reset for the next request so a cancelled prior call doesn't
        # latch onto the next inference.
        clear_abort()
        self._abort_flag.value = 0

    def _poll(self) -> None:
        from lilbee.providers.llama_cpp.abort_signal import request_abort

        while not self._stop.wait(_ABORT_BRIDGE_POLL_S):
            if self._abort_flag.value:
                request_abort()
                return
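
# Parent-side cancel sketch (illustrative wiring, not the project's
# actual supervisor): the parent owns the shared mp.Value, and a cancel
# is one flag flip, which _poll notices within ~_ABORT_BRIDGE_POLL_S
# and relays into ggml via request_abort():
#
#   >>> import multiprocessing as mp
#   >>> abort_flag = mp.Value("i", 0)  # shared with the worker at spawn
#   >>> abort_flag.value = 1           # user pressed Esc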


def _handle_chat(reply: Reply, payload: Any, state: WorkerLoopState) -> None:
    """Run one chat request and dispatch to the streaming/non-streaming handler."""
    if not isinstance(payload, ChatRequest):
        # Raise-and-catch rather than constructing the error directly so
        # the exception carries a traceback for _serialize_exception.
        try:
            raise TypeError(f"chat payload must be ChatRequest, got {type(payload).__name__}")
        except TypeError as exc:
            reply.send(WireKind.ERROR, _serialize_exception(exc))
        return
    session: _ChatSession = state.session
    with _AbortBridge(session._abort_flag):
        try:
            response = session.chat(
                messages=payload.messages,
                stream=payload.stream,
                options=payload.options,
                model=payload.model,
            )
        except Exception as exc:
            reply.send(WireKind.ERROR, _serialize_exception(exc))
            return
        try:
            if payload.stream:
                _handle_chat_streaming(reply, response, state)
            else:
                _handle_chat_non_streaming(reply, response)
        except Exception as exc:
            reply.send(WireKind.ERROR, _serialize_exception(exc))


def chat_worker_main(
    data_conn: Any, health_conn: Any, abort_flag: Any, role_config: RoleConfig
) -> None:
    """Chat worker entrypoint: load llama-cpp lazily, serve until shutdown."""
    run_worker(
        data_conn,
        health_conn,
        abort_flag,
        role_config,
        session_factory=_ChatSession,
        kind_handlers={WireKind.CHAT: _handle_chat},
    )


__all__ = ["chat_worker_main"]
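
# Minimal parent-side spawn sketch (assumptions: both conns are plain
# multiprocessing.Pipe ends and `role_config` is an already-built
# RoleConfig; the real supervisor lives elsewhere in lilbee):
#
#   >>> import multiprocessing as mp
#   >>> data_parent, data_child = mp.Pipe()
#   >>> health_parent, health_child = mp.Pipe()
#   >>> abort_flag = mp.Value("i", 0)
#   >>> proc = mp.Process(
#   ...     target=chat_worker_main,
#   ...     args=(data_child, health_child, abort_flag, role_config),
#   ...     daemon=True,
#   ... )
#   >>> proc.start()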