Coverage for src/lilbee/providers/worker/rerank_worker.py: 100%
46 statements
« prev ^ index » next — coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
1"""Long-lived rerank worker subprocess body."""
3from __future__ import annotations
5import contextlib
6from typing import Any
8from lilbee.providers.worker.transport import RerankPayload, RoleConfig
9from lilbee.providers.worker.transport_pipe import _serialize_exception
10from lilbee.providers.worker.wire_kinds import WireKind
11from lilbee.providers.worker.worker_runtime import Reply, WorkerLoopState, run_worker
14class _RerankSession:
15 """Lazy-loaded Llama reranker, kept alive for the worker's lifetime."""
17 def __init__(self, role_config: RoleConfig) -> None:
18 self._role_config = role_config
19 self._llm: Any = None
21 def score(self, query: str, candidates: list[str]) -> list[float]:
22 """Score *candidates* against *query*, loading the model on first call."""
23 if self._llm is None:
24 self._llm = self._load()
25 return self._compute(self._llm, query, candidates)
27 def _load(self) -> Any:
28 from lilbee.providers.llama_cpp.provider import load_llama
29 from lilbee.providers.model_cache import LoaderMode
31 return load_llama(self._role_config.model_path, mode=LoaderMode.RERANK)
33 @staticmethod
34 def _compute(llm: Any, query: str, candidates: list[str]) -> list[float]:
35 # circular: lilbee.providers.llama_cpp.__init__ eagerly imports
36 # provider.py, which imports this worker module. Function-local
37 # import keeps that cycle from firing at module-load time.
38 from lilbee.providers.llama_cpp.batching import compute_rerank_scores
40 return compute_rerank_scores(llm, query, candidates)
42 def close(self) -> None:
43 """Release the loaded model, if any. Idempotent."""
44 if self._llm is None:
45 return
46 with contextlib.suppress(Exception):
47 self._llm.close()
48 self._llm = None
def _handle_rerank(reply: Reply, payload: Any, state: WorkerLoopState) -> None:
    """Run one rerank request and send the typed reply (or error)."""
    if not isinstance(payload, RerankPayload):
        # Raise-and-catch so the serialized error carries a real traceback.
        try:
            raise TypeError(
                f"rerank payload must be RerankPayload, got {type(payload).__name__}"
            )
        except TypeError as exc:
            reply.send(WireKind.ERROR, _serialize_exception(exc))
        return
    session: _RerankSession = state.session
    try:
        scores = session.score(payload.query, payload.candidates)
    except Exception as exc:
        # Any scoring failure (load or compute) is reported, not fatal.
        reply.send(WireKind.ERROR, _serialize_exception(exc))
    else:
        reply.send(WireKind.RESULT, scores)
def rerank_worker_main(
    data_conn: Any, health_conn: Any, abort_flag: Any, role_config: RoleConfig
) -> None:
    """Rerank worker entrypoint: load llama-cpp lazily, serve until shutdown."""

    def _make_session(role_cfg: RoleConfig, _abort: Any) -> _RerankSession:
        # Reranking never consults the abort flag; the session only needs config.
        return _RerankSession(role_cfg)

    run_worker(
        data_conn,
        health_conn,
        abort_flag,
        role_config,
        session_factory=_make_session,
        kind_handlers={WireKind.RERANK: _handle_rerank},
    )
# Public API: only the subprocess entrypoint is exported; everything else
# in this module is worker-internal plumbing.
__all__ = ["rerank_worker_main"]