Coverage for src / lilbee / providers / worker / transport_pipe.py: 100%
232 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
1"""``multiprocessing.Pipe``-backed worker channel and spawner.
3Concrete impl of the ``WorkerChannel`` / ``WorkerSpawner`` Protocols
4from :mod:`lilbee.providers.worker.transport`. Pipe-specific discipline
5rules are documented in ``docs/architecture.md``.
6"""
8from __future__ import annotations
10import asyncio
11import contextlib
12import logging
13import multiprocessing
14import pickle
15import threading
16import traceback
17from collections.abc import AsyncIterator
18from concurrent.futures import ThreadPoolExecutor
19from dataclasses import dataclass
20from typing import Any
22from lilbee.providers.worker.transport import (
23 RoleConfig,
24 WorkerChannel,
25 WorkerEntrypoint,
26 WorkerHandle,
27 WorkerRole,
28)
29from lilbee.providers.worker.wire_kinds import WireKind
log = logging.getLogger(__name__)


_PICKLE_MAX_BYTES = 32 * 1024 * 1024
"""``Connection.send`` raises past about 32 MiB on POSIX.

Used by ``_check_pickle_size`` to fail fast with a descriptive
``WorkerError`` before handing an oversized frame to the pipe.
"""
@dataclass(frozen=True)
class _SerializedException:
    """Pickle-friendly stand-in for an exception that crossed the wire.

    Produced by :func:`_serialize_exception` on the worker side and
    rebuilt into a :class:`WorkerError` by :func:`_deserialize_exception`
    on the parent side.
    """

    type_name: str  # worker-side exception class name, e.g. "ValueError"
    message: str  # str(exc) from the worker-side exception
    traceback_str: str  # full formatted traceback text
class WorkerError(RuntimeError):
    """Parent-side mirror of an exception that a worker reported over the pipe.

    The message is ``"<original_type>: <message>"``; the worker-side type
    name and formatted traceback are kept as attributes so the supervisor
    can log a full diagnostic trail.
    """

    def __init__(self, original_type: str, message: str, traceback_str: str) -> None:
        # Record the worker-side details first, then build the combined message.
        self.original_type = original_type
        self.traceback_str = traceback_str
        super().__init__(f"{original_type}: {message}")
class WorkerCrashError(WorkerError):
    """Raised when a worker process dies mid-request (EOF on the pipe).

    The error message embeds both ``log_path`` and the tail of that log so
    a native-side crash (no Python exception to serialize) still surfaces
    a diagnostic trail.
    """

    def __init__(self, role: WorkerRole, *, log_path: str | None = None) -> None:
        # Only consult the log when we actually know where it lives.
        tail = ""
        suffix = ""
        if log_path:
            tail = _read_log_tail(log_path)
            suffix = f" See {log_path} for details."
            if tail:
                suffix += f"\nLast log lines:\n{tail}"
        super().__init__(
            "WorkerCrashError",
            f"Worker '{role}' subprocess exited unexpectedly.{suffix}",
            "",
        )
        self.role = role
        self.log_path = log_path
        self.log_tail = tail
81_LOG_TAIL_BYTES = 4096
82"""Read at most this many bytes from the end of a worker log on crash.
84A llama.cpp init that aborts emits a few dozen lines tops; 4 KiB leaves
85enough room for a stack-like trace without pulling a multi-megabyte log
86file into a single error message.
87"""
90def _read_log_tail(log_path: str) -> str:
91 """Return the last ``_LOG_TAIL_BYTES`` of *log_path*, or ``""`` on any error.
93 Best-effort: a missing or unreadable log file must not turn a worker
94 crash into a different unrelated exception. Decoded as utf-8 with
95 ``errors="replace"`` so the Tesseract diacritic bytes that leak into
96 fd 2 on Windows do not crash the error path itself. The caller has
97 already verified ``log_path`` is non-empty.
98 """
99 try:
100 import os as _os
102 size = _os.path.getsize(log_path)
103 offset = max(0, size - _LOG_TAIL_BYTES)
104 with open(log_path, "rb") as handle:
105 handle.seek(offset)
106 data = handle.read()
107 except OSError:
108 return ""
109 return data.decode("utf-8", errors="replace")
def _serialize_exception(exc: BaseException) -> _SerializedException:
    """Reduce *exc* to a pickle-safe ``(type_name, message, traceback)`` triple."""
    formatted = traceback.format_exception(type(exc), exc, exc.__traceback__)
    return _SerializedException(
        type_name=exc.__class__.__name__,
        message=str(exc),
        traceback_str="".join(formatted),
    )
def _deserialize_exception(payload: _SerializedException) -> WorkerError:
    """Rebuild a parent-side :class:`WorkerError` from a worker-sent payload."""
    return WorkerError(
        payload.type_name,
        payload.message,
        payload.traceback_str,
    )
