Coverage for src / lilbee / providers / worker / transport_pipe.py: 100%

232 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

1"""``multiprocessing.Pipe``-backed worker channel and spawner. 

2 

3Concrete impl of the ``WorkerChannel`` / ``WorkerSpawner`` Protocols 

4from :mod:`lilbee.providers.worker.transport`. Pipe-specific discipline 

5rules are documented in ``docs/architecture.md``. 

6""" 

7 

8from __future__ import annotations 

9 

10import asyncio 

11import contextlib 

12import logging 

13import multiprocessing 

14import pickle 

15import threading 

16import traceback 

17from collections.abc import AsyncIterator 

18from concurrent.futures import ThreadPoolExecutor 

19from dataclasses import dataclass 

20from typing import Any 

21 

22from lilbee.providers.worker.transport import ( 

23 RoleConfig, 

24 WorkerChannel, 

25 WorkerEntrypoint, 

26 WorkerHandle, 

27 WorkerRole, 

28) 

29from lilbee.providers.worker.wire_kinds import WireKind 

30 

31log = logging.getLogger(__name__) 

32 

33 

# Enforced by _check_pickle_size before every data-pipe send so oversized
# payloads fail fast with a WorkerError instead of a mid-send OSError.
_PICKLE_MAX_BYTES = 32 * 1024 * 1024
"""``Connection.send`` raises past about 32 MiB on POSIX."""

36 

37 

@dataclass(frozen=True)
class _SerializedException:
    """Pickle-friendly stand-in for an exception that crossed the wire."""

    type_name: str  # exception class name only, e.g. "ValueError" (no module path)
    message: str  # str(exc) captured at serialization time
    traceback_str: str  # fully formatted traceback text; "" when unavailable

45 

46 

class WorkerError(RuntimeError):
    """Parent-side exception wrapping an error reported by a worker process.

    ``str(err)`` reads ``"<original_type>: <message>"``; the worker-side
    traceback text is preserved on ``traceback_str`` for diagnostics.
    """

    def __init__(self, original_type: str, message: str, traceback_str: str) -> None:
        # Record the worker-side details before delegating to RuntimeError.
        self.original_type = original_type
        self.traceback_str = traceback_str
        super().__init__(f"{original_type}: {message}")

54 

55 

class WorkerCrashError(WorkerError):
    """Raised when a worker process dies mid-request (EOF on the pipe).

    The error message embeds both ``log_path`` and the tail of that log so
    a native-side crash (no Python exception to serialize) still surfaces
    a diagnostic trail.
    """

    def __init__(self, role: WorkerRole, *, log_path: str | None = None) -> None:
        # Best-effort log tail: "" when no log path is known or reading fails.
        tail = _read_log_tail(log_path) if log_path else ""
        detail = ""
        if log_path:
            detail += f" See {log_path} for details."
        if tail:
            detail += f"\nLast log lines:\n{tail}"
        super().__init__(
            "WorkerCrashError",
            f"Worker '{role}' subprocess exited unexpectedly.{detail}",
            "",
        )
        self.role = role
        self.log_path = log_path
        self.log_tail = tail

79 

80 

# Consumed only by _read_log_tail; bump here if crash reports need more context.
_LOG_TAIL_BYTES = 4096
"""Read at most this many bytes from the end of a worker log on crash.

A llama.cpp init that aborts emits a few dozen lines tops; 4 KiB leaves
enough room for a stack-like trace without pulling a multi-megabyte log
file into a single error message.
"""

88 

89 

def _read_log_tail(log_path: str) -> str:
    """Return the last ``_LOG_TAIL_BYTES`` of *log_path*, or ``""`` on any error.

    Best-effort: a missing or unreadable log file must not turn a worker
    crash into a different unrelated exception. Decoded as utf-8 with
    ``errors="replace"`` so the Tesseract diacritic bytes that leak into
    fd 2 on Windows do not crash the error path itself. The caller has
    already verified ``log_path`` is non-empty.
    """
    import os as _os

    try:
        file_size = _os.path.getsize(log_path)
        start = file_size - _LOG_TAIL_BYTES
        with open(log_path, "rb") as fh:
            fh.seek(start if start > 0 else 0)
            raw = fh.read()
    except OSError:
        return ""
    # Decode outside the except guard: replace-mode decoding cannot raise.
    return raw.decode("utf-8", errors="replace")

110 

111 

def _serialize_exception(exc: BaseException) -> _SerializedException:
    """Reduce an exception to a pickle-safe ``(type_name, message, traceback)`` triple."""
    formatted = traceback.format_exception(type(exc), exc, exc.__traceback__)
    return _SerializedException(
        type_name=type(exc).__name__,
        message=str(exc),
        traceback_str="".join(formatted),
    )

