Coverage for src/lilbee/providers/worker/chat_worker.py: 100%

146 statements  

coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

1"""Long-lived chat worker subprocess body, with token streaming.""" 

2 

3from __future__ import annotations 

4 

5import contextlib 

6import threading 

7import time 

8from typing import Any 

9 

10from lilbee.providers.worker.transport import ChatRequest, RoleConfig 

11from lilbee.providers.worker.transport_pipe import _serialize_exception 

12from lilbee.providers.worker.wire_kinds import WireKind 

13from lilbee.providers.worker.worker_runtime import Reply, WorkerLoopState, run_worker 

14 

15_ABORT_BRIDGE_POLL_S = 0.025 

16"""How often the abort bridge polls the parent's mp.Value flag. 

17 

1825 ms is the budget between the user pressing Esc and ggml's next 

19abort_callback poll on a slow-token chat. Faster than this stops being 

20visible to a human; slower than this leaves the user staring at an 

21unresponsive UI for the duration of one stuck token. 

22""" 

_STREAM_BATCH_MAX_CHUNKS = 16
"""Flush a streaming-chat batch once this many tokens have queued.

Bounded so a long answer at high tok/s doesn't grow the buffer without
limit. Sized so each flush coalesces ~16 would-be per-token writes into
a single syscall.
"""

_STREAM_BATCH_MAX_INTERVAL_S = 0.05
"""Flush a streaming-chat batch at least this often (50 ms).

The check fires only when the *next* token arrives (we never wake on a
timer), so a generator that stalls after token N leaves the buffered
tail parked until token N+1 lands. The eager first flush at the top of
:func:`_handle_chat_streaming` guarantees the user sees something as
soon as the very first token arrives, so the parked-tail case only
delays subsequent batches, not initial output.
"""
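
# Illustrative arithmetic (not part of this module): how the two batching caps above
# interact at a steady token rate. At ~200 tok/s the 50 ms interval cap fires first
# (~10 tokens per flush); at ~1000 tok/s the 16-token size cap fires first. The helper
# name and the rates are just worked examples built from the constants above.
def _sketch_tokens_per_flush(tokens_per_second: float) -> float:
    """Roughly how many tokens each flush batches at a steady *tokens_per_second*."""
    by_interval = tokens_per_second * _STREAM_BATCH_MAX_INTERVAL_S
    return max(1.0, min(float(_STREAM_BATCH_MAX_CHUNKS), by_interval))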

class _ChatSession:
    """Lazy-loaded Llama chat handle, kept alive for the worker's lifetime.

    Reloads in place when the parent passes a per-call ``model`` override
    different from the currently loaded one.
    """

    def __init__(self, role_config: RoleConfig, abort_flag: Any) -> None:
        self._role_config = role_config
        self._abort_flag = abort_flag
        self._llm: Any = None
        self._model_path: str = ""

    def chat(
        self,
        *,
        messages: list[dict[str, str]],
        stream: bool,
        options: dict[str, Any] | None,
        model: str | None,
    ) -> Any:
        """Run one chat completion and return the llama-cpp response."""
        llm = self._ensure_loaded(model)
        kwargs: dict[str, Any] = dict(options) if options else {}
        return llm.create_chat_completion(messages=messages, stream=stream, **kwargs)

    def _ensure_loaded(self, model_override: str | None) -> Any:
        from lilbee.providers.llama_cpp.provider import load_llama, resolve_model_path
        from lilbee.providers.model_cache import LoaderMode

        target_path = (
            resolve_model_path(model_override) if model_override else self._role_config.model_path
        )
        target_str = str(target_path)
        if self._llm is None or target_str != self._model_path:
            self._close_model()
            # No abort_callback_override: routing the cancel signal through
            # ggml's mid-token abort path crashes the worker on macOS Metal.
            # Cancel is enforced one token boundary later by the Python-side
            # polling loop in _handle_chat_streaming.
            self._llm = load_llama(target_path, mode=LoaderMode.CHAT)
            self._model_path = target_str
        return self._llm

    def _close_model(self) -> None:
        if self._llm is not None:
            with contextlib.suppress(Exception):
                self._llm.close()
            self._llm = None

    def close(self) -> None:
        """Release the loaded model. Idempotent."""
        self._close_model()
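
# Illustrative sketch (not part of this module, never called by the worker): how a
# session reloads when a per-call model override changes. The model filename is
# hypothetical; only _ChatSession and its chat() signature come from the class above.
def _sketch_session_reload(role_config: RoleConfig, abort_flag: Any) -> None:
    session = _ChatSession(role_config, abort_flag)
    try:
        # First call lazily loads the role's default model.
        session.chat(
            messages=[{"role": "user", "content": "hi"}],
            stream=False,
            options=None,
            model=None,
        )
        # A different per-call override closes the old handle and loads the new model.
        session.chat(
            messages=[{"role": "user", "content": "hi"}],
            stream=False,
            options=None,
            model="other-model.gguf",
        )
    finally:
        session.close()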

def _extract_stream_content(chunk: Any) -> str | None:
    """Pull the text content out of one llama-cpp streaming chunk."""
    choices = chunk.get("choices") if isinstance(chunk, dict) else None
    if not choices:
        return None
    delta = choices[0].get("delta") if isinstance(choices[0], dict) else None
    if not isinstance(delta, dict):
        return None
    content = delta.get("content")
    return content if isinstance(content, str) and content else None
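
# Illustrative sketch (not part of this module): the OpenAI-style streaming chunk
# nesting that _extract_stream_content expects. Field values are made up; only the
# choices[0].delta.content path is implied by the code above.
def _sketch_stream_chunk_shape() -> None:
    example = {
        "choices": [
            {"index": 0, "delta": {"content": "Hel"}, "finish_reason": None},
        ],
    }
    assert _extract_stream_content(example) == "Hel"
    # Chunks whose delta carries only a role (or is empty) yield None and are skipped.
    assert _extract_stream_content({"choices": [{"delta": {"role": "assistant"}}]}) is None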

def _handle_chat_streaming(reply: Reply, response_iter: Any, state: WorkerLoopState) -> None:
    """Drain *response_iter* and emit batched stream_chunk frames on the data pipe.

    Polls ``state.session._abort_flag`` between chunks so a cancel from the
    parent flushes a clean ``stream_end`` at the next token boundary.
    Tokens are accumulated and flushed every ``_STREAM_BATCH_MAX_CHUNKS``
    or ``_STREAM_BATCH_MAX_INTERVAL_S``, whichever comes first, so the
    pipe sees ~one syscall per batch instead of one per token.

    Cancel path: ``break`` exits the for loop normally, control falls
    through to ``completed_cleanly = True``, the finally clause flushes
    any buffered tail, and ``stream_end`` fires. The parent's
    ``stream()`` reader then returns cleanly without a hang.
    """
    abort_flag = state.session._abort_flag
    buffer: list[str] = []
    last_flush = time.monotonic()
    seen_first_token = False
    completed_cleanly = False
    try:
        for raw_chunk in response_iter:
            if abort_flag.value:
                with contextlib.suppress(Exception):
                    response_iter.close()
                break
            content = _extract_stream_content(raw_chunk)
            if content is None:
                continue
            buffer.append(content)
            now = time.monotonic()
            # Flush the very first token immediately so a generator that
            # stalls after one token still surfaces something to the user.
            should_flush = (
                not seen_first_token
                or len(buffer) >= _STREAM_BATCH_MAX_CHUNKS
                or (now - last_flush) >= _STREAM_BATCH_MAX_INTERVAL_S
            )
            if should_flush:
                reply.send(WireKind.STREAM_CHUNK, "".join(buffer))
                buffer.clear()
                last_flush = now
                seen_first_token = True
        completed_cleanly = True
    finally:
        # Flush any buffered tokens regardless of how the loop exited so
        # the user sees partial output before the error frame the outer
        # handler may emit.
        if buffer:
            reply.send(WireKind.STREAM_CHUNK, "".join(buffer))
        if completed_cleanly:
            reply.send(WireKind.STREAM_END, None)
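
# Hypothetical parent-side drain loop (not part of this module): illustrates the frame
# sequence _handle_chat_streaming produces. The frame encoding is assumed here to be
# (kind, payload) pairs received from a multiprocessing connection; the real parent
# transport in lilbee may differ.
def _sketch_parent_stream_reader(data_conn: Any) -> str:
    parts: list[str] = []
    while True:
        kind, payload = data_conn.recv()
        if kind == WireKind.STREAM_CHUNK:
            parts.append(payload)          # one batched run of tokens
        elif kind == WireKind.STREAM_END:
            return "".join(parts)          # clean end, including after a cancel
        elif kind == WireKind.ERROR:
            raise RuntimeError(payload)    # serialized exception from the worker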

def _extract_non_streaming_content(response: Any) -> str:
    """Pull the assistant text out of one llama-cpp non-streaming response."""
    if not isinstance(response, dict):
        raise TypeError(f"chat response must be dict, got {type(response).__name__}")
    choices = response.get("choices")
    if not isinstance(choices, list) or not choices:
        raise TypeError("chat response missing 'choices' list")
    first = choices[0]
    if not isinstance(first, dict):
        raise TypeError(f"chat choices[0] must be dict, got {type(first).__name__}")
    message = first.get("message")
    if not isinstance(message, dict):
        raise TypeError("chat choices[0].message missing or not dict")
    content = message.get("content")
    return content if isinstance(content, str) else ""
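
# Illustrative sketch (not part of this module): the OpenAI-style non-streaming response
# nesting that _extract_non_streaming_content expects. Field values are made up; only
# the choices[0].message.content path is implied by the code above.
def _sketch_non_streaming_response_shape() -> None:
    example = {
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": "Hello!"},
                "finish_reason": "stop",
            },
        ],
    }
    assert _extract_non_streaming_content(example) == "Hello!"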

def _handle_chat_non_streaming(reply: Reply, response: Any) -> None:
    """Emit one result frame with the full assistant message text."""
    text = _extract_non_streaming_content(response)
    reply.send(WireKind.RESULT, text)

class _AbortBridge:
    """Mirror the parent's mp.Value abort flag into ggml's threading.Event.

    Without this, cancel only takes effect at Python-loop boundaries
    between yielded tokens. A token that takes 30+ seconds inside the
    ggml decode (full context, slow GPU, big buffer) keeps generating
    because the Python loop never gets a chance to read the flag.
    Calling ``request_abort()`` flips the threading.Event that the
    loaded llama's ``abort_callback`` polls inside ggml, so cancel
    takes effect at the next ggml poll point (every few tokens) instead
    of waiting for the next Python yield.
    """

    def __init__(self, abort_flag: Any) -> None:
        self._abort_flag = abort_flag
        self._stop = threading.Event()
        self._thread: threading.Thread | None = None

    def __enter__(self) -> _AbortBridge:
        from lilbee.providers.llama_cpp.abort_signal import clear_abort

        # Reset both flags before the chat starts: a stale parent-side
        # cancel from a prior call must not abort the new request.
        clear_abort()
        self._abort_flag.value = 0
        self._stop.clear()
        self._thread = threading.Thread(target=self._poll, name="chat-abort-bridge", daemon=True)
        self._thread.start()
        return self

    def __exit__(self, *_exc_info: Any) -> None:
        from lilbee.providers.llama_cpp.abort_signal import clear_abort

        self._stop.set()
        thread = self._thread
        if thread is not None:
            thread.join(timeout=1.0)
        # Reset for the next request so a cancelled prior call doesn't
        # latch onto the next inference.
        clear_abort()
        self._abort_flag.value = 0

    def _poll(self) -> None:
        from lilbee.providers.llama_cpp.abort_signal import request_abort

        while not self._stop.wait(_ABORT_BRIDGE_POLL_S):
            if self._abort_flag.value:
                request_abort()
                return
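
# Hypothetical sketch (not part of this module) of the abort-signal plumbing the bridge
# drives: a process-wide threading.Event that request_abort()/clear_abort() flip and that
# a ggml abort_callback can poll every few tokens. Only the two imported names come from
# the real lilbee.providers.llama_cpp.abort_signal module; its implementation may differ.
def _sketch_abort_signal_module() -> tuple[Any, Any, Any]:
    abort_event = threading.Event()

    def request_abort() -> None:
        # What _AbortBridge._poll calls once the parent's mp.Value flag flips.
        abort_event.set()

    def clear_abort() -> None:
        # What __enter__/__exit__ call so a stale cancel never latches onto the next chat.
        abort_event.clear()

    def abort_callback() -> bool:
        # What ggml would poll between decode steps; returning True stops generation.
        return abort_event.is_set()

    return request_abort, clear_abort, abort_callback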

def _handle_chat(reply: Reply, payload: Any, state: WorkerLoopState) -> None:
    """Run one chat request and dispatch to the streaming/non-streaming handler."""
    if not isinstance(payload, ChatRequest):
        try:
            raise TypeError(f"chat payload must be ChatRequest, got {type(payload).__name__}")
        except TypeError as exc:
            reply.send(WireKind.ERROR, _serialize_exception(exc))
        return
    session: _ChatSession = state.session
    with _AbortBridge(session._abort_flag):
        try:
            response = session.chat(
                messages=payload.messages,
                stream=payload.stream,
                options=payload.options,
                model=payload.model,
            )
        except Exception as exc:
            reply.send(WireKind.ERROR, _serialize_exception(exc))
            return
        try:
            if payload.stream:
                _handle_chat_streaming(reply, response, state)
            else:
                _handle_chat_non_streaming(reply, response)
        except Exception as exc:
            reply.send(WireKind.ERROR, _serialize_exception(exc))

def chat_worker_main(
    data_conn: Any, health_conn: Any, abort_flag: Any, role_config: RoleConfig
) -> None:
    """Chat worker entrypoint: load llama-cpp lazily, serve until shutdown."""
    run_worker(
        data_conn,
        health_conn,
        abort_flag,
        role_config,
        session_factory=_ChatSession,
        kind_handlers={WireKind.CHAT: _handle_chat},
    )


__all__ = ["chat_worker_main"]
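
# Hypothetical parent-side launch sketch (not part of this module): how a supervisor
# might wire the pipes, the shared abort flag, and the role config before handing them
# to chat_worker_main in a subprocess. The pipe/flag wiring is an assumption built from
# the entrypoint signature above; RoleConfig construction details are not shown.
def _sketch_spawn_chat_worker(role_config: RoleConfig) -> None:
    import multiprocessing as mp

    ctx = mp.get_context("spawn")
    parent_data, child_data = ctx.Pipe()        # chat frames: stream_chunk / result / error
    parent_health, child_health = ctx.Pipe()    # liveness / shutdown signalling
    abort_flag = ctx.Value("i", 0)              # parent sets to 1 on Esc; bridge polls it

    worker = ctx.Process(
        target=chat_worker_main,
        args=(child_data, child_health, abort_flag, role_config),
        daemon=True,
    )
    worker.start()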