def _check_pickle_size(payload: Any, kind: WireKind) -> None:
    """Raise ``WorkerError`` early if *payload* would exceed the pipe send cap.

    Pickles the full ``(kind, payload)`` frame exactly as it will be sent,
    so the measured size matches what ``Connection.send`` would produce.
    """
    try:
        encoded = pickle.dumps((kind, payload))
    except Exception as exc:
        raise WorkerError("PickleError", f"Failed to pickle {kind!r} payload: {exc}", "") from exc
    if len(encoded) > _PICKLE_MAX_BYTES:
        raise WorkerError(
            "PayloadTooLarge",
            f"{kind!r} payload is {len(encoded)} bytes; pipe send cap is {_PICKLE_MAX_BYTES}.",
            "",
        )
141def _worker_log_path(role: WorkerRole) -> str | None:
142 """Return the worker's log file path, or ``None`` if no data root is set.
144 Mirrors :func:`lilbee.providers.worker.worker_runtime.configure_worker_logging`
145 so the parent's :class:`WorkerCrashError` points at the file the worker
146 wrote. ``LILBEE_DATA`` is canonicalized at cfg construction, so this
147 only returns ``None`` in tests that explicitly clear the env.
148 """
149 import os
151 # circular: worker_runtime imports transport_pipe._serialize_exception at
152 # module top, so the constant import lives inline at the one call site.
153 from lilbee.providers.worker.worker_runtime import WORKER_LOGS_DIR_NAME
155 data_dir = os.environ.get("LILBEE_DATA")
156 if not data_dir:
157 return None
158 return os.path.join(data_dir, WORKER_LOGS_DIR_NAME, f"worker-{role}.log")
class PipeChannel:
    """One worker process talked to via a duplex :class:`multiprocessing.Pipe`.

    Owns the parent end of two pipes (data and health), the abort flag,
    and a per-channel :class:`ThreadPoolExecutor`. The data pipe carries
    one call at a time: ``call`` and ``stream`` acquire ``_call_lock`` for
    their full request/reply or request/stream lifetime, so a reply (or
    stream chunk) can only ever belong to the call currently holding the
    lock. The health pipe carries ping/pong and shutdown/ack and is
    served by a dedicated daemon thread on the worker side, so a long
    inference never starves liveness or shutdown.
    """

    def __init__(
        self,
        *,
        role: WorkerRole,
        process: multiprocessing.process.BaseProcess,
        parent_conn: Any,
        health_conn: Any,
        abort_flag: Any,
    ) -> None:
        self._role = role
        self._process = process
        # Parent ends of the two pipes; the spawner closed the child ends.
        self._conn = parent_conn  # data pipe: call/stream traffic
        self._health_conn = health_conn  # health pipe: ping/pong + shutdown/ack
        self._abort = abort_flag  # shared flag flipped by cancel(); worker polls it
        # max_workers=2 so a blocking data-pipe recv and a health-pipe round
        # trip can be in flight at the same time.
        self._executor = ThreadPoolExecutor(
            max_workers=2,
            thread_name_prefix=f"pipechan-{role}",
        )
        # asyncio locks serialize event-loop callers per pipe.
        self._call_lock = asyncio.Lock()
        self._health_lock = asyncio.Lock()
        # _in_flight is read from other threads via the in_flight property,
        # hence the threading.Lock rather than relying on loop affinity.
        self._in_flight: int = 0
        self._in_flight_lock = threading.Lock()
        self._closed: bool = False
        self._closed_lock = threading.Lock()

    @property
    def role(self) -> WorkerRole:
        """Worker role this channel addresses."""
        return self._role

    @property
    def is_alive(self) -> bool:
        """Return True iff the underlying process is still running."""
        return self._process.is_alive()

    @property
    def pid(self) -> int | None:
        """Worker process id (``None`` until ``start`` returns)."""
        return self._process.pid

    @property
    def in_flight(self) -> int:
        """Number of requests sent but not yet fully replied to."""
        with self._in_flight_lock:
            return self._in_flight

    def _bump_in_flight(self, delta: int) -> None:
        # delta is +1 on request start, -1 in the matching finally block.
        with self._in_flight_lock:
            self._in_flight += delta

    def _ensure_open(self) -> None:
        # Guard at the top of every public entry point: after close() the
        # pipes and executor are gone, so fail with a typed error instead.
        with self._closed_lock:
            if self._closed:
                raise WorkerError(
                    "PoolShutdownError",
                    f"Channel for worker '{self._role}' is closed.",
                    "",
                )

    def _crash(self) -> WorkerCrashError:
        # Build (not raise) the crash error; callers `raise self._crash() from exc`.
        return WorkerCrashError(self._role, log_path=_worker_log_path(self._role))

    async def _send_data(self, kind: WireKind, payload: Any) -> None:
        """Pickle-pre-check + thread-bounded ``conn.send`` on the data pipe."""
        _check_pickle_size(payload, kind)
        loop = asyncio.get_running_loop()
        try:
            await loop.run_in_executor(self._executor, self._conn.send, (kind, payload))
        except (BrokenPipeError, ConnectionResetError, EOFError, OSError) as exc:
            # Any pipe-level failure is treated as the worker having died.
            raise self._crash() from exc

    async def _recv_data(self) -> tuple[WireKind, Any]:
        """Block in a worker thread until the next ``(kind, payload)`` frame arrives."""
        loop = asyncio.get_running_loop()
        try:
            frame = await loop.run_in_executor(self._executor, self._conn.recv)
        except (EOFError, OSError, ConnectionResetError, BrokenPipeError) as exc:
            raise self._crash() from exc
        return frame  # type: ignore[no-any-return]

    async def call(self, kind: WireKind, payload: Any, *, timeout: float) -> Any:
        """Send one request, await one reply on the data pipe.

        The ``_call_lock`` is held for the full request/reply window so a
        reply that lands on the pipe can only belong to the call holding
        the lock. New callers queue on the lock behind the in-flight one.
        """
        self._ensure_open()
        async with self._call_lock:
            self._bump_in_flight(1)
            try:
                await self._send_data(kind, payload)
                # NOTE(review): if wait_for times out here, the executor thread
                # may still be blocked in conn.recv and could consume a later
                # reply — presumably callers treat a timeout as fatal for the
                # channel; confirm against supervisor usage.
                msg_kind, value = await asyncio.wait_for(self._recv_data(), timeout=timeout)
                if msg_kind == WireKind.ERROR:
                    raise _deserialize_exception(value)
                if msg_kind != WireKind.RESULT:
                    raise WorkerError(
                        "ProtocolError",
                        f"Worker '{self._role}' replied with unexpected kind {msg_kind!r}.",
                        "",
                    )
                return value
            finally:
                self._bump_in_flight(-1)

    async def stream(self, kind: WireKind, payload: Any) -> AsyncIterator[Any]:
        """Send one request, yield streamed chunks until the terminator arrives.

        The ``_call_lock`` is held for the full stream lifetime, so frames
        recv'd by this coroutine belong to this stream by construction.
        New callers queue behind the active stream.
        """
        self._ensure_open()
        async with self._call_lock:
            self._bump_in_flight(1)
            try:
                await self._send_data(kind, payload)
                while True:
                    msg_kind, value = await self._recv_data()
                    if msg_kind == WireKind.STREAM_CHUNK:
                        yield value
                    elif msg_kind == WireKind.STREAM_END:
                        # Normal terminator; return ends the async generator.
                        return
                    elif msg_kind == WireKind.ERROR:
                        raise _deserialize_exception(value)
                    else:
                        raise WorkerError(
                            "ProtocolError",
                            f"Worker '{self._role}' streamed unexpected kind {msg_kind!r}.",
                            "",
                        )
            finally:
                self._bump_in_flight(-1)

    async def ping(self, *, timeout: float) -> None:
        """Round-trip a ping over the health pipe; raise on timeout / crash.

        The worker dedicates a daemon thread to the health pipe so pings
        and shutdown can be served while the data loop is busy in
        ``session.embed`` / ``create_chat_completion``.
        """
        self._ensure_open()
        kind = await self._health_round_trip(WireKind.PING, None, timeout=timeout)
        if kind != WireKind.PONG:
            raise WorkerError(
                "ProtocolError",
                f"Worker '{self._role}' ping reply was {kind!r}, want 'pong'.",
                "",
            )

    async def _health_round_trip(
        self, send_kind: WireKind, send_payload: Any, *, timeout: float
    ) -> WireKind:
        """Send one frame on the health pipe and await its reply within *timeout*."""
        loop = asyncio.get_running_loop()
        # _health_lock pairs each send with its reply, independent of _call_lock.
        async with self._health_lock:
            try:
                await loop.run_in_executor(
                    self._executor, self._health_conn.send, (send_kind, send_payload)
                )
            except (BrokenPipeError, ConnectionResetError, EOFError, OSError) as exc:
                raise self._crash() from exc
            try:
                frame = await asyncio.wait_for(
                    loop.run_in_executor(self._executor, self._health_conn.recv),
                    timeout=timeout,
                )
            except (EOFError, OSError, ConnectionResetError, BrokenPipeError) as exc:
                raise self._crash() from exc
        # Reply payload is intentionally discarded; only the kind matters.
        reply_kind, _ = frame
        return reply_kind  # type: ignore[no-any-return]

    def cancel(self) -> None:
        """Flip the abort flag to 1; in-flight tokens may still drain."""
        self._abort.value = 1

    def clear_abort(self) -> None:
        """Reset the abort flag to 0 before the next request."""
        self._abort.value = 0

    async def close(self, *, timeout: float) -> None:
        """Send shutdown on the health pipe, await ack, then join the process.

        The data pipe is never used for shutdown: a long in-flight call
        on the data pipe would otherwise serialize behind a shutdown
        request. Health-pipe shutdown is served by the worker's
        dedicated heartbeat thread, so this returns within the timeout
        regardless of what the data loop is doing. Any in-flight data
        call sees the process exit and surfaces as :class:`WorkerCrashError`.
        """
        # Idempotent: only the first close() proceeds past this block.
        with self._closed_lock:
            if self._closed:
                return
            self._closed = True
        try:
            # Best-effort graceful shutdown; a dead worker already achieves
            # the goal, so timeout / WorkerError here is deliberately ignored.
            with contextlib.suppress(TimeoutError, WorkerError):
                await self._health_round_trip(WireKind.SHUTDOWN, None, timeout=timeout)
            await asyncio.get_running_loop().run_in_executor(
                self._executor, self._join_process, timeout
            )
        finally:
            # Close both pipe ends and tear down the executor even if the
            # join path raised; suppress secondary errors from dead fds.
            with contextlib.suppress(Exception):
                self._conn.close()
            with contextlib.suppress(Exception):
                self._health_conn.close()
            self._executor.shutdown(wait=False, cancel_futures=True)

    def _join_process(self, timeout: float) -> None:
        """Wait *timeout* seconds for the process; terminate if still alive.

        On non-clean exit (signal, non-zero code) record the exit reason in
        the worker log so the user has something to attach to a bug report.
        """
        self._process.join(timeout=timeout)
        if self._process.is_alive():
            log.warning("Worker '%s' did not exit gracefully; terminating", self._role)
            self._process.terminate()
            # Short grace period after SIGTERM before giving up on the join.
            self._process.join(timeout=2.0)
        self._record_exit_reason()

    def _record_exit_reason(self) -> None:
        """Append worker exit reason (signal or non-zero code) to the worker log."""
        code = self._process.exitcode
        # None means still running (terminate failed); 0 is a clean exit.
        if code is None or code == 0:
            return
        log_path = _worker_log_path(self._role)
        message = self._format_exit_reason(code)
        log.warning("Worker '%s' %s", self._role, message)
        if log_path is None:
            return
        # Best-effort append; a missing/readonly log must not raise here.
        with contextlib.suppress(OSError), open(log_path, "a") as handle:
            handle.write(f"\n[supervisor] {message}\n")

    @staticmethod
    def _format_exit_reason(exit_code: int) -> str:
        # multiprocessing reports "killed by signal N" as exitcode -N.
        if exit_code >= 0:
            return f"exited with code {exit_code}"
        import signal

        signum = -exit_code
        try:
            name = signal.Signals(signum).name
        except ValueError:
            # Not a known signal on this platform; fall back to the number.
            name = f"SIG{signum}"
        return f"killed by signal {name} ({signum})"