120 

121 

def _deserialize_exception(payload: _SerializedException) -> WorkerError:
    """Rebuild a parent-side exception from a serialized worker exception."""
    rebuilt = WorkerError(payload.type_name, payload.message, payload.traceback_str)
    return rebuilt

125 

126 

def _check_pickle_size(payload: Any, kind: WireKind) -> None:
    """Raise ``WorkerError`` early if *payload* would exceed the pipe send cap."""
    try:
        # Mirror the exact frame Connection.send would pickle: (kind, payload).
        encoded = pickle.dumps((kind, payload))
    except Exception as exc:
        raise WorkerError("PickleError", f"Failed to pickle {kind!r} payload: {exc}", "") from exc
    size = len(encoded)
    if size > _PICKLE_MAX_BYTES:
        raise WorkerError(
            "PayloadTooLarge",
            f"{kind!r} payload is {size} bytes; pipe send cap is {_PICKLE_MAX_BYTES}.",
            "",
        )

139 

140 

def _worker_log_path(role: WorkerRole) -> str | None:
    """Return the worker's log file path, or ``None`` if no data root is set.

    Mirrors :func:`lilbee.providers.worker.worker_runtime.configure_worker_logging`
    so the parent's :class:`WorkerCrashError` points at the file the worker
    wrote. ``LILBEE_DATA`` is canonicalized at cfg construction, so this
    only returns ``None`` in tests that explicitly clear the env.
    """
    import os

    # circular: worker_runtime imports transport_pipe._serialize_exception at
    # module top, so the constant import lives inline at the one call site.
    from lilbee.providers.worker.worker_runtime import WORKER_LOGS_DIR_NAME

    data_dir = os.environ.get("LILBEE_DATA")
    if data_dir:
        return os.path.join(data_dir, WORKER_LOGS_DIR_NAME, f"worker-{role}.log")
    return None

159 

160 

