Coverage for src / lilbee / providers / routing_provider.py: 100%
113 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
1"""Routing provider: prefix-based dispatch between the SDK backend and llama-cpp."""
3from __future__ import annotations
5import contextlib
6import logging
7from collections.abc import Callable
8from pathlib import Path
9from typing import Any, Literal, overload
11from lilbee.catalog import is_rerank_ref
12from lilbee.core.config import cfg
13from lilbee.providers.base import ClosableIterator, LLMProvider, ProviderError
14from lilbee.providers.litellm_sdk import LitellmSdkBackend
15from lilbee.providers.model_ref import ProviderModelRef, parse_model_ref
16from lilbee.providers.sdk_llm_provider import SdkLLMProvider
17from lilbee.providers.worker.transport import OcrBackend
18from lilbee.vision import PageText
log: logging.Logger = logging.getLogger(name=__name__)

# A native HuggingFace ref of the form ``<org>/<repo>/<filename>.gguf``
# carries at least this many "/" separators.
_NATIVE_GGUF_REF_MIN_SLASHES: int = 2
"""``<org>/<repo>/<filename>.gguf`` has at least two slashes."""
class RoutingProvider(LLMProvider):
    """Dispatches calls based on the model ref prefix.

    ``ollama/``, ``openai/``, ``anthropic/``, ``gemini/`` go to the SDK
    provider. Other refs (the HuggingFace ``<org>/<repo>/<file>.gguf``
    shape) go to llama-cpp, which resolves them against the native
    registry. A registry miss surfaces the native ProviderError
    unchanged, rather than silently falling through to a remote backend.

    Both sub-providers are constructed lazily on first use.
    """

    def __init__(self) -> None:
        # Built on demand by ``_get_llama_cpp`` / ``_get_sdk_provider`` so that
        # constructing the router does not instantiate either backend.
        self._llama_cpp: LLMProvider | None = None
        self._sdk_provider: SdkLLMProvider | None = None

    def _get_llama_cpp(self) -> LLMProvider:
        """Return the llama-cpp provider, creating it on first use."""
        if self._llama_cpp is None:
            # Imported here so the llama-cpp module is only loaded when first needed.
            from lilbee.providers.llama_cpp import LlamaCppProvider

            self._llama_cpp = LlamaCppProvider()
        return self._llama_cpp

    def _get_sdk_provider(self) -> SdkLLMProvider:
        """Return the SDK provider, creating it on first use with ``cfg.llm_api_key``."""
        if self._sdk_provider is None:
            self._sdk_provider = SdkLLMProvider(
                LitellmSdkBackend(),
                api_key=cfg.llm_api_key,
            )
        return self._sdk_provider

    def _pick_backend(self, ref: ProviderModelRef) -> LLMProvider:
        """Pick the backend for *ref* purely by prefix.

        Remote refs go to the SDK provider; everything else goes to llama-cpp.
        """
        if ref.is_remote:
            return self._get_sdk_provider()
        return self._get_llama_cpp()

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* with the backend that owns ``cfg.embedding_model``."""
        ref = parse_model_ref(cfg.embedding_model)
        return self._pick_backend(ref).embed(texts)

    @overload
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: Literal[False] = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str: ...

    @overload
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: Literal[True],
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> ClosableIterator[str]: ...

    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: bool = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str | ClosableIterator[str]:
        """Route a chat call by the prefix of ``model`` (or ``cfg.chat_model``).

        ``model`` is forwarded unchanged (possibly ``None``) to the selected
        backend; only the routing decision falls back to ``cfg.chat_model``.
        """
        ref = parse_model_ref(model or cfg.chat_model)
        backend = self._pick_backend(ref)
        # Split on stream so each call resolves to a specific overload; the
        # base impl signature accepts bool but the @overloads on the LLMProvider
        # Protocol require Literal narrowing at the boundary.
        if stream:
            return backend.chat(messages, stream=True, options=options, model=model)
        return backend.chat(messages, stream=False, options=options, model=model)

    def vision_ocr(
        self,
        png_bytes: bytes,
        model: str,
        prompt: str = "",
        *,
        timeout: float | None = None,
    ) -> str:
        """Dispatch by ``model``'s ref prefix, same rules as :meth:`chat`."""
        ref = parse_model_ref(model)
        return self._pick_backend(ref).vision_ocr(png_bytes, model, prompt, timeout=timeout)

    def pdf_ocr(
        self,
        path: Path,
        *,
        backend: OcrBackend,
        model: str = "",
        per_page_timeout_s: float | None = None,
        quiet: bool = True,
        on_progress: Callable[..., None] | None = None,
    ) -> list[PageText]:
        """Dispatch by ``model``'s ref prefix, same rules as :meth:`vision_ocr`.

        Hosted refs reach :class:`SdkLLMProvider`, which raises
        ``NotImplementedError`` for PDF OCR; native refs reach the
        llama-cpp pool worker. ``model`` is empty when the caller wants
        the configured ``cfg.vision_model`` to drive the dispatch; it is
        forwarded unchanged to the selected backend.
        """
        ref = parse_model_ref(model or cfg.vision_model)
        return self._pick_backend(ref).pdf_ocr(
            path,
            backend=backend,
            model=model,
            per_page_timeout_s=per_page_timeout_s,
            quiet=quiet,
            on_progress=on_progress,
        )

    def list_models(self) -> list[str]:
        """Return the sorted union of native and SDK-visible models.

        Both halves are wrapped so an unreachable remote backend or a
        missing native registry does not mask the other. A failing half
        contributes no models and is logged at debug level; the SDK half
        is skipped entirely when ``sdk.available()`` is false.
        """
        native: set[str] = set()
        try:
            native = set(self._get_llama_cpp().list_models())
        except Exception:
            # Best-effort listing: record why, but keep the SDK half usable.
            log.debug("Native model listing failed", exc_info=True)

        remote: set[str] = set()
        try:
            # Construction and the availability probe are inside the guard too,
            # so no SDK-side failure can hide the native models.
            sdk = self._get_sdk_provider()
            if sdk.available():
                remote = set(sdk.list_models())
        except Exception:
            log.debug("SDK model listing failed", exc_info=True)

        return sorted(native | remote)

    def list_chat_models(self, provider: str) -> list[str]:
        """Delegate to the SDK backend; native llama-cpp has no catalog.

        Returns an empty list when the SDK backend is not available.
        """
        sdk = self._get_sdk_provider()
        if not sdk.available():
            return []
        return sdk.list_chat_models(provider)

    def pull_model(self, model: str, *, on_progress: Callable[..., Any] | None = None) -> None:
        """Pull via the SDK backend if installed, otherwise raise.

        Raises:
            ProviderError: if the SDK backend is not available.
        """
        sdk = self._get_sdk_provider()
        if not sdk.available():
            raise ProviderError(f"Cannot pull model {model!r}: no pull-capable backend available")
        sdk.pull_model(model, on_progress=on_progress)

    def show_model(self, model: str) -> dict[str, Any] | None:
        """Show model info from the backend selected by the ref prefix."""
        ref = parse_model_ref(model)
        return self._pick_backend(ref).show_model(model)

    def get_capabilities(self, model: str) -> list[str]:
        """Return capability tags from the backend selected by the ref prefix."""
        ref = parse_model_ref(model)
        return self._pick_backend(ref).get_capabilities(model)

    def rerank(self, query: str, candidates: list[str]) -> list[float]:
        """Dispatch rerank to the backend that owns ``cfg.reranker_model``.

        Native GGUF refs go to llama-cpp; hosted refs go through the SDK
        provider. Raises ``ProviderError`` when ``cfg.reranker_model`` is
        empty or the hosted backend does not support reranking.
        """
        if not cfg.reranker_model:
            raise ProviderError("No reranker configured. Set cfg.reranker_model first.")
        if _is_native_rerank_ref(cfg.reranker_model):
            return self._get_llama_cpp().rerank(query, candidates)
        sdk = self._get_sdk_provider()
        if not sdk.supports_rerank():
            raise ProviderError(
                f"Cannot rerank with {cfg.reranker_model!r}: "
                "hosted rerank backend not available. "
                "Install the 'litellm' extra to enable hosted reranking."
            )
        return sdk.rerank(query, candidates)

    def supports_rerank(self) -> bool:
        """Capability probe: can the routed backend rerank if configured?

        Pure capability check, NOT "a reranker is currently active". An
        empty ``cfg.reranker_model`` returns ``True`` so the settings UI
        keeps the picker visible; callers that need to know whether
        reranking is actually configured must check ``bool(cfg.reranker_model)``
        separately. Delegates to the backend that would handle the
        configured model when one is set.
        """
        model = cfg.reranker_model
        if not model:
            return True
        if _is_native_rerank_ref(model):
            return self._get_llama_cpp().supports_rerank()
        return self._get_sdk_provider().supports_rerank()

    def shutdown(self) -> None:
        """Shut down sub-providers to release resources.

        Only sub-providers that were actually constructed are shut down.
        Each one gets its ``shutdown()`` call even if the other raises;
        any exception still propagates to the caller afterwards.
        """
        with contextlib.ExitStack() as stack:
            # ExitStack runs callbacks in LIFO order: registering the SDK side
            # first means llama-cpp is shut down first, then the SDK side.
            if self._sdk_provider is not None:
                stack.callback(self._sdk_provider.shutdown)
            if self._llama_cpp is not None:
                stack.callback(self._llama_cpp.shutdown)

    def invalidate_load_cache(self, model_path: Path | None = None) -> None:
        """Forward to the native side only; the SDK side has no local cache.

        Does nothing if the llama-cpp provider has not been constructed yet.
        """
        if self._llama_cpp is not None:
            self._llama_cpp.invalidate_load_cache(model_path)

    def warm_up_pool(self) -> None:
        """Forward to the native side; the SDK side has no worker pool.

        Lazily constructs the llama-cpp provider if it isn't already up so
        eager-start during ``Services`` boot still warms the configured
        native roles, even when the user hasn't issued a chat call yet.
        """
        self._get_llama_cpp().warm_up_pool()
def _is_native_rerank_ref(model: str) -> bool:
    """Decide whether *model* belongs to the native llama-cpp rerank worker.

    A ref qualifies when either:

    1. it names a featured rerank catalog entry (``is_rerank_ref``), or
    2. it looks like a native HuggingFace GGUF ref,
       ``<org>/<repo>/<filename>.gguf``: at least
       ``_NATIVE_GGUF_REF_MIN_SLASHES`` slashes and a ``.gguf`` suffix
       (matched case-insensitively). This lets ``cfg.reranker_model``
       point at any installed native GGUF reranker, not only featured ones.

    An empty ref never qualifies. Refs that do not qualify are handled by
    the caller as hosted refs.
    """
    if not model:
        return False
    if is_rerank_ref(model):
        return True
    # Suffix test first; the slash count only matters for GGUF-looking refs.
    if not model.lower().endswith(".gguf"):
        return False
    return model.count("/") >= _NATIVE_GGUF_REF_MIN_SLASHES