Coverage for src / lilbee / modelhub / model_manager / discovery.py: 100%

94 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-28 01:01 +0000

1"""Remote model discovery and task classification.""" 

2 

3import logging 

4import os 

5from collections.abc import Callable 

6 

7import httpx 

8 

9from lilbee.app.services import get_services 

10from lilbee.catalog.types import ModelTask 

11from lilbee.core.config.model import cfg 

12from lilbee.modelhub.model_manager.types import RemoteModel 

13from lilbee.providers.backend_names import BackendName 

14from lilbee.providers.local_servers import ( 

15 LM_STUDIO, 

16 OLLAMA, 

17 LocalServerSpec, 

18 openai_models_url, 

19) 

20from lilbee.providers.local_servers.config_urls import configured_local_servers 

21from lilbee.providers.sdk_backend import PROVIDER_KEYS 

22 

log = logging.getLogger(__name__)

# Family tags that mark a model as an embedding model.
_EMBEDDING_FAMILIES = frozenset(("bert", "nomic-bert", "e5", "bge"))
# Name fragments that mark an embedding model when the server (LM Studio)
# reports ids but no family. The trailing hyphens keep chat models whose names
# merely contain those letters from matching.
_EMBEDDING_NAME_PATTERNS = frozenset(("embed", "bge-", "e5-", "gte-"))
_VISION_NAME_PATTERNS = frozenset(("llava", "vision", "moondream", "ocr", "minicpm-v"))
# Reranker patterns are checked ahead of the embedding ones so that
# ``bge-reranker-*`` (family "bge") is not misclassified as EMBEDDING.
_RERANKER_NAME_PATTERNS = frozenset(("reranker", "rerank", "cross-encoder"))

# Default value of the ``timeout`` argument passed to each discovery request.
_CLASSIFY_DEFAULT_TIMEOUT_S = 5.0

35 

36 

def _classify_remote_task(name: str, family: str) -> ModelTask:
    """Map a remote model to RERANK, EMBEDDING, VISION, or CHAT.

    The checks run in exactly that order and the first match wins; CHAT is the
    fallback. A model counts as an embedding model when either its family tag
    or its name matches, the name path covering servers such as LM Studio that
    report no family.
    """
    lowered_name = name.lower()

    def _name_has(patterns: frozenset[str]) -> bool:
        # Case-insensitive substring test against the model name.
        return any(fragment in lowered_name for fragment in patterns)

    # Rerankers first: ``bge-reranker-*`` would otherwise match the embedding
    # family/name patterns below.
    if _name_has(_RERANKER_NAME_PATTERNS):
        return ModelTask.RERANK

    lowered_family = family.lower()
    family_is_embedding = any(tag in lowered_family for tag in _EMBEDDING_FAMILIES)
    if family_is_embedding or _name_has(_EMBEDDING_NAME_PATTERNS):
        return ModelTask.EMBEDDING

    if _name_has(_VISION_NAME_PATTERNS):
        return ModelTask.VISION

    return ModelTask.CHAT

54 

55 

def reclassify_by_name(ref: str, declared_task: str) -> str:
    """Return RERANK or VISION when *ref* names such a model, else *declared_task*.

    Guards against pre-fix manifests that stored ``task="chat"`` for models
    whose ref plainly identifies a reranker (e.g. ``bge-reranker-*``) or a
    vision loader. The model bar relies on this so that a historical mis-tag
    does not put a reranker into the chat picker.
    """
    lowered_ref = ref.lower()
    # Ordered: the reranker patterns are consulted before the vision ones.
    overrides = (
        (_RERANKER_NAME_PATTERNS, ModelTask.RERANK),
        (_VISION_NAME_PATTERNS, ModelTask.VISION),
    )
    for patterns, forced_task in overrides:
        if any(fragment in lowered_ref for fragment in patterns):
            return forced_task
    return declared_task

70 

71 

def classify_remote_models(
    base_url: str,
    spec: LocalServerSpec,
    *,
    timeout: float = _CLASSIFY_DEFAULT_TIMEOUT_S,
) -> list[RemoteModel]:
    """Discover and classify all models from one local server by task.

    The listing strategy is looked up by ``spec.key`` and the provider label is
    ``spec.display_name`` (Ollama ``/api/tags`` vs LM Studio ``/v1/models``),
    so a server reached at a non-default host is classified correctly.

    Args:
        base_url: Root URL of the server to query.
        spec: Registry entry describing the server.
        timeout: Timeout handed to the discovery request.

    Returns:
        The classified models, or ``[]`` on any error -- including a server
        key with no registered strategy -- so read-only callers stay
        responsive when the backend is down or not yet supported.
    """
    discover = _DISCOVERY_BY_KEY.get(spec.key)
    if discover is None:
        # A direct subscript would raise KeyError here and, through
        # classify_all_remote_models, drop the results of every other server.
        log.warning("No model discovery strategy registered for local server %r", spec.key)
        return []
    return discover(base_url, spec.display_name, timeout)

87 

88 

def classify_all_remote_models(
    *,
    timeout: float = _CLASSIFY_DEFAULT_TIMEOUT_S,
) -> list[RemoteModel]:
    """Return the source-labeled models of every configured local server.

    Servers are queried one after another, in the order
    ``configured_local_servers()`` yields them, and their results are
    concatenated into a single flat list.
    """
    return [
        model
        for spec, base_url in configured_local_servers()
        for model in classify_remote_models(base_url, spec, timeout=timeout)
    ]

98 

99 