class PipeChannel:
    """One worker process talked to via a duplex :class:`multiprocessing.Pipe`.

    Owns the parent end of two pipes (data and health), the abort flag,
    and a per-channel :class:`ThreadPoolExecutor`. The data pipe carries
    one call at a time: ``call`` and ``stream`` acquire ``_call_lock`` for
    their full request/reply or request/stream lifetime, so a reply (or
    stream chunk) can only ever belong to the call currently holding the
    lock. The health pipe carries ping/pong and shutdown/ack and is
    served by a dedicated daemon thread on the worker side, so a long
    inference never starves liveness or shutdown.
    """

    def __init__(
        self,
        *,
        role: WorkerRole,
        process: multiprocessing.process.BaseProcess,
        parent_conn: Any,
        health_conn: Any,
        abort_flag: Any,
    ) -> None:
        self._role = role
        self._process = process
        self._conn = parent_conn
        self._health_conn = health_conn
        self._abort = abort_flag
        # Two threads: one for blocking data-pipe ops, one for health-pipe
        # ops, so a long data recv never starves a ping or shutdown send.
        self._executor = ThreadPoolExecutor(
            max_workers=2,
            thread_name_prefix=f"pipechan-{role}",
        )
        self._call_lock = asyncio.Lock()
        self._health_lock = asyncio.Lock()
        self._in_flight = 0
        self._in_flight_lock = threading.Lock()
        self._closed = False
        self._closed_lock = threading.Lock()

    @property
    def role(self) -> WorkerRole:
        """Worker role this channel addresses."""
        return self._role

    @property
    def is_alive(self) -> bool:
        """Return True iff the underlying process is still running."""
        return self._process.is_alive()

    @property
    def pid(self) -> int | None:
        """Worker process id (``None`` until ``start`` returns)."""
        return self._process.pid

    @property
    def in_flight(self) -> int:
        """Number of requests sent but not yet fully replied to."""
        with self._in_flight_lock:
            return self._in_flight

    def _bump_in_flight(self, delta: int) -> None:
        # Counter is read from other threads, hence the threading lock.
        with self._in_flight_lock:
            self._in_flight += delta

    def _ensure_open(self) -> None:
        """Raise ``WorkerError`` if ``close`` has already begun."""
        with self._closed_lock:
            if self._closed:
                raise WorkerError(
                    "PoolShutdownError",
                    f"Channel for worker '{self._role}' is closed.",
                    "",
                )

    def _crash(self) -> WorkerCrashError:
        """Build the crash error, attaching the worker log path when known."""
        return WorkerCrashError(self._role, log_path=_worker_log_path(self._role))

    async def _send_data(self, kind: WireKind, payload: Any) -> None:
        """Pickle-pre-check + thread-bounded ``conn.send`` on the data pipe."""
        _check_pickle_size(payload, kind)
        loop = asyncio.get_running_loop()
        try:
            await loop.run_in_executor(self._executor, self._conn.send, (kind, payload))
        except (BrokenPipeError, ConnectionResetError, EOFError, OSError) as exc:
            raise self._crash() from exc

    async def _recv_data(self) -> tuple[WireKind, Any]:
        """Block in a worker thread until the next ``(kind, payload)`` frame arrives."""
        loop = asyncio.get_running_loop()
        try:
            frame = await loop.run_in_executor(self._executor, self._conn.recv)
        except (EOFError, OSError, ConnectionResetError, BrokenPipeError) as exc:
            raise self._crash() from exc
        return frame  # type: ignore[no-any-return]

    async def call(self, kind: WireKind, payload: Any, *, timeout: float) -> Any:
        """Send one request, await one reply on the data pipe.

        The ``_call_lock`` is held for the full request/reply window so a
        reply that lands on the pipe can only belong to the call holding
        the lock. New callers queue on the lock behind the in-flight one.

        Raises ``WorkerError`` on a worker-reported exception or protocol
        violation, ``WorkerCrashError`` on EOF, and the event loop's
        timeout error if the reply misses *timeout*.
        """
        self._ensure_open()
        async with self._call_lock:
            self._bump_in_flight(1)
            try:
                await self._send_data(kind, payload)
                msg_kind, value = await asyncio.wait_for(self._recv_data(), timeout=timeout)
                if msg_kind == WireKind.ERROR:
                    raise _deserialize_exception(value)
                if msg_kind != WireKind.RESULT:
                    raise WorkerError(
                        "ProtocolError",
                        f"Worker '{self._role}' replied with unexpected kind {msg_kind!r}.",
                        "",
                    )
                return value
            finally:
                self._bump_in_flight(-1)

    async def stream(self, kind: WireKind, payload: Any) -> AsyncIterator[Any]:
        """Send one request, yield streamed chunks until the terminator arrives.

        The ``_call_lock`` is held for the full stream lifetime, so frames
        recv'd by this coroutine belong to this stream by construction.
        New callers queue behind the active stream.
        """
        self._ensure_open()
        async with self._call_lock:
            self._bump_in_flight(1)
            try:
                await self._send_data(kind, payload)
                while True:
                    msg_kind, value = await self._recv_data()
                    if msg_kind == WireKind.STREAM_CHUNK:
                        yield value
                    elif msg_kind == WireKind.STREAM_END:
                        return
                    elif msg_kind == WireKind.ERROR:
                        raise _deserialize_exception(value)
                    else:
                        raise WorkerError(
                            "ProtocolError",
                            f"Worker '{self._role}' streamed unexpected kind {msg_kind!r}.",
                            "",
                        )
            finally:
                self._bump_in_flight(-1)

    async def ping(self, *, timeout: float) -> None:
        """Round-trip a ping over the health pipe; raise on timeout / crash.

        The worker dedicates a daemon thread to the health pipe so pings
        and shutdown can be served while the data loop is busy in
        ``session.embed`` / ``create_chat_completion``.
        """
        self._ensure_open()
        kind = await self._health_round_trip(WireKind.PING, None, timeout=timeout)
        if kind != WireKind.PONG:
            raise WorkerError(
                "ProtocolError",
                f"Worker '{self._role}' ping reply was {kind!r}, want 'pong'.",
                "",
            )

    async def _health_round_trip(
        self, send_kind: WireKind, send_payload: Any, *, timeout: float
    ) -> WireKind:
        """Send one frame on the health pipe and await its reply within *timeout*."""
        loop = asyncio.get_running_loop()
        async with self._health_lock:
            try:
                await loop.run_in_executor(
                    self._executor, self._health_conn.send, (send_kind, send_payload)
                )
            except (BrokenPipeError, ConnectionResetError, EOFError, OSError) as exc:
                raise self._crash() from exc
            try:
                frame = await asyncio.wait_for(
                    loop.run_in_executor(self._executor, self._health_conn.recv),
                    timeout=timeout,
                )
            except (EOFError, OSError, ConnectionResetError, BrokenPipeError) as exc:
                raise self._crash() from exc
            reply_kind, _ = frame
            return reply_kind  # type: ignore[no-any-return]

    def cancel(self) -> None:
        """Flip the abort flag to 1; in-flight tokens may still drain."""
        self._abort.value = 1

    def clear_abort(self) -> None:
        """Reset the abort flag to 0 before the next request."""
        self._abort.value = 0

    async def close(self, *, timeout: float) -> None:
        """Send shutdown on the health pipe, await ack, then join the process.

        The data pipe is never used for shutdown: a long in-flight call
        on the data pipe would otherwise serialize behind a shutdown
        request. Health-pipe shutdown is served by the worker's
        dedicated heartbeat thread, so this returns within the timeout
        regardless of what the data loop is doing. Any in-flight data
        call sees the process exit and surfaces as :class:`WorkerCrashError`.
        """
        with self._closed_lock:
            if self._closed:
                return
            self._closed = True
        try:
            # asyncio.wait_for raises asyncio.TimeoutError, which is only an
            # alias of the builtin TimeoutError on Python 3.11+; suppress
            # both so a slow shutdown ack on older interpreters still falls
            # through to the join/terminate path instead of raising here.
            with contextlib.suppress(asyncio.TimeoutError, TimeoutError, WorkerError):
                await self._health_round_trip(WireKind.SHUTDOWN, None, timeout=timeout)
            await asyncio.get_running_loop().run_in_executor(
                self._executor, self._join_process, timeout
            )
        finally:
            with contextlib.suppress(Exception):
                self._conn.close()
            with contextlib.suppress(Exception):
                self._health_conn.close()
            self._executor.shutdown(wait=False, cancel_futures=True)

    def _join_process(self, timeout: float) -> None:
        """Wait *timeout* seconds for the process; terminate if still alive.

        On non-clean exit (signal, non-zero code) record the exit reason in
        the worker log so the user has something to attach to a bug report.
        """
        self._process.join(timeout=timeout)
        if self._process.is_alive():
            log.warning("Worker '%s' did not exit gracefully; terminating", self._role)
            self._process.terminate()
            self._process.join(timeout=2.0)
        self._record_exit_reason()

    def _record_exit_reason(self) -> None:
        """Append worker exit reason (signal or non-zero code) to the worker log."""
        code = self._process.exitcode
        if code is None or code == 0:
            return
        log_path = _worker_log_path(self._role)
        message = self._format_exit_reason(code)
        log.warning("Worker '%s' %s", self._role, message)
        if log_path is None:
            return
        # Best-effort append; a missing/readonly log must not raise in teardown.
        with contextlib.suppress(OSError), open(log_path, "a") as handle:
            handle.write(f"\n[supervisor] {message}\n")

    @staticmethod
    def _format_exit_reason(exit_code: int) -> str:
        """Render ``Process.exitcode`` as human-readable text (negative = signal)."""
        if exit_code >= 0:
            return f"exited with code {exit_code}"
        import signal

        signum = -exit_code
        try:
            name = signal.Signals(signum).name
        except ValueError:
            name = f"SIG{signum}"
        return f"killed by signal {name} ({signum})"

