Coverage for src / lilbee / modelhub / model_manager / discovery.py: 100%
94 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
1"""Remote model discovery and task classification."""
3import logging
4import os
5from collections.abc import Callable
7import httpx
9from lilbee.app.services import get_services
10from lilbee.catalog.types import ModelTask
11from lilbee.core.config.model import cfg
12from lilbee.modelhub.model_manager.types import RemoteModel
13from lilbee.providers.backend_names import BackendName
14from lilbee.providers.local_servers import (
15 LM_STUDIO,
16 OLLAMA,
17 LocalServerSpec,
18 openai_models_url,
19)
20from lilbee.providers.local_servers.config_urls import configured_local_servers
21from lilbee.providers.sdk_backend import PROVIDER_KEYS
23log = logging.getLogger(__name__)
# Family tags that identify an embedding model.
_EMBEDDING_FAMILIES = frozenset(("bert", "nomic-bert", "e5", "bge"))
# Name fragments that identify an embedding model when a server (LM Studio)
# reports only ids and no family. The trailing hyphens stop chat models whose
# names merely contain the same letters from matching.
_EMBEDDING_NAME_PATTERNS = frozenset(("embed", "bge-", "e5-", "gte-"))
# Name fragments that identify a vision model.
_VISION_NAME_PATTERNS = frozenset(("llava", "vision", "moondream", "ocr", "minicpm-v"))
# Name fragments that identify a reranker. These are checked before the
# embedding patterns so ``bge-reranker-*`` (family "bge") is not classified
# as EMBEDDING.
_RERANKER_NAME_PATTERNS = frozenset(("reranker", "rerank", "cross-encoder"))

# Default HTTP timeout, in seconds, for model-listing requests.
_CLASSIFY_DEFAULT_TIMEOUT_S = 5.0
def _classify_remote_task(name: str, family: str) -> ModelTask:
    """Map a remote model's name and family to a ``ModelTask``.

    Checks run in a fixed order: rerank, then embedding, then vision, with
    chat as the fallback. Embedding is detected from either the family tag
    or the model name; the name check covers servers such as LM Studio that
    report no family. All matching is case-insensitive substring matching.
    """
    lowered_name = name.lower()

    def name_contains(patterns: frozenset[str]) -> bool:
        # True when any pattern occurs anywhere in the lowercased name.
        return any(pattern in lowered_name for pattern in patterns)

    # Rerankers first, so ``bge-reranker-*`` never reaches the embedding check.
    if name_contains(_RERANKER_NAME_PATTERNS):
        return ModelTask.RERANK

    lowered_family = family.lower()
    family_is_embedding = any(tag in lowered_family for tag in _EMBEDDING_FAMILIES)
    if family_is_embedding or name_contains(_EMBEDDING_NAME_PATTERNS):
        return ModelTask.EMBEDDING

    if name_contains(_VISION_NAME_PATTERNS):
        return ModelTask.VISION

    return ModelTask.CHAT
def reclassify_by_name(ref: str, declared_task: str) -> str:
    """Return the task implied by *ref*, falling back to *declared_task*.

    Older manifests may have stored ``task="chat"`` for models whose ref
    plainly marks them as rerankers (e.g. ``bge-reranker-*``) or vision
    models. The model bar calls this so that such a stale tag does not put a
    reranker in the chat picker. Reranker patterns are checked before vision
    patterns; when neither matches, *declared_task* is returned unchanged.
    """
    lowered_ref = ref.lower()
    # Ordered (patterns, task) pairs: the first match wins.
    overrides = (
        (_RERANKER_NAME_PATTERNS, ModelTask.RERANK),
        (_VISION_NAME_PATTERNS, ModelTask.VISION),
    )
    for patterns, task in overrides:
        if any(pattern in lowered_ref for pattern in patterns):
            return task
    return declared_task
def classify_remote_models(
    base_url: str,
    spec: LocalServerSpec,
    *,
    timeout: float = _CLASSIFY_DEFAULT_TIMEOUT_S,
) -> list[RemoteModel]:
    """Discover the models served at *base_url* and classify each by task.

    The listing strategy is looked up from ``spec.key`` (Ollama ``/api/tags``
    vs LM Studio ``/v1/models``) and the provider label is
    ``spec.display_name``, so a server running on a non-default host is
    still handled by the right strategy. The strategies return ``[]`` when
    the request fails, which keeps read-only callers responsive while the
    backend is down.

    Raises:
        KeyError: if ``spec.key`` has no entry in ``_DISCOVERY_BY_KEY``.
    """
    # The strategy lookup happens before ``spec.display_name`` is read.
    return _DISCOVERY_BY_KEY[spec.key](base_url, spec.display_name, timeout)
def classify_all_remote_models(
    *,
    timeout: float = _CLASSIFY_DEFAULT_TIMEOUT_S,
) -> list[RemoteModel]:
    """Return the classified models of every configured local server.

    Servers are visited in the order ``configured_local_servers()`` yields
    them, and each server's models are appended in the order its listing
    returned them. Every model carries its source server's provider label.
    """
    return [
        model
        for spec, base_url in configured_local_servers()
        for model in classify_remote_models(base_url, spec, timeout=timeout)
    ]
