Coverage for src/lilbee/providers/routing_provider.py: 100% of 113 statements (coverage.py v7.13.4, created at 2026-05-15 20:55 +0000)


1"""Routing provider: prefix-based dispatch between the SDK backend and llama-cpp.""" 

2 

3from __future__ import annotations 

4 

5import contextlib 

6import logging 

7from collections.abc import Callable 

8from pathlib import Path 

9from typing import Any, Literal, overload 

10 

11from lilbee.catalog import is_rerank_ref 

12from lilbee.core.config import cfg 

13from lilbee.providers.base import ClosableIterator, LLMProvider, ProviderError 

14from lilbee.providers.litellm_sdk import LitellmSdkBackend 

15from lilbee.providers.model_ref import ProviderModelRef, parse_model_ref 

16from lilbee.providers.sdk_llm_provider import SdkLLMProvider 

17from lilbee.providers.worker.transport import OcrBackend 

18from lilbee.vision import PageText 

19 

20log = logging.getLogger(__name__) 

21 

22_NATIVE_GGUF_REF_MIN_SLASHES = 2 

23"""``<org>/<repo>/<filename>.gguf`` has at least two slashes.""" 

24 

25 

26class RoutingProvider(LLMProvider): 

27 """Dispatches calls based on the model ref prefix. 

28 

29 ``ollama/``, ``openai/``, ``anthropic/``, ``gemini/`` go to the SDK 

30 provider. Other refs (the HuggingFace ``<org>/<repo>/<file>.gguf`` 

31 shape) go to llama-cpp, which resolves them against the native 

32 registry. A registry miss surfaces the native ProviderError 

33 unchanged, rather than silently falling through to a remote backend. 

34 """ 

    def __init__(self) -> None:
        self._llama_cpp: LLMProvider | None = None
        self._sdk_provider: SdkLLMProvider | None = None

    def _get_llama_cpp(self) -> LLMProvider:
        if self._llama_cpp is None:
            from lilbee.providers.llama_cpp import LlamaCppProvider

            self._llama_cpp = LlamaCppProvider()
        return self._llama_cpp

    def _get_sdk_provider(self) -> SdkLLMProvider:
        if self._sdk_provider is None:
            self._sdk_provider = SdkLLMProvider(
                LitellmSdkBackend(),
                base_url=cfg.remote_base_url,
                api_key=cfg.llm_api_key,
            )
        return self._sdk_provider

    def _pick_backend(self, ref: ProviderModelRef) -> LLMProvider:
        """Pick the backend for *ref* purely by prefix."""
        if ref.is_remote:
            return self._get_sdk_provider()
        return self._get_llama_cpp()

    def embed(self, texts: list[str]) -> list[list[float]]:
        ref = parse_model_ref(cfg.embedding_model)
        return self._pick_backend(ref).embed(texts)

    @overload
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: Literal[False] = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str: ...

    @overload
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: Literal[True],
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> ClosableIterator[str]: ...

    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: bool = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str | ClosableIterator[str]:
        ref = parse_model_ref(model or cfg.chat_model)
        backend = self._pick_backend(ref)
        # Split on stream so each call resolves to a specific overload; the
        # base impl signature accepts bool but the @overloads on the LLMProvider
        # Protocol require Literal narrowing at the boundary.
        if stream:
            return backend.chat(messages, stream=True, options=options, model=model)
        return backend.chat(messages, stream=False, options=options, model=model)
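
    # Illustrative only (hypothetical call sites): because of the @overload pair
    # above, a type checker narrows the return type from the stream argument, e.g.
    #   text: str = provider.chat(messages)
    #   chunks: ClosableIterator[str] = provider.chat(messages, stream=True)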

    def vision_ocr(
        self,
        png_bytes: bytes,
        model: str,
        prompt: str = "",
        *,
        timeout: float | None = None,
    ) -> str:
        """Dispatch by ``model``'s ref prefix, same rules as :meth:`chat`."""
        ref = parse_model_ref(model)
        return self._pick_backend(ref).vision_ocr(png_bytes, model, prompt, timeout=timeout)

    def pdf_ocr(
        self,
        path: Path,
        *,
        backend: OcrBackend,
        model: str = "",
        per_page_timeout_s: float | None = None,
        quiet: bool = True,
        on_progress: Callable[..., None] | None = None,
    ) -> list[PageText]:
        """Dispatch by ``model``'s ref prefix, same rules as :meth:`vision_ocr`.

        Hosted refs reach :class:`SdkLLMProvider`, which raises
        ``NotImplementedError`` for PDF OCR; native refs reach the
        llama-cpp pool worker. ``model`` is empty when the caller wants
        the configured ``cfg.vision_model`` to drive the dispatch.
        """
        ref = parse_model_ref(model or cfg.vision_model)
        return self._pick_backend(ref).pdf_ocr(
            path,
            backend=backend,
            model=model,
            per_page_timeout_s=per_page_timeout_s,
            quiet=quiet,
            on_progress=on_progress,
        )
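
    # Illustrative only (hypothetical calls): pdf_ocr(path, backend=backend) with the
    # default empty model dispatches on cfg.vision_model; a hosted ref such as
    # "gemini/<model>" reaches SdkLLMProvider and raises NotImplementedError, while a
    # native "<org>/<repo>/<file>.gguf" ref runs in the llama-cpp pool worker.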

    def list_models(self) -> list[str]:
        """Return the union of native and SDK-visible models.

        Both halves are wrapped so an unreachable remote backend or a
        missing native registry does not mask the other.
        """
        native: set[str] = set()
        with contextlib.suppress(Exception):
            native = set(self._get_llama_cpp().list_models())
        sdk = self._get_sdk_provider()
        if not sdk.available():
            return sorted(native)
        try:
            remote = set(sdk.list_models())
        except Exception:
            return sorted(native)
        return sorted(native | remote)

    def list_chat_models(self, provider: str) -> list[str]:
        """Delegate to the SDK backend; native llama-cpp has no catalog."""
        sdk = self._get_sdk_provider()
        if not sdk.available():
            return []
        return sdk.list_chat_models(provider)

    def pull_model(self, model: str, *, on_progress: Callable[..., Any] | None = None) -> None:
        """Pull via the SDK backend if installed, otherwise raise."""
        sdk = self._get_sdk_provider()
        if not sdk.available():
            raise ProviderError(f"Cannot pull model {model!r}: no pull-capable backend available")
        sdk.pull_model(model, on_progress=on_progress)

    def show_model(self, model: str) -> dict[str, Any] | None:
        """Show model info from the backend selected by the ref prefix."""
        ref = parse_model_ref(model)
        return self._pick_backend(ref).show_model(model)

    def get_capabilities(self, model: str) -> list[str]:
        """Return capability tags from the backend selected by the ref prefix."""
        ref = parse_model_ref(model)
        return self._pick_backend(ref).get_capabilities(model)

    def rerank(self, query: str, candidates: list[str]) -> list[float]:
        """Dispatch rerank to the backend that owns ``cfg.reranker_model``.

        Native GGUF refs go to llama-cpp; hosted refs go through the SDK
        provider. Raises ``ProviderError`` when ``cfg.reranker_model`` is
        empty or the selected backend does not support reranking.
        """
        if not cfg.reranker_model:
            raise ProviderError("No reranker configured. Set cfg.reranker_model first.")
        if _is_native_rerank_ref(cfg.reranker_model):
            return self._get_llama_cpp().rerank(query, candidates)
        sdk = self._get_sdk_provider()
        if not sdk.supports_rerank():
            raise ProviderError(
                f"Cannot rerank with {cfg.reranker_model!r}: "
                "hosted rerank backend not available. "
                "Install the 'litellm' extra to enable hosted reranking."
            )
        return sdk.rerank(query, candidates)

    def supports_rerank(self) -> bool:
        """Capability probe: can the routed backend rerank if configured?

        Pure capability check, NOT "a reranker is currently active". An
        empty ``cfg.reranker_model`` returns ``True`` so the settings UI
        keeps the picker visible; callers that need to know whether
        reranking is actually configured must check ``bool(cfg.reranker_model)``
        separately. Delegates to the backend that would handle the
        configured model when one is set.
        """
        model = cfg.reranker_model
        if not model:
            return True
        if _is_native_rerank_ref(model):
            return self._get_llama_cpp().supports_rerank()
        return self._get_sdk_provider().supports_rerank()
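
    # Illustrative only (hypothetical sequence): with cfg.reranker_model == "",
    # supports_rerank() returns True (capability probe) but rerank() raises
    # ProviderError, so callers gate actual reranking on bool(cfg.reranker_model).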

    def shutdown(self) -> None:
        """Shut down sub-providers to release resources."""
        if self._llama_cpp is not None:
            self._llama_cpp.shutdown()
        if self._sdk_provider is not None:
            self._sdk_provider.shutdown()

    def invalidate_load_cache(self, model_path: Path | None = None) -> None:
        """Forward to the native side only; the SDK side has no local cache."""
        if self._llama_cpp is not None:
            self._llama_cpp.invalidate_load_cache(model_path)

    def warm_up_pool(self) -> None:
        """Forward to the native side; the SDK side has no worker pool.

        Lazily constructs the llama-cpp provider if it isn't already up so
        eager-start during ``Services`` boot still warms the configured
        native roles, even when the user hasn't issued a chat call yet.
        """
        self._get_llama_cpp().warm_up_pool()


def _is_native_rerank_ref(model: str) -> bool:
    """Return True iff *model* should route to the native llama-cpp rerank worker.

    Two acceptance paths:

    1. The ref resolves to a featured rerank catalog entry (the historical
       fast path).
    2. The ref has the native HuggingFace GGUF shape
       ``<org>/<repo>/<filename>.gguf`` (two slashes, ``.gguf`` suffix). This
       lets users point ``cfg.reranker_model`` at any installed native GGUF
       reranker instead of only the ones that ship in ``FEATURED_ALL``.
       Non-GGUF refs without a known SDK prefix still raise downstream
       through ``parse_model_ref``.
    """
    if not model:
        return False
    if is_rerank_ref(model):
        return True
    return model.lower().endswith(".gguf") and model.count("/") >= _NATIVE_GGUF_REF_MIN_SLASHES
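

# Illustrative only (hypothetical refs): "<org>/<repo>/reranker-q4.gguf" has two
# slashes and a ".gguf" suffix, so _is_native_rerank_ref() accepts it even when it is
# not a featured catalog entry; a hosted ref such as "openai/<model>" (assuming it is
# not a featured rerank entry) fails both checks, and RoutingProvider.rerank() then
# routes it through the SDK provider instead.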