Coverage for src / lilbee / providers / worker / transport_pipe.py: 100%

239 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-28 01:01 +0000

1"""``multiprocessing.Pipe``-backed worker channel and spawner. 

2 

3Concrete impl of the ``WorkerChannel`` / ``WorkerSpawner`` Protocols 

4from :mod:`lilbee.providers.worker.transport`. Pipe-specific discipline 

5rules are documented in ``docs/architecture.md``. 

6""" 

7 

8from __future__ import annotations 

9 

10import asyncio 

11import contextlib 

12import logging 

13import multiprocessing 

14import pickle 

15import threading 

16import traceback 

17from collections.abc import AsyncIterator 

18from concurrent.futures import ThreadPoolExecutor 

19from dataclasses import dataclass 

20from typing import Any 

21 

22from lilbee.providers.worker.transport import ( 

23 RoleConfig, 

24 WorkerChannel, 

25 WorkerEntrypoint, 

26 WorkerHandle, 

27 WorkerRole, 

28) 

29from lilbee.providers.worker.wire_kinds import WireKind 

30 

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


# Size ceiling that ``_check_pickle_size`` enforces on a pickled
# ``(kind, payload)`` frame before it is handed to ``Connection.send``.
_PICKLE_MAX_BYTES = 32 * 1024 * 1024
"""``Connection.send`` raises past about 32 MiB on POSIX."""


# Default value of ``PipeChannel.stream(stream_chunk_timeout=...)``, in seconds.
_STREAM_CHUNK_TIMEOUT_S = 300.0
"""Max wait for the next frame of an in-flight stream before declaring a stall.

A stalled worker keeps pinging alive on the health pipe, so without this the
consumer would hang forever on a silent data pipe. The ceiling sits above the
slowest legitimate single-chunk gap: PDF-OCR streams one page per chunk and a
page can take up to ``LILBEE_OCR_TIMEOUT`` (120s default) on a cold model, so
300s leaves headroom while still releasing the caller on a genuine hang.
"""

47 

48 

@dataclass(frozen=True)
class _SerializedException:
    """Pickle-friendly stand-in for an exception that crossed the wire.

    Holds nothing but strings. Instances are produced by
    ``_serialize_exception`` and converted into a ``WorkerError`` by
    ``_deserialize_exception``.
    """

    # Class name of the original exception (``type(exc).__name__``).
    type_name: str
    # ``str(exc)`` of the original exception.
    message: str
    # Formatted traceback text of the original exception.
    traceback_str: str

56 

57 

class WorkerError(RuntimeError):
    """Raised on the parent side when a worker reports an exception.

    ``str(error)`` reads ``"<original_type>: <message>"``.

    Attributes:
        original_type: Class name of the exception reported by the worker.
        traceback_str: Formatted worker-side traceback text (may be empty).
    """

    def __init__(self, original_type: str, message: str, traceback_str: str) -> None:
        super().__init__(f"{original_type}: {message}")
        self.original_type = original_type
        self.traceback_str = traceback_str

    def __reduce__(self) -> tuple[Any, ...]:
        """Support ``pickle`` / ``copy`` without re-running ``__init__``.

        The default exception reduction rebuilds via ``type(self)(*self.args)``,
        but ``args`` only holds the single formatted message, which matches
        neither this three-argument constructor nor a subclass's own
        signature. Rebuild from ``args`` plus the instance ``__dict__`` instead.
        """
        return (WorkerError._rebuild, (type(self), self.args, dict(vars(self))))

    @staticmethod
    def _rebuild(
        exc_type: type[WorkerError], args: tuple[Any, ...], state: dict[str, Any]
    ) -> WorkerError:
        """Recreate an instance of *exc_type* from saved ``args`` and attributes."""
        # ``__new__`` stores ``args`` on the exception; ``__init__`` is skipped
        # on purpose because its signature differs between subclasses.
        exc = exc_type.__new__(exc_type, *args)
        vars(exc).update(state)
        return exc

65 

66 