def _discover_via_ollama_tags(
    base_url: str, provider: BackendName, timeout: float
) -> list[RemoteModel]:
    """Classify models from Ollama's ``/api/tags`` using family metadata.

    Args:
        base_url: Root URL of the server; ``/api/tags`` is appended to it.
        provider: Label stored on every returned ``RemoteModel``.
        timeout: Timeout passed to the HTTP request.

    Returns:
        One ``RemoteModel`` per listed entry, or ``[]`` when the request or
        JSON decoding fails. Malformed payloads are tolerated instead of
        raising: a null or missing ``models`` list yields ``[]``, non-object
        entries are skipped, and null ``name`` / ``details`` / ``family`` /
        ``parameter_size`` values are treated as empty.
    """
    try:
        resp = httpx.get(f"{base_url}/api/tags", timeout=timeout)
        resp.raise_for_status()
        # ``or []`` covers an explicit ``"models": null``. ``list()`` keeps a
        # non-iterable value inside this guard so it is reported as a failure.
        raw_models = list(resp.json().get("models") or [])
    except Exception:
        log.debug("Failed to classify remote models", exc_info=True)
        return []

    result: list[RemoteModel] = []
    for model in raw_models:
        if not isinstance(model, dict):
            # Skip entries that are not JSON objects; they carry no metadata.
            continue
        name = model.get("name") or ""
        details = model.get("details")
        if not isinstance(details, dict):
            # Missing or null details: classify from the name alone.
            details = {}
        family = details.get("family") or ""
        param_size = details.get("parameter_size") or ""
        result.append(
            RemoteModel(
                name=name,
                task=_classify_remote_task(name, family),
                family=family,
                parameter_size=param_size,
                provider=provider,
            )
        )
    return result
def _discover_via_openai_models(
    base_url: str, provider: BackendName, timeout: float
) -> list[RemoteModel]:
    """Classify models from an OpenAI-compatible ``/v1/models`` endpoint.

    These servers report only ids (no family), so task detection runs off the
    name patterns, which LM Studio ids usually carry. Every id is surfaced: LM
    Studio presents LM Link remote/cloud models here as if local, so the list
    is intentionally not filtered to locally-downloaded models.

    Args:
        base_url: Root URL of the server, resolved to the models endpoint by
            ``openai_models_url``.
        provider: Label stored on every returned ``RemoteModel``.
        timeout: Timeout passed to the HTTP request.

    Returns:
        One ``RemoteModel`` per entry with a non-empty ``id``, or ``[]`` when
        the request or JSON decoding fails. A null or missing ``data`` list
        yields ``[]``, and non-object entries are skipped instead of raising.
    """
    try:
        resp = httpx.get(openai_models_url(base_url), timeout=timeout)
        resp.raise_for_status()
        # ``or []`` covers an explicit ``"data": null``. ``list()`` keeps a
        # non-iterable value inside this guard so it is reported as a failure.
        raw_models = list(resp.json().get("data") or [])
    except Exception:
        log.debug("Failed to classify remote models", exc_info=True)
        return []

    result: list[RemoteModel] = []
    for model in raw_models:
        if not isinstance(model, dict):
            # Skip entries that are not JSON objects; they have no id.
            continue
        name = model.get("id") or ""
        if not name:
            continue
        result.append(
            RemoteModel(
                name=name,
                task=_classify_remote_task(name, ""),
                family="",
                parameter_size="",
                provider=provider,
            )
        )
    return result
# Signature shared by every listing strategy:
# (base_url, provider label, timeout) -> classified models.
_DiscoveryFn = Callable[[str, BackendName, float], list[RemoteModel]]

# Maps each local server's routing key to the strategy that lists its models.
# Kept at module level so it remains the single source of truth as servers
# are added to the registry.
_DISCOVERY_BY_KEY: dict[str, _DiscoveryFn] = {
    OLLAMA.key: _discover_via_ollama_tags,
    LM_STUDIO.key: _discover_via_openai_models,
}
def _has_provider_key(cfg_field: str, env_var: str) -> bool:
    """Report whether an API key is available for a provider.

    The environment variable *env_var* is checked first; a non-empty value
    is enough. Otherwise the result is the truthiness of the lilbee config
    attribute *cfg_field*, with a missing attribute counting as no key.
    """
    env_value = os.environ.get(env_var)
    return True if env_value else bool(getattr(cfg, cfg_field, ""))
def discover_api_models() -> dict[str, list[RemoteModel]]:
    """Return frontier chat models grouped by provider display name.

    For every provider in ``PROVIDER_KEYS`` that has an API key configured,
    the active provider backend is asked for its chat models, and whatever it
    returns is passed through without curation. Providers that list no models
    are left out of the result. When no key is configured at all, the
    function returns ``{}`` before touching the services layer.
    """
    # Keep only the fields needed later: the provider id and its label.
    configured: list[tuple] = []
    for prov, cfg_f, env, label in PROVIDER_KEYS:
        if _has_provider_key(cfg_f, env):
            configured.append((prov, label))
    if not configured:
        return {}

    backend = get_services().provider

    grouped: dict[str, list[RemoteModel]] = {}
    for prov, display_name in configured:
        models = []
        for model_name in backend.list_chat_models(prov):
            models.append(
                RemoteModel(
                    name=model_name,
                    task=ModelTask.CHAT,
                    family="",
                    parameter_size="",
                    provider=display_name,
                )
            )
        if models:
            grouped[display_name] = models
    return grouped
def detect_remote_embedding_models() -> list[str]:
    """List the names of embedding models found on all configured local servers.

    Names appear in the order ``classify_all_remote_models()`` returns them.
    """
    embedding_names: list[str] = []
    for model in classify_all_remote_models():
        if model.task == ModelTask.EMBEDDING:
            embedding_names.append(model.name)
    return embedding_names