Coverage for src / lilbee / modelhub / model_manager / discovery.py: 100%
65 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
1"""Remote model discovery and task classification."""
3import logging
4import os
6import httpx
8from lilbee.app.services import get_services
9from lilbee.catalog.types import ModelTask
10from lilbee.core.config.model import cfg
11from lilbee.modelhub.model_manager.types import RemoteModel
12from lilbee.providers.sdk_backend import (
13 PROVIDER_KEYS,
14 detect_backend_name,
15)
log = logging.getLogger(__name__)

# Family tags (as reported by the backend's /api/tags ``details.family``
# field) that identify a model as an embedding model.
_EMBEDDING_FAMILIES = frozenset({"bert", "nomic-bert", "e5", "bge"})
# Name substrings that identify a model as a vision/multimodal model.
_VISION_NAME_PATTERNS = frozenset({"llava", "vision", "moondream", "ocr", "minicpm-v"})
# Reranker detection runs BEFORE embedding detection so ``bge-reranker-*``
# (family "bge" but clearly a reranker) does not get misclassified as
# EMBEDDING. ``cross-encoder`` covers the SBERT-style naming convention
# for cross-encoder rerankers.
_RERANKER_NAME_PATTERNS = frozenset({"reranker", "rerank", "cross-encoder"})

# Default HTTP timeout in seconds for the /api/tags discovery request.
_CLASSIFY_DEFAULT_TIMEOUT_S = 5.0
def _classify_remote_task(name: str, family: str) -> ModelTask:
    """Map a remote model's name/family metadata onto a :class:`ModelTask`.

    Order matters: reranker name patterns are checked before the
    embedding-family check so a model such as ``bge-reranker-base``
    (family ``bge``) is tagged RERANK instead of EMBEDDING. Vision is
    detected by name pattern; anything unmatched defaults to CHAT.
    """
    lowered_name = name.lower()
    lowered_family = family.lower()
    # (patterns, haystack, task) triples, evaluated in priority order.
    checks = (
        (_RERANKER_NAME_PATTERNS, lowered_name, ModelTask.RERANK),
        (_EMBEDDING_FAMILIES, lowered_family, ModelTask.EMBEDDING),
        (_VISION_NAME_PATTERNS, lowered_name, ModelTask.VISION),
    )
    for patterns, haystack, task in checks:
        if any(pattern in haystack for pattern in patterns):
            return task
    return ModelTask.CHAT
def reclassify_by_name(ref: str, declared_task: str) -> str:
    """Force RERANK / VISION when *ref* names a known non-chat role.

    Guards against pre-fix manifests that persisted ``task="chat"`` for
    models whose ref plainly marks them as rerankers (e.g.
    ``bge-reranker-*``) or vision loaders. The model bar relies on this
    so a historical mis-tag never surfaces a reranker in the chat
    picker.
    """
    lowered = ref.lower()
    # Check reranker patterns first, then vision; fall through to the
    # declared task when nothing matches.
    for patterns, task in (
        (_RERANKER_NAME_PATTERNS, ModelTask.RERANK),
        (_VISION_NAME_PATTERNS, ModelTask.VISION),
    ):
        if any(pattern in lowered for pattern in patterns):
            return task
    return declared_task
def classify_remote_models(
    base_url: str = "http://localhost:11434",
    *,
    timeout: float = _CLASSIFY_DEFAULT_TIMEOUT_S,
) -> list[RemoteModel]:
    """Discover and classify all models from the SDK backend by task.

    Embedding detection comes from /api/tags family metadata; reranker
    and vision detection come from name patterns. Any failure
    (including a timeout) yields an empty list so callers in read-only
    code paths stay responsive while the backend is down.
    """
    try:
        response = httpx.get(f"{base_url}/api/tags", timeout=timeout)
        response.raise_for_status()
        entries = response.json().get("models", [])
    except Exception:
        # Deliberate best-effort: discovery must never crash callers.
        log.debug("Failed to classify remote models", exc_info=True)
        return []

    backend = detect_backend_name(base_url)
    classified: list[RemoteModel] = []
    for entry in entries:
        model_name = entry.get("name", "")
        detail = entry.get("details", {})
        model_family = detail.get("family", "")
        classified.append(
            RemoteModel(
                name=model_name,
                task=_classify_remote_task(model_name, model_family),
                family=model_family,
                parameter_size=detail.get("parameter_size", ""),
                provider=backend,
            )
        )
    return classified
101def _has_provider_key(cfg_field: str, env_var: str) -> bool:
102 """Return True if a usable API key exists via env var or lilbee config."""
103 if os.environ.get(env_var):
104 return True
105 return bool(getattr(cfg, cfg_field, ""))
def discover_api_models() -> dict[str, list[RemoteModel]]:
    """Return frontier chat models grouped by provider.

    Surfaces exactly what the active provider's backend exposes for
    each configured API key — no curation. Returns early, before any
    SDK call, when no keys are configured.
    """
    # PROVIDER_KEYS entries are (provider, cfg_field, env_var, label).
    keyed = [entry for entry in PROVIDER_KEYS if _has_provider_key(entry[1], entry[2])]
    if not keyed:
        return {}

    provider = get_services().provider

    grouped: dict[str, list[RemoteModel]] = {}
    for prov, _cfg_field, _env_var, display_name in keyed:
        models = [
            RemoteModel(
                name=chat_name,
                task=ModelTask.CHAT,
                family="",
                parameter_size="",
                provider=display_name,
            )
            for chat_name in provider.list_chat_models(prov)
        ]
        if models:
            grouped[display_name] = models
    return grouped
def detect_remote_embedding_models(base_url: str = "http://localhost:11434") -> list[str]:
    """Return names of models classified as embedding from the SDK backend."""
    remotes = classify_remote_models(base_url)
    return [remote.name for remote in remotes if remote.task == ModelTask.EMBEDDING]