class WorkerCrashError(WorkerError):
    """Signals that a worker process went away mid-request (EOF on the pipe).

    A native-side crash leaves no Python exception to serialize, so the
    message carries the worker's ``log_path`` and the tail of that log as
    the diagnostic trail instead.

    Attributes:
        role: Role of the worker that died.
        log_path: Path of the worker log, or ``None`` when unknown.
        log_tail: Text read from the end of the log (``""`` if none).
    """

    def __init__(self, role: WorkerRole, *, log_path: str | None = None) -> None:
        # Only touch the filesystem when a path was actually supplied.
        log_tail = _read_log_tail(log_path) if log_path else ""
        detail = ""
        if log_path:
            detail += f" See {log_path} for details."
        if log_tail:
            detail += f"\nLast log lines:\n{log_tail}"
        super().__init__(
            "WorkerCrashError",
            f"Worker '{role}' subprocess exited unexpectedly.{detail}",
            "",
        )
        self.role = role
        self.log_path = log_path
        self.log_tail = log_tail

90 

91 

# Byte budget used by ``_read_log_tail`` when it reads the end of a worker log.
_LOG_TAIL_BYTES = 4096
"""Read at most this many bytes from the end of a worker log on crash.

A llama.cpp init that aborts emits a few dozen lines tops; 4 KiB leaves
enough room for a stack-like trace without pulling a multi-megabyte log
file into a single error message.
"""

99 

100 

101def _read_log_tail(log_path: str) -> str: 

102 """Return the last ``_LOG_TAIL_BYTES`` of *log_path*, or ``""`` on any error. 

103 

104 Best-effort: a missing or unreadable log file must not turn a worker 

105 crash into a different unrelated exception. Decoded as utf-8 with 

106 ``errors="replace"`` so the Tesseract diacritic bytes that leak into 

107 fd 2 on Windows do not crash the error path itself. The caller has 

108 already verified ``log_path`` is non-empty. 

109 """ 

110 try: 

111 import os as _os 

112 

113 size = _os.path.getsize(log_path) 

114 offset = max(0, size - _LOG_TAIL_BYTES) 

115 with open(log_path, "rb") as handle: 

116 handle.seek(offset) 

117 data = handle.read() 

118 except OSError: 

119 return "" 

120 return data.decode("utf-8", errors="replace") 

121 

122 

def _serialize_exception(exc: BaseException) -> _SerializedException:
    """Flatten *exc* into a string-only ``_SerializedException``.

    The result carries the exception's class name, ``str(exc)`` and the
    fully formatted traceback, all of which pickle safely.
    """
    exc_type = type(exc)
    formatted = traceback.format_exception(exc_type, exc, exc.__traceback__)
    return _SerializedException(exc_type.__name__, str(exc), "".join(formatted))

131 

132 

def _deserialize_exception(payload: _SerializedException) -> WorkerError:
    """Turn a worker's serialized exception into a parent-side ``WorkerError``."""
    return WorkerError(
        original_type=payload.type_name,
        message=payload.message,
        traceback_str=payload.traceback_str,
    )

136 

137 

def _check_pickle_size(payload: Any, kind: WireKind) -> None:
    """Fail fast with ``WorkerError`` when the frame cannot go over the pipe.

    The ``(kind, payload)`` frame is pickled up front. A pickling failure is
    reported as ``PickleError``; a pickle larger than ``_PICKLE_MAX_BYTES``
    is reported as ``PayloadTooLarge``. Returns ``None`` otherwise.
    """
    frame = (kind, payload)
    try:
        encoded = pickle.dumps(frame)
    except Exception as exc:
        raise WorkerError("PickleError", f"Failed to pickle {kind!r} payload: {exc}", "") from exc
    size = len(encoded)
    if size <= _PICKLE_MAX_BYTES:
        return
    raise WorkerError(
        "PayloadTooLarge",
        f"{kind!r} payload is {size} bytes; pipe send cap is {_PICKLE_MAX_BYTES}.",
        "",
    )

150 

151 

