Coverage for src/lilbee/providers/worker/rerank_worker.py: 100%

46 statements  

coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

1"""Long-lived rerank worker subprocess body.""" 

2 

3from __future__ import annotations 

4 

5import contextlib 

6from typing import Any 

7 

8from lilbee.providers.worker.transport import RerankPayload, RoleConfig 

9from lilbee.providers.worker.transport_pipe import _serialize_exception 

10from lilbee.providers.worker.wire_kinds import WireKind 

11from lilbee.providers.worker.worker_runtime import Reply, WorkerLoopState, run_worker 

12 

13 

class _RerankSession:
    """Lazy-loaded Llama reranker, kept alive for the worker's lifetime."""

    def __init__(self, role_config: RoleConfig) -> None:
        self._role_config = role_config
        self._llm: Any = None

    def score(self, query: str, candidates: list[str]) -> list[float]:
        """Score *candidates* against *query*, loading the model on first call."""
        if self._llm is None:
            self._llm = self._load()
        return self._compute(self._llm, query, candidates)

    def _load(self) -> Any:
        from lilbee.providers.llama_cpp.provider import load_llama
        from lilbee.providers.model_cache import LoaderMode

        return load_llama(self._role_config.model_path, mode=LoaderMode.RERANK)

    @staticmethod
    def _compute(llm: Any, query: str, candidates: list[str]) -> list[float]:
        # circular: lilbee.providers.llama_cpp.__init__ eagerly imports
        # provider.py, which imports this worker module. Function-local
        # import keeps that cycle from firing at module-load time.
        from lilbee.providers.llama_cpp.batching import compute_rerank_scores

        return compute_rerank_scores(llm, query, candidates)

    def close(self) -> None:
        """Release the loaded model, if any. Idempotent."""
        if self._llm is None:
            return
        with contextlib.suppress(Exception):
            self._llm.close()
        self._llm = None


def _handle_rerank(reply: Reply, payload: Any, state: WorkerLoopState) -> None:
    """Run one rerank request and send the typed reply (or error)."""
    if not isinstance(payload, RerankPayload):
        # Raise and immediately catch so the TypeError carries a traceback
        # by the time it reaches _serialize_exception.
        try:
            raise TypeError(f"rerank payload must be RerankPayload, got {type(payload).__name__}")
        except TypeError as exc:
            reply.send(WireKind.ERROR, _serialize_exception(exc))
            return
    session: _RerankSession = state.session
    try:
        scores = session.score(payload.query, payload.candidates)
    except Exception as exc:
        reply.send(WireKind.ERROR, _serialize_exception(exc))
        return
    reply.send(WireKind.RESULT, scores)


def rerank_worker_main(
    data_conn: Any, health_conn: Any, abort_flag: Any, role_config: RoleConfig
) -> None:
    """Rerank worker entrypoint: load llama-cpp lazily, serve until shutdown."""
    run_worker(
        data_conn,
        health_conn,
        abort_flag,
        role_config,
        session_factory=lambda role_cfg, _abort: _RerankSession(role_cfg),
        kind_handlers={WireKind.RERANK: _handle_rerank},
    )


__all__ = ["rerank_worker_main"]
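
For context, this entrypoint is the body of a worker subprocess started by a supervising parent. The sketch below is illustrative only: it assumes RoleConfig can be constructed with a model_path keyword (hypothetical; the real fields live in the transport module) and that multiprocessing Pipe connections and an Event satisfy the Any-typed data_conn, health_conn, and abort_flag parameters. The actual request/shutdown protocol belongs to run_worker and the transport layer and is not reproduced here.

# Parent-side launch sketch (not part of this module).
import multiprocessing as mp

from lilbee.providers.worker.rerank_worker import rerank_worker_main
from lilbee.providers.worker.transport import RoleConfig

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    parent_data, child_data = ctx.Pipe()
    parent_health, child_health = ctx.Pipe()
    abort_flag = ctx.Event()

    # Hypothetical constructor call; only the model_path attribute is known
    # from this file.
    role_config = RoleConfig(model_path="models/reranker.gguf")

    proc = ctx.Process(
        target=rerank_worker_main,
        args=(child_data, child_health, abort_flag, role_config),
        daemon=True,
    )
    proc.start()
    # A real supervisor would now exchange RerankPayload requests over
    # parent_data via the project's transport layer; that protocol lives in
    # worker_runtime / transport_pipe and is omitted here.
    proc.terminate()  # blunt teardown for this sketch only
    proc.join()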