def _discover_via_ollama_tags(
    base_url: str, provider: BackendName, timeout: float
) -> list[RemoteModel]:
    """Classify models from Ollama's ``/api/tags`` using family metadata.

    Args:
        base_url: Root URL of the server; a trailing slash is tolerated.
        provider: Label stored on every returned model.
        timeout: Timeout for the HTTP request.

    Returns:
        One ``RemoteModel`` per named entry in the response. ``[]`` when the
        request fails, the body cannot be read as a JSON object, or the
        ``models`` field is missing, null, or empty.
    """
    try:
        # Strip a trailing slash so the request never targets ``//api/tags``.
        resp = httpx.get(f"{base_url.rstrip('/')}/api/tags", timeout=timeout)
        resp.raise_for_status()
        # ``.get("models", [])`` only falls back when the key is absent; an
        # explicit JSON null would come back as None and break the loop below,
        # outside this try block.
        raw_models = resp.json().get("models") or []
    except Exception:
        log.debug("Failed to classify remote models", exc_info=True)
        return []

    result: list[RemoteModel] = []
    for model in raw_models:
        if not isinstance(model, dict):
            continue
        name = model.get("name") or ""
        if not name:
            # Same rule as the OpenAI-compatible path: a nameless entry
            # cannot be selected, so it is not surfaced.
            continue
        # Null values are treated like missing ones; a null family would
        # otherwise reach ``family.lower()`` in the classifier.
        details = model.get("details") or {}
        family = details.get("family") or ""
        param_size = details.get("parameter_size") or ""
        result.append(
            RemoteModel(
                name=name,
                task=_classify_remote_task(name, family),
                family=family,
                parameter_size=param_size,
                provider=provider,
            )
        )
    return result

129 

130 

def _discover_via_openai_models(
    base_url: str, provider: BackendName, timeout: float
) -> list[RemoteModel]:
    """Classify models from an OpenAI-compatible ``/v1/models`` endpoint.

    These servers report only ids (no family), so task detection runs off the
    name patterns, which LM Studio ids usually carry. Every id is surfaced: LM
    Studio presents LM Link remote/cloud models here as if local, so the list
    is intentionally not filtered to locally-downloaded models.

    Args:
        base_url: Root URL of the server.
        provider: Label stored on every returned model.
        timeout: Timeout for the HTTP request.

    Returns:
        One ``RemoteModel`` per entry that carries a non-empty ``id``. ``[]``
        when the request fails, the body cannot be read as a JSON object, or
        the ``data`` field is missing, null, or empty.
    """
    try:
        resp = httpx.get(openai_models_url(base_url), timeout=timeout)
        resp.raise_for_status()
        # ``.get("data", [])`` only falls back when the key is absent; an
        # explicit JSON null would come back as None and break the loop below,
        # outside this try block.
        raw_models = resp.json().get("data") or []
    except Exception:
        log.debug("Failed to classify remote models", exc_info=True)
        return []

    result: list[RemoteModel] = []
    for model in raw_models:
        if not isinstance(model, dict):
            # Entries that are not JSON objects carry no id to read.
            continue
        name = model.get("id") or ""
        if not name:
            continue
        result.append(
            RemoteModel(
                name=name,
                task=_classify_remote_task(name, ""),
                family="",
                parameter_size="",
                provider=provider,
            )
        )
    return result

165 

166 

# Signature shared by every listing strategy: (base_url, provider, timeout).
_DiscoverFn = Callable[[str, BackendName, float], list[RemoteModel]]

# Maps each local server's routing key to the function that lists its models.
# Kept at module level so there is a single place to register a strategy when
# a server is added to the registry.
_DISCOVERY_BY_KEY: dict[str, _DiscoverFn] = {
    OLLAMA.key: _discover_via_ollama_tags,
    LM_STUDIO.key: _discover_via_openai_models,
}

173 

174 

def _has_provider_key(cfg_field: str, env_var: str) -> bool:
    """Tell whether an API key is available from the environment or from cfg.

    The environment variable is checked first; the ``cfg`` attribute named by
    *cfg_field* is only read when the variable is unset or empty. A missing
    attribute counts as no key.
    """
    from_env = os.environ.get(env_var)
    return bool(from_env) or bool(getattr(cfg, cfg_field, ""))

180 

181 

def discover_api_models() -> dict[str, list[RemoteModel]]:
    """Group frontier chat models by provider display name.

    For every provider that has an API key configured, the active provider
    backend is asked for its chat models and the answer is passed through
    without curation. Providers that list no models are left out of the
    result. When no key is present at all, ``{}`` is returned before the
    services/SDK are touched.
    """
    keyed_providers = [
        (prov, label)
        for prov, cfg_f, env, label in PROVIDER_KEYS
        if _has_provider_key(cfg_f, env)
    ]
    if not keyed_providers:
        return {}

    backend = get_services().provider

    grouped: dict[str, list[RemoteModel]] = {}
    for prov, label in keyed_providers:
        models = [
            RemoteModel(
                name=model_name,
                task=ModelTask.CHAT,
                family="",
                parameter_size="",
                provider=label,
            )
            for model_name in backend.list_chat_models(prov)
        ]
        if not models:
            continue
        grouped[label] = models
    return grouped

214 

215 

def detect_remote_embedding_models() -> list[str]:
    """List the names of the embedding models found on any configured local server."""
    discovered = classify_all_remote_models()
    embedders = (model for model in discovered if model.task == ModelTask.EMBEDDING)
    return [model.name for model in embedders]