Coverage for src/lilbee/providers/routing_provider.py: 100%
113 statements
coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
"""Routing provider: prefix-based dispatch between the SDK backend and llama-cpp."""

from __future__ import annotations

import contextlib
import logging
from collections.abc import Callable
from pathlib import Path
from typing import Any, Literal, overload

from lilbee.catalog import is_rerank_ref
from lilbee.core.config import cfg
from lilbee.providers.base import ClosableIterator, LLMProvider, ProviderError
from lilbee.providers.litellm_sdk import LitellmSdkBackend
from lilbee.providers.model_ref import ProviderModelRef, parse_model_ref
from lilbee.providers.sdk_llm_provider import SdkLLMProvider
from lilbee.providers.worker.transport import OcrBackend
from lilbee.vision import PageText

log = logging.getLogger(__name__)

_NATIVE_GGUF_REF_MIN_SLASHES = 2
"""``<org>/<repo>/<filename>.gguf`` has at least two slashes."""
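
# Illustrative ref (not a registry entry): "some-org/some-repo/model.Q4_K_M.gguf"
# has exactly two slashes and a ".gguf" suffix, the minimal shape treated as a
# native GGUF ref by the routing helpers below.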


class RoutingProvider(LLMProvider):
    """Dispatches calls based on the model ref prefix.

    ``ollama/``, ``openai/``, ``anthropic/``, ``gemini/`` go to the SDK
    provider. Other refs (the HuggingFace ``<org>/<repo>/<file>.gguf``
    shape) go to llama-cpp, which resolves them against the native
    registry. A registry miss surfaces the native ``ProviderError``
    unchanged, rather than silently falling through to a remote backend.
    """

    def __init__(self) -> None:
        self._llama_cpp: LLMProvider | None = None
        self._sdk_provider: SdkLLMProvider | None = None

    def _get_llama_cpp(self) -> LLMProvider:
        if self._llama_cpp is None:
            from lilbee.providers.llama_cpp import LlamaCppProvider

            self._llama_cpp = LlamaCppProvider()
        return self._llama_cpp

    def _get_sdk_provider(self) -> SdkLLMProvider:
        if self._sdk_provider is None:
            self._sdk_provider = SdkLLMProvider(
                LitellmSdkBackend(),
                base_url=cfg.remote_base_url,
                api_key=cfg.llm_api_key,
            )
        return self._sdk_provider

    def _pick_backend(self, ref: ProviderModelRef) -> LLMProvider:
        """Pick the backend for *ref* purely by prefix."""
        if ref.is_remote:
            return self._get_sdk_provider()
        return self._get_llama_cpp()
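
    # Hedged sketch (not part of the module, refs are illustrative): how refs
    # route under the prefix rules in the class docstring, assuming
    # ``parse_model_ref`` marks the SDK prefixes as remote.
    #
    #   router = RoutingProvider()
    #   router._pick_backend(parse_model_ref("openai/gpt-4o"))        # -> SdkLLMProvider
    #   router._pick_backend(parse_model_ref("org/repo/model.gguf"))  # -> llama-cpp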

    def embed(self, texts: list[str]) -> list[list[float]]:
        ref = parse_model_ref(cfg.embedding_model)
        return self._pick_backend(ref).embed(texts)

    @overload
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: Literal[False] = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str: ...

    @overload
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: Literal[True],
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> ClosableIterator[str]: ...

    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: bool = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str | ClosableIterator[str]:
        ref = parse_model_ref(model or cfg.chat_model)
        backend = self._pick_backend(ref)
        # Split on stream so each call resolves to a specific overload; the
        # base impl signature accepts bool but the @overloads on the LLMProvider
        # Protocol require Literal narrowing at the boundary.
        if stream:
            return backend.chat(messages, stream=True, options=options, model=model)
        return backend.chat(messages, stream=False, options=options, model=model)
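
    # Hedged usage sketch (messages and model defaults are illustrative): both
    # call shapes dispatch through ``_pick_backend``; only ``stream`` differs,
    # which is what the overloads above narrow on.
    #
    #   text = router.chat([{"role": "user", "content": "hi"}])                # -> str
    #   chunks = router.chat([{"role": "user", "content": "hi"}], stream=True)
    #   for chunk in chunks:  # -> ClosableIterator[str]
    #       ...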

    def vision_ocr(
        self,
        png_bytes: bytes,
        model: str,
        prompt: str = "",
        *,
        timeout: float | None = None,
    ) -> str:
        """Dispatch by ``model``'s ref prefix, same rules as :meth:`chat`."""
        ref = parse_model_ref(model)
        return self._pick_backend(ref).vision_ocr(png_bytes, model, prompt, timeout=timeout)

    def pdf_ocr(
        self,
        path: Path,
        *,
        backend: OcrBackend,
        model: str = "",
        per_page_timeout_s: float | None = None,
        quiet: bool = True,
        on_progress: Callable[..., None] | None = None,
    ) -> list[PageText]:
        """Dispatch by ``model``'s ref prefix, same rules as :meth:`vision_ocr`.

        Hosted refs reach :class:`SdkLLMProvider`, which raises
        ``NotImplementedError`` for PDF OCR; native refs reach the
        llama-cpp pool worker. ``model`` is empty when the caller wants
        the configured ``cfg.vision_model`` to drive the dispatch.
        """
        ref = parse_model_ref(model or cfg.vision_model)
        return self._pick_backend(ref).pdf_ocr(
            path,
            backend=backend,
            model=model,
            per_page_timeout_s=per_page_timeout_s,
            quiet=quiet,
            on_progress=on_progress,
        )
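
    # Hedged sketch (path and config values are illustrative): with ``model``
    # left empty, ``cfg.vision_model`` drives the dispatch, so a native vision
    # ref sends the PDF to the llama-cpp pool worker, while a hosted ref would
    # surface :class:`SdkLLMProvider`'s ``NotImplementedError`` for PDF OCR.
    #
    #   pages = router.pdf_ocr(Path("scan.pdf"), backend=ocr_backend)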

    def list_models(self) -> list[str]:
        """Return the union of native and SDK-visible models.

        Both halves are wrapped so an unreachable remote backend or a
        missing native registry does not mask the other.
        """
        native: set[str] = set()
        with contextlib.suppress(Exception):
            native = set(self._get_llama_cpp().list_models())
        sdk = self._get_sdk_provider()
        if not sdk.available():
            return sorted(native)
        try:
            remote = set(sdk.list_models())
        except Exception:
            return sorted(native)
        return sorted(native | remote)
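
    # Hedged sketch (model names are illustrative): if the native side lists
    # {"org/repo/a.gguf"} and the SDK side lists {"openai/gpt-4o"}, the result
    # is the sorted union; if either half raises, the other half is returned
    # on its own.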

    def list_chat_models(self, provider: str) -> list[str]:
        """Delegate to the SDK backend; native llama-cpp has no catalog."""
        sdk = self._get_sdk_provider()
        if not sdk.available():
            return []
        return sdk.list_chat_models(provider)

    def pull_model(self, model: str, *, on_progress: Callable[..., Any] | None = None) -> None:
        """Pull via the SDK backend if installed, otherwise raise."""
        sdk = self._get_sdk_provider()
        if not sdk.available():
            raise ProviderError(f"Cannot pull model {model!r}: no pull-capable backend available")
        sdk.pull_model(model, on_progress=on_progress)

    def show_model(self, model: str) -> dict[str, Any] | None:
        """Show model info from the backend selected by the ref prefix."""
        ref = parse_model_ref(model)
        return self._pick_backend(ref).show_model(model)

    def get_capabilities(self, model: str) -> list[str]:
        """Return capability tags from the backend selected by the ref prefix."""
        ref = parse_model_ref(model)
        return self._pick_backend(ref).get_capabilities(model)

    def rerank(self, query: str, candidates: list[str]) -> list[float]:
        """Dispatch rerank to the backend that owns ``cfg.reranker_model``.

        Native GGUF refs go to llama-cpp; hosted refs go through the SDK
        provider. Raises ``ProviderError`` when ``cfg.reranker_model`` is
        empty or the selected backend does not support reranking.
        """
        if not cfg.reranker_model:
            raise ProviderError("No reranker configured. Set cfg.reranker_model first.")
        if _is_native_rerank_ref(cfg.reranker_model):
            return self._get_llama_cpp().rerank(query, candidates)
        sdk = self._get_sdk_provider()
        if not sdk.supports_rerank():
            raise ProviderError(
                f"Cannot rerank with {cfg.reranker_model!r}: "
                "hosted rerank backend not available. "
                "Install the 'litellm' extra to enable hosted reranking."
            )
        return sdk.rerank(query, candidates)
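
    # Hedged dispatch sketch (refs are illustrative, not shipped catalog entries):
    #
    #   cfg.reranker_model = "org/repo/reranker.Q4_K_M.gguf"   # native shape -> llama-cpp
    #   cfg.reranker_model = "cohere/rerank-english-v3.0"      # hosted ref   -> SDK provider
    #   cfg.reranker_model = ""                                # rerank() raises ProviderError
    #
    #   scores = router.rerank("query", ["candidate a", "candidate b"])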

    def supports_rerank(self) -> bool:
        """Capability probe: can the routed backend rerank if configured?

        Pure capability check, NOT "a reranker is currently active". An
        empty ``cfg.reranker_model`` returns ``True`` so the settings UI
        keeps the picker visible; callers that need to know whether
        reranking is actually configured must check ``bool(cfg.reranker_model)``
        separately. Delegates to the backend that would handle the
        configured model when one is set.
        """
        model = cfg.reranker_model
        if not model:
            return True
        if _is_native_rerank_ref(model):
            return self._get_llama_cpp().supports_rerank()
        return self._get_sdk_provider().supports_rerank()
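
    # Hedged sketch of the probe semantics above: an empty reranker still
    # reports capability, so callers combine it with the config check.
    #
    #   cfg.reranker_model = ""   # supports_rerank() -> True (capability only)
    #   rerank_active = router.supports_rerank() and bool(cfg.reranker_model)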

    def shutdown(self) -> None:
        """Shut down sub-providers to release resources."""
        if self._llama_cpp is not None:
            self._llama_cpp.shutdown()
        if self._sdk_provider is not None:
            self._sdk_provider.shutdown()

    def invalidate_load_cache(self, model_path: Path | None = None) -> None:
        """Forward to the native side only; the SDK side has no local cache."""
        if self._llama_cpp is not None:
            self._llama_cpp.invalidate_load_cache(model_path)

    def warm_up_pool(self) -> None:
        """Forward to the native side; the SDK side has no worker pool.

        Lazily constructs the llama-cpp provider if it isn't already up so
        eager-start during ``Services`` boot still warms the configured
        native roles, even when the user hasn't issued a chat call yet.
        """
        self._get_llama_cpp().warm_up_pool()


def _is_native_rerank_ref(model: str) -> bool:
    """Return True iff *model* should route to the native llama-cpp rerank worker.

    Two acceptance paths:

    1. The ref resolves to a featured rerank catalog entry (the historical
       fast path).
    2. The ref has the native HuggingFace GGUF shape
       ``<org>/<repo>/<filename>.gguf`` (two slashes, ``.gguf`` suffix). This
       lets users point ``cfg.reranker_model`` at any installed native GGUF
       reranker instead of only the ones that ship in ``FEATURED_ALL``.
       Non-GGUF refs without a known SDK prefix still raise downstream
       through ``parse_model_ref``.
    """
    if not model:
        return False
    if is_rerank_ref(model):
        return True
    return model.lower().endswith(".gguf") and model.count("/") >= _NATIVE_GGUF_REF_MIN_SLASHES
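

# Hedged sketch of the two acceptance paths (refs are illustrative, not actual
# catalog entries):
#
#   _is_native_rerank_ref("org/repo/reranker.Q8_0.gguf")   # True: .gguf + two slashes
#   _is_native_rerank_ref("cohere/rerank-english-v3.0")    # False: routes to the SDK side
#   _is_native_rerank_ref("")                              # False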