def _worker_log_path(role: WorkerRole) -> str | None:
    """Locate the log file of the *role* worker, if a data root is configured.

    Mirrors :func:`lilbee.providers.worker.worker_runtime.configure_worker_logging`
    so that the parent's :class:`WorkerCrashError` points at the file the
    worker wrote. ``LILBEE_DATA`` is canonicalized at cfg construction, so
    ``None`` only comes back in tests that explicitly clear the env.
    """
    import os

    # circular: worker_runtime imports transport_pipe._serialize_exception at
    # module top, so the constant import lives inline at the one call site.
    from lilbee.providers.worker.worker_runtime import WORKER_LOGS_DIR_NAME

    data_root = os.environ.get("LILBEE_DATA")
    if data_root:
        return os.path.join(data_root, WORKER_LOGS_DIR_NAME, f"worker-{role}.log")
    return None

170 

171 

class PipeChannel:
    """One worker process talked to via a duplex :class:`multiprocessing.Pipe`.

    Owns the parent end of two pipes (data and health), the abort flag,
    and a per-channel :class:`ThreadPoolExecutor`. The data pipe carries
    one call at a time: ``call`` and ``stream`` acquire ``_call_lock`` for
    their full request/reply or request/stream lifetime, so a reply (or
    stream chunk) can only ever belong to the call currently holding the
    lock. The health pipe carries ping/pong and shutdown/ack and is
    served by a dedicated daemon thread on the worker side, so a long
    inference never starves liveness or shutdown.

    Blocking ``send`` / ``recv`` calls run on the channel executor (two
    threads). A ``recv`` whose awaiting coroutine timed out keeps its
    executor thread blocked until the process dies, so both threads can
    end up saturated; health-pipe traffic and ``close`` are written to stay
    bounded by their timeout even then.
    """

    def __init__(
        self,
        *,
        role: WorkerRole,
        process: multiprocessing.process.BaseProcess,
        parent_conn: Any,
        health_conn: Any,
        abort_flag: Any,
    ) -> None:
        """Wrap an already-started worker process.

        Args:
            role: Worker role this channel addresses.
            process: The worker process object.
            parent_conn: Parent end of the data pipe.
            health_conn: Parent end of the health pipe.
            abort_flag: Shared value whose ``.value`` is set to 1 / 0 by
                ``cancel`` / ``clear_abort``.
        """
        self._role = role
        self._process = process
        self._conn = parent_conn
        self._health_conn = health_conn
        self._abort = abort_flag
        self._executor = ThreadPoolExecutor(
            max_workers=2,
            thread_name_prefix=f"pipechan-{role}",
        )
        # asyncio locks serialize coroutines on each pipe; the threading locks
        # guard plain counters/flags that are read through sync accessors.
        self._call_lock = asyncio.Lock()
        self._health_lock = asyncio.Lock()
        self._in_flight = 0
        self._in_flight_lock = threading.Lock()
        self._closed = False
        self._closed_lock = threading.Lock()

    @property
    def role(self) -> WorkerRole:
        """Worker role this channel addresses."""
        return self._role

    @property
    def is_alive(self) -> bool:
        """Return True iff the underlying process is still running."""
        return self._process.is_alive()

    @property
    def pid(self) -> int | None:
        """Worker process id (``None`` until ``start`` returns)."""
        return self._process.pid

    @property
    def in_flight(self) -> int:
        """Number of requests sent but not yet fully replied to."""
        with self._in_flight_lock:
            return self._in_flight

    def _bump_in_flight(self, delta: int) -> None:
        """Add *delta* (``+1`` / ``-1``) to the in-flight counter under its lock."""
        with self._in_flight_lock:
            self._in_flight += delta

    def _ensure_open(self) -> None:
        """Raise ``WorkerError`` (``PoolShutdownError``) once ``close`` has begun."""
        with self._closed_lock:
            if self._closed:
                raise WorkerError(
                    "PoolShutdownError",
                    f"Channel for worker '{self._role}' is closed.",
                    "",
                )

    def _crash(self) -> WorkerCrashError:
        """Build the crash error for this role, pointing at its worker log."""
        return WorkerCrashError(self._role, log_path=_worker_log_path(self._role))

    async def _send_data(self, kind: WireKind, payload: Any) -> None:
        """Pickle-pre-check + thread-bounded ``conn.send`` on the data pipe.

        Raises:
            WorkerError: The payload cannot be pickled or exceeds the size cap.
            WorkerCrashError: The pipe is broken or closed.
        """
        _check_pickle_size(payload, kind)
        loop = asyncio.get_running_loop()
        try:
            await loop.run_in_executor(self._executor, self._conn.send, (kind, payload))
        except (BrokenPipeError, ConnectionResetError, EOFError, OSError) as exc:
            raise self._crash() from exc

    async def _recv_data(self) -> tuple[WireKind, Any]:
        """Block in a worker thread until the next ``(kind, payload)`` frame arrives.

        Raises:
            WorkerCrashError: EOF or an OS-level error on the data pipe.
        """
        loop = asyncio.get_running_loop()
        try:
            frame = await loop.run_in_executor(self._executor, self._conn.recv)
        except (EOFError, OSError, ConnectionResetError, BrokenPipeError) as exc:
            raise self._crash() from exc
        return frame  # type: ignore[no-any-return]

    async def call(self, kind: WireKind, payload: Any, *, timeout: float) -> Any:
        """Send one request, await one reply on the data pipe.

        The ``_call_lock`` is held for the full request/reply window so a
        reply that lands on the pipe can only belong to the call holding
        the lock. New callers queue on the lock behind the in-flight one.

        Args:
            kind: Wire kind of the request frame.
            payload: Picklable request payload.
            timeout: Seconds to wait for the reply frame (the send itself is
                not covered by this timeout).

        Returns:
            The payload of the worker's ``RESULT`` frame.

        Raises:
            WorkerError: Channel closed, worker-reported error, or an
                unexpected reply kind.
            WorkerCrashError: The pipe broke while sending or receiving.
            TimeoutError: No reply arrived within *timeout*.
        """
        self._ensure_open()
        async with self._call_lock:
            self._bump_in_flight(1)
            try:
                await self._send_data(kind, payload)
                msg_kind, value = await asyncio.wait_for(self._recv_data(), timeout=timeout)
                if msg_kind == WireKind.ERROR:
                    raise _deserialize_exception(value)
                if msg_kind != WireKind.RESULT:
                    raise WorkerError(
                        "ProtocolError",
                        f"Worker '{self._role}' replied with unexpected kind {msg_kind!r}.",
                        "",
                    )
                return value
            finally:
                self._bump_in_flight(-1)

    async def _recv_chunk(self, timeout: float) -> tuple[WireKind, Any]:
        """Recv the next stream frame within *timeout*; a stall is a crash.

        A worker that stops emitting frames keeps the health pipe alive, so
        a bare ``recv`` would hang the consumer forever. Bounding each frame
        turns that silent stall into a :class:`WorkerCrashError` the pool can
        recycle from.
        """
        try:
            return await asyncio.wait_for(self._recv_data(), timeout=timeout)
        except TimeoutError as exc:
            raise self._crash() from exc

    async def stream(
        self,
        kind: WireKind,
        payload: Any,
        *,
        stream_chunk_timeout: float = _STREAM_CHUNK_TIMEOUT_S,
    ) -> AsyncIterator[Any]:
        """Send one request, yield streamed chunks until the terminator arrives.

        The ``_call_lock`` is held for the full stream lifetime, so frames
        recv'd by this coroutine belong to this stream by construction.
        New callers queue behind the active stream. Each frame is bounded by
        ``stream_chunk_timeout`` so a mid-stream stall releases the caller.

        Yields:
            The payload of every ``STREAM_CHUNK`` frame, in arrival order;
            iteration ends on ``STREAM_END``.

        Raises:
            WorkerError: Channel closed, worker-reported error, or an
                unexpected frame kind.
            WorkerCrashError: The pipe broke or a frame did not arrive in time.
        """
        self._ensure_open()
        async with self._call_lock:
            self._bump_in_flight(1)
            try:
                await self._send_data(kind, payload)
                while True:
                    msg_kind, value = await self._recv_chunk(stream_chunk_timeout)
                    if msg_kind == WireKind.STREAM_CHUNK:
                        yield value
                    elif msg_kind == WireKind.STREAM_END:
                        return
                    elif msg_kind == WireKind.ERROR:
                        raise _deserialize_exception(value)
                    else:
                        raise WorkerError(
                            "ProtocolError",
                            f"Worker '{self._role}' streamed unexpected kind {msg_kind!r}.",
                            "",
                        )
            finally:
                self._bump_in_flight(-1)

    async def ping(self, *, timeout: float) -> None:
        """Round-trip a ping over the health pipe; raise on timeout / crash.

        The worker dedicates a daemon thread to the health pipe so pings
        and shutdown can be served while the data loop is busy in
        ``session.embed`` / ``create_chat_completion``.
        """
        self._ensure_open()
        kind = await self._health_round_trip(WireKind.PING, None, timeout=timeout)
        if kind != WireKind.PONG:
            raise WorkerError(
                "ProtocolError",
                f"Worker '{self._role}' ping reply was {kind!r}, want 'pong'.",
                "",
            )

    async def _health_round_trip(
        self, send_kind: WireKind, send_payload: Any, *, timeout: float
    ) -> WireKind:
        """Send one frame on the health pipe and await its reply within *timeout*.

        The send and the recv share one *timeout* budget. Both run on the
        channel executor, and when its two threads are stuck in ``recv`` (a
        stalled stream plus a timed-out health recv) a queued send would
        otherwise never start and this coroutine - and therefore ``close`` -
        would wait forever, never reaching the terminate that unblocks them.

        Returns:
            The kind of the reply frame; the reply payload is discarded.

        Raises:
            WorkerCrashError: The health pipe is broken or closed.
            TimeoutError: The exchange did not finish in time (see the note
                in the body about Python 3.11+).
        """
        loop = asyncio.get_running_loop()

        async def _exchange() -> Any:
            await loop.run_in_executor(
                self._executor, self._health_conn.send, (send_kind, send_payload)
            )
            return await loop.run_in_executor(self._executor, self._health_conn.recv)

        async with self._health_lock:
            try:
                frame = await asyncio.wait_for(_exchange(), timeout=timeout)
            # NOTE: since Python 3.11 ``asyncio.wait_for`` raises the builtin
            # ``TimeoutError``, which is an ``OSError`` subclass, so a timeout
            # is reported as ``WorkerCrashError`` by this clause as well; on
            # older interpreters ``asyncio.TimeoutError`` propagates unchanged.
            except (BrokenPipeError, ConnectionResetError, EOFError, OSError) as exc:
                raise self._crash() from exc
        reply_kind, _ = frame
        return reply_kind  # type: ignore[no-any-return]

    def cancel(self) -> None:
        """Flip the abort flag to 1; in-flight tokens may still drain."""
        self._abort.value = 1

    def clear_abort(self) -> None:
        """Reset the abort flag to 0 before the next request."""
        self._abort.value = 0

    async def close(self, *, timeout: float) -> None:
        """Send shutdown on the health pipe, await ack, then join the process.

        The data pipe is never used for shutdown: a long in-flight call
        on the data pipe would otherwise serialize behind a shutdown
        request. Health-pipe shutdown is served by the worker's
        dedicated heartbeat thread, so this returns within the timeout
        regardless of what the data loop is doing. Any in-flight data
        call sees the process exit and surfaces as :class:`WorkerCrashError`.

        Calling ``close`` again after the first call is a no-op.
        """
        with self._closed_lock:
            if self._closed:
                return
            self._closed = True
        try:
            # A worker that is already gone or unresponsive must not prevent
            # the join/terminate below, so shutdown failures are swallowed.
            with contextlib.suppress(TimeoutError, WorkerError):
                await self._health_round_trip(WireKind.SHUTDOWN, None, timeout=timeout)
            # Join on the loop's default executor, never the channel's own:
            # a stalled stream / timed-out health recv leaves the channel
            # executor's threads blocked in conn.recv until the process dies,
            # and terminate() is what unblocks them. Scheduling the join there
            # too would deadlock when both worker threads are saturated.
            await asyncio.get_running_loop().run_in_executor(None, self._join_process, timeout)
        finally:
            with contextlib.suppress(Exception):
                self._conn.close()
            with contextlib.suppress(Exception):
                self._health_conn.close()
            self._executor.shutdown(wait=False, cancel_futures=True)

    def _join_process(self, timeout: float) -> None:
        """Wait *timeout* seconds for the process; terminate if still alive.

        On non-clean exit (signal, non-zero code) record the exit reason in
        the worker log so the user has something to attach to a bug report.
        """
        self._process.join(timeout=timeout)
        if self._process.is_alive():
            log.warning("Worker '%s' did not exit gracefully; terminating", self._role)
            self._process.terminate()
            self._process.join(timeout=2.0)
        self._record_exit_reason()

    def _record_exit_reason(self) -> None:
        """Append worker exit reason (signal or non-zero code) to the worker log."""
        code = self._process.exitcode
        if code is None or code == 0:
            return
        log_path = _worker_log_path(self._role)
        message = self._format_exit_reason(code)
        log.warning("Worker '%s' %s", self._role, message)
        if log_path is None:
            return
        # Explicit utf-8 matches how ``_read_log_tail`` decodes this file and
        # keeps the write independent of the locale's default encoding.
        with contextlib.suppress(OSError), open(log_path, "a", encoding="utf-8") as handle:
            handle.write(f"\n[supervisor] {message}\n")

    @staticmethod
    def _format_exit_reason(exit_code: int) -> str:
        """Describe *exit_code*; a negative code is reported as signal ``-exit_code``."""
        if exit_code >= 0:
            return f"exited with code {exit_code}"
        import signal

        signum = -exit_code
        try:
            name = signal.Signals(signum).name
        except ValueError:
            # Number has no member in ``signal.Signals`` on this platform.
            name = f"SIG{signum}"
        return f"killed by signal {name} ({signum})"