class PipeSpawner:
    """Spawns worker subprocesses connected to the parent via :class:`multiprocessing.Pipe`."""

    def __init__(self, *, daemon: bool = True) -> None:
        # Always use the spawn start method for a clean child interpreter.
        self._ctx = multiprocessing.get_context("spawn")
        self._daemon = daemon

    def spawn(
        self,
        worker_main: WorkerEntrypoint,
        role_config: RoleConfig,
    ) -> tuple[WorkerChannel, WorkerHandle]:
        """Start a worker subprocess and return its channel + handle.

        Two pipes per worker: ``data_pipe`` carries call/stream traffic,
        ``health_pipe`` carries ping/pong and shutdown/ack. The worker
        dedicates a daemon thread to the health pipe so heartbeats and
        shutdown stay live during long inference.
        """
        role = role_config.role
        data_parent, data_child = self._ctx.Pipe(duplex=True)
        health_parent, health_child = self._ctx.Pipe(duplex=True)
        abort = self._ctx.Value("b", 0, lock=True)
        proc = self._ctx.Process(
            target=worker_main,
            args=(data_child, health_child, abort, role_config),
            daemon=self._daemon,
            name=f"lilbee-worker-{role}",
        )
        proc.start()
        # Drop the parent's references to the child ends so EOF propagates
        # once the worker process exits.
        data_child.close()
        health_child.close()
        handle = WorkerHandle(pid=proc.pid, role=role)
        channel = PipeChannel(
            role=role,
            process=proc,
            parent_conn=data_parent,
            health_conn=health_parent,
            abort_flag=abort,
        )
        log.info("Spawned worker role=%s pid=%s", role, proc.pid)
        return channel, handle
464__all__ = [
465 "PipeChannel",
466 "PipeSpawner",
467 "WorkerCrashError",
468 "WorkerError",
469]