419 

420 

class PipeSpawner:
    """Spawns worker subprocesses connected to the parent via :class:`multiprocessing.Pipe`."""

    def __init__(self, *, daemon: bool = True) -> None:
        # "spawn" start method: fresh interpreter per worker, no inherited state.
        self._ctx = multiprocessing.get_context("spawn")
        self._daemon = daemon

    def spawn(
        self,
        worker_main: WorkerEntrypoint,
        role_config: RoleConfig,
    ) -> tuple[WorkerChannel, WorkerHandle]:
        """Start a worker subprocess and return its channel + handle.

        Two pipes per worker: ``data_pipe`` carries call/stream traffic,
        ``health_pipe`` carries ping/pong and shutdown/ack. The worker
        dedicates a daemon thread to the health pipe so heartbeats and
        shutdown stay live during long inference.
        """
        data_parent, data_child = self._ctx.Pipe(duplex=True)
        health_parent, health_child = self._ctx.Pipe(duplex=True)
        abort_flag = self._ctx.Value("b", 0, lock=True)
        proc = self._ctx.Process(
            target=worker_main,
            args=(data_child, health_child, abort_flag, role_config),
            daemon=self._daemon,
            name=f"lilbee-worker-{role_config.role}",
        )
        proc.start()
        # Drop the parent's copies of the child ends so an EOF on the pipe
        # reliably signals worker death.
        data_child.close()
        health_child.close()
        channel = PipeChannel(
            role=role_config.role,
            process=proc,
            parent_conn=data_parent,
            health_conn=health_parent,
            abort_flag=abort_flag,
        )
        handle = WorkerHandle(pid=proc.pid, role=role_config.role)
        log.info("Spawned worker role=%s pid=%s", role_config.role, proc.pid)
        return channel, handle

462 

463 

# Public surface: the pipe-backed channel/spawner pair plus the two error
# types callers are expected to catch.
__all__ = [
    "PipeChannel",
    "PipeSpawner",
    "WorkerCrashError",
    "WorkerError",
]