453 

454 

class PipeSpawner:
    """Spawns worker subprocesses connected to the parent via :class:`multiprocessing.Pipe`."""

    def __init__(self, *, daemon: bool = True) -> None:
        """Prepare a ``spawn``-context factory.

        Args:
            daemon: ``daemon`` flag given to every process this spawner starts.
        """
        self._ctx = multiprocessing.get_context("spawn")
        self._daemon = daemon

    def spawn(
        self,
        worker_main: WorkerEntrypoint,
        role_config: RoleConfig,
    ) -> tuple[WorkerChannel, WorkerHandle]:
        """Start a worker subprocess and return its channel + handle.

        Two pipes per worker: ``data_pipe`` carries call/stream traffic,
        ``health_pipe`` carries ping/pong and shutdown/ack. The worker
        dedicates a daemon thread to the health pipe so heartbeats and
        shutdown stay live during long inference.

        Args:
            worker_main: Entry point run in the child; it is called with
                ``(child_data, child_health, abort_flag, role_config)``.
            role_config: Configuration passed to the worker; its ``role``
                names the process and the channel.

        Raises:
            Exception: Whatever pipe/process creation or ``start`` raised;
                every pipe end opened so far is closed before re-raising.
        """
        opened: list[Any] = []
        try:
            parent_data, child_data = self._ctx.Pipe(duplex=True)
            opened += (parent_data, child_data)
            parent_health, child_health = self._ctx.Pipe(duplex=True)
            opened += (parent_health, child_health)
            abort_flag = self._ctx.Value("b", 0, lock=True)
            process = self._ctx.Process(
                target=worker_main,
                args=(child_data, child_health, abort_flag, role_config),
                daemon=self._daemon,
                name=f"lilbee-worker-{role_config.role}",
            )
            process.start()
        except BaseException:
            # No channel will ever own these connections, so release them here
            # instead of leaking file descriptors on a failed spawn.
            for conn in opened:
                with contextlib.suppress(Exception):
                    conn.close()
            raise
        # The child holds its own copies now; drop the parent's references so
        # the parent sees EOF when the worker exits.
        child_data.close()
        child_health.close()
        channel = PipeChannel(
            role=role_config.role,
            process=process,
            parent_conn=parent_data,
            health_conn=parent_health,
            abort_flag=abort_flag,
        )
        handle = WorkerHandle(pid=process.pid, role=role_config.role)
        log.info("Spawned worker role=%s pid=%s", role_config.role, process.pid)
        return channel, handle

496 

497 

# Names exported by ``from ... import *``; the underscore-prefixed helpers
# defined in this module are not listed.
__all__ = [
    "PipeChannel",
    "PipeSpawner",
    "WorkerCrashError",
    "WorkerError",
]

503]