Coverage for src / lilbee / providers / routing_provider.py: 100%

113 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-28 01:01 +0000

1"""Routing provider: prefix-based dispatch between the SDK backend and llama-cpp.""" 

2 

3from __future__ import annotations 

4 

5import contextlib 

6import logging 

7from collections.abc import Callable 

8from pathlib import Path 

9from typing import Any, Literal, overload 

10 

11from lilbee.catalog import is_rerank_ref 

12from lilbee.core.config import cfg 

13from lilbee.providers.base import ClosableIterator, LLMProvider, ProviderError 

14from lilbee.providers.litellm_sdk import LitellmSdkBackend 

15from lilbee.providers.model_ref import ProviderModelRef, parse_model_ref 

16from lilbee.providers.sdk_llm_provider import SdkLLMProvider 

17from lilbee.providers.worker.transport import OcrBackend 

18from lilbee.vision import PageText 

19 

# Module-level logger for this provider.
log = logging.getLogger(__name__)

# A native HuggingFace GGUF ref (``<org>/<repo>/<filename>.gguf``) contains at
# least this many slashes; ``_is_native_rerank_ref`` uses it to recognise that shape.
_NATIVE_GGUF_REF_MIN_SLASHES = 2

24 

25 

class RoutingProvider(LLMProvider):
    """Dispatches calls based on the model ref prefix.

    ``ollama/``, ``openai/``, ``anthropic/``, ``gemini/`` go to the SDK
    provider. Other refs (the HuggingFace ``<org>/<repo>/<file>.gguf``
    shape) go to llama-cpp, which resolves them against the native
    registry. A registry miss surfaces the native ProviderError
    unchanged, rather than silently falling through to a remote backend.

    Both sub-providers are built lazily on first use; ``LlamaCppProvider``
    is also imported lazily, inside :meth:`_get_llama_cpp`.
    """

    def __init__(self) -> None:
        # Sub-providers start unset and are created on demand by the getters below.
        self._llama_cpp: LLMProvider | None = None
        self._sdk_provider: SdkLLMProvider | None = None

    def _get_llama_cpp(self) -> LLMProvider:
        """Return the native llama-cpp provider, creating it on first call."""
        if self._llama_cpp is None:
            # Deferred import: the native module is only loaded when a native ref is used.
            from lilbee.providers.llama_cpp import LlamaCppProvider

            self._llama_cpp = LlamaCppProvider()
        return self._llama_cpp

    def _get_sdk_provider(self) -> SdkLLMProvider:
        """Return the SDK (litellm) provider, creating it on first call."""
        if self._sdk_provider is None:
            self._sdk_provider = SdkLLMProvider(
                LitellmSdkBackend(),
                api_key=cfg.llm_api_key,
            )
        return self._sdk_provider

    def _pick_backend(self, ref: ProviderModelRef) -> LLMProvider:
        """Pick the backend for *ref* purely by prefix.

        Remote refs (``ref.is_remote``) go to the SDK provider; everything
        else goes to llama-cpp.
        """
        if ref.is_remote:
            return self._get_sdk_provider()
        return self._get_llama_cpp()

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* with the backend that owns ``cfg.embedding_model``."""
        ref = parse_model_ref(cfg.embedding_model)
        return self._pick_backend(ref).embed(texts)

    @overload
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: Literal[False] = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str: ...

    @overload
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: Literal[True],
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> ClosableIterator[str]: ...

    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: bool = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str | ClosableIterator[str]:
        """Route a chat call by ``model`` (or ``cfg.chat_model`` when unset).

        The original ``model`` argument, possibly ``None``, is forwarded
        unchanged to the selected backend.
        """
        ref = parse_model_ref(model or cfg.chat_model)
        backend = self._pick_backend(ref)
        # Split on stream so each call resolves to a specific overload; the
        # base impl signature accepts bool but the @overloads on the LLMProvider
        # Protocol require Literal narrowing at the boundary.
        if stream:
            return backend.chat(messages, stream=True, options=options, model=model)
        return backend.chat(messages, stream=False, options=options, model=model)

    def vision_ocr(
        self,
        png_bytes: bytes,
        model: str,
        prompt: str = "",
        *,
        timeout: float | None = None,
    ) -> str:
        """Dispatch by ``model``'s ref prefix, same rules as :meth:`chat`."""
        ref = parse_model_ref(model)
        return self._pick_backend(ref).vision_ocr(png_bytes, model, prompt, timeout=timeout)

    def pdf_ocr(
        self,
        path: Path,
        *,
        backend: OcrBackend,
        model: str = "",
        per_page_timeout_s: float | None = None,
        quiet: bool = True,
        on_progress: Callable[..., None] | None = None,
    ) -> list[PageText]:
        """Dispatch by ``model``'s ref prefix, same rules as :meth:`vision_ocr`.

        Hosted refs reach :class:`SdkLLMProvider`, which raises
        ``NotImplementedError`` for PDF OCR; native refs reach the
        llama-cpp pool worker. ``model`` is empty when the caller wants
        the configured ``cfg.vision_model`` to drive the dispatch; the
        empty value is forwarded as-is to the selected backend.
        """
        ref = parse_model_ref(model or cfg.vision_model)
        return self._pick_backend(ref).pdf_ocr(
            path,
            backend=backend,
            model=model,
            per_page_timeout_s=per_page_timeout_s,
            quiet=quiet,
            on_progress=on_progress,
        )

    def list_models(self) -> list[str]:
        """Return the sorted union of native and SDK-visible models.

        Both halves are wrapped so an unreachable remote backend or a
        missing native registry does not mask the other. A native failure
        is ignored; an SDK failure, including building or probing the SDK
        provider, is logged at debug level and leaves only the native list.
        """
        native: set[str] = set()
        # Best effort: a missing or broken native registry contributes nothing.
        with contextlib.suppress(Exception):
            native = set(self._get_llama_cpp().list_models())
        try:
            sdk = self._get_sdk_provider()
            if not sdk.available():
                return sorted(native)
            remote = set(sdk.list_models())
        except Exception:
            log.debug("SDK model listing failed; returning native models only", exc_info=True)
            return sorted(native)
        return sorted(native | remote)

    def list_chat_models(self, provider: str) -> list[str]:
        """Delegate to the SDK backend; native llama-cpp has no catalog.

        Returns an empty list when the SDK backend is not available.
        """
        sdk = self._get_sdk_provider()
        if not sdk.available():
            return []
        return sdk.list_chat_models(provider)

    def pull_model(self, model: str, *, on_progress: Callable[..., Any] | None = None) -> None:
        """Pull via the SDK backend if installed, otherwise raise.

        Raises:
            ProviderError: when the SDK backend is not available.
        """
        sdk = self._get_sdk_provider()
        if not sdk.available():
            raise ProviderError(f"Cannot pull model {model!r}: no pull-capable backend available")
        sdk.pull_model(model, on_progress=on_progress)

    def show_model(self, model: str) -> dict[str, Any] | None:
        """Show model info from the backend selected by the ref prefix."""
        ref = parse_model_ref(model)
        return self._pick_backend(ref).show_model(model)

    def get_capabilities(self, model: str) -> list[str]:
        """Return capability tags from the backend selected by the ref prefix."""
        ref = parse_model_ref(model)
        return self._pick_backend(ref).get_capabilities(model)

    def rerank(self, query: str, candidates: list[str]) -> list[float]:
        """Dispatch rerank to the backend that owns ``cfg.reranker_model``.

        Native GGUF refs go to llama-cpp; hosted refs go through the SDK
        provider. Raises ``ProviderError`` when ``cfg.reranker_model`` is
        empty or the selected backend does not support reranking.
        """
        # Read the setting once so every check below sees the same value.
        model = cfg.reranker_model
        if not model:
            raise ProviderError("No reranker configured. Set cfg.reranker_model first.")
        if _is_native_rerank_ref(model):
            return self._get_llama_cpp().rerank(query, candidates)
        sdk = self._get_sdk_provider()
        if not sdk.supports_rerank():
            raise ProviderError(
                f"Cannot rerank with {model!r}: "
                "hosted rerank backend not available. "
                "Install the 'litellm' extra to enable hosted reranking."
            )
        return sdk.rerank(query, candidates)

    def supports_rerank(self) -> bool:
        """Capability probe: can the routed backend rerank if configured?

        Pure capability check, NOT "a reranker is currently active". An
        empty ``cfg.reranker_model`` returns ``True`` so the settings UI
        keeps the picker visible; callers that need to know whether
        reranking is actually configured must check ``bool(cfg.reranker_model)``
        separately. Delegates to the backend that would handle the
        configured model when one is set.
        """
        model = cfg.reranker_model
        if not model:
            return True
        if _is_native_rerank_ref(model):
            return self._get_llama_cpp().supports_rerank()
        return self._get_sdk_provider().supports_rerank()

    def shutdown(self) -> None:
        """Shut down sub-providers to release resources.

        Only sub-providers that were actually created are shut down. The
        SDK side is shut down even if the native shutdown raises; that
        exception still propagates afterwards.
        """
        try:
            if self._llama_cpp is not None:
                self._llama_cpp.shutdown()
        finally:
            if self._sdk_provider is not None:
                self._sdk_provider.shutdown()

    def invalidate_load_cache(self, model_path: Path | None = None) -> None:
        """Forward to the native side only; the SDK side has no local cache.

        No-op when the llama-cpp provider has not been created yet.
        """
        if self._llama_cpp is not None:
            self._llama_cpp.invalidate_load_cache(model_path)

    def warm_up_pool(self) -> None:
        """Forward to the native side; the SDK side has no worker pool.

        Lazily constructs the llama-cpp provider if it isn't already up so
        eager-start during ``Services`` boot still warms the configured
        native roles, even when the user hasn't issued a chat call yet.
        """
        self._get_llama_cpp().warm_up_pool()

240 

241 

def _is_native_rerank_ref(model: str) -> bool:
    """Decide whether *model* should be served by the native llama-cpp rerank worker.

    A ref qualifies in either of two ways:

    1. It is recognised by ``is_rerank_ref`` as a featured rerank catalog
       entry (the historical fast path).
    2. It has the native HuggingFace GGUF shape
       ``<org>/<repo>/<filename>.gguf``: a case-insensitive ``.gguf``
       suffix and at least ``_NATIVE_GGUF_REF_MIN_SLASHES`` slashes. This
       lets users point ``cfg.reranker_model`` at any installed native
       GGUF reranker, not only the ones that ship in ``FEATURED_ALL``.

    An empty ref never qualifies. Non-GGUF refs without a known SDK prefix
    still raise downstream through ``parse_model_ref``.
    """
    if not model:
        return False
    if is_rerank_ref(model):
        return True
    # Fall back to a purely structural check of the ref string.
    if not model.lower().endswith(".gguf"):
        return False
    return model.count("/") >= _NATIVE_GGUF_REF_MIN_SLASHES